From 27d6d274419d9845fd45905b4c51197b7462ca7e Mon Sep 17 00:00:00 2001 From: Tsubasa Munekata Date: Thu, 22 Feb 2024 21:28:27 +0900 Subject: [PATCH] Fix --- clause/where.go | 68 ++++++++++++++++++++++++++++++++++++-------- clause/where_test.go | 7 +++-- tests/query_test.go | 5 ++++ 3 files changed, 65 insertions(+), 15 deletions(-) diff --git a/clause/where.go b/clause/where.go index 3db8a3d3..9ac78578 100644 --- a/clause/where.go +++ b/clause/where.go @@ -153,6 +153,11 @@ func Not(exprs ...Expression) Expression { if len(exprs) == 0 { return nil } + if len(exprs) == 1 { + if andCondition, ok := exprs[0].(AndConditions); ok { + exprs = andCondition.Exprs + } + } return NotConditions{Exprs: exprs} } @@ -161,19 +166,58 @@ type NotConditions struct { } func (not NotConditions) Build(builder Builder) { - if len(not.Exprs) > 1 { - builder.WriteByte('(') + anyNegationBuilder := false + for _, c := range not.Exprs { + if _, ok := c.(NegationExpressionBuilder); ok { + anyNegationBuilder = true + break + } } - for idx, c := range not.Exprs { - if idx > 0 { - builder.WriteString(AndWithSpace) + if anyNegationBuilder { + if len(not.Exprs) > 1 { + builder.WriteByte('(') } - if negationBuilder, ok := c.(NegationExpressionBuilder); ok { - negationBuilder.NegationBuild(builder) - } else { - builder.WriteString("NOT ") + for idx, c := range not.Exprs { + if idx > 0 { + builder.WriteString(AndWithSpace) + } + + if negationBuilder, ok := c.(NegationExpressionBuilder); ok { + negationBuilder.NegationBuild(builder) + } else { + builder.WriteString("NOT ") + e, wrapInParentheses := c.(Expr) + if wrapInParentheses { + sql := strings.ToUpper(e.SQL) + if wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace); wrapInParentheses { + builder.WriteByte('(') + } + } + + c.Build(builder) + + if wrapInParentheses { + builder.WriteByte(')') + } + } + } + + if len(not.Exprs) > 1 { + builder.WriteByte(')') + } + } else { + builder.WriteString("NOT ") + if len(not.Exprs) > 1 { + builder.WriteByte('(') + } + + for idx, c := range not.Exprs { + if idx > 0 { + builder.WriteString(AndWithSpace) + } + e, wrapInParentheses := c.(Expr) if wrapInParentheses { sql := strings.ToUpper(e.SQL) @@ -188,9 +232,9 @@ func (not NotConditions) Build(builder Builder) { builder.WriteByte(')') } } - } - if len(not.Exprs) > 1 { - builder.WriteByte(')') + if len(not.Exprs) > 1 { + builder.WriteByte(')') + } } } diff --git a/clause/where_test.go b/clause/where_test.go index 28475a5c..7d5aca1f 100644 --- a/clause/where_test.go +++ b/clause/where_test.go @@ -107,10 +107,11 @@ func TestWhere(t *testing.T) { }, { []clause.Interface{clause.Select{}, clause.From{}, clause.Where{ - Exprs: []clause.Expression{clause.Not(clause.And(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false}))}, + Exprs: []clause.Expression{clause.Not(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}}, + clause.Expr{SQL: "`age` <= ?", Vars: []interface{}{60}})}, }}, - "SELECT * FROM `users` WHERE NOT (`users`.`id` = ? AND `score` <= ?)", - []interface{}{"1", 100}, + "SELECT * FROM `users` WHERE NOT (`score` <= ? AND `age` <= ?)", + []interface{}{100, 60}, }, } diff --git a/tests/query_test.go b/tests/query_test.go index cadf7164..e780e3bf 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -554,6 +554,11 @@ func TestNot(t *testing.T) { if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*users.*..*name.* <> .+ AND .*users.*..*age.* <> .+").MatchString(result.Statement.SQL.String()) { t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) } + + result = dryDB.Not(DB.Where("manager IS NULL").Where("age >= ?", 20)).Find(&User{}) + if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE NOT \\(manager IS NULL AND age >= .+\\) AND .users.\\..deleted_at. IS NULL").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) + } } func TestNotWithAllFields(t *testing.T) {