fix: AfterQuery using safer right trim while clearing from clause's join added as part of https://github.com/go-gorm/gorm/pull/7027

This commit is contained in:
Abhijeet Bhowmik 2024-08-09 16:48:22 +00:00
parent 4a50b36f63
commit 51c9e58204
4 changed files with 115 additions and 1 deletions

View File

@ -288,7 +288,7 @@ func AfterQuery(db *gorm.DB) {
// clear the joins after query because preload need it
if v, ok := db.Statement.Clauses["FROM"].Expression.(clause.From); ok {
fromClause := db.Statement.Clauses["FROM"]
fromClause.Expression = clause.From{Tables: v.Tables, Joins: v.Joins[:len(v.Joins)-len(db.Statement.Joins)]} // keep the original From Joins
fromClause.Expression = clause.From{Tables: v.Tables, Joins: utils.RTrimSlice(v.Joins, len(v.Joins)-len(db.Statement.Joins))} // keep the original From Joins
db.Statement.Clauses["FROM"] = fromClause
}
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.AfterFind && db.RowsAffected > 0 {

View File

@ -476,3 +476,19 @@ func TestJoinsPreload_Issue7013_NoEntries(t *testing.T) {
AssertEqual(t, len(entries), 0)
}
func TestJoinWithWrongColumnName_MultipleAfterQueryCalls(t *testing.T) {
type result struct {
gorm.Model
Name string
Pets []*Pet `gorm:"foreignKey:UserID"`
}
user := *GetUser("joins_with_select", Config{Pets: 2})
DB.Save(&user)
var results []result
var total int64
assert.NotPanics(t, func() {
err := DB.Table("users").Select("users.id, pets.id as pet_id, pets.name").Joins("left join pets on pets.user_id = users.id").Where("users.name = ? and pets.names = ?", "joins_with_select", "joins_with_select_pet_2").Preload("Pets").Find(&results).Limit(-1).Offset(-1).Count(&total).Error
assert.ErrorContains(t, err, "no such column: pets.names")
})
}

View File

@ -166,3 +166,26 @@ func SplitNestedRelationName(name string) []string {
func JoinNestedRelationNames(relationNames []string) string {
return strings.Join(relationNames, nestedRelationSplit)
}
// MaxInt returns maximum of two integers
func MaxInt(n1, n2 int) int {
if n1 > n2 {
return n1
}
return n2
}
// MinInt returns minimum of two integers
func MinInt(n1, n2 int) int {
if n1 < n2 {
return n1
}
return n2
}
// RTrimSlice Right trims the give slice by given length
func RTrimSlice[T any](v []T, trimLen int) []T {
rPtr := MaxInt(trimLen, 0) // should not be negative
rPtr = MinInt(len(v), rPtr) // should not be greater than length
return v[:rPtr]
}

View File

@ -138,3 +138,78 @@ func TestToString(t *testing.T) {
})
}
}
func TestMaxInt(t *testing.T) {
type testVal struct {
n1, n2 int
}
integerSet := []int{100, 10, 0, -10, -100} // test set in desc order
samples := []testVal{}
for _, i := range integerSet {
for _, j := range integerSet {
samples = append(samples, testVal{n1: i, n2: j})
}
}
for _, sample := range samples {
t.Run("", func(t *testing.T) {
result := MaxInt(sample.n1, sample.n2)
if !(result >= sample.n1 && result >= sample.n2) {
t.Fatalf("For n1=%d and n2=%d, result is %d;", sample.n1, sample.n2, result)
}
})
}
}
func TestMinInt(t *testing.T) {
type testVal struct {
n1, n2 int
}
integerSet := []int{100, 10, 0, -10, -100} // test set in desc order
samples := []testVal{}
for _, i := range integerSet {
for _, j := range integerSet {
samples = append(samples, testVal{n1: i, n2: j})
}
}
for _, sample := range samples {
t.Run("", func(t *testing.T) {
result := MinInt(sample.n1, sample.n2)
if !(result <= sample.n1 && result <= sample.n2) {
t.Fatalf("For n1=%d and n2=%d, result is %d;", sample.n1, sample.n2, result)
}
})
}
}
func TestRTrimSlice(t *testing.T) {
samples := []struct {
input []int
trimLen int
expected []int
}{
{[]int{1, 2, 3, 4, 5}, 3, []int{1, 2, 3}},
{[]int{1, 2, 3, 4, 5}, 0, []int{}},
{[]int{1, 2, 3, 4, 5}, 5, []int{1, 2, 3, 4, 5}},
{[]int{1, 2, 3, 4, 5}, 10, []int{1, 2, 3, 4, 5}}, // trimLen greater than slice length
{[]int{1, 2, 3, 4, 5}, -1, []int{}}, // negative trimLen
{[]int{}, 3, []int{}}, // empty slice
{[]int{1, 2, 3}, 1, []int{1}}, // trim to a single element
}
for _, sample := range samples {
t.Run("", func(t *testing.T) {
result := RTrimSlice(sample.input, sample.trimLen)
if !AssertEqual(result, sample.expected) {
t.Errorf("Triming %v by length %d gives %v but want %v", sample.input, sample.trimLen, result, sample.expected)
}
})
}
}