From c807fe3202610e36fc898f4e85132914d8293c79 Mon Sep 17 00:00:00 2001 From: Gerhard Gruber Date: Thu, 27 Jul 2017 17:14:00 +0200 Subject: [PATCH] Started with expression extension Improved questionmark parameter placeholder replacementw Added support for subqueries in Where and Having clauses Queries can be transformed into subqueries by calling .Subquery() on a db object See main_test.go:TestQueryBuilderSubselectInWhere Fixed comment spacing Refactoring, adding Having Subquery support, allowing db.T for tablenames Removed quoting from tablename in db.T, use db.QT for that Refactoring, adding Having Subquery support, allowing db.T for tablenames Added changes Started with expression extension Refactoring, adding Having Subquery support, allowing db.T for tablenames Added method to easily update fields of the Model struct Added column comparison and Join support Added subquery support for InnerJoin querybuilder Fixed column comparison Added support for column prefixes Models can set their column prefix by implementing the method ColumnPrefix() string Fixed multi-parameter subselects and introduced aliasing Improved Related method Improved Related method to search for foreign key struct fields with the suffix "ID" (additional to "Id") Got QueryExpr support from upstream Added support for subqueries in Where and Having clauses Queries can be transformed into subqueries by calling .Subquery() on a db object See main_test.go:TestQueryBuilderSubselectInWhere Improved questionmark parameter placeholder replacementw Refactoring, adding Having Subquery support, allowing db.T for tablenames Removed quoting from tablename in db.T, use db.QT for that Removed quoting from tablename in db.T, use db.QT for that Added changes Added method to easily update fields of the Model struct Fixed column comparison Added support for column prefixes Models can set their column prefix by implementing the method ColumnPrefix() string Fixed multi-parameter subselects and introduced aliasing Improved Related method Improved Related method to search for foreign key struct fields with the suffix "ID" (additional to "Id") Added select extension for multiple columns Added support for LEFT RIGHT OUTER joins Fixed slice support for lexpr.In() Publizised LExpr Added DateFormatting for all dialects Added SUM function for columns Fixed FormatDate Added count for column Removed literal expressions LExpr Rewrote LExpr methods to work with expr structs. Added methods BAnd and BOr (bitwise & and | ) Added SetLogWriter method Added NotIn query expression Added Distinct query expression Added DistinctColumn query expression Same as Distinct but returns a string Added method OnExp to jexpr Improved query expression .Eq() method for nil pointers Fixed rebase errors --- dialect.go | 39 ++++++ dialect_common.go | 4 + dialect_mssql.go | 209 ++++++++++++++++++++++++++++++ dialect_mysql.go | 17 +++ dialect_postgres.go | 17 +++ dialect_sqlite3.go | 17 +++ dialects/mssql/mssql.go | 195 +--------------------------- expression_ext.go | 277 ++++++++++++++++++++++++++++++++++++++++ main.go | 13 +- main_test.go | 9 ++ model_struct.go | 7 +- scope.go | 25 +++- search.go | 8 +- 13 files changed, 636 insertions(+), 201 deletions(-) create mode 100644 dialect_mssql.go create mode 100644 expression_ext.go diff --git a/dialect.go b/dialect.go index 5f6439c1..850253a9 100644 --- a/dialect.go +++ b/dialect.go @@ -50,6 +50,19 @@ type Dialect interface { // CurrentDatabase return current database name CurrentDatabase() string + + // Formates the date and tries to be as similar as possible with all databases + // | | gorm | SQLITE | MYSQL | MSSQL | POSTGRES | + // | ----------- | ---- | ----------------- | ----------------- | ---------------- | ---------------- | + // | YEAR | %y | %Y (0000-9999) | %Y (0000-9999) | YYYY (0000-9999) | YYYY (0000-9999) | + // | MONTH | %m | %m (01-12) | %m (01-12) | MM (00-12) | MM (0000-9999) | + // | WEEK | %w | %W (00-53) | %u (00-53) | -- | WW (1-53) | + // | DAY | %d | %d (00-31) | %d (00-31) | dd (00-31) | DD (00-31) | + // | DAY OF WEEK | %D | %w (SUN 0-6 SAT) | %w (SUN 0-6 SAT) | -- | D (SUN 1-7 SAT) | + // | HOUR | %h | %H (00-24) | %H (00-23) | HH (00-31) | HH24 (00-23) | + // | MINUTE | %m | %M (00-59) | %i (00-59) | mm (00-59) | MI (00-59) | + // | SECOND | %s | %S (00-59) | %S (00-59) | ss (00-59) | SS (00-59) | + FormatDate(*expr, string) *expr } var dialectsMap = map[string]Dialect{} @@ -128,3 +141,29 @@ func currentDatabaseAndTable(dialect Dialect, tableName string) (string, string) } return dialect.CurrentDatabase(), tableName } + +func parseDateFormat(format string, mapping map[rune]string) string { + var parsedFormat string + isFormatter := false + + for _, rune := range format { + if !isFormatter { + if rune == '%' { + isFormatter = true + } else { + parsedFormat += string(rune) + } + continue + } + + isFormatter = false + + if sign, ok := mapping[rune]; ok { + parsedFormat += sign + } else { + parsedFormat += "%" + string(rune) + } + } + + return parsedFormat +} diff --git a/dialect_common.go b/dialect_common.go index b9f0c7da..cccb56dd 100644 --- a/dialect_common.go +++ b/dialect_common.go @@ -174,3 +174,7 @@ func (DefaultForeignKeyNamer) BuildKeyName(kind, tableName string, fields ...str func IsByteArrayOrSlice(value reflect.Value) bool { return (value.Kind() == reflect.Array || value.Kind() == reflect.Slice) && value.Type().Elem() == reflect.TypeOf(uint8(0)) } + +func (commonDialect) FormatDate(e *expr, format string) *expr { + return e +} diff --git a/dialect_mssql.go b/dialect_mssql.go new file mode 100644 index 00000000..63e0f652 --- /dev/null +++ b/dialect_mssql.go @@ -0,0 +1,209 @@ +package gorm + +import ( + "fmt" + "reflect" + "strconv" + "strings" + "time" +) + +func setIdentityInsert(scope *Scope) { + if scope.Dialect().GetName() == "mssql" { + for _, field := range scope.PrimaryFields() { + if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok && !field.IsBlank { + scope.NewDB().Exec(fmt.Sprintf("SET IDENTITY_INSERT %v ON", scope.TableName())) + scope.InstanceSet("mssql:identity_insert_on", true) + } + } + } +} + +func turnOffIdentityInsert(scope *Scope) { + if scope.Dialect().GetName() == "mssql" { + if _, ok := scope.InstanceGet("mssql:identity_insert_on"); ok { + scope.NewDB().Exec(fmt.Sprintf("SET IDENTITY_INSERT %v OFF", scope.TableName())) + } + } +} + +func init() { + DefaultCallback.Create().After("gorm:begin_transaction").Register("mssql:set_identity_insert", setIdentityInsert) + DefaultCallback.Create().Before("gorm:commit_or_rollback_transaction").Register("mssql:turn_off_identity_insert", turnOffIdentityInsert) + RegisterDialect("mssql", &mssql{}) +} + +type mssql struct { + db SQLCommon + DefaultForeignKeyNamer +} + +func (mssql) GetName() string { + return "mssql" +} + +func (s *mssql) SetDB(db SQLCommon) { + s.db = db +} + +func (mssql) BindVar(i int) string { + return "$$$" // ? +} + +func (mssql) Quote(key string) string { + return fmt.Sprintf(`[%s]`, key) +} + +func (s *mssql) DataTypeOf(field *StructField) string { + var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s) + + if sqlType == "" { + switch dataValue.Kind() { + case reflect.Bool: + sqlType = "bit" + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: + if s.fieldCanAutoIncrement(field) { + field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" + sqlType = "int IDENTITY(1,1)" + } else { + sqlType = "int" + } + case reflect.Int64, reflect.Uint64: + if s.fieldCanAutoIncrement(field) { + field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" + sqlType = "bigint IDENTITY(1,1)" + } else { + sqlType = "bigint" + } + case reflect.Float32, reflect.Float64: + sqlType = "float" + case reflect.String: + if size > 0 && size < 8000 { + sqlType = fmt.Sprintf("nvarchar(%d)", size) + } else { + sqlType = "nvarchar(max)" + } + case reflect.Struct: + if _, ok := dataValue.Interface().(time.Time); ok { + sqlType = "datetimeoffset" + } + default: + if IsByteArrayOrSlice(dataValue) { + if size > 0 && size < 8000 { + sqlType = fmt.Sprintf("varbinary(%d)", size) + } else { + sqlType = "varbinary(max)" + } + } + } + } + + if sqlType == "" { + panic(fmt.Sprintf("invalid sql type %s (%s) for mssql", dataValue.Type().Name(), dataValue.Kind().String())) + } + + if strings.TrimSpace(additionalType) == "" { + return sqlType + } + return fmt.Sprintf("%v %v", sqlType, additionalType) +} + +func (s mssql) fieldCanAutoIncrement(field *StructField) bool { + if value, ok := field.TagSettings["AUTO_INCREMENT"]; ok { + return value != "FALSE" + } + return field.IsPrimaryKey +} + +func (s mssql) HasIndex(tableName string, indexName string) bool { + var count int + s.db.QueryRow("SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName).Scan(&count) + return count > 0 +} + +func (s mssql) RemoveIndex(tableName string, indexName string) error { + _, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName))) + return err +} + +func (s mssql) HasForeignKey(tableName string, foreignKeyName string) bool { + return false +} + +func (s mssql) HasTable(tableName string) bool { + var count int + currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) + s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, currentDatabase).Scan(&count) + return count > 0 +} + +func (s mssql) HasColumn(tableName string, columnName string) bool { + var count int + currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) + s.db.QueryRow("SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", currentDatabase, tableName, columnName).Scan(&count) + return count > 0 +} + +func (s mssql) ModifyColumn(tableName string, columnName string, typ string) error { + _, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v %v", tableName, columnName, typ)) + return err +} + +func (s mssql) CurrentDatabase() (name string) { + s.db.QueryRow("SELECT DB_NAME() AS [Current Database]").Scan(&name) + return +} + +func (mssql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) { + if offset != nil { + if parsedOffset, err := strconv.ParseInt(fmt.Sprint(offset), 0, 0); err == nil && parsedOffset >= 0 { + sql += fmt.Sprintf(" OFFSET %d ROWS", parsedOffset) + } + } + if limit != nil { + if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 { + if sql == "" { + // add default zero offset + sql += " OFFSET 0 ROWS" + } + sql += fmt.Sprintf(" FETCH NEXT %d ROWS ONLY", parsedLimit) + } + } + return +} + +func (mssql) SelectFromDummyTable() string { + return "" +} + +func (mssql) LastInsertIDReturningSuffix(tableName, columnName string) string { + return "" +} + +func (mssql) DefaultValueStr() string { + return "DEFAULT VALUES" +} + +func currentDatabaseAndTable(dialect Dialect, tableName string) (string, string) { + if strings.Contains(tableName, ".") { + splitStrings := strings.SplitN(tableName, ".", 2) + return splitStrings[0], splitStrings[1] + } + return dialect.CurrentDatabase(), tableName +} + +func (mssql) FormatDate(e *expr, format string) *expr { + mapping := map[rune]string{ + 'y': "YYYY", + 'm': "MM", + 'd': "dd", + 'h': "HH", + 'M': "mm", + 's': "ss", + } + parsedFormat := parseDateFormat(format, mapping) + + e.expr = "(format(" + e.expr + ", '" + parsedFormat + "'))" + return e +} + diff --git a/dialect_mysql.go b/dialect_mysql.go index b162bade..b8350dad 100644 --- a/dialect_mysql.go +++ b/dialect_mysql.go @@ -189,3 +189,20 @@ func (s mysql) BuildKeyName(kind, tableName string, fields ...string) string { func (mysql) DefaultValueStr() string { return "VALUES()" } + +func (mysql) FormatDate(e *expr, format string) *expr { + mapping := map[rune]string{ + 'y': "%Y", + 'm': "%m", + 'w': "%u", + 'd': "%d", + 'D': "%w", + 'h': "%H", + 'M': "%i", + 's': "%S", + } + parsedFormat := parseDateFormat(format, mapping) + + e.expr = "(DATE_FORMAT(" + e.expr + ", '" + parsedFormat + "'))" + return e +} diff --git a/dialect_postgres.go b/dialect_postgres.go index c44c6a5b..01c5d03e 100644 --- a/dialect_postgres.go +++ b/dialect_postgres.go @@ -141,3 +141,20 @@ func isJSON(value reflect.Value) bool { _, ok := value.Interface().(json.RawMessage) return ok } + +func (postgres) FormatDate(e *expr, format string) *expr { + mapping := map[rune]string{ + 'y': "YYYY", + 'm': "MM", + 'w': "WW", + 'd': "DD", + 'D': "D", + 'h': "HH24", + 'M': "MI", + 's': "SS", + } + parsedFormat := parseDateFormat(format, mapping) + + e.expr = "(to_char(" + e.expr + ", '" + parsedFormat + "'))" + return e +} diff --git a/dialect_sqlite3.go b/dialect_sqlite3.go index f26f6be3..2bd36014 100644 --- a/dialect_sqlite3.go +++ b/dialect_sqlite3.go @@ -105,3 +105,20 @@ func (s sqlite3) CurrentDatabase() (name string) { } return } + +func (sqlite3) FormatDate(e *expr, format string) *expr { + mapping := map[rune]string{ + 'y': "%Y", + 'm': "%m", + 'w': "%W", + 'd': "%d", + 'D': "%w", + 'h': "%H", + 'M': "%M", + 's': "%S", + } + parsedFormat := parseDateFormat(format, mapping) + + e.expr = "(strftime('" + parsedFormat + "', " + e.expr + "))" + return e +} diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go index e0606465..26bb38eb 100644 --- a/dialects/mssql/mssql.go +++ b/dialects/mssql/mssql.go @@ -1,196 +1,3 @@ package mssql -import ( - "fmt" - "reflect" - "strconv" - "strings" - "time" - - _ "github.com/denisenkom/go-mssqldb" - "github.com/jinzhu/gorm" -) - -func setIdentityInsert(scope *gorm.Scope) { - if scope.Dialect().GetName() == "mssql" { - for _, field := range scope.PrimaryFields() { - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok && !field.IsBlank { - scope.NewDB().Exec(fmt.Sprintf("SET IDENTITY_INSERT %v ON", scope.TableName())) - scope.InstanceSet("mssql:identity_insert_on", true) - } - } - } -} - -func turnOffIdentityInsert(scope *gorm.Scope) { - if scope.Dialect().GetName() == "mssql" { - if _, ok := scope.InstanceGet("mssql:identity_insert_on"); ok { - scope.NewDB().Exec(fmt.Sprintf("SET IDENTITY_INSERT %v OFF", scope.TableName())) - } - } -} - -func init() { - gorm.DefaultCallback.Create().After("gorm:begin_transaction").Register("mssql:set_identity_insert", setIdentityInsert) - gorm.DefaultCallback.Create().Before("gorm:commit_or_rollback_transaction").Register("mssql:turn_off_identity_insert", turnOffIdentityInsert) - gorm.RegisterDialect("mssql", &mssql{}) -} - -type mssql struct { - db gorm.SQLCommon - gorm.DefaultForeignKeyNamer -} - -func (mssql) GetName() string { - return "mssql" -} - -func (s *mssql) SetDB(db gorm.SQLCommon) { - s.db = db -} - -func (mssql) BindVar(i int) string { - return "$$$" // ? -} - -func (mssql) Quote(key string) string { - return fmt.Sprintf(`[%s]`, key) -} - -func (s *mssql) DataTypeOf(field *gorm.StructField) string { - var dataValue, sqlType, size, additionalType = gorm.ParseFieldStructForDialect(field, s) - - if sqlType == "" { - switch dataValue.Kind() { - case reflect.Bool: - sqlType = "bit" - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: - if s.fieldCanAutoIncrement(field) { - field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" - sqlType = "int IDENTITY(1,1)" - } else { - sqlType = "int" - } - case reflect.Int64, reflect.Uint64: - if s.fieldCanAutoIncrement(field) { - field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" - sqlType = "bigint IDENTITY(1,1)" - } else { - sqlType = "bigint" - } - case reflect.Float32, reflect.Float64: - sqlType = "float" - case reflect.String: - if size > 0 && size < 8000 { - sqlType = fmt.Sprintf("nvarchar(%d)", size) - } else { - sqlType = "nvarchar(max)" - } - case reflect.Struct: - if _, ok := dataValue.Interface().(time.Time); ok { - sqlType = "datetimeoffset" - } - default: - if gorm.IsByteArrayOrSlice(dataValue) { - if size > 0 && size < 8000 { - sqlType = fmt.Sprintf("varbinary(%d)", size) - } else { - sqlType = "varbinary(max)" - } - } - } - } - - if sqlType == "" { - panic(fmt.Sprintf("invalid sql type %s (%s) for mssql", dataValue.Type().Name(), dataValue.Kind().String())) - } - - if strings.TrimSpace(additionalType) == "" { - return sqlType - } - return fmt.Sprintf("%v %v", sqlType, additionalType) -} - -func (s mssql) fieldCanAutoIncrement(field *gorm.StructField) bool { - if value, ok := field.TagSettings["AUTO_INCREMENT"]; ok { - return value != "FALSE" - } - return field.IsPrimaryKey -} - -func (s mssql) HasIndex(tableName string, indexName string) bool { - var count int - s.db.QueryRow("SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName).Scan(&count) - return count > 0 -} - -func (s mssql) RemoveIndex(tableName string, indexName string) error { - _, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName))) - return err -} - -func (s mssql) HasForeignKey(tableName string, foreignKeyName string) bool { - return false -} - -func (s mssql) HasTable(tableName string) bool { - var count int - currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, currentDatabase).Scan(&count) - return count > 0 -} - -func (s mssql) HasColumn(tableName string, columnName string) bool { - var count int - currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) - s.db.QueryRow("SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", currentDatabase, tableName, columnName).Scan(&count) - return count > 0 -} - -func (s mssql) ModifyColumn(tableName string, columnName string, typ string) error { - _, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v %v", tableName, columnName, typ)) - return err -} - -func (s mssql) CurrentDatabase() (name string) { - s.db.QueryRow("SELECT DB_NAME() AS [Current Database]").Scan(&name) - return -} - -func (mssql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) { - if offset != nil { - if parsedOffset, err := strconv.ParseInt(fmt.Sprint(offset), 0, 0); err == nil && parsedOffset >= 0 { - sql += fmt.Sprintf(" OFFSET %d ROWS", parsedOffset) - } - } - if limit != nil { - if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 { - if sql == "" { - // add default zero offset - sql += " OFFSET 0 ROWS" - } - sql += fmt.Sprintf(" FETCH NEXT %d ROWS ONLY", parsedLimit) - } - } - return -} - -func (mssql) SelectFromDummyTable() string { - return "" -} - -func (mssql) LastInsertIDReturningSuffix(tableName, columnName string) string { - return "" -} - -func (mssql) DefaultValueStr() string { - return "DEFAULT VALUES" -} - -func currentDatabaseAndTable(dialect gorm.Dialect, tableName string) (string, string) { - if strings.Contains(tableName, ".") { - splitStrings := strings.SplitN(tableName, ".", 2) - return splitStrings[0], splitStrings[1] - } - return dialect.CurrentDatabase(), tableName -} +import _ "github.com/denisenkom/go-mssqldb" diff --git a/expression_ext.go b/expression_ext.go new file mode 100644 index 00000000..39f21262 --- /dev/null +++ b/expression_ext.go @@ -0,0 +1,277 @@ +package gorm + +import ( + "reflect" + "strings" +) + +type jexpr struct { + expr string + args []interface{} +} + +func join(joinType string, db *DB, model interface{}, alias ...string) *jexpr { + var al string + if len(alias) > 0 { + al = alias[0] + } + + if val, ok := model.(*expr); ok { + return &jexpr{expr: " " + joinType + " JOIN (" + val.expr + ") " + al, args: val.args} + } + return &jexpr{expr: " " + joinType + " JOIN " + db.T(model) + " " + al} +} + +func (db *DB) InnerJoin(model interface{}, alias ...string) *jexpr { + return join("INNER", db, model, alias...) +} + +func (db *DB) LeftJoin(model interface{}, alias ...string) *jexpr { + return join("LEFT", db, model, alias...) +} + +func (db *DB) RightJoin(model interface{}, alias ...string) *jexpr { + return join("RIGHT", db, model, alias...) +} + +func (db *DB) OuterJoin(model interface{}, alias ...string) *jexpr { + return join("OUTER", db, model, alias...) +} + +func (je *jexpr) On(col1 *expr, col2 *expr) *expr { + return &expr{expr: je.expr + " ON " + col1.expr + " = " + col2.expr, args: je.args} +} + +func (je *jexpr) OnExp(e2 *expr) *expr { + e := &expr{expr: je.expr + " ON " + e2.expr, args: je.args} + e.args = append(e.args, e2.args...) + return e +} + +func (db *DB) L(model interface{}, name string) *expr { + scope := db.NewScope(model) + field, _ := scope.FieldByName(name) + return &expr{expr: scope.Quote(scope.TableName()) + "." + scope.Quote(field.DBName)} +} + +func (db *DB) LA(model interface{}, alias string, name string) *expr { + scope := db.NewScope(model) + field, _ := scope.FieldByName(name) + return &expr{expr: scope.Quote(alias) + "." + scope.Quote(field.DBName)} +} + +func (db *DB) C(model interface{}, names ...string) string { + columns := make([]string, 0) + + scope := db.NewScope(model) + for _, name := range names { + field, _ := scope.FieldByName(name) + columns = append(columns, field.DBName) + } + + return strings.Join(columns, ", ") +} + +func (db *DB) CA(model interface{}, alias string, names ...string) string { + columns := make([]string, 0) + + for _, name := range names { + columns = append(columns, db.LA(model, alias, name).expr) + } + + return strings.Join(columns, ", ") +} + +func (db *DB) CQ(model interface{}, names ...string) string { + columns := make([]string, 0) + + for _, name := range names { + columns = append(columns, db.L(model, name).expr) + } + + return strings.Join(columns, ", ") +} + +func (db *DB) T(model interface{}) string { + scope := db.NewScope(model) + return scope.TableName() +} + +func (db *DB) QT(model interface{}) string { + scope := db.NewScope(model) + return scope.QuotedTableName() +} + +func (e *expr) operator(operator string, value interface{}) *expr { + if value == nil { + e.expr = "(" + e.expr + " " + operator + " )" + return e + } + + if _, ok := value.(*expr); ok { + e.expr = "(" + e.expr + " " + operator + " (?))" + } else { + e.expr = "(" + e.expr + " " + operator + " ?)" + } + e.args = append(e.args, value) + + return e +} + +func (e *expr) Gt(value interface{}) *expr { + return e.operator(">", value) +} + +func (e *expr) Ge(value interface{}) *expr { + return e.operator(">=", value) +} + +func (e *expr) Lt(value interface{}) *expr { + return e.operator("<", value) +} + +func (e *expr) Le(value interface{}) *expr { + return e.operator("<=", value) +} + +func (e *expr) BAnd(value interface{}) *expr { + return e.operator("&", value) +} + +func (e *expr) BOr(value interface{}) *expr { + return e.operator("|", value) +} + +func (e *expr) Like(value interface{}) *expr { + return e.operator("LIKE", value) +} + +func (e *expr) Eq(value interface{}) *expr { + if value == nil { + return e.operator("IS NULL", value) + } else if val := reflect.ValueOf(value); val.Kind() == reflect.Ptr && val.IsNil() { + return e.operator("IS NULL", nil) + } + + return e.operator("=", value) +} + +func (e *expr) Neq(value interface{}) *expr { + if value == nil { + return e.operator("IS NOT NULL", value) + } + + return e.operator("!=", value) +} + +func (e *expr) Sum() string { + return "SUM(" + e.expr + ")" +} + +func (e *expr) Count() string { + return "COUNT(" + e.expr + ")" +} + +func (e *expr) Distinct() *expr { + e.expr = "DISTINCT " + e.expr + return e +} + +func (e *expr) DistinctColumn() string { + return "DISTINCT " + e.expr +} + +func (e *expr) in(operator string, values ...interface{}) *expr { + // NOTE: Maybe there is a better way to do this? :) + if len(values) == 1 { + if s := reflect.ValueOf(values[0]); s.Kind() == reflect.Slice { + vals := make([]interface{}, s.Len()) + qm := make([]string, s.Len()) + + for i := 0; i < s.Len(); i++ { + vals[i] = s.Index(i).Interface() + qm[i] = "?" + } + + e.expr = "(" + e.expr + operator + " IN (" + strings.Join(qm, ",") + "))" + e.args = append(e.args, vals...) + return e + } + } + + qm := make([]string, len(values)) + for i := 0; i < len(values); i++ { + qm[i] = "?" + } + + e.expr = "(" + e.expr + operator + " IN (" + strings.Join(qm, ",") + "))" + e.args = append(e.args, values...) + return e +} + +func (e *expr) In(values ...interface{}) *expr { + return e.in("", values...) +} + +func (e *expr) NotIn(values ...interface{}) *expr { + return e.in(" NOT", values...) +} + +func (e *expr) OrderAsc() string { + return e.expr + " ASC " +} + +func (e *expr) OrderDesc() string { + return e.expr + " DESC " +} + +func (e *expr) Or(e2 *expr) *expr { + e.expr = "(" + e.expr + " OR " + e2.expr + ")" + e.args = append(e.args, e2.args...) + + return e +} + +func (e *expr) And(e2 *expr) *expr { + e.expr = "(" + e.expr + " AND " + e2.expr + ")" + e.args = append(e.args, e2.args...) + + return e +} + +func (db *DB) UpdateFields(fields ...string) *DB { + sets := make(map[string]interface{}) + m := reflect.ValueOf(db.Value).Elem() + for _, field := range fields { + sets[db.C(db.Value, field)] = m.FieldByName(field).Interface() + } + + return db.Update(sets) +} + +func (db *DB) SelectFields(fields ...string) *DB { + selects := strings.Join(fields, ", ") + + return db.Select(selects) +} + +func (e *expr) Intersect(e2 *expr) *expr { + e.expr = "((" + e.expr + ") INTERSECT (" + e2.expr + "))" + e.args = append(e.args, e2.args...) + + return e +} + +func (e *expr) Alias(alias string) *expr { + e.expr = e.expr + " " + alias + " " + + return e +} + +func (db *DB) FormatDate(e *expr, format string) *expr { + return db.Dialect().FormatDate(e, format) +} + +func (db *DB) FormatDateColumn(e *expr, format string) string { + return db.FormatDate(e, format).expr +} diff --git a/main.go b/main.go index c26e05c8..94beab61 100644 --- a/main.go +++ b/main.go @@ -133,6 +133,11 @@ func (s *DB) SetLogger(log logger) { s.logger = log } +// SetLogWriter sets the LogWriter the default logger should write to +func (s *DB) SetLogWriter(log LogWriter) { + s.logger = Logger{log} +} + // LogMode set log mode, `true` for detailed logs, `false` for no log, default, will only print error logs func (s *DB) LogMode(enable bool) *DB { if enable { @@ -169,11 +174,15 @@ func (s *DB) NewScope(value interface{}) *Scope { } // QueryExpr returns the query as expr object -func (s *DB) QueryExpr() *expr { +func (s *DB) QueryExpr(alias ...string) *expr { scope := s.NewScope(s.Value) scope.InstanceSet("skip_bindvar", true) scope.prepareQuerySQL() + if len(alias) > 0 { + return Expr("( "+scope.SQL+" ) "+alias[0]+" ", scope.SQLVars...) + } + return Expr(scope.SQL, scope.SQLVars...) } @@ -242,7 +251,7 @@ func (s *DB) Having(query interface{}, values ...interface{}) *DB { // Joins specify Joins conditions // db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user) -func (s *DB) Joins(query string, args ...interface{}) *DB { +func (s *DB) Joins(query interface{}, args ...interface{}) *DB { return s.clone().search.Joins(query, args...).db } diff --git a/main_test.go b/main_test.go index 66c46af0..7a3199e5 100644 --- a/main_test.go +++ b/main_test.go @@ -629,6 +629,7 @@ func TestQueryBuilderSubselectInWhere(t *testing.T) { if len(users) != 2 { t.Errorf("Two users should be found, instead found %d", len(users)) } + DB.Delete(&User{}) } func TestQueryBuilderRawQueryWithSubquery(t *testing.T) { @@ -689,6 +690,14 @@ func TestQueryBuilderSubselectInHaving(t *testing.T) { if len(users) != 1 { t.Errorf("Two user group should be found, instead found %d", len(users)) } + + DB.Select("*").Where("name LIKE ?", "query_expr_having_%").Where("age >= (?)", DB. + Select("AVG(age)").Where("name LIKE ?", "query_expr_having_%").Table("users").QueryExpr()).Find(&users) + + if len(users) != 2 { + t.Errorf("Two users should be found, instead found %d", len(users)) + } + DB.Delete(&User{}) } func DialectHasTzSupport() bool { diff --git a/model_struct.go b/model_struct.go index f571e2e8..1b6751b7 100644 --- a/model_struct.go +++ b/model_struct.go @@ -158,6 +158,11 @@ func (scope *Scope) GetModelStruct() *ModelStruct { modelStruct.ModelType = reflectType + var columnPrefix string + if columnPrefixer, ok := reflect.New(modelStruct.ModelType).Interface().(columnPrefixer); ok { + columnPrefix = columnPrefixer.ColumnPrefix() + } + // Get all fields for i := 0; i < reflectType.NumField(); i++ { if fieldStruct := reflectType.Field(i); ast.IsExported(fieldStruct.Name) { @@ -587,7 +592,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { if value, ok := field.TagSettings["COLUMN"]; ok { field.DBName = value } else { - field.DBName = ToDBName(fieldStruct.Name) + field.DBName = columnPrefix + ToDBName(fieldStruct.Name) } modelStruct.StructFields = append(modelStruct.StructFields, field) diff --git a/scope.go b/scope.go index 25077efc..8a72cab5 100644 --- a/scope.go +++ b/scope.go @@ -142,6 +142,7 @@ func (scope *Scope) FieldByName(name string) (field *Field, ok bool) { if field.Name == name || field.DBName == name { return field, true } + if field.DBName == dbName { mostMatchedField = field } @@ -274,6 +275,7 @@ func (scope *Scope) AddToVars(value interface{}) string { if skipBindVar { return "?" } + return scope.Dialect().BindVar(len(scope.SQLVars)) } @@ -306,6 +308,10 @@ type tabler interface { TableName() string } +type columnPrefixer interface { + ColumnPrefix() string +} + type dbTabler interface { TableName(*DB) string } @@ -339,7 +345,7 @@ func (scope *Scope) QuotedTableName() (name string) { return scope.Quote(scope.TableName()) } -// CombinedConditionSql return combined condition sql +// CombinedConditionSQL return combined condition SQL func (scope *Scope) CombinedConditionSql() string { joinSQL := scope.joinsSQL() whereSQL := scope.whereSQL() @@ -473,6 +479,18 @@ func (scope *Scope) quoteIfPossible(str string) string { return str } +func (scope *Scope) replaceParameterPlaceholderLiteral(sql string, parameter interface{}, addToVars bool) string { + if val, ok := parameter.(string); ok && !addToVars { + return strings.Replace(sql, "?", val, 1) + } + + return strings.Replace(sql, "?", scope.AddToVars(parameter), 1) +} + +func (scope *Scope) replaceParameterPlaceholder(sql string, parameter interface{}) string { + return scope.replaceParameterPlaceholderLiteral(sql, parameter, true) +} + func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field) { var ( ignored interface{} @@ -578,6 +596,9 @@ func (scope *Scope) buildCondition(clause map[string]interface{}, include bool) } } return strings.Join(sqls, " AND ") + case *expr: + clause["args"] = []interface{}{value} + str = "?" case interface{}: var sqls []string newScope := scope.New(value) @@ -1048,7 +1069,7 @@ func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope { toScope := scope.db.NewScope(value) tx := scope.db.Set("gorm:association:source", scope.Value) - for _, foreignKey := range append(foreignKeys, toScope.typeName()+"Id", scope.typeName()+"Id") { + for _, foreignKey := range append(foreignKeys, toScope.typeName()+"ID", scope.typeName()+"ID", toScope.typeName()+"Id", scope.typeName()+"Id") { fromField, _ := scope.FieldByName(foreignKey) toField, _ := toScope.FieldByName(foreignKey) diff --git a/search.go b/search.go index 90138595..e49cc65c 100644 --- a/search.go +++ b/search.go @@ -106,8 +106,12 @@ func (s *search) Having(query interface{}, values ...interface{}) *search { return s } -func (s *search) Joins(query string, values ...interface{}) *search { - s.joinConditions = append(s.joinConditions, map[string]interface{}{"query": query, "args": values}) +func (s *search) Joins(query interface{}, values ...interface{}) *search { + if val, ok := query.(*expr); ok { + s.joinConditions = append(s.joinConditions, map[string]interface{}{"query": val.expr, "args": val.args}) + } else { + s.joinConditions = append(s.joinConditions, map[string]interface{}{"query": query, "args": values}) + } return s }