Add WithResult support for generics API

This commit is contained in:
Jinzhu 2025-05-22 20:19:23 +08:00
parent 774d957089
commit ddaee81548
8 changed files with 73 additions and 1 deletions

View File

@ -89,6 +89,10 @@ func Create(config *Config) func(db *gorm.DB) {
db.AddError(rows.Close())
}()
gorm.Scan(rows, db, mode)
if db.Statement.Result != nil {
db.Statement.Result.RowsAffected = db.RowsAffected
}
}
return
@ -103,6 +107,12 @@ func Create(config *Config) func(db *gorm.DB) {
}
db.RowsAffected, _ = result.RowsAffected()
if db.Statement.Result != nil {
db.Statement.Result.Result = result
db.Statement.Result.RowsAffected = db.RowsAffected
}
if db.RowsAffected == 0 {
return
}

View File

@ -157,8 +157,14 @@ func Delete(config *Config) func(db *gorm.DB) {
ok, mode := hasReturning(db, supportReturning)
if !ok {
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
if db.AddError(err) == nil {
db.RowsAffected, _ = result.RowsAffected()
if db.Statement.Result != nil {
db.Statement.Result.Result = result
db.Statement.Result.RowsAffected = db.RowsAffected
}
}
return
@ -166,6 +172,10 @@ func Delete(config *Config) func(db *gorm.DB) {
if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil {
gorm.Scan(rows, db, mode)
if db.Statement.Result != nil {
db.Statement.Result.RowsAffected = db.RowsAffected
}
db.AddError(rows.Close())
}
}

View File

@ -25,6 +25,10 @@ func Query(db *gorm.DB) {
db.AddError(rows.Close())
}()
gorm.Scan(rows, db, 0)
if db.Statement.Result != nil {
db.Statement.Result.RowsAffected = db.RowsAffected
}
}
}
}

View File

@ -13,5 +13,10 @@ func RawExec(db *gorm.DB) {
}
db.RowsAffected, _ = result.RowsAffected()
if db.Statement.Result != nil {
db.Statement.Result.Result = result
db.Statement.Result.RowsAffected = db.RowsAffected
}
}
}

View File

@ -92,6 +92,10 @@ func Update(config *Config) func(db *gorm.DB) {
gorm.Scan(rows, db, mode)
db.Statement.Dest = dest
db.AddError(rows.Close())
if db.Statement.Result != nil {
db.Statement.Result.RowsAffected = db.RowsAffected
}
}
} else {
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
@ -99,6 +103,11 @@ func Update(config *Config) func(db *gorm.DB) {
if db.AddError(err) == nil {
db.RowsAffected, _ = result.RowsAffected()
}
if db.Statement.Result != nil {
db.Statement.Result.Result = result
db.Statement.Result.RowsAffected = db.RowsAffected
}
}
}
}

View File

@ -11,6 +11,23 @@ import (
"gorm.io/gorm/logger"
)
type result struct {
Result sql.Result
RowsAffected int64
}
func (info *result) ModifyStatement(stmt *Statement) {
stmt.Result = info
}
// Build implements clause.Expression interface
func (result) Build(clause.Builder) {
}
func WithResult() *result {
return &result{}
}
type Interface[T any] interface {
Raw(sql string, values ...interface{}) ExecInterface[T]
Exec(ctx context.Context, sql string, values ...interface{}) error
@ -85,7 +102,7 @@ type op func(*DB) *DB
func G[T any](db *DB, opts ...clause.Expression) Interface[T] {
v := &g[T]{
db: db.Session(&Session{NewDB: true}),
db: db,
ops: make([]op, 0, 5),
}

View File

@ -47,6 +47,7 @@ type Statement struct {
attrs []interface{}
assigns []interface{}
scopes []func(*DB) *DB
Result *result
}
type join struct {
@ -532,6 +533,7 @@ func (stmt *Statement) clone() *Statement {
Context: stmt.Context,
RaiseErrorOnNotFound: stmt.RaiseErrorOnNotFound,
SkipHooks: stmt.SkipHooks,
Result: stmt.Result,
}
if stmt.SQL.Len() > 0 {

View File

@ -729,3 +729,18 @@ func TestGenericsUpsert(t *testing.T) {
t.Errorf("should update name on conflict, but got name %+v", langs[0].Name)
}
}
func TestGenericsWithResult(t *testing.T) {
ctx := context.Background()
users := []User{{Name: "TestGenericsWithResult", Age: 18}, {Name: "TestGenericsWithResult2", Age: 18}}
result := gorm.WithResult()
err := gorm.G[User](DB, result).CreateInBatches(ctx, &users, 2)
if err != nil {
t.Errorf("failed to create users WithResult")
}
if result.RowsAffected != 2 {
t.Errorf("failed to get affected rows, got %d, should be %d", result.RowsAffected, 2)
}
}