diff --git a/expression_ext.go b/expression_ext.go index cafd87e7..abddfeb7 100644 --- a/expression_ext.go +++ b/expression_ext.go @@ -383,6 +383,32 @@ func (e *expr) NotIn(values ...interface{}) *expr { return e.in(" NOT", values...) } +func (e *expr) orderByCase(orderValues []interface{}) *expr { + e.expr = "( CASE " + e.expr + + for i, orderValue := range orderValues { + e.expr += fmt.Sprintf(" WHEN %d THEN %d ELSE %d", orderValue, i+1, i+2) + } + + e.expr += " END )" + + return e +} + +func (e *expr) OrderByCaseASC(orderValues ...interface{}) *expr { + e = e.orderByCase(orderValues) + e.expr += " ASC" + + return e +} + +func (e *expr) OrderByCaseDESC(orderValues ...interface{}) *expr { + e = e.orderByCase(orderValues) + e.expr += " DESC" + + return e +} + func (e *expr) OrderAsc() string { return e.expr + " ASC " } @@ -451,77 +477,3 @@ func (db *DB) FormatDate(e *expr, format string) *expr { func (db *DB) FormatDateColumn(e *expr, format string) string { return db.FormatDate(e, format).expr } - -func (db *DB) GetSQL() string { - scope := db.NewScope(db.Value) - - scope.prepareQuerySQL() - - stmt := strings.ReplaceAll(scope.SQL, "$$$", "?") - for _, arg := range scope.SQLVars { - stmt = strings.Replace(stmt, "?", "'"+escape(fmt.Sprintf("%v", arg))+"'", 1) - } - - return stmt -} - -func (db *DB) GetSQLWhereClause() string { - scope := db.NewScope(db.Value) - - stmt := strings.Replace(strings.ReplaceAll(scope.whereSQL(), "$$$", "?"), "WHERE", "", 1) - - for _, arg := range scope.SQLVars { - stmt = strings.Replace(stmt, "?", "'"+escape(fmt.Sprintf("%v", arg))+"'", 1) - } - - return stmt -} - -func escape(source string) string { - var j int = 0 - if len(source) == 0 { - return "" - } - tempStr := source[:] - desc := make([]byte, len(tempStr)*2) - for i := 0; i < len(tempStr); i++ { - flag := false - var escape byte - switch tempStr[i] { - case '\r': - flag = true - escape = '\r' - - case '\n': - flag = true - escape = '\n' - - case '\\': - flag = true - escape = '\\' - - case '\'': - flag = true - escape = '\'' - - case '"': - flag = true - escape = '"' - - case '\032': - flag = true - escape = 'Z' - - default: - } - if flag { - desc[j] = '\\' - desc[j+1] = escape - j = j + 2 - } else { - desc[j] = tempStr[i] - j = j + 1 - } - } - return string(desc[0:j]) -}