From e330694e262a94523e4c330961a39c0ba221bc62 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 20 May 2025 22:52:55 +0800 Subject: [PATCH] handle error of generics Joins/Preload --- generics.go | 26 +++++++++++++++----------- tests/generics_test.go | 15 +++++++++++++++ 2 files changed, 30 insertions(+), 11 deletions(-) diff --git a/generics.go b/generics.go index 4953c758..230f07f5 100644 --- a/generics.go +++ b/generics.go @@ -31,8 +31,8 @@ type ChainInterface[T any] interface { Or(query interface{}, args ...interface{}) ChainInterface[T] Limit(offset int) ChainInterface[T] Offset(offset int) ChainInterface[T] - Joins(query clause.JoinTarget, args func(db JoinBuilder, joinTable clause.Table, curTable clause.Table) error) ChainInterface[T] - Preload(association string, args func(db PreloadBuilder) error) ChainInterface[T] + Joins(query clause.JoinTarget, on func(db JoinBuilder, joinTable clause.Table, curTable clause.Table) error) ChainInterface[T] + Preload(association string, query func(db PreloadBuilder) error) ChainInterface[T] Select(query string, args ...interface{}) ChainInterface[T] Omit(columns ...string) ChainInterface[T] MapColumns(m map[string]string) ChainInterface[T] @@ -287,15 +287,17 @@ func (q preloadBuilder) Scopes(scopes ...func(db *Statement)) PreloadBuilder { return q } -func (c chainG[T]) Joins(jt clause.JoinTarget, args func(db JoinBuilder, joinTable clause.Table, curTable clause.Table) error) ChainInterface[T] { +func (c chainG[T]) Joins(jt clause.JoinTarget, on func(db JoinBuilder, joinTable clause.Table, curTable clause.Table) error) ChainInterface[T] { return c.with(func(db *DB) *DB { if jt.Table == "" { jt.Table = clause.JoinTable(strings.Split(jt.Association, ".")...).Name } - q := joinBuilder{db: db.Session(&Session{NewDB: true}).getInstance().Table(jt.Table)} - if args != nil { - args(q, clause.Table{Name: jt.Table}, clause.Table{Name: clause.CurrentTable}) + q := joinBuilder{db: db.Session(&Session{NewDB: true, Initialized: true}).Table(jt.Table)} + if on != nil { + if err := on(q, clause.Table{Name: jt.Table}, clause.Table{Name: clause.CurrentTable}); err != nil { + db.AddError(err) + } } j := join{ @@ -383,12 +385,14 @@ func (c chainG[T]) Order(value interface{}) ChainInterface[T] { }) } -func (c chainG[T]) Preload(association string, args func(db PreloadBuilder) error) ChainInterface[T] { +func (c chainG[T]) Preload(association string, query func(db PreloadBuilder) error) ChainInterface[T] { return c.with(func(db *DB) *DB { - return db.Preload(association, func(db *DB) *DB { - q := preloadBuilder{db: db} - if args != nil { - args(q) + return db.Preload(association, func(tx *DB) *DB { + q := preloadBuilder{db: tx} + if query != nil { + if err := query(q); err != nil { + db.AddError(err) + } } return q.db }) diff --git a/tests/generics_test.go b/tests/generics_test.go index 2efaacdc..2e0dbc28 100644 --- a/tests/generics_test.go +++ b/tests/generics_test.go @@ -2,6 +2,7 @@ package tests_test import ( "context" + "errors" "fmt" "reflect" "sort" @@ -367,6 +368,13 @@ func TestGenericsJoinsAndPreload(t *testing.T) { t.Fatalf("Joins expected %s, got %+v", u.Name, result) } + _, err = db.Joins(clause.Has("Company"), func(db gorm.JoinBuilder, joinTable clause.Table, curTable clause.Table) error { + return errors.New("join error") + }).First(ctx) + if err == nil { + t.Fatalf("Joins should got error, but got nil") + } + // Preload result3, err := db.Preload("Company", nil).Where("name = ?", u.Name).First(ctx) if err != nil { @@ -392,6 +400,13 @@ func TestGenericsJoinsAndPreload(t *testing.T) { t.Fatalf("Preload other company should not loaded, user %v company expect %v but got %+v", u.Name, u.Company.Name, result.Company.Name) } } + + _, err = db.Preload("Company", func(db gorm.PreloadBuilder) error { + return errors.New("preload error") + }).Find(ctx) + if err == nil { + t.Fatalf("Preload should failed, but got nil") + } } func TestGenericsDistinct(t *testing.T) {