diff --git a/dialect.go b/dialect.go index 5f6439c1..850253a9 100644 --- a/dialect.go +++ b/dialect.go @@ -50,6 +50,19 @@ type Dialect interface { // CurrentDatabase return current database name CurrentDatabase() string + + // Formates the date and tries to be as similar as possible with all databases + // | | gorm | SQLITE | MYSQL | MSSQL | POSTGRES | + // | ----------- | ---- | ----------------- | ----------------- | ---------------- | ---------------- | + // | YEAR | %y | %Y (0000-9999) | %Y (0000-9999) | YYYY (0000-9999) | YYYY (0000-9999) | + // | MONTH | %m | %m (01-12) | %m (01-12) | MM (00-12) | MM (0000-9999) | + // | WEEK | %w | %W (00-53) | %u (00-53) | -- | WW (1-53) | + // | DAY | %d | %d (00-31) | %d (00-31) | dd (00-31) | DD (00-31) | + // | DAY OF WEEK | %D | %w (SUN 0-6 SAT) | %w (SUN 0-6 SAT) | -- | D (SUN 1-7 SAT) | + // | HOUR | %h | %H (00-24) | %H (00-23) | HH (00-31) | HH24 (00-23) | + // | MINUTE | %m | %M (00-59) | %i (00-59) | mm (00-59) | MI (00-59) | + // | SECOND | %s | %S (00-59) | %S (00-59) | ss (00-59) | SS (00-59) | + FormatDate(*expr, string) *expr } var dialectsMap = map[string]Dialect{} @@ -128,3 +141,29 @@ func currentDatabaseAndTable(dialect Dialect, tableName string) (string, string) } return dialect.CurrentDatabase(), tableName } + +func parseDateFormat(format string, mapping map[rune]string) string { + var parsedFormat string + isFormatter := false + + for _, rune := range format { + if !isFormatter { + if rune == '%' { + isFormatter = true + } else { + parsedFormat += string(rune) + } + continue + } + + isFormatter = false + + if sign, ok := mapping[rune]; ok { + parsedFormat += sign + } else { + parsedFormat += "%" + string(rune) + } + } + + return parsedFormat +} diff --git a/dialect_common.go b/dialect_common.go index b9f0c7da..cccb56dd 100644 --- a/dialect_common.go +++ b/dialect_common.go @@ -174,3 +174,7 @@ func (DefaultForeignKeyNamer) BuildKeyName(kind, tableName string, fields ...str func IsByteArrayOrSlice(value reflect.Value) bool { return (value.Kind() == reflect.Array || value.Kind() == reflect.Slice) && value.Type().Elem() == reflect.TypeOf(uint8(0)) } + +func (commonDialect) FormatDate(e *expr, format string) *expr { + return e +} diff --git a/dialect_mssql.go b/dialect_mssql.go new file mode 100644 index 00000000..63e0f652 --- /dev/null +++ b/dialect_mssql.go @@ -0,0 +1,209 @@ +package gorm + +import ( + "fmt" + "reflect" + "strconv" + "strings" + "time" +) + +func setIdentityInsert(scope *Scope) { + if scope.Dialect().GetName() == "mssql" { + for _, field := range scope.PrimaryFields() { + if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok && !field.IsBlank { + scope.NewDB().Exec(fmt.Sprintf("SET IDENTITY_INSERT %v ON", scope.TableName())) + scope.InstanceSet("mssql:identity_insert_on", true) + } + } + } +} + +func turnOffIdentityInsert(scope *Scope) { + if scope.Dialect().GetName() == "mssql" { + if _, ok := scope.InstanceGet("mssql:identity_insert_on"); ok { + scope.NewDB().Exec(fmt.Sprintf("SET IDENTITY_INSERT %v OFF", scope.TableName())) + } + } +} + +func init() { + DefaultCallback.Create().After("gorm:begin_transaction").Register("mssql:set_identity_insert", setIdentityInsert) + DefaultCallback.Create().Before("gorm:commit_or_rollback_transaction").Register("mssql:turn_off_identity_insert", turnOffIdentityInsert) + RegisterDialect("mssql", &mssql{}) +} + +type mssql struct { + db SQLCommon + DefaultForeignKeyNamer +} + +func (mssql) GetName() string { + return "mssql" +} + +func (s *mssql) SetDB(db SQLCommon) { + s.db = db +} + +func (mssql) BindVar(i int) string { + return "$$$" // ? +} + +func (mssql) Quote(key string) string { + return fmt.Sprintf(`[%s]`, key) +} + +func (s *mssql) DataTypeOf(field *StructField) string { + var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s) + + if sqlType == "" { + switch dataValue.Kind() { + case reflect.Bool: + sqlType = "bit" + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: + if s.fieldCanAutoIncrement(field) { + field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" + sqlType = "int IDENTITY(1,1)" + } else { + sqlType = "int" + } + case reflect.Int64, reflect.Uint64: + if s.fieldCanAutoIncrement(field) { + field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" + sqlType = "bigint IDENTITY(1,1)" + } else { + sqlType = "bigint" + } + case reflect.Float32, reflect.Float64: + sqlType = "float" + case reflect.String: + if size > 0 && size < 8000 { + sqlType = fmt.Sprintf("nvarchar(%d)", size) + } else { + sqlType = "nvarchar(max)" + } + case reflect.Struct: + if _, ok := dataValue.Interface().(time.Time); ok { + sqlType = "datetimeoffset" + } + default: + if IsByteArrayOrSlice(dataValue) { + if size > 0 && size < 8000 { + sqlType = fmt.Sprintf("varbinary(%d)", size) + } else { + sqlType = "varbinary(max)" + } + } + } + } + + if sqlType == "" { + panic(fmt.Sprintf("invalid sql type %s (%s) for mssql", dataValue.Type().Name(), dataValue.Kind().String())) + } + + if strings.TrimSpace(additionalType) == "" { + return sqlType + } + return fmt.Sprintf("%v %v", sqlType, additionalType) +} + +func (s mssql) fieldCanAutoIncrement(field *StructField) bool { + if value, ok := field.TagSettings["AUTO_INCREMENT"]; ok { + return value != "FALSE" + } + return field.IsPrimaryKey +} + +func (s mssql) HasIndex(tableName string, indexName string) bool { + var count int + s.db.QueryRow("SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName).Scan(&count) + return count > 0 +} + +func (s mssql) RemoveIndex(tableName string, indexName string) error { + _, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName))) + return err +} + +func (s mssql) HasForeignKey(tableName string, foreignKeyName string) bool { + return false +} + +func (s mssql) HasTable(tableName string) bool { + var count int + currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) + s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, currentDatabase).Scan(&count) + return count > 0 +} + +func (s mssql) HasColumn(tableName string, columnName string) bool { + var count int + currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) + s.db.QueryRow("SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", currentDatabase, tableName, columnName).Scan(&count) + return count > 0 +} + +func (s mssql) ModifyColumn(tableName string, columnName string, typ string) error { + _, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v %v", tableName, columnName, typ)) + return err +} + +func (s mssql) CurrentDatabase() (name string) { + s.db.QueryRow("SELECT DB_NAME() AS [Current Database]").Scan(&name) + return +} + +func (mssql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) { + if offset != nil { + if parsedOffset, err := strconv.ParseInt(fmt.Sprint(offset), 0, 0); err == nil && parsedOffset >= 0 { + sql += fmt.Sprintf(" OFFSET %d ROWS", parsedOffset) + } + } + if limit != nil { + if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 { + if sql == "" { + // add default zero offset + sql += " OFFSET 0 ROWS" + } + sql += fmt.Sprintf(" FETCH NEXT %d ROWS ONLY", parsedLimit) + } + } + return +} + +func (mssql) SelectFromDummyTable() string { + return "" +} + +func (mssql) LastInsertIDReturningSuffix(tableName, columnName string) string { + return "" +} + +func (mssql) DefaultValueStr() string { + return "DEFAULT VALUES" +} + +func currentDatabaseAndTable(dialect Dialect, tableName string) (string, string) { + if strings.Contains(tableName, ".") { + splitStrings := strings.SplitN(tableName, ".", 2) + return splitStrings[0], splitStrings[1] + } + return dialect.CurrentDatabase(), tableName +} + +func (mssql) FormatDate(e *expr, format string) *expr { + mapping := map[rune]string{ + 'y': "YYYY", + 'm': "MM", + 'd': "dd", + 'h': "HH", + 'M': "mm", + 's': "ss", + } + parsedFormat := parseDateFormat(format, mapping) + + e.expr = "(format(" + e.expr + ", '" + parsedFormat + "'))" + return e +} + diff --git a/dialect_mysql.go b/dialect_mysql.go index b162bade..b8350dad 100644 --- a/dialect_mysql.go +++ b/dialect_mysql.go @@ -189,3 +189,20 @@ func (s mysql) BuildKeyName(kind, tableName string, fields ...string) string { func (mysql) DefaultValueStr() string { return "VALUES()" } + +func (mysql) FormatDate(e *expr, format string) *expr { + mapping := map[rune]string{ + 'y': "%Y", + 'm': "%m", + 'w': "%u", + 'd': "%d", + 'D': "%w", + 'h': "%H", + 'M': "%i", + 's': "%S", + } + parsedFormat := parseDateFormat(format, mapping) + + e.expr = "(DATE_FORMAT(" + e.expr + ", '" + parsedFormat + "'))" + return e +} diff --git a/dialect_postgres.go b/dialect_postgres.go index c44c6a5b..01c5d03e 100644 --- a/dialect_postgres.go +++ b/dialect_postgres.go @@ -141,3 +141,20 @@ func isJSON(value reflect.Value) bool { _, ok := value.Interface().(json.RawMessage) return ok } + +func (postgres) FormatDate(e *expr, format string) *expr { + mapping := map[rune]string{ + 'y': "YYYY", + 'm': "MM", + 'w': "WW", + 'd': "DD", + 'D': "D", + 'h': "HH24", + 'M': "MI", + 's': "SS", + } + parsedFormat := parseDateFormat(format, mapping) + + e.expr = "(to_char(" + e.expr + ", '" + parsedFormat + "'))" + return e +} diff --git a/dialect_sqlite3.go b/dialect_sqlite3.go index f26f6be3..2bd36014 100644 --- a/dialect_sqlite3.go +++ b/dialect_sqlite3.go @@ -105,3 +105,20 @@ func (s sqlite3) CurrentDatabase() (name string) { } return } + +func (sqlite3) FormatDate(e *expr, format string) *expr { + mapping := map[rune]string{ + 'y': "%Y", + 'm': "%m", + 'w': "%W", + 'd': "%d", + 'D': "%w", + 'h': "%H", + 'M': "%M", + 's': "%S", + } + parsedFormat := parseDateFormat(format, mapping) + + e.expr = "(strftime('" + parsedFormat + "', " + e.expr + "))" + return e +} diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go index e0606465..26bb38eb 100644 --- a/dialects/mssql/mssql.go +++ b/dialects/mssql/mssql.go @@ -1,196 +1,3 @@ package mssql -import ( - "fmt" - "reflect" - "strconv" - "strings" - "time" - - _ "github.com/denisenkom/go-mssqldb" - "github.com/jinzhu/gorm" -) - -func setIdentityInsert(scope *gorm.Scope) { - if scope.Dialect().GetName() == "mssql" { - for _, field := range scope.PrimaryFields() { - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok && !field.IsBlank { - scope.NewDB().Exec(fmt.Sprintf("SET IDENTITY_INSERT %v ON", scope.TableName())) - scope.InstanceSet("mssql:identity_insert_on", true) - } - } - } -} - -func turnOffIdentityInsert(scope *gorm.Scope) { - if scope.Dialect().GetName() == "mssql" { - if _, ok := scope.InstanceGet("mssql:identity_insert_on"); ok { - scope.NewDB().Exec(fmt.Sprintf("SET IDENTITY_INSERT %v OFF", scope.TableName())) - } - } -} - -func init() { - gorm.DefaultCallback.Create().After("gorm:begin_transaction").Register("mssql:set_identity_insert", setIdentityInsert) - gorm.DefaultCallback.Create().Before("gorm:commit_or_rollback_transaction").Register("mssql:turn_off_identity_insert", turnOffIdentityInsert) - gorm.RegisterDialect("mssql", &mssql{}) -} - -type mssql struct { - db gorm.SQLCommon - gorm.DefaultForeignKeyNamer -} - -func (mssql) GetName() string { - return "mssql" -} - -func (s *mssql) SetDB(db gorm.SQLCommon) { - s.db = db -} - -func (mssql) BindVar(i int) string { - return "$$$" // ? -} - -func (mssql) Quote(key string) string { - return fmt.Sprintf(`[%s]`, key) -} - -func (s *mssql) DataTypeOf(field *gorm.StructField) string { - var dataValue, sqlType, size, additionalType = gorm.ParseFieldStructForDialect(field, s) - - if sqlType == "" { - switch dataValue.Kind() { - case reflect.Bool: - sqlType = "bit" - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: - if s.fieldCanAutoIncrement(field) { - field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" - sqlType = "int IDENTITY(1,1)" - } else { - sqlType = "int" - } - case reflect.Int64, reflect.Uint64: - if s.fieldCanAutoIncrement(field) { - field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" - sqlType = "bigint IDENTITY(1,1)" - } else { - sqlType = "bigint" - } - case reflect.Float32, reflect.Float64: - sqlType = "float" - case reflect.String: - if size > 0 && size < 8000 { - sqlType = fmt.Sprintf("nvarchar(%d)", size) - } else { - sqlType = "nvarchar(max)" - } - case reflect.Struct: - if _, ok := dataValue.Interface().(time.Time); ok { - sqlType = "datetimeoffset" - } - default: - if gorm.IsByteArrayOrSlice(dataValue) { - if size > 0 && size < 8000 { - sqlType = fmt.Sprintf("varbinary(%d)", size) - } else { - sqlType = "varbinary(max)" - } - } - } - } - - if sqlType == "" { - panic(fmt.Sprintf("invalid sql type %s (%s) for mssql", dataValue.Type().Name(), dataValue.Kind().String())) - } - - if strings.TrimSpace(additionalType) == "" { - return sqlType - } - return fmt.Sprintf("%v %v", sqlType, additionalType) -} - -func (s mssql) fieldCanAutoIncrement(field *gorm.StructField) bool { - if value, ok := field.TagSettings["AUTO_INCREMENT"]; ok { - return value != "FALSE" - } - return field.IsPrimaryKey -} - -func (s mssql) HasIndex(tableName string, indexName string) bool { - var count int - s.db.QueryRow("SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName).Scan(&count) - return count > 0 -} - -func (s mssql) RemoveIndex(tableName string, indexName string) error { - _, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName))) - return err -} - -func (s mssql) HasForeignKey(tableName string, foreignKeyName string) bool { - return false -} - -func (s mssql) HasTable(tableName string) bool { - var count int - currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, currentDatabase).Scan(&count) - return count > 0 -} - -func (s mssql) HasColumn(tableName string, columnName string) bool { - var count int - currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) - s.db.QueryRow("SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", currentDatabase, tableName, columnName).Scan(&count) - return count > 0 -} - -func (s mssql) ModifyColumn(tableName string, columnName string, typ string) error { - _, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v %v", tableName, columnName, typ)) - return err -} - -func (s mssql) CurrentDatabase() (name string) { - s.db.QueryRow("SELECT DB_NAME() AS [Current Database]").Scan(&name) - return -} - -func (mssql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) { - if offset != nil { - if parsedOffset, err := strconv.ParseInt(fmt.Sprint(offset), 0, 0); err == nil && parsedOffset >= 0 { - sql += fmt.Sprintf(" OFFSET %d ROWS", parsedOffset) - } - } - if limit != nil { - if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 { - if sql == "" { - // add default zero offset - sql += " OFFSET 0 ROWS" - } - sql += fmt.Sprintf(" FETCH NEXT %d ROWS ONLY", parsedLimit) - } - } - return -} - -func (mssql) SelectFromDummyTable() string { - return "" -} - -func (mssql) LastInsertIDReturningSuffix(tableName, columnName string) string { - return "" -} - -func (mssql) DefaultValueStr() string { - return "DEFAULT VALUES" -} - -func currentDatabaseAndTable(dialect gorm.Dialect, tableName string) (string, string) { - if strings.Contains(tableName, ".") { - splitStrings := strings.SplitN(tableName, ".", 2) - return splitStrings[0], splitStrings[1] - } - return dialect.CurrentDatabase(), tableName -} +import _ "github.com/denisenkom/go-mssqldb" diff --git a/expression_ext.go b/expression_ext.go new file mode 100644 index 00000000..39f21262 --- /dev/null +++ b/expression_ext.go @@ -0,0 +1,277 @@ +package gorm + +import ( + "reflect" + "strings" +) + +type jexpr struct { + expr string + args []interface{} +} + +func join(joinType string, db *DB, model interface{}, alias ...string) *jexpr { + var al string + if len(alias) > 0 { + al = alias[0] + } + + if val, ok := model.(*expr); ok { + return &jexpr{expr: " " + joinType + " JOIN (" + val.expr + ") " + al, args: val.args} + } + return &jexpr{expr: " " + joinType + " JOIN " + db.T(model) + " " + al} +} + +func (db *DB) InnerJoin(model interface{}, alias ...string) *jexpr { + return join("INNER", db, model, alias...) +} + +func (db *DB) LeftJoin(model interface{}, alias ...string) *jexpr { + return join("LEFT", db, model, alias...) +} + +func (db *DB) RightJoin(model interface{}, alias ...string) *jexpr { + return join("RIGHT", db, model, alias...) +} + +func (db *DB) OuterJoin(model interface{}, alias ...string) *jexpr { + return join("OUTER", db, model, alias...) +} + +func (je *jexpr) On(col1 *expr, col2 *expr) *expr { + return &expr{expr: je.expr + " ON " + col1.expr + " = " + col2.expr, args: je.args} +} + +func (je *jexpr) OnExp(e2 *expr) *expr { + e := &expr{expr: je.expr + " ON " + e2.expr, args: je.args} + e.args = append(e.args, e2.args...) + return e +} + +func (db *DB) L(model interface{}, name string) *expr { + scope := db.NewScope(model) + field, _ := scope.FieldByName(name) + return &expr{expr: scope.Quote(scope.TableName()) + "." + scope.Quote(field.DBName)} +} + +func (db *DB) LA(model interface{}, alias string, name string) *expr { + scope := db.NewScope(model) + field, _ := scope.FieldByName(name) + return &expr{expr: scope.Quote(alias) + "." + scope.Quote(field.DBName)} +} + +func (db *DB) C(model interface{}, names ...string) string { + columns := make([]string, 0) + + scope := db.NewScope(model) + for _, name := range names { + field, _ := scope.FieldByName(name) + columns = append(columns, field.DBName) + } + + return strings.Join(columns, ", ") +} + +func (db *DB) CA(model interface{}, alias string, names ...string) string { + columns := make([]string, 0) + + for _, name := range names { + columns = append(columns, db.LA(model, alias, name).expr) + } + + return strings.Join(columns, ", ") +} + +func (db *DB) CQ(model interface{}, names ...string) string { + columns := make([]string, 0) + + for _, name := range names { + columns = append(columns, db.L(model, name).expr) + } + + return strings.Join(columns, ", ") +} + +func (db *DB) T(model interface{}) string { + scope := db.NewScope(model) + return scope.TableName() +} + +func (db *DB) QT(model interface{}) string { + scope := db.NewScope(model) + return scope.QuotedTableName() +} + +func (e *expr) operator(operator string, value interface{}) *expr { + if value == nil { + e.expr = "(" + e.expr + " " + operator + " )" + return e + } + + if _, ok := value.(*expr); ok { + e.expr = "(" + e.expr + " " + operator + " (?))" + } else { + e.expr = "(" + e.expr + " " + operator + " ?)" + } + e.args = append(e.args, value) + + return e +} + +func (e *expr) Gt(value interface{}) *expr { + return e.operator(">", value) +} + +func (e *expr) Ge(value interface{}) *expr { + return e.operator(">=", value) +} + +func (e *expr) Lt(value interface{}) *expr { + return e.operator("<", value) +} + +func (e *expr) Le(value interface{}) *expr { + return e.operator("<=", value) +} + +func (e *expr) BAnd(value interface{}) *expr { + return e.operator("&", value) +} + +func (e *expr) BOr(value interface{}) *expr { + return e.operator("|", value) +} + +func (e *expr) Like(value interface{}) *expr { + return e.operator("LIKE", value) +} + +func (e *expr) Eq(value interface{}) *expr { + if value == nil { + return e.operator("IS NULL", value) + } else if val := reflect.ValueOf(value); val.Kind() == reflect.Ptr && val.IsNil() { + return e.operator("IS NULL", nil) + } + + return e.operator("=", value) +} + +func (e *expr) Neq(value interface{}) *expr { + if value == nil { + return e.operator("IS NOT NULL", value) + } + + return e.operator("!=", value) +} + +func (e *expr) Sum() string { + return "SUM(" + e.expr + ")" +} + +func (e *expr) Count() string { + return "COUNT(" + e.expr + ")" +} + +func (e *expr) Distinct() *expr { + e.expr = "DISTINCT " + e.expr + return e +} + +func (e *expr) DistinctColumn() string { + return "DISTINCT " + e.expr +} + +func (e *expr) in(operator string, values ...interface{}) *expr { + // NOTE: Maybe there is a better way to do this? :) + if len(values) == 1 { + if s := reflect.ValueOf(values[0]); s.Kind() == reflect.Slice { + vals := make([]interface{}, s.Len()) + qm := make([]string, s.Len()) + + for i := 0; i < s.Len(); i++ { + vals[i] = s.Index(i).Interface() + qm[i] = "?" + } + + e.expr = "(" + e.expr + operator + " IN (" + strings.Join(qm, ",") + "))" + e.args = append(e.args, vals...) + return e + } + } + + qm := make([]string, len(values)) + for i := 0; i < len(values); i++ { + qm[i] = "?" + } + + e.expr = "(" + e.expr + operator + " IN (" + strings.Join(qm, ",") + "))" + e.args = append(e.args, values...) + return e +} + +func (e *expr) In(values ...interface{}) *expr { + return e.in("", values...) +} + +func (e *expr) NotIn(values ...interface{}) *expr { + return e.in(" NOT", values...) +} + +func (e *expr) OrderAsc() string { + return e.expr + " ASC " +} + +func (e *expr) OrderDesc() string { + return e.expr + " DESC " +} + +func (e *expr) Or(e2 *expr) *expr { + e.expr = "(" + e.expr + " OR " + e2.expr + ")" + e.args = append(e.args, e2.args...) + + return e +} + +func (e *expr) And(e2 *expr) *expr { + e.expr = "(" + e.expr + " AND " + e2.expr + ")" + e.args = append(e.args, e2.args...) + + return e +} + +func (db *DB) UpdateFields(fields ...string) *DB { + sets := make(map[string]interface{}) + m := reflect.ValueOf(db.Value).Elem() + for _, field := range fields { + sets[db.C(db.Value, field)] = m.FieldByName(field).Interface() + } + + return db.Update(sets) +} + +func (db *DB) SelectFields(fields ...string) *DB { + selects := strings.Join(fields, ", ") + + return db.Select(selects) +} + +func (e *expr) Intersect(e2 *expr) *expr { + e.expr = "((" + e.expr + ") INTERSECT (" + e2.expr + "))" + e.args = append(e.args, e2.args...) + + return e +} + +func (e *expr) Alias(alias string) *expr { + e.expr = e.expr + " " + alias + " " + + return e +} + +func (db *DB) FormatDate(e *expr, format string) *expr { + return db.Dialect().FormatDate(e, format) +} + +func (db *DB) FormatDateColumn(e *expr, format string) string { + return db.FormatDate(e, format).expr +} diff --git a/main.go b/main.go index c26e05c8..94beab61 100644 --- a/main.go +++ b/main.go @@ -133,6 +133,11 @@ func (s *DB) SetLogger(log logger) { s.logger = log } +// SetLogWriter sets the LogWriter the default logger should write to +func (s *DB) SetLogWriter(log LogWriter) { + s.logger = Logger{log} +} + // LogMode set log mode, `true` for detailed logs, `false` for no log, default, will only print error logs func (s *DB) LogMode(enable bool) *DB { if enable { @@ -169,11 +174,15 @@ func (s *DB) NewScope(value interface{}) *Scope { } // QueryExpr returns the query as expr object -func (s *DB) QueryExpr() *expr { +func (s *DB) QueryExpr(alias ...string) *expr { scope := s.NewScope(s.Value) scope.InstanceSet("skip_bindvar", true) scope.prepareQuerySQL() + if len(alias) > 0 { + return Expr("( "+scope.SQL+" ) "+alias[0]+" ", scope.SQLVars...) + } + return Expr(scope.SQL, scope.SQLVars...) } @@ -242,7 +251,7 @@ func (s *DB) Having(query interface{}, values ...interface{}) *DB { // Joins specify Joins conditions // db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user) -func (s *DB) Joins(query string, args ...interface{}) *DB { +func (s *DB) Joins(query interface{}, args ...interface{}) *DB { return s.clone().search.Joins(query, args...).db } diff --git a/main_test.go b/main_test.go index 66c46af0..7a3199e5 100644 --- a/main_test.go +++ b/main_test.go @@ -629,6 +629,7 @@ func TestQueryBuilderSubselectInWhere(t *testing.T) { if len(users) != 2 { t.Errorf("Two users should be found, instead found %d", len(users)) } + DB.Delete(&User{}) } func TestQueryBuilderRawQueryWithSubquery(t *testing.T) { @@ -689,6 +690,14 @@ func TestQueryBuilderSubselectInHaving(t *testing.T) { if len(users) != 1 { t.Errorf("Two user group should be found, instead found %d", len(users)) } + + DB.Select("*").Where("name LIKE ?", "query_expr_having_%").Where("age >= (?)", DB. + Select("AVG(age)").Where("name LIKE ?", "query_expr_having_%").Table("users").QueryExpr()).Find(&users) + + if len(users) != 2 { + t.Errorf("Two users should be found, instead found %d", len(users)) + } + DB.Delete(&User{}) } func DialectHasTzSupport() bool { diff --git a/model_struct.go b/model_struct.go index f571e2e8..1b6751b7 100644 --- a/model_struct.go +++ b/model_struct.go @@ -158,6 +158,11 @@ func (scope *Scope) GetModelStruct() *ModelStruct { modelStruct.ModelType = reflectType + var columnPrefix string + if columnPrefixer, ok := reflect.New(modelStruct.ModelType).Interface().(columnPrefixer); ok { + columnPrefix = columnPrefixer.ColumnPrefix() + } + // Get all fields for i := 0; i < reflectType.NumField(); i++ { if fieldStruct := reflectType.Field(i); ast.IsExported(fieldStruct.Name) { @@ -587,7 +592,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { if value, ok := field.TagSettings["COLUMN"]; ok { field.DBName = value } else { - field.DBName = ToDBName(fieldStruct.Name) + field.DBName = columnPrefix + ToDBName(fieldStruct.Name) } modelStruct.StructFields = append(modelStruct.StructFields, field) diff --git a/scope.go b/scope.go index 25077efc..8a72cab5 100644 --- a/scope.go +++ b/scope.go @@ -142,6 +142,7 @@ func (scope *Scope) FieldByName(name string) (field *Field, ok bool) { if field.Name == name || field.DBName == name { return field, true } + if field.DBName == dbName { mostMatchedField = field } @@ -274,6 +275,7 @@ func (scope *Scope) AddToVars(value interface{}) string { if skipBindVar { return "?" } + return scope.Dialect().BindVar(len(scope.SQLVars)) } @@ -306,6 +308,10 @@ type tabler interface { TableName() string } +type columnPrefixer interface { + ColumnPrefix() string +} + type dbTabler interface { TableName(*DB) string } @@ -339,7 +345,7 @@ func (scope *Scope) QuotedTableName() (name string) { return scope.Quote(scope.TableName()) } -// CombinedConditionSql return combined condition sql +// CombinedConditionSQL return combined condition SQL func (scope *Scope) CombinedConditionSql() string { joinSQL := scope.joinsSQL() whereSQL := scope.whereSQL() @@ -473,6 +479,18 @@ func (scope *Scope) quoteIfPossible(str string) string { return str } +func (scope *Scope) replaceParameterPlaceholderLiteral(sql string, parameter interface{}, addToVars bool) string { + if val, ok := parameter.(string); ok && !addToVars { + return strings.Replace(sql, "?", val, 1) + } + + return strings.Replace(sql, "?", scope.AddToVars(parameter), 1) +} + +func (scope *Scope) replaceParameterPlaceholder(sql string, parameter interface{}) string { + return scope.replaceParameterPlaceholderLiteral(sql, parameter, true) +} + func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field) { var ( ignored interface{} @@ -578,6 +596,9 @@ func (scope *Scope) buildCondition(clause map[string]interface{}, include bool) } } return strings.Join(sqls, " AND ") + case *expr: + clause["args"] = []interface{}{value} + str = "?" case interface{}: var sqls []string newScope := scope.New(value) @@ -1048,7 +1069,7 @@ func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope { toScope := scope.db.NewScope(value) tx := scope.db.Set("gorm:association:source", scope.Value) - for _, foreignKey := range append(foreignKeys, toScope.typeName()+"Id", scope.typeName()+"Id") { + for _, foreignKey := range append(foreignKeys, toScope.typeName()+"ID", scope.typeName()+"ID", toScope.typeName()+"Id", scope.typeName()+"Id") { fromField, _ := scope.FieldByName(foreignKey) toField, _ := toScope.FieldByName(foreignKey) diff --git a/search.go b/search.go index 90138595..e49cc65c 100644 --- a/search.go +++ b/search.go @@ -106,8 +106,12 @@ func (s *search) Having(query interface{}, values ...interface{}) *search { return s } -func (s *search) Joins(query string, values ...interface{}) *search { - s.joinConditions = append(s.joinConditions, map[string]interface{}{"query": query, "args": values}) +func (s *search) Joins(query interface{}, values ...interface{}) *search { + if val, ok := query.(*expr); ok { + s.joinConditions = append(s.joinConditions, map[string]interface{}{"query": val.expr, "args": val.args}) + } else { + s.joinConditions = append(s.joinConditions, map[string]interface{}{"query": query, "args": values}) + } return s }