Add WithResult support for generics API
This commit is contained in:
		
							parent
							
								
									774d957089
								
							
						
					
					
						commit
						ddaee81548
					
				| @ -89,6 +89,10 @@ func Create(config *Config) func(db *gorm.DB) { | ||||
| 					db.AddError(rows.Close()) | ||||
| 				}() | ||||
| 				gorm.Scan(rows, db, mode) | ||||
| 
 | ||||
| 				if db.Statement.Result != nil { | ||||
| 					db.Statement.Result.RowsAffected = db.RowsAffected | ||||
| 				} | ||||
| 			} | ||||
| 
 | ||||
| 			return | ||||
| @ -103,6 +107,12 @@ func Create(config *Config) func(db *gorm.DB) { | ||||
| 		} | ||||
| 
 | ||||
| 		db.RowsAffected, _ = result.RowsAffected() | ||||
| 
 | ||||
| 		if db.Statement.Result != nil { | ||||
| 			db.Statement.Result.Result = result | ||||
| 			db.Statement.Result.RowsAffected = db.RowsAffected | ||||
| 		} | ||||
| 
 | ||||
| 		if db.RowsAffected == 0 { | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| @ -157,8 +157,14 @@ func Delete(config *Config) func(db *gorm.DB) { | ||||
| 			ok, mode := hasReturning(db, supportReturning) | ||||
| 			if !ok { | ||||
| 				result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) | ||||
| 
 | ||||
| 				if db.AddError(err) == nil { | ||||
| 					db.RowsAffected, _ = result.RowsAffected() | ||||
| 
 | ||||
| 					if db.Statement.Result != nil { | ||||
| 						db.Statement.Result.Result = result | ||||
| 						db.Statement.Result.RowsAffected = db.RowsAffected | ||||
| 					} | ||||
| 				} | ||||
| 
 | ||||
| 				return | ||||
| @ -166,6 +172,10 @@ func Delete(config *Config) func(db *gorm.DB) { | ||||
| 
 | ||||
| 			if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil { | ||||
| 				gorm.Scan(rows, db, mode) | ||||
| 
 | ||||
| 				if db.Statement.Result != nil { | ||||
| 					db.Statement.Result.RowsAffected = db.RowsAffected | ||||
| 				} | ||||
| 				db.AddError(rows.Close()) | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| @ -25,6 +25,10 @@ func Query(db *gorm.DB) { | ||||
| 				db.AddError(rows.Close()) | ||||
| 			}() | ||||
| 			gorm.Scan(rows, db, 0) | ||||
| 
 | ||||
| 			if db.Statement.Result != nil { | ||||
| 				db.Statement.Result.RowsAffected = db.RowsAffected | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @ -13,5 +13,10 @@ func RawExec(db *gorm.DB) { | ||||
| 		} | ||||
| 
 | ||||
| 		db.RowsAffected, _ = result.RowsAffected() | ||||
| 
 | ||||
| 		if db.Statement.Result != nil { | ||||
| 			db.Statement.Result.Result = result | ||||
| 			db.Statement.Result.RowsAffected = db.RowsAffected | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @ -92,6 +92,10 @@ func Update(config *Config) func(db *gorm.DB) { | ||||
| 					gorm.Scan(rows, db, mode) | ||||
| 					db.Statement.Dest = dest | ||||
| 					db.AddError(rows.Close()) | ||||
| 
 | ||||
| 					if db.Statement.Result != nil { | ||||
| 						db.Statement.Result.RowsAffected = db.RowsAffected | ||||
| 					} | ||||
| 				} | ||||
| 			} else { | ||||
| 				result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) | ||||
| @ -99,6 +103,11 @@ func Update(config *Config) func(db *gorm.DB) { | ||||
| 				if db.AddError(err) == nil { | ||||
| 					db.RowsAffected, _ = result.RowsAffected() | ||||
| 				} | ||||
| 
 | ||||
| 				if db.Statement.Result != nil { | ||||
| 					db.Statement.Result.Result = result | ||||
| 					db.Statement.Result.RowsAffected = db.RowsAffected | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
							
								
								
									
										19
									
								
								generics.go
									
									
									
									
									
								
							
							
						
						
									
										19
									
								
								generics.go
									
									
									
									
									
								
							| @ -11,6 +11,23 @@ import ( | ||||
| 	"gorm.io/gorm/logger" | ||||
| ) | ||||
| 
 | ||||
| type result struct { | ||||
| 	Result       sql.Result | ||||
| 	RowsAffected int64 | ||||
| } | ||||
| 
 | ||||
| func (info *result) ModifyStatement(stmt *Statement) { | ||||
| 	stmt.Result = info | ||||
| } | ||||
| 
 | ||||
| // Build implements clause.Expression interface
 | ||||
| func (result) Build(clause.Builder) { | ||||
| } | ||||
| 
 | ||||
| func WithResult() *result { | ||||
| 	return &result{} | ||||
| } | ||||
| 
 | ||||
| type Interface[T any] interface { | ||||
| 	Raw(sql string, values ...interface{}) ExecInterface[T] | ||||
| 	Exec(ctx context.Context, sql string, values ...interface{}) error | ||||
| @ -85,7 +102,7 @@ type op func(*DB) *DB | ||||
| 
 | ||||
| func G[T any](db *DB, opts ...clause.Expression) Interface[T] { | ||||
| 	v := &g[T]{ | ||||
| 		db:  db.Session(&Session{NewDB: true}), | ||||
| 		db:  db, | ||||
| 		ops: make([]op, 0, 5), | ||||
| 	} | ||||
| 
 | ||||
|  | ||||
| @ -47,6 +47,7 @@ type Statement struct { | ||||
| 	attrs                []interface{} | ||||
| 	assigns              []interface{} | ||||
| 	scopes               []func(*DB) *DB | ||||
| 	Result               *result | ||||
| } | ||||
| 
 | ||||
| type join struct { | ||||
| @ -532,6 +533,7 @@ func (stmt *Statement) clone() *Statement { | ||||
| 		Context:              stmt.Context, | ||||
| 		RaiseErrorOnNotFound: stmt.RaiseErrorOnNotFound, | ||||
| 		SkipHooks:            stmt.SkipHooks, | ||||
| 		Result:               stmt.Result, | ||||
| 	} | ||||
| 
 | ||||
| 	if stmt.SQL.Len() > 0 { | ||||
|  | ||||
| @ -729,3 +729,18 @@ func TestGenericsUpsert(t *testing.T) { | ||||
| 		t.Errorf("should update name on conflict, but got name %+v", langs[0].Name) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestGenericsWithResult(t *testing.T) { | ||||
| 	ctx := context.Background() | ||||
| 	users := []User{{Name: "TestGenericsWithResult", Age: 18}, {Name: "TestGenericsWithResult2", Age: 18}} | ||||
| 
 | ||||
| 	result := gorm.WithResult() | ||||
| 	err := gorm.G[User](DB, result).CreateInBatches(ctx, &users, 2) | ||||
| 	if err != nil { | ||||
| 		t.Errorf("failed to create users WithResult") | ||||
| 	} | ||||
| 
 | ||||
| 	if result.RowsAffected != 2 { | ||||
| 		t.Errorf("failed to get affected rows, got %d, should be %d", result.RowsAffected, 2) | ||||
| 	} | ||||
| } | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu