gorm/expression_ext.go
2024-07-17 18:38:38 +02:00

748 lines
16 KiB
Go

package gorm
import (
"fmt"
"reflect"
"strings"
)
// jexpr is a JOIN clause under construction, e.g. " LEFT JOIN `users` u".
// It is produced by join (via InnerJoin, LeftJoin, ...) and turned into a
// complete *expr by On or OnExp, which append the " ON ..." condition.
type jexpr struct {
	expr string        // SQL text of the JOIN clause built so far
	args []interface{} // placeholder arguments referenced by expr (from a subquery target)
}
// join builds the " <joinType> JOIN <target> [alias]" fragment shared by
// InnerJoin, LeftJoin, RightJoin and OuterJoin. Only the first alias is used.
//
// When model is an *expr its SQL is used as the join target together with its
// arguments; if an alias is given the SQL is wrapped in parentheses first (the
// subquery form). Any other model is rendered as its quoted table name
// (db.QT) followed by a space and the alias, which may be empty.
func join(joinType string, db *DB, model interface{}, alias ...string) *jexpr {
	prefix := " " + joinType + " JOIN "

	aliasName := ""
	if len(alias) != 0 {
		aliasName = alias[0]
	}

	target, isExpr := model.(*expr)
	if !isExpr {
		return &jexpr{expr: prefix + db.QT(model) + " " + aliasName}
	}
	if aliasName == "" {
		return &jexpr{expr: prefix + target.expr, args: target.args}
	}
	return &jexpr{expr: prefix + "(" + target.expr + ") " + aliasName, args: target.args}
}
// InnerJoin starts an " INNER JOIN" on model (a model value or an *expr),
// optionally aliased. Complete it with On or OnExp.
func (db *DB) InnerJoin(model interface{}, alias ...string) *jexpr {
	const kind = "INNER"
	return join(kind, db, model, alias...)
}

// LeftJoin starts a " LEFT JOIN" on model, optionally aliased.
// Complete it with On or OnExp.
func (db *DB) LeftJoin(model interface{}, alias ...string) *jexpr {
	const kind = "LEFT"
	return join(kind, db, model, alias...)
}

// RightJoin starts a " RIGHT JOIN" on model, optionally aliased.
// Complete it with On or OnExp.
func (db *DB) RightJoin(model interface{}, alias ...string) *jexpr {
	const kind = "RIGHT"
	return join(kind, db, model, alias...)
}

// OuterJoin starts an " OUTER JOIN" on model, optionally aliased.
// Complete it with On or OnExp.
//
// NOTE(review): standard SQL expects LEFT/RIGHT/FULL before OUTER, so a bare
// "OUTER JOIN" is probably rejected by most databases. Confirm whether
// "FULL OUTER" was intended.
func (db *DB) OuterJoin(model interface{}, alias ...string) *jexpr {
	const kind = "OUTER"
	return join(kind, db, model, alias...)
}
// On completes the JOIN with the equality condition "col1 = col2" and returns
// the whole clause as a new *expr.
//
// Placeholder arguments are collected in the order their SQL appears: first
// the join target's, then col1's, then col2's. Previously col1/col2 arguments
// were dropped, which desynchronised placeholders and bind values.
func (je *jexpr) On(col1 *expr, col2 *expr) *expr {
	// Build into a fresh slice so appending can never write into a backing
	// array shared with je or with the *expr the join was created from.
	var args []interface{}
	args = append(args, je.args...)
	args = append(args, col1.args...)
	args = append(args, col2.args...)
	return &expr{expr: je.expr + " ON " + col1.expr + " = " + col2.expr, args: args}
}

