Add returning support to delete
This commit is contained in:
		
							parent
							
								
									af3fbdc2fc
								
							
						
					
					
						commit
						835d7bde59
					
				| @ -57,7 +57,7 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) { | ||||
| 	deleteCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction) | ||||
| 	deleteCallback.Register("gorm:before_delete", BeforeDelete) | ||||
| 	deleteCallback.Register("gorm:delete_before_associations", DeleteBeforeAssociations) | ||||
| 	deleteCallback.Register("gorm:delete", Delete) | ||||
| 	deleteCallback.Register("gorm:delete", Delete(config)) | ||||
| 	deleteCallback.Register("gorm:after_delete", AfterDelete) | ||||
| 	deleteCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) | ||||
| 	deleteCallback.Clauses = config.DeleteClauses | ||||
|  | ||||
| @ -7,6 +7,7 @@ import ( | ||||
| 	"gorm.io/gorm" | ||||
| 	"gorm.io/gorm/clause" | ||||
| 	"gorm.io/gorm/schema" | ||||
| 	"gorm.io/gorm/utils" | ||||
| ) | ||||
| 
 | ||||
| func BeforeCreate(db *gorm.DB) { | ||||
| @ -31,18 +32,12 @@ func BeforeCreate(db *gorm.DB) { | ||||
| } | ||||
| 
 | ||||
| func Create(config *Config) func(db *gorm.DB) { | ||||
| 	withReturning := false | ||||
| 	for _, clause := range config.CreateClauses { | ||||
| 		if clause == "RETURNING" { | ||||
| 			withReturning = true | ||||
| 		} | ||||
| 	} | ||||
| 	supportReturning := utils.Contains(config.CreateClauses, "RETURNING") | ||||
| 
 | ||||
| 	return func(db *gorm.DB) { | ||||
| 		if db.Error != nil { | ||||
| 			return | ||||
| 		} | ||||
| 		onReturning := false | ||||
| 
 | ||||
| 		if db.Statement.Schema != nil { | ||||
| 			if !db.Statement.Unscoped { | ||||
| @ -51,8 +46,7 @@ func Create(config *Config) func(db *gorm.DB) { | ||||
| 				} | ||||
| 			} | ||||
| 
 | ||||
| 			if withReturning && len(db.Statement.Schema.FieldsWithDefaultDBValue) > 0 { | ||||
| 				onReturning = true | ||||
| 			if supportReturning && len(db.Statement.Schema.FieldsWithDefaultDBValue) > 0 { | ||||
| 				if _, ok := db.Statement.Clauses["RETURNING"]; !ok { | ||||
| 					fromColumns := make([]clause.Column, 0, len(db.Statement.Schema.FieldsWithDefaultDBValue)) | ||||
| 					for _, field := range db.Statement.Schema.FieldsWithDefaultDBValue { | ||||
| @ -72,18 +66,15 @@ func Create(config *Config) func(db *gorm.DB) { | ||||
| 		} | ||||
| 
 | ||||
| 		if !db.DryRun && db.Error == nil { | ||||
| 			if onReturning { | ||||
| 				doNothing := false | ||||
| 
 | ||||
| 			if ok, mode := hasReturning(db, supportReturning); ok { | ||||
| 				if c, ok := db.Statement.Clauses["ON CONFLICT"]; ok { | ||||
| 					onConflict, _ := c.Expression.(clause.OnConflict) | ||||
| 					doNothing = onConflict.DoNothing | ||||
| 					if onConflict, _ := c.Expression.(clause.OnConflict); onConflict.DoNothing { | ||||
| 						mode |= gorm.ScanOnConflictDoNothing | ||||
| 					} | ||||
| 				} | ||||
| 				if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil { | ||||
| 					if doNothing { | ||||
| 						gorm.Scan(rows, db, gorm.ScanUpdate|gorm.ScanOnConflictDoNothing) | ||||
| 					} else { | ||||
| 						gorm.Scan(rows, db, gorm.ScanUpdate) | ||||
| 					} | ||||
| 					gorm.Scan(rows, db, mode) | ||||
| 					rows.Close() | ||||
| 				} | ||||
| 			} else { | ||||
|  | ||||
| @ -7,6 +7,7 @@ import ( | ||||
| 	"gorm.io/gorm" | ||||
| 	"gorm.io/gorm/clause" | ||||
| 	"gorm.io/gorm/schema" | ||||
| 	"gorm.io/gorm/utils" | ||||
| ) | ||||
| 
 | ||||
| func BeforeDelete(db *gorm.DB) { | ||||
| @ -104,8 +105,14 @@ func DeleteBeforeAssociations(db *gorm.DB) { | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func Delete(db *gorm.DB) { | ||||
| 	if db.Error == nil { | ||||
| func Delete(config *Config) func(db *gorm.DB) { | ||||
| 	supportReturning := utils.Contains(config.DeleteClauses, "RETURNING") | ||||
| 
 | ||||
| 	return func(db *gorm.DB) { | ||||
| 		if db.Error != nil { | ||||
| 			return | ||||
| 		} | ||||
| 
 | ||||
| 		if db.Statement.Schema != nil && !db.Statement.Unscoped { | ||||
| 			for _, c := range db.Statement.Schema.DeleteClauses { | ||||
| 				db.Statement.AddClause(c) | ||||
| @ -144,12 +151,16 @@ func Delete(db *gorm.DB) { | ||||
| 		} | ||||
| 
 | ||||
| 		if !db.DryRun && db.Error == nil { | ||||
| 			result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) | ||||
| 
 | ||||
| 			if err == nil { | ||||
| 				db.RowsAffected, _ = result.RowsAffected() | ||||
| 			if ok, mode := hasReturning(db, supportReturning); ok { | ||||
| 				if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil { | ||||
| 					gorm.Scan(rows, db, mode) | ||||
| 					rows.Close() | ||||
| 				} | ||||
| 			} else { | ||||
| 				db.AddError(err) | ||||
| 				result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) | ||||
| 				if db.AddError(err) == nil { | ||||
| 					db.RowsAffected, _ = result.RowsAffected() | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| @ -93,3 +93,16 @@ func ConvertSliceOfMapToValuesForCreate(stmt *gorm.Statement, mapValues []map[st | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func hasReturning(tx *gorm.DB, supportReturning bool) (bool, gorm.ScanMode) { | ||||
| 	if supportReturning { | ||||
| 		if c, ok := tx.Statement.Clauses["RETURNING"]; ok { | ||||
| 			returning, _ := c.Expression.(clause.Returning) | ||||
| 			if len(returning.Columns) == 0 || (len(returning.Columns) == 1 && returning.Columns[0].Name == "*") { | ||||
| 				return true, 0 | ||||
| 			} | ||||
| 			return true, gorm.ScanUpdate | ||||
| 		} | ||||
| 	} | ||||
| 	return false, 0 | ||||
| } | ||||
|  | ||||
| @ -7,6 +7,7 @@ import ( | ||||
| 	"gorm.io/gorm" | ||||
| 	"gorm.io/gorm/clause" | ||||
| 	"gorm.io/gorm/schema" | ||||
| 	"gorm.io/gorm/utils" | ||||
| ) | ||||
| 
 | ||||
| func SetupUpdateReflectValue(db *gorm.DB) { | ||||
| @ -51,12 +52,7 @@ func BeforeUpdate(db *gorm.DB) { | ||||
| } | ||||
| 
 | ||||
| func Update(config *Config) func(db *gorm.DB) { | ||||
| 	withReturning := false | ||||
| 	for _, clause := range config.UpdateClauses { | ||||
| 		if clause == "RETURNING" { | ||||
| 			withReturning = true | ||||
| 		} | ||||
| 	} | ||||
| 	supportReturning := utils.Contains(config.UpdateClauses, "RETURNING") | ||||
| 
 | ||||
| 	return func(db *gorm.DB) { | ||||
| 		if db.Error != nil { | ||||
| @ -86,18 +82,16 @@ func Update(config *Config) func(db *gorm.DB) { | ||||
| 		} | ||||
| 
 | ||||
| 		if !db.DryRun && db.Error == nil { | ||||
| 			if _, ok := db.Statement.Clauses["RETURNING"]; withReturning && ok { | ||||
| 			if ok, mode := hasReturning(db, supportReturning); ok { | ||||
| 				if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil { | ||||
| 					gorm.Scan(rows, db, gorm.ScanUpdate) | ||||
| 					gorm.Scan(rows, db, mode) | ||||
| 					rows.Close() | ||||
| 				} | ||||
| 			} else { | ||||
| 				result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) | ||||
| 
 | ||||
| 				if err == nil { | ||||
| 				if db.AddError(err) == nil { | ||||
| 					db.RowsAffected, _ = result.RowsAffected() | ||||
| 				} else { | ||||
| 					db.AddError(err) | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| @ -11,6 +11,7 @@ func (returning Returning) Name() string { | ||||
| 
 | ||||
| // Build build where clause
 | ||||
| func (returning Returning) Build(builder Builder) { | ||||
| 	if len(returning.Columns) > 0 { | ||||
| 		for idx, column := range returning.Columns { | ||||
| 			if idx > 0 { | ||||
| 				builder.WriteByte(',') | ||||
| @ -18,6 +19,9 @@ func (returning Returning) Build(builder Builder) { | ||||
| 
 | ||||
| 			builder.WriteQuoted(column) | ||||
| 		} | ||||
| 	} else { | ||||
| 		builder.WriteByte('*') | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // MergeClause merge order by clauses
 | ||||
|  | ||||
							
								
								
									
										2
									
								
								scan.go
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								scan.go
									
									
									
									
									
								
							| @ -241,7 +241,7 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) { | ||||
| 		case reflect.Slice, reflect.Array: | ||||
| 			var elem reflect.Value | ||||
| 
 | ||||
| 			if !update { | ||||
| 			if !update && reflectValue.Len() != 0 { | ||||
| 				db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20)) | ||||
| 			} | ||||
| 
 | ||||
|  | ||||
| @ -9,8 +9,8 @@ require ( | ||||
| 	gorm.io/driver/mysql v1.1.2 | ||||
| 	gorm.io/driver/postgres v1.2.0 | ||||
| 	gorm.io/driver/sqlite v1.2.0 | ||||
| 	gorm.io/driver/sqlserver v1.1.1 | ||||
| 	gorm.io/gorm v1.21.16 | ||||
| 	gorm.io/driver/sqlserver v1.1.2 | ||||
| 	gorm.io/gorm v1.22.0 | ||||
| ) | ||||
| 
 | ||||
| replace gorm.io/gorm => ../ | ||||
|  | ||||
| @ -167,7 +167,7 @@ func TestUpdates(t *testing.T) { | ||||
| 	} | ||||
| 
 | ||||
| 	// update with gorm exprs
 | ||||
| 	if err := DB.Debug().Model(&user3).Clauses(clause.Returning{Columns: []clause.Column{{Name: "age"}}}).Updates(map[string]interface{}{"age": gorm.Expr("age + ?", 100)}).Error; err != nil { | ||||
| 	if err := DB.Model(&user3).Clauses(clause.Returning{Columns: []clause.Column{{Name: "age"}}}).Updates(map[string]interface{}{"age": gorm.Expr("age + ?", 100)}).Error; err != nil { | ||||
| 		t.Errorf("Not error should happen when updating with gorm expr, but got %v", err) | ||||
| 	} | ||||
| 	var user4 User | ||||
|  | ||||
| @ -72,6 +72,15 @@ func ToStringKey(values ...interface{}) string { | ||||
| 	return strings.Join(results, "_") | ||||
| } | ||||
| 
 | ||||
| func Contains(elems []string, elem string) bool { | ||||
| 	for _, e := range elems { | ||||
| 		if elem == e { | ||||
| 			return true | ||||
| 		} | ||||
| 	} | ||||
| 	return false | ||||
| } | ||||
| 
 | ||||
| func AssertEqual(src, dst interface{}) bool { | ||||
| 	if !reflect.DeepEqual(src, dst) { | ||||
| 		if valuer, ok := src.(driver.Valuer); ok { | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu