From 25f561a742776af41b3165e2600e782ec9defe8b Mon Sep 17 00:00:00 2001 From: River Date: Thu, 19 Aug 2021 14:33:18 +0800 Subject: [PATCH] feat: QuoteTo accept clause.Expr (#4621) * feat: QuoteTo accept clause.Expr * test: update Expr build test --- clause/expression_test.go | 12 ++++++++++++ statement.go | 2 ++ 2 files changed, 14 insertions(+) diff --git a/clause/expression_test.go b/clause/expression_test.go index 0ccd0771..05074865 100644 --- a/clause/expression_test.go +++ b/clause/expression_test.go @@ -156,6 +156,18 @@ func TestExpression(t *testing.T) { }, ExpectedVars: []interface{}{"a", "b"}, Result: "`column-name` NOT IN (?,?)", + }, { + Expressions: []clause.Expression{ + clause.Eq{Column: clause.Expr{SQL: "SUM(?)", Vars: []interface{}{clause.Column{Name: "id"}}}, Value: 100}, + }, + ExpectedVars: []interface{}{100}, + Result: "SUM(`id`) = ?", + }, { + Expressions: []clause.Expression{ + clause.Gte{Column: clause.Expr{SQL: "SUM(?)", Vars: []interface{}{clause.Column{Table: "users", Name: "id"}}}, Value: 100}, + }, + ExpectedVars: []interface{}{100}, + Result: "SUM(`users`.`id`) >= ?", }} for idx, result := range results { diff --git a/statement.go b/statement.go index 8b682c84..93b78c12 100644 --- a/statement.go +++ b/statement.go @@ -129,6 +129,8 @@ func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) { stmt.QuoteTo(writer, d) } writer.WriteByte(')') + case clause.Expr: + v.Build(stmt) case string: stmt.DB.Dialector.QuoteTo(writer, v) case []string: