Add WithResult support for generics API
This commit is contained in:
parent
774d957089
commit
ddaee81548
@ -89,6 +89,10 @@ func Create(config *Config) func(db *gorm.DB) {
|
|||||||
db.AddError(rows.Close())
|
db.AddError(rows.Close())
|
||||||
}()
|
}()
|
||||||
gorm.Scan(rows, db, mode)
|
gorm.Scan(rows, db, mode)
|
||||||
|
|
||||||
|
if db.Statement.Result != nil {
|
||||||
|
db.Statement.Result.RowsAffected = db.RowsAffected
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return
|
return
|
||||||
@ -103,6 +107,12 @@ func Create(config *Config) func(db *gorm.DB) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
db.RowsAffected, _ = result.RowsAffected()
|
db.RowsAffected, _ = result.RowsAffected()
|
||||||
|
|
||||||
|
if db.Statement.Result != nil {
|
||||||
|
db.Statement.Result.Result = result
|
||||||
|
db.Statement.Result.RowsAffected = db.RowsAffected
|
||||||
|
}
|
||||||
|
|
||||||
if db.RowsAffected == 0 {
|
if db.RowsAffected == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -157,8 +157,14 @@ func Delete(config *Config) func(db *gorm.DB) {
|
|||||||
ok, mode := hasReturning(db, supportReturning)
|
ok, mode := hasReturning(db, supportReturning)
|
||||||
if !ok {
|
if !ok {
|
||||||
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||||
|
|
||||||
if db.AddError(err) == nil {
|
if db.AddError(err) == nil {
|
||||||
db.RowsAffected, _ = result.RowsAffected()
|
db.RowsAffected, _ = result.RowsAffected()
|
||||||
|
|
||||||
|
if db.Statement.Result != nil {
|
||||||
|
db.Statement.Result.Result = result
|
||||||
|
db.Statement.Result.RowsAffected = db.RowsAffected
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return
|
return
|
||||||
@ -166,6 +172,10 @@ func Delete(config *Config) func(db *gorm.DB) {
|
|||||||
|
|
||||||
if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil {
|
if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil {
|
||||||
gorm.Scan(rows, db, mode)
|
gorm.Scan(rows, db, mode)
|
||||||
|
|
||||||
|
if db.Statement.Result != nil {
|
||||||
|
db.Statement.Result.RowsAffected = db.RowsAffected
|
||||||
|
}
|
||||||
db.AddError(rows.Close())
|
db.AddError(rows.Close())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -25,6 +25,10 @@ func Query(db *gorm.DB) {
|
|||||||
db.AddError(rows.Close())
|
db.AddError(rows.Close())
|
||||||
}()
|
}()
|
||||||
gorm.Scan(rows, db, 0)
|
gorm.Scan(rows, db, 0)
|
||||||
|
|
||||||
|
if db.Statement.Result != nil {
|
||||||
|
db.Statement.Result.RowsAffected = db.RowsAffected
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -13,5 +13,10 @@ func RawExec(db *gorm.DB) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
db.RowsAffected, _ = result.RowsAffected()
|
db.RowsAffected, _ = result.RowsAffected()
|
||||||
|
|
||||||
|
if db.Statement.Result != nil {
|
||||||
|
db.Statement.Result.Result = result
|
||||||
|
db.Statement.Result.RowsAffected = db.RowsAffected
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -92,6 +92,10 @@ func Update(config *Config) func(db *gorm.DB) {
|
|||||||
gorm.Scan(rows, db, mode)
|
gorm.Scan(rows, db, mode)
|
||||||
db.Statement.Dest = dest
|
db.Statement.Dest = dest
|
||||||
db.AddError(rows.Close())
|
db.AddError(rows.Close())
|
||||||
|
|
||||||
|
if db.Statement.Result != nil {
|
||||||
|
db.Statement.Result.RowsAffected = db.RowsAffected
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||||
@ -99,6 +103,11 @@ func Update(config *Config) func(db *gorm.DB) {
|
|||||||
if db.AddError(err) == nil {
|
if db.AddError(err) == nil {
|
||||||
db.RowsAffected, _ = result.RowsAffected()
|
db.RowsAffected, _ = result.RowsAffected()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if db.Statement.Result != nil {
|
||||||
|
db.Statement.Result.Result = result
|
||||||
|
db.Statement.Result.RowsAffected = db.RowsAffected
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
19
generics.go
19
generics.go
@ -11,6 +11,23 @@ import (
|
|||||||
"gorm.io/gorm/logger"
|
"gorm.io/gorm/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type result struct {
|
||||||
|
Result sql.Result
|
||||||
|
RowsAffected int64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (info *result) ModifyStatement(stmt *Statement) {
|
||||||
|
stmt.Result = info
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build implements clause.Expression interface
|
||||||
|
func (result) Build(clause.Builder) {
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithResult() *result {
|
||||||
|
return &result{}
|
||||||
|
}
|
||||||
|
|
||||||
type Interface[T any] interface {
|
type Interface[T any] interface {
|
||||||
Raw(sql string, values ...interface{}) ExecInterface[T]
|
Raw(sql string, values ...interface{}) ExecInterface[T]
|
||||||
Exec(ctx context.Context, sql string, values ...interface{}) error
|
Exec(ctx context.Context, sql string, values ...interface{}) error
|
||||||
@ -85,7 +102,7 @@ type op func(*DB) *DB
|
|||||||
|
|
||||||
func G[T any](db *DB, opts ...clause.Expression) Interface[T] {
|
func G[T any](db *DB, opts ...clause.Expression) Interface[T] {
|
||||||
v := &g[T]{
|
v := &g[T]{
|
||||||
db: db.Session(&Session{NewDB: true}),
|
db: db,
|
||||||
ops: make([]op, 0, 5),
|
ops: make([]op, 0, 5),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -47,6 +47,7 @@ type Statement struct {
|
|||||||
attrs []interface{}
|
attrs []interface{}
|
||||||
assigns []interface{}
|
assigns []interface{}
|
||||||
scopes []func(*DB) *DB
|
scopes []func(*DB) *DB
|
||||||
|
Result *result
|
||||||
}
|
}
|
||||||
|
|
||||||
type join struct {
|
type join struct {
|
||||||
@ -532,6 +533,7 @@ func (stmt *Statement) clone() *Statement {
|
|||||||
Context: stmt.Context,
|
Context: stmt.Context,
|
||||||
RaiseErrorOnNotFound: stmt.RaiseErrorOnNotFound,
|
RaiseErrorOnNotFound: stmt.RaiseErrorOnNotFound,
|
||||||
SkipHooks: stmt.SkipHooks,
|
SkipHooks: stmt.SkipHooks,
|
||||||
|
Result: stmt.Result,
|
||||||
}
|
}
|
||||||
|
|
||||||
if stmt.SQL.Len() > 0 {
|
if stmt.SQL.Len() > 0 {
|
||||||
|
@ -729,3 +729,18 @@ func TestGenericsUpsert(t *testing.T) {
|
|||||||
t.Errorf("should update name on conflict, but got name %+v", langs[0].Name)
|
t.Errorf("should update name on conflict, but got name %+v", langs[0].Name)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGenericsWithResult(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
users := []User{{Name: "TestGenericsWithResult", Age: 18}, {Name: "TestGenericsWithResult2", Age: 18}}
|
||||||
|
|
||||||
|
result := gorm.WithResult()
|
||||||
|
err := gorm.G[User](DB, result).CreateInBatches(ctx, &users, 2)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to create users WithResult")
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.RowsAffected != 2 {
|
||||||
|
t.Errorf("failed to get affected rows, got %d, should be %d", result.RowsAffected, 2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user