From ddaee81548c8ffe79d9e67e56c56788324d6788b Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 22 May 2025 20:19:23 +0800 Subject: [PATCH] Add WithResult support for generics API --- callbacks/create.go | 10 ++++++++++ callbacks/delete.go | 10 ++++++++++ callbacks/query.go | 4 ++++ callbacks/raw.go | 5 +++++ callbacks/update.go | 9 +++++++++ generics.go | 19 ++++++++++++++++++- statement.go | 2 ++ tests/generics_test.go | 15 +++++++++++++++ 8 files changed, 73 insertions(+), 1 deletion(-) diff --git a/callbacks/create.go b/callbacks/create.go index 8b7846b6..d8701f51 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -89,6 +89,10 @@ func Create(config *Config) func(db *gorm.DB) { db.AddError(rows.Close()) }() gorm.Scan(rows, db, mode) + + if db.Statement.Result != nil { + db.Statement.Result.RowsAffected = db.RowsAffected + } } return @@ -103,6 +107,12 @@ func Create(config *Config) func(db *gorm.DB) { } db.RowsAffected, _ = result.RowsAffected() + + if db.Statement.Result != nil { + db.Statement.Result.Result = result + db.Statement.Result.RowsAffected = db.RowsAffected + } + if db.RowsAffected == 0 { return } diff --git a/callbacks/delete.go b/callbacks/delete.go index 84f446a3..07ed6fee 100644 --- a/callbacks/delete.go +++ b/callbacks/delete.go @@ -157,8 +157,14 @@ func Delete(config *Config) func(db *gorm.DB) { ok, mode := hasReturning(db, supportReturning) if !ok { result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) + if db.AddError(err) == nil { db.RowsAffected, _ = result.RowsAffected() + + if db.Statement.Result != nil { + db.Statement.Result.Result = result + db.Statement.Result.RowsAffected = db.RowsAffected + } } return @@ -166,6 +172,10 @@ func Delete(config *Config) func(db *gorm.DB) { if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil { gorm.Scan(rows, db, mode) + + if db.Statement.Result != nil { + db.Statement.Result.RowsAffected = db.RowsAffected + } db.AddError(rows.Close()) } } diff --git a/callbacks/query.go b/callbacks/query.go index c8632cc5..548bf709 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -25,6 +25,10 @@ func Query(db *gorm.DB) { db.AddError(rows.Close()) }() gorm.Scan(rows, db, 0) + + if db.Statement.Result != nil { + db.Statement.Result.RowsAffected = db.RowsAffected + } } } } diff --git a/callbacks/raw.go b/callbacks/raw.go index 013e638c..3bb647c4 100644 --- a/callbacks/raw.go +++ b/callbacks/raw.go @@ -13,5 +13,10 @@ func RawExec(db *gorm.DB) { } db.RowsAffected, _ = result.RowsAffected() + + if db.Statement.Result != nil { + db.Statement.Result.Result = result + db.Statement.Result.RowsAffected = db.RowsAffected + } } } diff --git a/callbacks/update.go b/callbacks/update.go index 7cde7f61..8e2782e1 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -92,6 +92,10 @@ func Update(config *Config) func(db *gorm.DB) { gorm.Scan(rows, db, mode) db.Statement.Dest = dest db.AddError(rows.Close()) + + if db.Statement.Result != nil { + db.Statement.Result.RowsAffected = db.RowsAffected + } } } else { result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) @@ -99,6 +103,11 @@ func Update(config *Config) func(db *gorm.DB) { if db.AddError(err) == nil { db.RowsAffected, _ = result.RowsAffected() } + + if db.Statement.Result != nil { + db.Statement.Result.Result = result + db.Statement.Result.RowsAffected = db.RowsAffected + } } } } diff --git a/generics.go b/generics.go index f2863dac..54ccfca0 100644 --- a/generics.go +++ b/generics.go @@ -11,6 +11,23 @@ import ( "gorm.io/gorm/logger" ) +type result struct { + Result sql.Result + RowsAffected int64 +} + +func (info *result) ModifyStatement(stmt *Statement) { + stmt.Result = info +} + +// Build implements clause.Expression interface +func (result) Build(clause.Builder) { +} + +func WithResult() *result { + return &result{} +} + type Interface[T any] interface { Raw(sql string, values ...interface{}) ExecInterface[T] Exec(ctx context.Context, sql string, values ...interface{}) error @@ -85,7 +102,7 @@ type op func(*DB) *DB func G[T any](db *DB, opts ...clause.Expression) Interface[T] { v := &g[T]{ - db: db.Session(&Session{NewDB: true}), + db: db, ops: make([]op, 0, 5), } diff --git a/statement.go b/statement.go index 19cdbbaf..c6183724 100644 --- a/statement.go +++ b/statement.go @@ -47,6 +47,7 @@ type Statement struct { attrs []interface{} assigns []interface{} scopes []func(*DB) *DB + Result *result } type join struct { @@ -532,6 +533,7 @@ func (stmt *Statement) clone() *Statement { Context: stmt.Context, RaiseErrorOnNotFound: stmt.RaiseErrorOnNotFound, SkipHooks: stmt.SkipHooks, + Result: stmt.Result, } if stmt.SQL.Len() > 0 { diff --git a/tests/generics_test.go b/tests/generics_test.go index 876c7409..5ab76ae7 100644 --- a/tests/generics_test.go +++ b/tests/generics_test.go @@ -729,3 +729,18 @@ func TestGenericsUpsert(t *testing.T) { t.Errorf("should update name on conflict, but got name %+v", langs[0].Name) } } + +func TestGenericsWithResult(t *testing.T) { + ctx := context.Background() + users := []User{{Name: "TestGenericsWithResult", Age: 18}, {Name: "TestGenericsWithResult2", Age: 18}} + + result := gorm.WithResult() + err := gorm.G[User](DB, result).CreateInBatches(ctx, &users, 2) + if err != nil { + t.Errorf("failed to create users WithResult") + } + + if result.RowsAffected != 2 { + t.Errorf("failed to get affected rows, got %d, should be %d", result.RowsAffected, 2) + } +}