// OnExp completes the JOIN with an arbitrary condition e2 (" ON <e2>") and
// returns the whole clause as a new *expr. The join target's arguments come
// first, followed by e2's.
func (je *jexpr) OnExp(e2 *expr) *expr {
	// A fresh slice avoids aliasing je.args: appending directly to it could
	// overwrite elements of another expression sharing the same backing array.
	var args []interface{}
	args = append(args, je.args...)
	args = append(args, e2.args...)
	return &expr{expr: je.expr + " ON " + e2.expr, args: args}
}
// L returns a qualified column reference "<table>.<column>" for field name of
// model. Both parts are quoted with scope.Quote; the table part is
// scope.TableName(). name is resolved through scope.FieldByName (presumably
// the Go struct field name; confirm against FieldByName).
//
// It panics with a descriptive message when name does not resolve to a field.
// Previously this surfaced as an opaque nil-pointer dereference.
func (db *DB) L(model interface{}, name string) *expr {
	scope := db.NewScope(model)
	// Only a nil field is treated as "not found", which is exactly the case
	// that used to crash; any non-nil result is used as before.
	field, _ := scope.FieldByName(name)
	if field == nil {
		panic(fmt.Sprintf("gorm: field %q not found in model %T", name, model))
	}
	return &expr{expr: scope.Quote(scope.TableName()) + "." + scope.Quote(field.DBName)}
}
// LA returns a column reference "<alias>.<column>" for field name of model,
// with both parts quoted by scope.Quote. Use it when the table is joined or
// selected under an alias.
//
// It panics with a descriptive message when name does not resolve to a field.
// Previously this was an opaque nil-pointer dereference.
func (db *DB) LA(model interface{}, alias string, name string) *expr {
	scope := db.NewScope(model)
	// A nil field is the only case that used to crash; keep every other path unchanged.
	field, _ := scope.FieldByName(name)
	if field == nil {
		panic(fmt.Sprintf("gorm: field %q not found in model %T", name, model))
	}
	return &expr{expr: scope.Quote(alias) + "." + scope.Quote(field.DBName)}
}
// LAO ("L with Alias Optional") qualifies field name of model with alias when
// one is given (LA), and with the model's table name otherwise (L).
func (db *DB) LAO(model interface{}, alias string, name string) *expr {
	if alias != "" {
		return db.LA(model, alias, name)
	}
	return db.L(model, name)
}
// QuoteExpr returns the column reference "<table>.<column>" with both
// identifiers quoted according to the current dialect, without consulting any
// model metadata.
func (db *DB) QuoteExpr(table string, column string) *expr {
	quote := db.NewScope(nil).Quote
	return &expr{expr: quote(table) + "." + quote(column)}
}
// C returns the unquoted, unqualified database column names for the given
// field names of model, joined with ", " (e.g. "id, created_at").
// No names yields "".
//
// It panics with a descriptive message when a name does not resolve to a
// field. Previously this was an opaque nil-pointer dereference.
func (db *DB) C(model interface{}, names ...string) string {
	scope := db.NewScope(model)
	columns := make([]string, 0, len(names))
	for _, name := range names {
		// A nil field is the only case that used to crash; everything else is unchanged.
		field, _ := scope.FieldByName(name)
		if field == nil {
			panic(fmt.Sprintf("gorm: field %q not found in model %T", name, model))
		}
		columns = append(columns, field.DBName)
	}
	return strings.Join(columns, ", ")
}
// CA returns the alias-qualified, quoted column references (see LA) for the
// given field names of model, joined with ", ".
func (db *DB) CA(model interface{}, alias string, names ...string) string {
	qualified := make([]string, len(names))
	for i, name := range names {
		qualified[i] = db.LA(model, alias, name).expr
	}
	return strings.Join(qualified, ", ")
}
// CAO is the multi-column form of LAO. It qualifies every listed field with
// alias when one is given (CA), and with the table name otherwise (CQ).
func (db *DB) CAO(model interface{}, alias string, names ...string) string {
	if alias != "" {
		return db.CA(model, alias, names...)
	}
	return db.CQ(model, names...)
}
// CQ returns the table-qualified, quoted column references (see L) for the
// given field names of model, joined with ", ".
func (db *DB) CQ(model interface{}, names ...string) string {
	qualified := make([]string, len(names))
	for i, name := range names {
		qualified[i] = db.L(model, name).expr
	}
	return strings.Join(qualified, ", ")
}
// T returns the (unquoted) table name gorm uses for model.
func (db *DB) T(model interface{}) string {
	return db.NewScope(model).TableName()
}

// QT returns the dialect-quoted table name gorm uses for model.
func (db *DB) QT(model interface{}) string {
	return db.NewScope(model).QuotedTableName()
}
// operator rewrites e in place into the parenthesised condition
// "(e <operator> <rhs>)" and returns e.
//
// The right-hand side depends on value:
//   - untyped nil: nothing, giving "(e <operator> )". This is how the unary
//     IS NULL / IS NOT NULL forms are produced.
//   - *expr: its SQL in its own parentheses, with its arguments appended.
//   - anything else: a "?" placeholder bound to value.
func (e *expr) operator(operator string, value interface{}) *expr {
	head := "(" + e.expr + " " + operator + " "
	switch rhs := value.(type) {
	case nil:
		e.expr = head + ")"
	case *expr:
		e.expr = head + "(" + rhs.expr + "))"
		e.args = append(e.args, rhs.args...)
	default:
		e.expr = head + "?)"
		e.args = append(e.args, value)
	}
	return e
}
// Union appends " UNION <stmt>" to e in place and returns e. stmt's arguments
// are added after e's own, matching their position in the SQL.
func (e *expr) Union(stmt *expr) *expr {
	e.args = append(e.args, stmt.args...)
	e.expr += " UNION " + stmt.expr
	return e
}
// Union chains all given statements with UNION, left to right, and returns
// the combined expression (nil when called with no statements).
// The first statement is modified in place and returned.
// Each element must be a *gorm.expr; the parameter is interface{} only so
// that external packages, which cannot name the unexported type, can call it.
// Any other type panics.
func Union(stmts ...interface{}) *expr {
	if len(stmts) == 0 {
		return nil
	}
	combined := stmts[0].(*expr)
	for _, next := range stmts[1:] {
		combined = combined.Union(next.(*expr))
	}
	return combined
}
// And combines all given statements with AND, left to right, and returns the
// combined expression (nil when called with no statements).
// The first statement is modified in place and returned.
// Each element must be a *gorm.expr; the parameter is interface{} only so
// that external packages, which cannot name the unexported type, can call it.
// Any other type panics.
func And(stmts ...interface{}) *expr {
	if len(stmts) == 0 {
		return nil
	}
	combined := stmts[0].(*expr)
	for _, next := range stmts[1:] {
		combined = combined.And(next.(*expr))
	}
	return combined
}
// Or combines all given statements with OR, left to right, and returns the
// combined expression (nil when called with no statements).
// The first statement is modified in place and returned.
// Each element must be a *gorm.expr; the parameter is interface{} only so
// that external packages, which cannot name the unexported type, can call it.
// Any other type panics.
func Or(stmts ...interface{}) *expr {
	if len(stmts) == 0 {
		return nil
	}
	combined := stmts[0].(*expr)
	for _, next := range stmts[1:] {
		combined = combined.Or(next.(*expr))
	}
	return combined
}
// Not negates stmt in place by rewriting its SQL to "NOT (<stmt>)" and
// returns it. Its arguments are unchanged.
// stmt must be a *gorm.expr; the parameter is interface{} only so that
// external packages can call it. Any other type panics.
func Not(stmt interface{}) *expr {
	negated := stmt.(*expr)
	inner := negated.expr
	negated.expr = "NOT (" + inner + ")"
	return negated
}
// Concat builds a new "CONCAT(a, b, ...)" expression. Each *expr argument is
// inlined with its arguments; every other value becomes a bound "?".
func Concat(stmts ...interface{}) *expr {
	out := &expr{expr: "CONCAT("}
	sep := ""
	for _, stmt := range stmts {
		out.expr += sep
		addStatementToExpression(out, stmt)
		sep = ", "
	}
	out.expr += ")"
	return out
}
// Gt rewrites e into "(e > value)"; see operator for how value is bound.
func (e *expr) Gt(value interface{}) *expr {
	const op = ">"
	return e.operator(op, value)
}

// Ge rewrites e into "(e >= value)"; see operator for how value is bound.
func (e *expr) Ge(value interface{}) *expr {
	const op = ">="
	return e.operator(op, value)
}

// Lt rewrites e into "(e < value)"; see operator for how value is bound.
func (e *expr) Lt(value interface{}) *expr {
	const op = "<"
	return e.operator(op, value)
}

// Le rewrites e into "(e <= value)"; see operator for how value is bound.
func (e *expr) Le(value interface{}) *expr {
	const op = "<="
	return e.operator(op, value)
}

// BAnd rewrites e into the bitwise "(e & value)".
func (e *expr) BAnd(value interface{}) *expr {
	const op = "&"
	return e.operator(op, value)
}

// BOr rewrites e into the bitwise "(e | value)".
func (e *expr) BOr(value interface{}) *expr {
	const op = "|"
	return e.operator(op, value)
}

// Like rewrites e into "(e LIKE value)".
func (e *expr) Like(value interface{}) *expr {
	const op = "LIKE"
	return e.operator(op, value)
}

// NotLike rewrites e into "(e NOT LIKE value)".
func (e *expr) NotLike(value interface{}) *expr {
	const op = "NOT LIKE"
	return e.operator(op, value)
}
// Eq rewrites e into "(e = value)". An untyped nil or a nil pointer produces
// "(e IS NULL )" instead, because "= NULL" never matches in SQL.
func (e *expr) Eq(value interface{}) *expr {
	if v := reflect.ValueOf(value); value == nil || (v.Kind() == reflect.Ptr && v.IsNil()) {
		return e.operator("IS NULL", nil)
	}
	return e.operator("=", value)
}

// Neq rewrites e into "(e != value)". An untyped nil or a nil pointer
// produces "(e IS NOT NULL )" instead.
func (e *expr) Neq(value interface{}) *expr {
	if v := reflect.ValueOf(value); value == nil || (v.Kind() == reflect.Ptr && v.IsNil()) {
		return e.operator("IS NOT NULL", nil)
	}
	return e.operator("!=", value)
}
// Sum returns "SUM(e)" as a string, leaving e unchanged.
func (e *expr) Sum() string {
	return "SUM(" + e.expr + ")"
}

// SumExpr rewrites e in place into "SUM(e)" and returns it.
func (e *expr) SumExpr() *expr {
	e.expr = e.Sum()
	return e
}

// Avg returns "AVG(e)" as a string, leaving e unchanged.
func (e *expr) Avg() string {
	return "AVG(" + e.expr + ")"
}

// AvgExpr rewrites e in place into "AVG(e)" and returns it.
func (e *expr) AvgExpr() *expr {
	e.expr = e.Avg()
	return e
}
// ReplaceExpr rewrites e in place into "REPLACE(e,?,?)", binding search and
// replace (in that order) as arguments, and returns e.
func (e *expr) ReplaceExpr(search string, replace string) *expr {
	e.args = append(e.args, search, replace)
	e.expr = "REPLACE(" + e.expr + ",?,?)"
	return e
}
// GroupConcatExpr rewrites e in place into the current dialect's string
// aggregation of e, joined by separator and ordered by orderExpr, and returns e:
//
//	mysql:    GROUP_CONCAT(e <order> SEPARATOR '<sep>')
//	mssql:    STRING_AGG(e, '<sep>') WITHIN GROUP (<order>)
//	sqlite3:  GROUP_CONCAT(e,'<sep>')   (ordering is not applied)
//	postgres: string_agg(e, '<sep>' <order>)
//	oracle:   LISTAGG(e, '<sep>') WITHIN GROUP (<order>)
//
// orderExpr must be non-nil and is expected to hold an "ORDER BY ..." clause,
// such as one built with Order. Its arguments are appended only when its SQL
// is actually emitted. separator is embedded as a single-quoted SQL literal
// with embedded single quotes doubled. It panics for any other dialect.
//
// NOTE(review): MySQL in its default mode also treats backslash as an escape
// character inside literals, so a separator ending in "\" is still unsafe there.
func (db *DB) GroupConcatExpr(e *expr, separator string, orderExpr *expr) *expr {
	// Doubling single quotes keeps a separator such as "', '" from closing
	// the literal early (and from injecting SQL through it).
	sep := "'" + strings.ReplaceAll(separator, "'", "''") + "'"
	orderArgs := orderExpr.args

	dbType := db.Dialect().GetName()
	switch dbType {
	case "mysql":
		e.expr = fmt.Sprintf("GROUP_CONCAT(%s %s SEPARATOR %s)", e.expr, orderExpr.expr, sep)
	case "mssql":
		e.expr = fmt.Sprintf("STRING_AGG(%s, %s) WITHIN GROUP (%s)", e.expr, sep, orderExpr.expr)
	case "sqlite3":
		// The ORDER BY clause is not emitted for SQLite, so its placeholder
		// arguments must not be bound either; otherwise the statement gets
		// more arguments than placeholders.
		e.expr = fmt.Sprintf("GROUP_CONCAT(%s,%s)", e.expr, sep)
		orderArgs = nil
	case "postgres":
		e.expr = fmt.Sprintf("string_agg(%s, %s %s)", e.expr, sep, orderExpr.expr)
	case "oracle":
		e.expr = fmt.Sprintf("LISTAGG(%s, %s) WITHIN GROUP (%s)", e.expr, sep, orderExpr.expr)
	default:
		panic(fmt.Sprintf("Unsupported database type %s for GroupConcat!", dbType))
	}
	// In every emitted form e's SQL precedes the ORDER BY SQL, so the order
	// arguments come after e's own.
	e.args = append(e.args, orderArgs...)
	return e
}
// GroupConcat returns the SQL text produced by GroupConcatExpr for the same
// arguments. Note that e is still modified in place, and any arguments it
// collects are not part of the returned string.
func (db *DB) GroupConcat(e *expr, separator string, orderExpr *expr) string {
	aggregated := db.GroupConcatExpr(e, separator, orderExpr)
	return aggregated.expr
}
// TimeUnit selects the unit used by TimestampDiff / TimestampDiffExpr.
type TimeUnit int

