diff --git a/callbacks/query.go b/callbacks/query.go index 9b2b17ea..fab67dce 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -288,7 +288,7 @@ func AfterQuery(db *gorm.DB) { // clear the joins after query because preload need it if v, ok := db.Statement.Clauses["FROM"].Expression.(clause.From); ok { fromClause := db.Statement.Clauses["FROM"] - fromClause.Expression = clause.From{Tables: v.Tables, Joins: v.Joins[:len(v.Joins)-len(db.Statement.Joins)]} // keep the original From Joins + fromClause.Expression = clause.From{Tables: v.Tables, Joins: utils.RTrimSlice(v.Joins, len(v.Joins)-len(db.Statement.Joins))} // keep the original From Joins db.Statement.Clauses["FROM"] = fromClause } if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.AfterFind && db.RowsAffected > 0 { diff --git a/tests/joins_test.go b/tests/joins_test.go index 497f8146..b9487945 100644 --- a/tests/joins_test.go +++ b/tests/joins_test.go @@ -476,3 +476,19 @@ func TestJoinsPreload_Issue7013_NoEntries(t *testing.T) { AssertEqual(t, len(entries), 0) } + +func TestJoinWithWrongColumnName_MultipleAfterQueryCalls(t *testing.T) { + type result struct { + gorm.Model + Name string + Pets []*Pet `gorm:"foreignKey:UserID"` + } + user := *GetUser("joins_with_select", Config{Pets: 2}) + DB.Save(&user) + var results []result + var total int64 + assert.NotPanics(t, func() { + err := DB.Table("users").Select("users.id, pets.id as pet_id, pets.name").Joins("left join pets on pets.user_id = users.id").Where("users.name = ? and pets.names = ?", "joins_with_select", "joins_with_select_pet_2").Preload("Pets").Find(&results).Limit(-1).Offset(-1).Count(&total).Error + assert.ErrorContains(t, err, "no such column: pets.names") + }) +} diff --git a/utils/utils.go b/utils/utils.go index b8d30b35..0b70da03 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -166,3 +166,26 @@ func SplitNestedRelationName(name string) []string { func JoinNestedRelationNames(relationNames []string) string { return strings.Join(relationNames, nestedRelationSplit) } + +// MaxInt returns maximum of two integers +func MaxInt(n1, n2 int) int { + if n1 > n2 { + return n1 + } + return n2 +} + +// MinInt returns minimum of two integers +func MinInt(n1, n2 int) int { + if n1 < n2 { + return n1 + } + return n2 +} + +// RTrimSlice Right trims the give slice by given length +func RTrimSlice[T any](v []T, trimLen int) []T { + rPtr := MaxInt(trimLen, 0) // should not be negative + rPtr = MinInt(len(v), rPtr) // should not be greater than length + return v[:rPtr] +} diff --git a/utils/utils_test.go b/utils/utils_test.go index 8ff42af8..3f581897 100644 --- a/utils/utils_test.go +++ b/utils/utils_test.go @@ -138,3 +138,78 @@ func TestToString(t *testing.T) { }) } } + +func TestMaxInt(t *testing.T) { + + type testVal struct { + n1, n2 int + } + + integerSet := []int{100, 10, 0, -10, -100} // test set in desc order + samples := []testVal{} + + for _, i := range integerSet { + for _, j := range integerSet { + samples = append(samples, testVal{n1: i, n2: j}) + } + } + + for _, sample := range samples { + t.Run("", func(t *testing.T) { + result := MaxInt(sample.n1, sample.n2) + if !(result >= sample.n1 && result >= sample.n2) { + t.Fatalf("For n1=%d and n2=%d, result is %d;", sample.n1, sample.n2, result) + } + }) + } +} + +func TestMinInt(t *testing.T) { + + type testVal struct { + n1, n2 int + } + + integerSet := []int{100, 10, 0, -10, -100} // test set in desc order + samples := []testVal{} + + for _, i := range integerSet { + for _, j := range integerSet { + samples = append(samples, testVal{n1: i, n2: j}) + } + } + + for _, sample := range samples { + t.Run("", func(t *testing.T) { + result := MinInt(sample.n1, sample.n2) + if !(result <= sample.n1 && result <= sample.n2) { + t.Fatalf("For n1=%d and n2=%d, result is %d;", sample.n1, sample.n2, result) + } + }) + } +} + +func TestRTrimSlice(t *testing.T) { + samples := []struct { + input []int + trimLen int + expected []int + }{ + {[]int{1, 2, 3, 4, 5}, 3, []int{1, 2, 3}}, + {[]int{1, 2, 3, 4, 5}, 0, []int{}}, + {[]int{1, 2, 3, 4, 5}, 5, []int{1, 2, 3, 4, 5}}, + {[]int{1, 2, 3, 4, 5}, 10, []int{1, 2, 3, 4, 5}}, // trimLen greater than slice length + {[]int{1, 2, 3, 4, 5}, -1, []int{}}, // negative trimLen + {[]int{}, 3, []int{}}, // empty slice + {[]int{1, 2, 3}, 1, []int{1}}, // trim to a single element + } + + for _, sample := range samples { + t.Run("", func(t *testing.T) { + result := RTrimSlice(sample.input, sample.trimLen) + if !AssertEqual(result, sample.expected) { + t.Errorf("Triming %v by length %d gives %v but want %v", sample.input, sample.trimLen, result, sample.expected) + } + }) + } +}