From 72d0fa61960c5c2472b561e3945654b3f020a233 Mon Sep 17 00:00:00 2001 From: Douglas Danger Manley Date: Sun, 7 Jun 2020 16:41:54 -0400 Subject: [PATCH] Fix Statement Where clone array corruption in v2 Method-chaining in gorm is predicated on a `Clause`'s `MergeClause` method ensuring that the two clauses are disconnected in terms of pointers (at least in the Wherec case). However, the original Where implementation used `append`, which only returns a new instance if the backing array needs to be resized. In some cases, this is true. Practically, go doubles the size of the slice once it gets full, so the following slice `append` calls would result in a new slice: * 0 -> 1 * 1 -> 2 * 2 -> 4 * 4 -> 8 * and so on. So, when the number of "where" conditions was 0, 1, 2, or 4, method-chaining would work as expected. However, when it was 3, 5, 6, or 7, modifying the copy would modify the original. This also updates the "order by", "group by" and "set" clauses. --- clause/group_by.go | 9 +++++++-- clause/order_by.go | 4 +++- clause/set.go | 4 +++- clause/where.go | 4 +++- statement_test.go | 37 +++++++++++++++++++++++++++++++++++++ 5 files changed, 53 insertions(+), 5 deletions(-) create mode 100644 statement_test.go diff --git a/clause/group_by.go b/clause/group_by.go index c1383c36..88231916 100644 --- a/clause/group_by.go +++ b/clause/group_by.go @@ -30,8 +30,13 @@ func (groupBy GroupBy) Build(builder Builder) { // MergeClause merge group by clause func (groupBy GroupBy) MergeClause(clause *Clause) { if v, ok := clause.Expression.(GroupBy); ok { - groupBy.Columns = append(v.Columns, groupBy.Columns...) - groupBy.Having = append(v.Having, groupBy.Having...) + copiedColumns := make([]Column, len(v.Columns)) + copy(copiedColumns, v.Columns) + groupBy.Columns = append(copiedColumns, groupBy.Columns...) + + copiedHaving := make([]Expression, len(v.Having)) + copy(copiedHaving, v.Having) + groupBy.Having = append(copiedHaving, groupBy.Having...) } clause.Expression = groupBy } diff --git a/clause/order_by.go b/clause/order_by.go index 307bf930..a8a9539a 100644 --- a/clause/order_by.go +++ b/clause/order_by.go @@ -40,7 +40,9 @@ func (orderBy OrderBy) MergeClause(clause *Clause) { } } - orderBy.Columns = append(v.Columns, orderBy.Columns...) + copiedColumns := make([]OrderByColumn, len(v.Columns)) + copy(copiedColumns, v.Columns) + orderBy.Columns = append(copiedColumns, orderBy.Columns...) } clause.Expression = orderBy diff --git a/clause/set.go b/clause/set.go index 7704ca36..2d3965d3 100644 --- a/clause/set.go +++ b/clause/set.go @@ -32,7 +32,9 @@ func (set Set) Build(builder Builder) { // MergeClause merge assignments clauses func (set Set) MergeClause(clause *Clause) { - clause.Expression = set + copiedAssignments := make([]Assignment, len(set)) + copy(copiedAssignments, set) + clause.Expression = Set(copiedAssignments) } func Assignments(values map[string]interface{}) Set { diff --git a/clause/where.go b/clause/where.go index 015addf8..806565d1 100644 --- a/clause/where.go +++ b/clause/where.go @@ -40,7 +40,9 @@ func (where Where) Build(builder Builder) { // MergeClause merge where clauses func (where Where) MergeClause(clause *Clause) { if w, ok := clause.Expression.(Where); ok { - where.Exprs = append(w.Exprs, where.Exprs...) + copiedExpressions := make([]Expression, len(w.Exprs)) + copy(copiedExpressions, w.Exprs) + where.Exprs = append(copiedExpressions, where.Exprs...) } clause.Expression = where diff --git a/statement_test.go b/statement_test.go new file mode 100644 index 00000000..7d730875 --- /dev/null +++ b/statement_test.go @@ -0,0 +1,37 @@ +package gorm + +import ( + "fmt" + "reflect" + "testing" + + "gorm.io/gorm/clause" +) + +func TestWhereCloneCorruption(t *testing.T) { + for whereCount := 1; whereCount <= 8; whereCount++ { + t.Run(fmt.Sprintf("w=%d", whereCount), func(t *testing.T) { + s := new(Statement) + for w := 0; w < whereCount; w++ { + s = s.clone() + s.AddClause(clause.Where{ + Exprs: s.BuildCondtion(fmt.Sprintf("where%d", w)), + }) + } + + s1 := s.clone() + s1.AddClause(clause.Where{ + Exprs: s.BuildCondtion("FINAL1"), + }) + s2 := s.clone() + s2.AddClause(clause.Where{ + Exprs: s.BuildCondtion("FINAL2"), + }) + + if reflect.DeepEqual(s1.Clauses["WHERE"], s2.Clauses["WHERE"]) { + t.Errorf("Where conditions should be different") + } + }) + } +} +