// Supported time units.
var (
	TimeUnitYear        TimeUnit = 0
	TimeUnitQuarter     TimeUnit = 1
	TimeUnitMonth       TimeUnit = 2
	TimeUnitDay         TimeUnit = 3
	TimeUnitWeek        TimeUnit = 4
	TimeUnitHour        TimeUnit = 5
	TimeUnitMinute      TimeUnit = 6
	TimeUnitSecond      TimeUnit = 7
	TimeUnitMicrosecond TimeUnit = 8
)

// String returns the SQL keyword for t in the given dialect: the upper-case
// TIMESTAMPDIFF unit for "mysql" and the lower-case DATEDIFF datepart for
// "mssql". Any other dialect or unit yields "unknown time unit".
//
// Previously mysql mapped TimeUnitDay to "WEEK" and TimeUnitWeek to "DAY".
func (t TimeUnit) String(dialect string) string {
	switch dialect {
	case "mysql":
		switch t {
		case TimeUnitYear:
			return "YEAR"
		case TimeUnitQuarter:
			return "QUARTER"
		case TimeUnitMonth:
			return "MONTH"
		case TimeUnitDay:
			return "DAY"
		case TimeUnitWeek:
			return "WEEK"
		case TimeUnitHour:
			return "HOUR"
		case TimeUnitMinute:
			return "MINUTE"
		case TimeUnitSecond:
			return "SECOND"
		case TimeUnitMicrosecond:
			return "MICROSECOND"
		}
	case "mssql":
		switch t {
		case TimeUnitYear:
			return "year"
		case TimeUnitQuarter:
			return "quarter"
		case TimeUnitMonth:
			return "month"
		case TimeUnitDay:
			return "day"
		case TimeUnitWeek:
			return "week"
		case TimeUnitHour:
			return "hour"
		case TimeUnitMinute:
			return "minute"
		case TimeUnitSecond:
			return "second"
		case TimeUnitMicrosecond:
			return "microsecond"
		}
	}
	return "unknown time unit"
}
// TimestampDiffExpr builds "TIMESTAMPDIFF(<unit>,t1,t2)" for mysql or
// "DATEDIFF(<unit>,t1,t2)" for mssql, where <unit> is unit.String(dialect).
// Each timestamp may be an *expr, which is inlined with its arguments, or any
// other value, which is bound as "?". It panics for any other dialect.
func (db *DB) TimestampDiffExpr(unit TimeUnit, timestamp1 interface{}, timestamp2 interface{}) *expr {
	dialect := db.Dialect().GetName()

	var fn string
	switch dialect {
	case "mysql":
		fn = "TIMESTAMPDIFF("
	case "mssql":
		fn = "DATEDIFF("
	default:
		panic(fmt.Sprintf("TIMESTAMPDIFF not supported for %s", dialect))
	}

	diff := &expr{expr: fn + unit.String(dialect) + ","}
	addStatementToExpression(diff, timestamp1)
	diff.expr += ","
	addStatementToExpression(diff, timestamp2)
	diff.expr += ")"
	return diff
}
// addStatementToExpression appends stm to e in place. An *expr contributes its
// SQL text and its arguments; any other value is appended as a "?"
// placeholder with stm bound as its argument.
func addStatementToExpression(e *expr, stm interface{}) {
	sub, isExpr := stm.(*expr)
	if !isExpr {
		e.expr += "?"
		e.args = append(e.args, stm)
		return
	}
	e.expr += sub.expr
	e.args = append(e.args, sub.args...)
}
// TimestampDiff returns only the SQL text of TimestampDiffExpr. Any arguments
// bound for non-*expr timestamps are not part of the returned string.
func (db *DB) TimestampDiff(unit TimeUnit, timestamp1 interface{}, timestamp2 interface{}) string {
	diff := db.TimestampDiffExpr(unit, timestamp1, timestamp2)
	return diff.expr
}
// CoalesceExpr builds "COALESCE(a, b, ...)". *expr arguments are inlined with
// their arguments; every other value is bound as "?".
func (db *DB) CoalesceExpr(stmts ...interface{}) *expr {
	out := &expr{expr: "COALESCE("}
	sep := ""
	for _, stmt := range stmts {
		out.expr += sep
		addStatementToExpression(out, stmt)
		sep = ", "
	}
	out.expr += ")"
	return out
}

