handle error of generics Joins/Preload

This commit is contained in:
Jinzhu 2025-05-20 22:52:55 +08:00
parent 4694673526
commit e330694e26
2 changed files with 30 additions and 11 deletions

View File

@ -31,8 +31,8 @@ type ChainInterface[T any] interface {
Or(query interface{}, args ...interface{}) ChainInterface[T]
Limit(offset int) ChainInterface[T]
Offset(offset int) ChainInterface[T]
Joins(query clause.JoinTarget, args func(db JoinBuilder, joinTable clause.Table, curTable clause.Table) error) ChainInterface[T]
Preload(association string, args func(db PreloadBuilder) error) ChainInterface[T]
Joins(query clause.JoinTarget, on func(db JoinBuilder, joinTable clause.Table, curTable clause.Table) error) ChainInterface[T]
Preload(association string, query func(db PreloadBuilder) error) ChainInterface[T]
Select(query string, args ...interface{}) ChainInterface[T]
Omit(columns ...string) ChainInterface[T]
MapColumns(m map[string]string) ChainInterface[T]
@ -287,15 +287,17 @@ func (q preloadBuilder) Scopes(scopes ...func(db *Statement)) PreloadBuilder {
return q
}
func (c chainG[T]) Joins(jt clause.JoinTarget, args func(db JoinBuilder, joinTable clause.Table, curTable clause.Table) error) ChainInterface[T] {
func (c chainG[T]) Joins(jt clause.JoinTarget, on func(db JoinBuilder, joinTable clause.Table, curTable clause.Table) error) ChainInterface[T] {
return c.with(func(db *DB) *DB {
if jt.Table == "" {
jt.Table = clause.JoinTable(strings.Split(jt.Association, ".")...).Name
}
q := joinBuilder{db: db.Session(&Session{NewDB: true}).getInstance().Table(jt.Table)}
if args != nil {
args(q, clause.Table{Name: jt.Table}, clause.Table{Name: clause.CurrentTable})
q := joinBuilder{db: db.Session(&Session{NewDB: true, Initialized: true}).Table(jt.Table)}
if on != nil {
if err := on(q, clause.Table{Name: jt.Table}, clause.Table{Name: clause.CurrentTable}); err != nil {
db.AddError(err)
}
}
j := join{
@ -383,12 +385,14 @@ func (c chainG[T]) Order(value interface{}) ChainInterface[T] {
})
}
func (c chainG[T]) Preload(association string, args func(db PreloadBuilder) error) ChainInterface[T] {
func (c chainG[T]) Preload(association string, query func(db PreloadBuilder) error) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Preload(association, func(db *DB) *DB {
q := preloadBuilder{db: db}
if args != nil {
args(q)
return db.Preload(association, func(tx *DB) *DB {
q := preloadBuilder{db: tx}
if query != nil {
if err := query(q); err != nil {
db.AddError(err)
}
}
return q.db
})

View File

@ -2,6 +2,7 @@ package tests_test
import (
"context"
"errors"
"fmt"
"reflect"
"sort"
@ -367,6 +368,13 @@ func TestGenericsJoinsAndPreload(t *testing.T) {
t.Fatalf("Joins expected %s, got %+v", u.Name, result)
}
_, err = db.Joins(clause.Has("Company"), func(db gorm.JoinBuilder, joinTable clause.Table, curTable clause.Table) error {
return errors.New("join error")
}).First(ctx)
if err == nil {
t.Fatalf("Joins should got error, but got nil")
}
// Preload
result3, err := db.Preload("Company", nil).Where("name = ?", u.Name).First(ctx)
if err != nil {
@ -392,6 +400,13 @@ func TestGenericsJoinsAndPreload(t *testing.T) {
t.Fatalf("Preload other company should not loaded, user %v company expect %v but got %+v", u.Name, u.Company.Name, result.Company.Name)
}
}
_, err = db.Preload("Company", func(db gorm.PreloadBuilder) error {
return errors.New("preload error")
}).Find(ctx)
if err == nil {
t.Fatalf("Preload should failed, but got nil")
}
}
func TestGenericsDistinct(t *testing.T) {