// Coalesce returns only the SQL text of CoalesceExpr; bound arguments are
// not part of the returned string.
func (db *DB) Coalesce(stmts ...interface{}) string {
	coalesced := db.CoalesceExpr(stmts...)
	return coalesced.expr
}
// Order builds an "ORDER BY a, b, ..." clause. *expr items (e.g. from
// OrderAscExpr) are inlined with their arguments; other values are bound as "?".
func Order(stmts ...interface{}) *expr {
	clause := &expr{expr: "ORDER BY "}
	for i := range stmts {
		if i > 0 {
			clause.expr += ", "
		}
		addStatementToExpression(clause, stmts[i])
	}
	return clause
}
// Max returns "MAX(e)" as a string, leaving e unchanged.
func (e *expr) Max() string {
	return "MAX(" + e.expr + ")"
}

// MaxExpr rewrites e in place into "MAX(e)" and returns it.
func (e *expr) MaxExpr() *expr {
	e.expr = e.Max()
	return e
}

// Min returns "MIN(e)" as a string, leaving e unchanged.
func (e *expr) Min() string {
	return "MIN(" + e.expr + ")"
}

// MinExpr rewrites e in place into "MIN(e)" and returns it.
func (e *expr) MinExpr() *expr {
	e.expr = e.Min()
	return e
}
// LowerExpr rewrites e in place into "LOWER(e)" and returns it.
func (e *expr) LowerExpr() *expr {
	e.expr = e.Lower()
	return e
}

// UpperExpr rewrites e in place into "UPPER(e)" and returns it.
func (e *expr) UpperExpr() *expr {
	e.expr = e.Upper()
	return e
}

// Lower returns "LOWER(e)" as a string, leaving e unchanged.
func (e *expr) Lower() string {
	return "LOWER(" + e.expr + ")"
}

// Upper returns "UPPER(e)" as a string, leaving e unchanged.
func (e *expr) Upper() string {
	return "UPPER(" + e.expr + ")"
}
// Count returns "COUNT(e)" as a string, leaving e unchanged.
func (e *expr) Count() string {
	return "COUNT(" + e.expr + ")"
}

// CountExpr rewrites e in place into "COUNT(e)" and returns it.
func (e *expr) CountExpr() *expr {
	e.expr = e.Count()
	return e
}

// Distinct prefixes e in place with "DISTINCT " and returns it.
func (e *expr) Distinct() *expr {
	e.expr = e.DistinctColumn()
	return e
}

// DistinctColumn returns "DISTINCT e" as a string, leaving e unchanged.
func (e *expr) DistinctColumn() string {
	return "DISTINCT " + e.expr
}
// in rewrites e in place into "(e<operator> IN (...))" and returns it.
// operator is "" for IN and " NOT" for NOT IN (note the leading space).
//
// values is interpreted as follows:
//   - A single slice argument is expanded to one "?" per element. An empty
//     slice collapses the whole condition to "1 = 0" (IN) or "1 = 1"
//     (NOT IN), and e's previous arguments are discarded with its SQL.
//   - A single *expr argument (typically a subquery) is inlined together
//     with its arguments.
//   - Otherwise each value becomes its own "?" placeholder.
func (e *expr) in(operator string, values ...interface{}) *expr {
	// placeholders returns "?,?,...,?" with n question marks ("" for n == 0).
	placeholders := func(n int) string {
		return strings.TrimSuffix(strings.Repeat("?,", n), ",")
	}

	if len(values) == 1 {
		if s := reflect.ValueOf(values[0]); s.Kind() == reflect.Slice {
			if s.Len() == 0 {
				// e's SQL is replaced by a constant condition, so any
				// arguments it carried would no longer match a placeholder.
				e.args = nil
				if operator == "" {
					e.expr = "1 = 0"
				} else {
					e.expr = "1 = 1"
				}
				return e
			}
			vals := make([]interface{}, s.Len())
			for i := range vals {
				vals[i] = s.Index(i).Interface()
			}
			e.expr = "(" + e.expr + operator + " IN (" + placeholders(len(vals)) + "))"
			e.args = append(e.args, vals...)
			return e
		}
		if sub, ok := values[0].(*expr); ok {
			e.expr = "(" + e.expr + operator + " IN (" + sub.expr + "))"
			e.args = append(e.args, sub.args...)
			return e
		}
	}

	e.expr = "(" + e.expr + operator + " IN (" + placeholders(len(values)) + "))"
	e.args = append(e.args, values...)
	return e
}
// In rewrites e into "(e IN (...))"; see in for how values are bound.
// With no values at all the condition becomes "1 = 0" (matches nothing).
// e's SQL is replaced, so its previous arguments are discarded too.
func (e *expr) In(values ...interface{}) *expr {
	if len(values) == 0 {
		e.expr = "1 = 0"
		e.args = nil
		return e
	}
	return e.in("", values...)
}

// NotIn rewrites e into "(e NOT IN (...))"; see in for how values are bound.
// With no values at all the condition becomes "1 = 1" (matches everything).
// e's SQL is replaced, so its previous arguments are discarded too.
func (e *expr) NotIn(values ...interface{}) *expr {
	if len(values) == 0 {
		e.expr = "1 = 1"
		e.args = nil
		return e
	}
	return e.in(" NOT", values...)
}
// OrderByCase rewrites e in place into a CASE expression that ranks rows by
// which of conditions e matches:
//
//	(CASE e WHEN ? THEN ? ... ELSE ? END)
//
// The i-th condition (0-based) maps to rank i+1, and anything else maps to
// len(conditions)+1. It is intended for use as a custom ORDER BY key.
func (e *expr) OrderByCase(conditions ...interface{}) *expr {
	var sb strings.Builder
	sb.WriteString("(CASE ")
	sb.WriteString(e.expr)
	for rank, condition := range conditions {
		sb.WriteString(" WHEN ? THEN ?")
		e.args = append(e.args, condition, rank+1)
	}
	sb.WriteString(" ELSE ? END)")
	e.args = append(e.args, len(conditions)+1)
	e.expr = sb.String()
	return e
}
// OrderAsc returns "e ASC " as a string, leaving e unchanged.
func (e *expr) OrderAsc() string {
	return e.expr + " ASC "
}

// OrderDesc returns "e DESC " as a string, leaving e unchanged.
func (e *expr) OrderDesc() string {
	return e.expr + " DESC "
}

// OrderAscExpr appends " ASC " to e in place and returns it.
func (e *expr) OrderAscExpr() *expr {
	e.expr = e.OrderAsc()
	return e
}

// OrderDescExpr appends " DESC " to e in place and returns it.
func (e *expr) OrderDescExpr() *expr {
	e.expr = e.OrderDesc()
	return e
}
// Or rewrites e in place into "(e OR e2)" and returns it, appending e2's
// arguments after e's own.
func (e *expr) Or(e2 *expr) *expr {
	e.args = append(e.args, e2.args...)
	e.expr = "(" + e.expr + " OR " + e2.expr + ")"
	return e
}

// And rewrites e in place into "(e AND e2)" and returns it, appending e2's
// arguments after e's own.
func (e *expr) And(e2 *expr) *expr {
	e.args = append(e.args, e2.args...)
	e.expr = "(" + e.expr + " AND " + e2.expr + ")"
	return e
}
// UpdateFields updates only the listed struct fields of db.Value, using each
// field's current value. db.Value must be a pointer to a struct. Association
// saving is disabled for this update.
func (db *DB) UpdateFields(fields ...string) *DB {
	record := reflect.ValueOf(db.Value).Elem()
	sets := make(map[string]interface{}, len(fields))
	for _, name := range fields {
		column := db.C(db.Value, name)
		sets[column] = record.FieldByName(name).Interface()
	}
	return db.clone().
		Set("gorm:save_associations", false).
		Set("gorm:association_save_reference", false).
		Update(sets)
}
// UpdateFieldsWithoutHooks behaves like UpdateFields but sets
// "gorm:update_column", so that (per the original contract) no Update hooks
// run and the UpdatedAt column, if any, is left untouched.
// fields are the Go struct field names, not column names.
func (db *DB) UpdateFieldsWithoutHooks(fields ...string) *DB {
	scoped := db.clone().Set("gorm:update_column", true)
	return scoped.UpdateFields(fields...)
}
// SelectFields selects the given SQL fragments, joined with ", ".
func (db *DB) SelectFields(fields ...string) *DB {
	return db.clone().Select(strings.Join(fields, ", "))
}

// SelectExprs selects a comma-separated list built from fields. Each *expr
// contributes its SQL and arguments; any other value becomes a bound "?".
func (db *DB) SelectExprs(fields ...interface{}) *DB {
	selection := &expr{}
	for i := range fields {
		if i > 0 {
			selection.expr += ", "
		}
		addStatementToExpression(selection, fields[i])
	}
	return db.clone().Select(selection.expr, selection.args...)
}
// Intersect rewrites e in place into "((e) INTERSECT (e2))" and returns it,
// appending e2's arguments after e's own.
func (e *expr) Intersect(e2 *expr) *expr {
	e.args = append(e.args, e2.args...)
	e.expr = "((" + e.expr + ") INTERSECT (" + e2.expr + "))"
	return e
}

// Alias appends " <alias> " to e in place and returns it. alias is inserted
// verbatim, without quoting.
func (e *expr) Alias(alias string) *expr {
	e.expr += " " + alias + " "
	return e
}
// FormatDate delegates to the current dialect's FormatDate. The accepted
// format syntax is whatever that dialect implementation expects
// (TODO confirm per dialect).
func (db *DB) FormatDate(e *expr, format string) *expr {
	dialect := db.Dialect()
	return dialect.FormatDate(e, format)
}

// FormatDateColumn returns only the SQL text produced by FormatDate.
func (db *DB) FormatDateColumn(e *expr, format string) string {
	formatted := db.FormatDate(e, format)
	return formatted.expr
}
// GetSQL renders the SELECT statement prepared for db with every bound
// variable inlined as a quoted literal (nil as NULL). It is meant for
// logging/debugging only; the result is not a safe substitute for parameter
// binding.
func (db *DB) GetSQL() string {
	scope := db.NewScope(db.Value)
	scope.prepareQuerySQL()
	return inlineSQLVars(strings.ReplaceAll(scope.SQL, "$$$", "?"), scope.SQLVars)
}

// GetSQLWhereClause renders only db's WHERE condition, with the first
// "WHERE" keyword removed, inlining bound variables like GetSQL.
func (db *DB) GetSQLWhereClause() string {
	scope := db.NewScope(db.Value)
	// whereSQL populates scope.SQLVars, so it must run before they are read.
	clause := strings.Replace(strings.ReplaceAll(scope.whereSQL(), "$$$", "?"), "WHERE", "", 1)
	return inlineSQLVars(clause, scope.SQLVars)
}

// inlineSQLVars substitutes the "?" placeholders of stmt, left to right, with
// the values of vars. A nil value becomes NULL; anything else becomes
// '<escaped %v>'. The statement is scanned once, so a "?" inside an
// already-inlined value is never mistaken for a placeholder. Surplus vars are
// ignored, and surplus placeholders are left as "?".
func inlineSQLVars(stmt string, vars []interface{}) string {
	var b strings.Builder
	b.Grow(len(stmt))
	for _, v := range vars {
		i := strings.IndexByte(stmt, '?')
		if i < 0 {
			break
		}
		b.WriteString(stmt[:i])
		if v == nil {
			b.WriteString("NULL")
		} else {
			b.WriteString("'" + escape(fmt.Sprintf("%v", v)) + "'")
		}
		stmt = stmt[i+1:]
	}
	b.WriteString(stmt)
	return b.String()
}
// escape backslash-escapes source so it can be embedded in a single- or
// double-quoted MySQL string literal, using MySQL's escape sequences:
// NUL -> \0, newline -> \n, carriage return -> \r, Ctrl-Z (0x1A) -> \Z, and
// \, ', " are each prefixed with a backslash. All other bytes, including
// UTF-8 multi-byte sequences, are copied unchanged.
//
// Previously \n and \r were emitted as a backslash followed by the raw
// control byte, and NUL was not escaped.
func escape(source string) string {
	var b strings.Builder
	b.Grow(len(source))
	// Byte-wise iteration is safe for UTF-8: none of the escaped bytes occur
	// inside multi-byte sequences (all such bytes are >= 0x80).
	for i := 0; i < len(source); i++ {
		c := source[i]
		switch c {
		case 0:
			b.WriteString(`\0`)
		case '\n':
			b.WriteString(`\n`)
		case '\r':
			b.WriteString(`\r`)
		case '\032':
			b.WriteString(`\Z`)
		case '\\', '\'', '"':
			b.WriteByte('\\')
			b.WriteByte(c)
		default:
			b.WriteByte(c)
		}
	}
	return b.String()
}