Improve returning support
This commit is contained in:
		
							parent
							
								
									d3211908a0
								
							
						
					
					
						commit
						af3fbdc2fc
					
				| @ -13,7 +13,6 @@ var ( | |||||||
| 
 | 
 | ||||||
| type Config struct { | type Config struct { | ||||||
| 	LastInsertIDReversed bool | 	LastInsertIDReversed bool | ||||||
| 	WithReturning        bool |  | ||||||
| 	CreateClauses        []string | 	CreateClauses        []string | ||||||
| 	QueryClauses         []string | 	QueryClauses         []string | ||||||
| 	UpdateClauses        []string | 	UpdateClauses        []string | ||||||
| @ -25,6 +24,19 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) { | |||||||
| 		return !db.SkipDefaultTransaction | 		return !db.SkipDefaultTransaction | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	if len(config.CreateClauses) == 0 { | ||||||
|  | 		config.CreateClauses = createClauses | ||||||
|  | 	} | ||||||
|  | 	if len(config.QueryClauses) == 0 { | ||||||
|  | 		config.QueryClauses = queryClauses | ||||||
|  | 	} | ||||||
|  | 	if len(config.DeleteClauses) == 0 { | ||||||
|  | 		config.DeleteClauses = deleteClauses | ||||||
|  | 	} | ||||||
|  | 	if len(config.UpdateClauses) == 0 { | ||||||
|  | 		config.UpdateClauses = updateClauses | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	createCallback := db.Callback().Create() | 	createCallback := db.Callback().Create() | ||||||
| 	createCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction) | 	createCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction) | ||||||
| 	createCallback.Register("gorm:before_create", BeforeCreate) | 	createCallback.Register("gorm:before_create", BeforeCreate) | ||||||
| @ -33,18 +45,12 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) { | |||||||
| 	createCallback.Register("gorm:save_after_associations", SaveAfterAssociations(true)) | 	createCallback.Register("gorm:save_after_associations", SaveAfterAssociations(true)) | ||||||
| 	createCallback.Register("gorm:after_create", AfterCreate) | 	createCallback.Register("gorm:after_create", AfterCreate) | ||||||
| 	createCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) | 	createCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) | ||||||
| 	if len(config.CreateClauses) == 0 { |  | ||||||
| 		config.CreateClauses = createClauses |  | ||||||
| 	} |  | ||||||
| 	createCallback.Clauses = config.CreateClauses | 	createCallback.Clauses = config.CreateClauses | ||||||
| 
 | 
 | ||||||
| 	queryCallback := db.Callback().Query() | 	queryCallback := db.Callback().Query() | ||||||
| 	queryCallback.Register("gorm:query", Query) | 	queryCallback.Register("gorm:query", Query) | ||||||
| 	queryCallback.Register("gorm:preload", Preload) | 	queryCallback.Register("gorm:preload", Preload) | ||||||
| 	queryCallback.Register("gorm:after_query", AfterQuery) | 	queryCallback.Register("gorm:after_query", AfterQuery) | ||||||
| 	if len(config.QueryClauses) == 0 { |  | ||||||
| 		config.QueryClauses = queryClauses |  | ||||||
| 	} |  | ||||||
| 	queryCallback.Clauses = config.QueryClauses | 	queryCallback.Clauses = config.QueryClauses | ||||||
| 
 | 
 | ||||||
| 	deleteCallback := db.Callback().Delete() | 	deleteCallback := db.Callback().Delete() | ||||||
| @ -54,9 +60,6 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) { | |||||||
| 	deleteCallback.Register("gorm:delete", Delete) | 	deleteCallback.Register("gorm:delete", Delete) | ||||||
| 	deleteCallback.Register("gorm:after_delete", AfterDelete) | 	deleteCallback.Register("gorm:after_delete", AfterDelete) | ||||||
| 	deleteCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) | 	deleteCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) | ||||||
| 	if len(config.DeleteClauses) == 0 { |  | ||||||
| 		config.DeleteClauses = deleteClauses |  | ||||||
| 	} |  | ||||||
| 	deleteCallback.Clauses = config.DeleteClauses | 	deleteCallback.Clauses = config.DeleteClauses | ||||||
| 
 | 
 | ||||||
| 	updateCallback := db.Callback().Update() | 	updateCallback := db.Callback().Update() | ||||||
| @ -64,13 +67,10 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) { | |||||||
| 	updateCallback.Register("gorm:setup_reflect_value", SetupUpdateReflectValue) | 	updateCallback.Register("gorm:setup_reflect_value", SetupUpdateReflectValue) | ||||||
| 	updateCallback.Register("gorm:before_update", BeforeUpdate) | 	updateCallback.Register("gorm:before_update", BeforeUpdate) | ||||||
| 	updateCallback.Register("gorm:save_before_associations", SaveBeforeAssociations(false)) | 	updateCallback.Register("gorm:save_before_associations", SaveBeforeAssociations(false)) | ||||||
| 	updateCallback.Register("gorm:update", Update) | 	updateCallback.Register("gorm:update", Update(config)) | ||||||
| 	updateCallback.Register("gorm:save_after_associations", SaveAfterAssociations(false)) | 	updateCallback.Register("gorm:save_after_associations", SaveAfterAssociations(false)) | ||||||
| 	updateCallback.Register("gorm:after_update", AfterUpdate) | 	updateCallback.Register("gorm:after_update", AfterUpdate) | ||||||
| 	updateCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) | 	updateCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) | ||||||
| 	if len(config.UpdateClauses) == 0 { |  | ||||||
| 		config.UpdateClauses = updateClauses |  | ||||||
| 	} |  | ||||||
| 	updateCallback.Clauses = config.UpdateClauses | 	updateCallback.Clauses = config.UpdateClauses | ||||||
| 
 | 
 | ||||||
| 	rowCallback := db.Callback().Row() | 	rowCallback := db.Callback().Row() | ||||||
|  | |||||||
| @ -31,18 +31,35 @@ func BeforeCreate(db *gorm.DB) { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func Create(config *Config) func(db *gorm.DB) { | func Create(config *Config) func(db *gorm.DB) { | ||||||
| 	if config.WithReturning { | 	withReturning := false | ||||||
| 		return CreateWithReturning | 	for _, clause := range config.CreateClauses { | ||||||
|  | 		if clause == "RETURNING" { | ||||||
|  | 			withReturning = true | ||||||
|  | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return func(db *gorm.DB) { | 	return func(db *gorm.DB) { | ||||||
| 		if db.Error != nil { | 		if db.Error != nil { | ||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
|  | 		onReturning := false | ||||||
| 
 | 
 | ||||||
| 		if db.Statement.Schema != nil && !db.Statement.Unscoped { | 		if db.Statement.Schema != nil { | ||||||
| 			for _, c := range db.Statement.Schema.CreateClauses { | 			if !db.Statement.Unscoped { | ||||||
| 				db.Statement.AddClause(c) | 				for _, c := range db.Statement.Schema.CreateClauses { | ||||||
|  | 					db.Statement.AddClause(c) | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			if withReturning && len(db.Statement.Schema.FieldsWithDefaultDBValue) > 0 { | ||||||
|  | 				onReturning = true | ||||||
|  | 				if _, ok := db.Statement.Clauses["RETURNING"]; !ok { | ||||||
|  | 					fromColumns := make([]clause.Column, 0, len(db.Statement.Schema.FieldsWithDefaultDBValue)) | ||||||
|  | 					for _, field := range db.Statement.Schema.FieldsWithDefaultDBValue { | ||||||
|  | 						fromColumns = append(fromColumns, clause.Column{Name: field.DBName}) | ||||||
|  | 					} | ||||||
|  | 					db.Statement.AddClause(clause.Returning{Columns: fromColumns}) | ||||||
|  | 				} | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| @ -55,180 +72,70 @@ func Create(config *Config) func(db *gorm.DB) { | |||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		if !db.DryRun && db.Error == nil { | 		if !db.DryRun && db.Error == nil { | ||||||
| 			result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) | 			if onReturning { | ||||||
| 
 | 				doNothing := false | ||||||
| 			if err != nil { | 				if c, ok := db.Statement.Clauses["ON CONFLICT"]; ok { | ||||||
| 				db.AddError(err) | 					onConflict, _ := c.Expression.(clause.OnConflict) | ||||||
| 				return | 					doNothing = onConflict.DoNothing | ||||||
| 			} | 				} | ||||||
| 
 | 				if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil { | ||||||
| 			db.RowsAffected, _ = result.RowsAffected() | 					if doNothing { | ||||||
| 
 | 						gorm.Scan(rows, db, gorm.ScanUpdate|gorm.ScanOnConflictDoNothing) | ||||||
| 			if db.RowsAffected != 0 && db.Statement.Schema != nil && | 					} else { | ||||||
| 				db.Statement.Schema.PrioritizedPrimaryField != nil && db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue { | 						gorm.Scan(rows, db, gorm.ScanUpdate) | ||||||
| 				if insertID, err := result.LastInsertId(); err == nil && insertID > 0 { |  | ||||||
| 					switch db.Statement.ReflectValue.Kind() { |  | ||||||
| 					case reflect.Slice, reflect.Array: |  | ||||||
| 						if config.LastInsertIDReversed { |  | ||||||
| 							for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- { |  | ||||||
| 								rv := db.Statement.ReflectValue.Index(i) |  | ||||||
| 								if reflect.Indirect(rv).Kind() != reflect.Struct { |  | ||||||
| 									break |  | ||||||
| 								} |  | ||||||
| 
 |  | ||||||
| 								_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv) |  | ||||||
| 								if isZero { |  | ||||||
| 									db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID) |  | ||||||
| 									insertID -= db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement |  | ||||||
| 								} |  | ||||||
| 							} |  | ||||||
| 						} else { |  | ||||||
| 							for i := 0; i < db.Statement.ReflectValue.Len(); i++ { |  | ||||||
| 								rv := db.Statement.ReflectValue.Index(i) |  | ||||||
| 								if reflect.Indirect(rv).Kind() != reflect.Struct { |  | ||||||
| 									break |  | ||||||
| 								} |  | ||||||
| 
 |  | ||||||
| 								if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv); isZero { |  | ||||||
| 									db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID) |  | ||||||
| 									insertID += db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement |  | ||||||
| 								} |  | ||||||
| 							} |  | ||||||
| 						} |  | ||||||
| 					case reflect.Struct: |  | ||||||
| 						if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue); isZero { |  | ||||||
| 							db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID) |  | ||||||
| 						} |  | ||||||
| 					} | 					} | ||||||
| 				} else { | 					rows.Close() | ||||||
|  | 				} | ||||||
|  | 			} else { | ||||||
|  | 				result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) | ||||||
|  | 
 | ||||||
|  | 				if err != nil { | ||||||
| 					db.AddError(err) | 					db.AddError(err) | ||||||
| 				} | 					return | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func CreateWithReturning(db *gorm.DB) { |  | ||||||
| 	if db.Error == nil { |  | ||||||
| 		if db.Statement.Schema != nil && !db.Statement.Unscoped { |  | ||||||
| 			for _, c := range db.Statement.Schema.CreateClauses { |  | ||||||
| 				db.Statement.AddClause(c) |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		if db.Statement.SQL.String() == "" { |  | ||||||
| 			db.Statement.AddClauseIfNotExists(clause.Insert{}) |  | ||||||
| 			db.Statement.AddClause(ConvertToCreateValues(db.Statement)) |  | ||||||
| 
 |  | ||||||
| 			db.Statement.Build(db.Statement.BuildClauses...) |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		if sch := db.Statement.Schema; sch != nil && len(sch.FieldsWithDefaultDBValue) > 0 { |  | ||||||
| 			db.Statement.WriteString(" RETURNING ") |  | ||||||
| 
 |  | ||||||
| 			var ( |  | ||||||
| 				fields = make([]*schema.Field, len(sch.FieldsWithDefaultDBValue)) |  | ||||||
| 				values = make([]interface{}, len(sch.FieldsWithDefaultDBValue)) |  | ||||||
| 			) |  | ||||||
| 
 |  | ||||||
| 			for idx, field := range sch.FieldsWithDefaultDBValue { |  | ||||||
| 				if idx > 0 { |  | ||||||
| 					db.Statement.WriteByte(',') |  | ||||||
| 				} | 				} | ||||||
| 
 | 
 | ||||||
| 				fields[idx] = field | 				db.RowsAffected, _ = result.RowsAffected() | ||||||
| 				db.Statement.WriteQuoted(field.DBName) | 				if db.RowsAffected != 0 && db.Statement.Schema != nil && | ||||||
| 			} | 					db.Statement.Schema.PrioritizedPrimaryField != nil && db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue { | ||||||
| 
 | 					if insertID, err := result.LastInsertId(); err == nil && insertID > 0 { | ||||||
| 			if !db.DryRun && db.Error == nil { | 						switch db.Statement.ReflectValue.Kind() { | ||||||
| 				db.RowsAffected = 0 | 						case reflect.Slice, reflect.Array: | ||||||
| 				rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) | 							if config.LastInsertIDReversed { | ||||||
| 
 | 								for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- { | ||||||
| 				if err == nil { | 									rv := db.Statement.ReflectValue.Index(i) | ||||||
| 					defer rows.Close() | 									if reflect.Indirect(rv).Kind() != reflect.Struct { | ||||||
| 
 | 										break | ||||||
| 					switch db.Statement.ReflectValue.Kind() { |  | ||||||
| 					case reflect.Slice, reflect.Array: |  | ||||||
| 						var ( |  | ||||||
| 							c                = db.Statement.Clauses["ON CONFLICT"] |  | ||||||
| 							onConflict, _    = c.Expression.(clause.OnConflict) |  | ||||||
| 							resetFieldValues = map[int]reflect.Value{} |  | ||||||
| 						) |  | ||||||
| 
 |  | ||||||
| 						for rows.Next() { |  | ||||||
| 						BEGIN: |  | ||||||
| 							reflectValue := db.Statement.ReflectValue.Index(int(db.RowsAffected)) |  | ||||||
| 							if reflect.Indirect(reflectValue).Kind() != reflect.Struct { |  | ||||||
| 								break |  | ||||||
| 							} |  | ||||||
| 
 |  | ||||||
| 							for idx, field := range fields { |  | ||||||
| 								fieldValue := field.ReflectValueOf(reflectValue) |  | ||||||
| 
 |  | ||||||
| 								if onConflict.DoNothing && !fieldValue.IsZero() { |  | ||||||
| 									db.RowsAffected++ |  | ||||||
| 
 |  | ||||||
| 									if int(db.RowsAffected) >= db.Statement.ReflectValue.Len() { |  | ||||||
| 										return |  | ||||||
| 									} | 									} | ||||||
| 
 | 
 | ||||||
| 									goto BEGIN | 									_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv) | ||||||
|  | 									if isZero { | ||||||
|  | 										db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID) | ||||||
|  | 										insertID -= db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement | ||||||
|  | 									} | ||||||
| 								} | 								} | ||||||
| 
 |  | ||||||
| 								if field.FieldType.Kind() == reflect.Ptr { |  | ||||||
| 									values[idx] = fieldValue.Addr().Interface() |  | ||||||
| 								} else { |  | ||||||
| 									reflectValue := reflect.New(reflect.PtrTo(field.FieldType)) |  | ||||||
| 									reflectValue.Elem().Set(fieldValue.Addr()) |  | ||||||
| 									values[idx] = reflectValue.Interface() |  | ||||||
| 									resetFieldValues[idx] = fieldValue |  | ||||||
| 								} |  | ||||||
| 							} |  | ||||||
| 
 |  | ||||||
| 							db.RowsAffected++ |  | ||||||
| 							if err := rows.Scan(values...); err != nil { |  | ||||||
| 								db.AddError(err) |  | ||||||
| 							} |  | ||||||
| 
 |  | ||||||
| 							for idx, fv := range resetFieldValues { |  | ||||||
| 								if v := reflect.ValueOf(values[idx]).Elem(); !v.IsNil() { |  | ||||||
| 									fv.Set(v.Elem()) |  | ||||||
| 								} |  | ||||||
| 							} |  | ||||||
| 						} |  | ||||||
| 					case reflect.Struct: |  | ||||||
| 						resetFieldValues := map[int]reflect.Value{} |  | ||||||
| 						for idx, field := range fields { |  | ||||||
| 							if field.FieldType.Kind() == reflect.Ptr { |  | ||||||
| 								values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface() |  | ||||||
| 							} else { | 							} else { | ||||||
| 								reflectValue := reflect.New(reflect.PtrTo(field.FieldType)) | 								for i := 0; i < db.Statement.ReflectValue.Len(); i++ { | ||||||
| 								fieldValue := field.ReflectValueOf(db.Statement.ReflectValue) | 									rv := db.Statement.ReflectValue.Index(i) | ||||||
| 								reflectValue.Elem().Set(fieldValue.Addr()) | 									if reflect.Indirect(rv).Kind() != reflect.Struct { | ||||||
| 								values[idx] = reflectValue.Interface() | 										break | ||||||
| 								resetFieldValues[idx] = fieldValue | 									} | ||||||
| 							} | 
 | ||||||
| 						} | 									if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv); isZero { | ||||||
| 						if rows.Next() { | 										db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID) | ||||||
| 							db.RowsAffected++ | 										insertID += db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement | ||||||
| 							db.AddError(rows.Scan(values...)) | 									} | ||||||
| 							for idx, fv := range resetFieldValues { |  | ||||||
| 								if v := reflect.ValueOf(values[idx]).Elem(); !v.IsNil() { |  | ||||||
| 									fv.Set(v.Elem()) |  | ||||||
| 								} | 								} | ||||||
| 							} | 							} | ||||||
|  | 						case reflect.Struct: | ||||||
|  | 							if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue); isZero { | ||||||
|  | 								db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID) | ||||||
|  | 							} | ||||||
| 						} | 						} | ||||||
|  | 					} else { | ||||||
|  | 						db.AddError(err) | ||||||
| 					} | 					} | ||||||
| 				} else { |  | ||||||
| 					db.AddError(err) |  | ||||||
| 				} | 				} | ||||||
| 			} | 			} | ||||||
| 		} else if !db.DryRun && db.Error == nil { |  | ||||||
| 			if result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); err == nil { |  | ||||||
| 				db.RowsAffected, _ = result.RowsAffected() |  | ||||||
| 			} else { |  | ||||||
| 				db.AddError(err) |  | ||||||
| 			} |  | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  | |||||||
| @ -22,7 +22,7 @@ func Query(db *gorm.DB) { | |||||||
| 			} | 			} | ||||||
| 			defer rows.Close() | 			defer rows.Close() | ||||||
| 
 | 
 | ||||||
| 			gorm.Scan(rows, db, false) | 			gorm.Scan(rows, db, 0) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  | |||||||
| @ -50,40 +50,56 @@ func BeforeUpdate(db *gorm.DB) { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func Update(db *gorm.DB) { | func Update(config *Config) func(db *gorm.DB) { | ||||||
| 	if db.Error != nil { | 	withReturning := false | ||||||
| 		return | 	for _, clause := range config.UpdateClauses { | ||||||
| 	} | 		if clause == "RETURNING" { | ||||||
| 
 | 			withReturning = true | ||||||
| 	if db.Statement.Schema != nil && !db.Statement.Unscoped { |  | ||||||
| 		for _, c := range db.Statement.Schema.UpdateClauses { |  | ||||||
| 			db.Statement.AddClause(c) |  | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if db.Statement.SQL.String() == "" { | 	return func(db *gorm.DB) { | ||||||
| 		db.Statement.SQL.Grow(180) | 		if db.Error != nil { | ||||||
| 		db.Statement.AddClauseIfNotExists(clause.Update{}) |  | ||||||
| 		if set := ConvertToAssignments(db.Statement); len(set) != 0 { |  | ||||||
| 			db.Statement.AddClause(set) |  | ||||||
| 		} else { |  | ||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
| 		db.Statement.Build(db.Statement.BuildClauses...) |  | ||||||
| 	} |  | ||||||
| 
 | 
 | ||||||
| 	if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok { | 		if db.Statement.Schema != nil && !db.Statement.Unscoped { | ||||||
| 		db.AddError(gorm.ErrMissingWhereClause) | 			for _, c := range db.Statement.Schema.UpdateClauses { | ||||||
| 		return | 				db.Statement.AddClause(c) | ||||||
| 	} | 			} | ||||||
|  | 		} | ||||||
| 
 | 
 | ||||||
| 	if !db.DryRun && db.Error == nil { | 		if db.Statement.SQL.String() == "" { | ||||||
| 		result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) | 			db.Statement.SQL.Grow(180) | ||||||
|  | 			db.Statement.AddClauseIfNotExists(clause.Update{}) | ||||||
|  | 			if set := ConvertToAssignments(db.Statement); len(set) != 0 { | ||||||
|  | 				db.Statement.AddClause(set) | ||||||
|  | 			} else { | ||||||
|  | 				return | ||||||
|  | 			} | ||||||
|  | 			db.Statement.Build(db.Statement.BuildClauses...) | ||||||
|  | 		} | ||||||
| 
 | 
 | ||||||
| 		if err == nil { | 		if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok { | ||||||
| 			db.RowsAffected, _ = result.RowsAffected() | 			db.AddError(gorm.ErrMissingWhereClause) | ||||||
| 		} else { | 			return | ||||||
| 			db.AddError(err) | 		} | ||||||
|  | 
 | ||||||
|  | 		if !db.DryRun && db.Error == nil { | ||||||
|  | 			if _, ok := db.Statement.Clauses["RETURNING"]; withReturning && ok { | ||||||
|  | 				if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil { | ||||||
|  | 					gorm.Scan(rows, db, gorm.ScanUpdate) | ||||||
|  | 					rows.Close() | ||||||
|  | 				} | ||||||
|  | 			} else { | ||||||
|  | 				result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) | ||||||
|  | 
 | ||||||
|  | 				if err == nil { | ||||||
|  | 					db.RowsAffected, _ = result.RowsAffected() | ||||||
|  | 				} else { | ||||||
|  | 					db.AddError(err) | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  | |||||||
| @ -26,7 +26,7 @@ func (onConflict OnConflict) Build(builder Builder) { | |||||||
| 		} | 		} | ||||||
| 		builder.WriteString(`) `) | 		builder.WriteString(`) `) | ||||||
| 	} | 	} | ||||||
| 	 | 
 | ||||||
| 	if len(onConflict.TargetWhere.Exprs) > 0 { | 	if len(onConflict.TargetWhere.Exprs) > 0 { | ||||||
| 		builder.WriteString(" WHERE ") | 		builder.WriteString(" WHERE ") | ||||||
| 		onConflict.TargetWhere.Build(builder) | 		onConflict.TargetWhere.Build(builder) | ||||||
|  | |||||||
| @ -511,7 +511,7 @@ func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error { | |||||||
| 		} | 		} | ||||||
| 		tx.Statement.ReflectValue = elem | 		tx.Statement.ReflectValue = elem | ||||||
| 	} | 	} | ||||||
| 	Scan(rows, tx, true) | 	Scan(rows, tx, ScanInitialized) | ||||||
| 	return tx.Error | 	return tx.Error | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | |||||||
							
								
								
									
										308
									
								
								scan.go
									
									
									
									
									
								
							
							
						
						
									
										308
									
								
								scan.go
									
									
									
									
									
								
							| @ -49,13 +49,93 @@ func scanIntoMap(mapValue map[string]interface{}, values []interface{}, columns | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func Scan(rows *sql.Rows, db *DB, initialized bool) { | func (db *DB) scanIntoStruct(sch *schema.Schema, rows *sql.Rows, reflectValue reflect.Value, values []interface{}, columns []string, fields []*schema.Field, joinFields [][2]*schema.Field) { | ||||||
| 	columns, _ := rows.Columns() | 	for idx, column := range columns { | ||||||
| 	values := make([]interface{}, len(columns)) | 		if sch == nil { | ||||||
|  | 			values[idx] = reflectValue.Interface() | ||||||
|  | 		} else if field := sch.LookUpField(column); field != nil && field.Readable { | ||||||
|  | 			values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface() | ||||||
|  | 		} else if names := strings.Split(column, "__"); len(names) > 1 { | ||||||
|  | 			if rel, ok := sch.Relationships.Relations[names[0]]; ok { | ||||||
|  | 				if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { | ||||||
|  | 					values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface() | ||||||
|  | 					continue | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 			values[idx] = &sql.RawBytes{} | ||||||
|  | 		} else if len(columns) == 1 { | ||||||
|  | 			sch = nil | ||||||
|  | 			values[idx] = reflectValue.Interface() | ||||||
|  | 		} else { | ||||||
|  | 			values[idx] = &sql.RawBytes{} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	db.RowsAffected++ | ||||||
|  | 	db.AddError(rows.Scan(values...)) | ||||||
|  | 
 | ||||||
|  | 	if sch != nil { | ||||||
|  | 		for idx, column := range columns { | ||||||
|  | 			if field := sch.LookUpField(column); field != nil && field.Readable { | ||||||
|  | 				field.Set(reflectValue, values[idx]) | ||||||
|  | 			} else if names := strings.Split(column, "__"); len(names) > 1 { | ||||||
|  | 				if rel, ok := sch.Relationships.Relations[names[0]]; ok { | ||||||
|  | 					if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { | ||||||
|  | 						relValue := rel.Field.ReflectValueOf(reflectValue) | ||||||
|  | 						value := reflect.ValueOf(values[idx]).Elem() | ||||||
|  | 
 | ||||||
|  | 						if relValue.Kind() == reflect.Ptr && relValue.IsNil() { | ||||||
|  | 							if value.IsNil() { | ||||||
|  | 								continue | ||||||
|  | 							} | ||||||
|  | 							relValue.Set(reflect.New(relValue.Type().Elem())) | ||||||
|  | 						} | ||||||
|  | 
 | ||||||
|  | 						field.Set(relValue, values[idx]) | ||||||
|  | 					} | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type ScanMode uint8 | ||||||
|  | 
 | ||||||
|  | const ( | ||||||
|  | 	ScanInitialized         ScanMode = 1 << 0 | ||||||
|  | 	ScanUpdate                       = 1 << 1 | ||||||
|  | 	ScanOnConflictDoNothing          = 1 << 2 | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | func Scan(rows *sql.Rows, db *DB, mode ScanMode) { | ||||||
|  | 	var ( | ||||||
|  | 		columns, _          = rows.Columns() | ||||||
|  | 		values              = make([]interface{}, len(columns)) | ||||||
|  | 		initialized         = mode&ScanInitialized != 0 | ||||||
|  | 		update              = mode&ScanUpdate != 0 | ||||||
|  | 		onConflictDonothing = mode&ScanOnConflictDoNothing != 0 | ||||||
|  | 	) | ||||||
|  | 
 | ||||||
| 	db.RowsAffected = 0 | 	db.RowsAffected = 0 | ||||||
| 
 | 
 | ||||||
| 	switch dest := db.Statement.Dest.(type) { | 	switch dest := db.Statement.Dest.(type) { | ||||||
| 	case map[string]interface{}, *map[string]interface{}: | 	case map[string]interface{}, *map[string]interface{}: | ||||||
|  | 		if update && db.Statement.Schema != nil { | ||||||
|  | 			switch db.Statement.ReflectValue.Kind() { | ||||||
|  | 			case reflect.Struct: | ||||||
|  | 				fields := make([]*schema.Field, len(columns)) | ||||||
|  | 				for idx, column := range columns { | ||||||
|  | 					if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable { | ||||||
|  | 						fields[idx] = field | ||||||
|  | 					} | ||||||
|  | 				} | ||||||
|  | 
 | ||||||
|  | 				if initialized || rows.Next() { | ||||||
|  | 					db.scanIntoStruct(db.Statement.Schema, rows, db.Statement.ReflectValue, values, columns, fields, nil) | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
| 		if initialized || rows.Next() { | 		if initialized || rows.Next() { | ||||||
| 			columnTypes, _ := rows.ColumnTypes() | 			columnTypes, _ := rows.ColumnTypes() | ||||||
| 			prepareValues(values, db, columnTypes, columns) | 			prepareValues(values, db, columnTypes, columns) | ||||||
| @ -71,7 +151,7 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { | |||||||
| 			} | 			} | ||||||
| 			scanIntoMap(mapValue, values, columns) | 			scanIntoMap(mapValue, values, columns) | ||||||
| 		} | 		} | ||||||
| 	case *[]map[string]interface{}: | 	case *[]map[string]interface{}, []map[string]interface{}: | ||||||
| 		columnTypes, _ := rows.ColumnTypes() | 		columnTypes, _ := rows.ColumnTypes() | ||||||
| 		for initialized || rows.Next() { | 		for initialized || rows.Next() { | ||||||
| 			prepareValues(values, db, columnTypes, columns) | 			prepareValues(values, db, columnTypes, columns) | ||||||
| @ -82,7 +162,11 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { | |||||||
| 
 | 
 | ||||||
| 			mapValue := map[string]interface{}{} | 			mapValue := map[string]interface{}{} | ||||||
| 			scanIntoMap(mapValue, values, columns) | 			scanIntoMap(mapValue, values, columns) | ||||||
| 			*dest = append(*dest, mapValue) | 			if values, ok := dest.([]map[string]interface{}); ok { | ||||||
|  | 				values = append(values, mapValue) | ||||||
|  | 			} else if values, ok := dest.(*[]map[string]interface{}); ok { | ||||||
|  | 				*values = append(*values, mapValue) | ||||||
|  | 			} | ||||||
| 		} | 		} | ||||||
| 	case *int, *int8, *int16, *int32, *int64, | 	case *int, *int8, *int16, *int32, *int64, | ||||||
| 		*uint, *uint8, *uint16, *uint32, *uint64, *uintptr, | 		*uint, *uint8, *uint16, *uint32, *uint64, *uintptr, | ||||||
| @ -96,155 +180,109 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { | |||||||
| 			db.AddError(rows.Scan(dest)) | 			db.AddError(rows.Scan(dest)) | ||||||
| 		} | 		} | ||||||
| 	default: | 	default: | ||||||
| 		Schema := db.Statement.Schema | 		var ( | ||||||
| 		reflectValue := db.Statement.ReflectValue | 			fields       = make([]*schema.Field, len(columns)) | ||||||
|  | 			joinFields   [][2]*schema.Field | ||||||
|  | 			sch          = db.Statement.Schema | ||||||
|  | 			reflectValue = db.Statement.ReflectValue | ||||||
|  | 		) | ||||||
|  | 
 | ||||||
| 		if reflectValue.Kind() == reflect.Interface { | 		if reflectValue.Kind() == reflect.Interface { | ||||||
| 			reflectValue = reflectValue.Elem() | 			reflectValue = reflectValue.Elem() | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
|  | 		reflectValueType := reflectValue.Type() | ||||||
|  | 		switch reflectValueType.Kind() { | ||||||
|  | 		case reflect.Array, reflect.Slice: | ||||||
|  | 			reflectValueType = reflectValueType.Elem() | ||||||
|  | 		} | ||||||
|  | 		isPtr := reflectValueType.Kind() == reflect.Ptr | ||||||
|  | 		if isPtr { | ||||||
|  | 			reflectValueType = reflectValueType.Elem() | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if sch != nil { | ||||||
|  | 			if reflectValueType != sch.ModelType && reflectValueType.Kind() == reflect.Struct { | ||||||
|  | 				sch, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy) | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			for idx, column := range columns { | ||||||
|  | 				if field := sch.LookUpField(column); field != nil && field.Readable { | ||||||
|  | 					fields[idx] = field | ||||||
|  | 				} else if names := strings.Split(column, "__"); len(names) > 1 { | ||||||
|  | 					if rel, ok := sch.Relationships.Relations[names[0]]; ok { | ||||||
|  | 						if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { | ||||||
|  | 							fields[idx] = field | ||||||
|  | 
 | ||||||
|  | 							if len(joinFields) == 0 { | ||||||
|  | 								joinFields = make([][2]*schema.Field, len(columns)) | ||||||
|  | 							} | ||||||
|  | 							joinFields[idx] = [2]*schema.Field{rel.Field, field} | ||||||
|  | 							continue | ||||||
|  | 						} | ||||||
|  | 					} | ||||||
|  | 					values[idx] = &sql.RawBytes{} | ||||||
|  | 				} else { | ||||||
|  | 					values[idx] = &sql.RawBytes{} | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			if len(columns) == 1 { | ||||||
|  | 				// isPluck
 | ||||||
|  | 				if _, ok := reflect.New(reflectValueType).Interface().(sql.Scanner); (reflectValueType != sch.ModelType && ok) || // is scanner
 | ||||||
|  | 					reflectValueType.Kind() != reflect.Struct || // is not struct
 | ||||||
|  | 					sch.ModelType.ConvertibleTo(schema.TimeReflectType) { // is time
 | ||||||
|  | 					sch = nil | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
| 		switch reflectValue.Kind() { | 		switch reflectValue.Kind() { | ||||||
| 		case reflect.Slice, reflect.Array: | 		case reflect.Slice, reflect.Array: | ||||||
| 			var ( | 			var elem reflect.Value | ||||||
| 				reflectValueType = reflectValue.Type().Elem() |  | ||||||
| 				isPtr            = reflectValueType.Kind() == reflect.Ptr |  | ||||||
| 				fields           = make([]*schema.Field, len(columns)) |  | ||||||
| 				joinFields       [][2]*schema.Field |  | ||||||
| 			) |  | ||||||
| 
 | 
 | ||||||
| 			if isPtr { | 			if !update { | ||||||
| 				reflectValueType = reflectValueType.Elem() | 				db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20)) | ||||||
| 			} |  | ||||||
| 
 |  | ||||||
| 			db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20)) |  | ||||||
| 
 |  | ||||||
| 			if Schema != nil { |  | ||||||
| 				if reflectValueType != Schema.ModelType && reflectValueType.Kind() == reflect.Struct { |  | ||||||
| 					Schema, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy) |  | ||||||
| 				} |  | ||||||
| 
 |  | ||||||
| 				for idx, column := range columns { |  | ||||||
| 					if field := Schema.LookUpField(column); field != nil && field.Readable { |  | ||||||
| 						fields[idx] = field |  | ||||||
| 					} else if names := strings.Split(column, "__"); len(names) > 1 { |  | ||||||
| 						if rel, ok := Schema.Relationships.Relations[names[0]]; ok { |  | ||||||
| 							if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { |  | ||||||
| 								fields[idx] = field |  | ||||||
| 
 |  | ||||||
| 								if len(joinFields) == 0 { |  | ||||||
| 									joinFields = make([][2]*schema.Field, len(columns)) |  | ||||||
| 								} |  | ||||||
| 								joinFields[idx] = [2]*schema.Field{rel.Field, field} |  | ||||||
| 								continue |  | ||||||
| 							} |  | ||||||
| 						} |  | ||||||
| 						values[idx] = &sql.RawBytes{} |  | ||||||
| 					} else { |  | ||||||
| 						values[idx] = &sql.RawBytes{} |  | ||||||
| 					} |  | ||||||
| 				} |  | ||||||
| 			} |  | ||||||
| 
 |  | ||||||
| 			// pluck values into slice of data
 |  | ||||||
| 			isPluck := false |  | ||||||
| 			if len(fields) == 1 { |  | ||||||
| 				if _, ok := reflect.New(reflectValueType).Interface().(sql.Scanner); ok || // is scanner
 |  | ||||||
| 					reflectValueType.Kind() != reflect.Struct || // is not struct
 |  | ||||||
| 					Schema.ModelType.ConvertibleTo(schema.TimeReflectType) { // is time
 |  | ||||||
| 					isPluck = true |  | ||||||
| 				} |  | ||||||
| 			} | 			} | ||||||
| 
 | 
 | ||||||
| 			for initialized || rows.Next() { | 			for initialized || rows.Next() { | ||||||
|  | 			BEGIN: | ||||||
| 				initialized = false | 				initialized = false | ||||||
| 				db.RowsAffected++ |  | ||||||
| 
 | 
 | ||||||
| 				elem := reflect.New(reflectValueType) | 				if update { | ||||||
| 				if isPluck { | 					if int(db.RowsAffected) >= reflectValue.Len() { | ||||||
| 					db.AddError(rows.Scan(elem.Interface())) | 						return | ||||||
| 				} else { |  | ||||||
| 					for idx, field := range fields { |  | ||||||
| 						if field != nil { |  | ||||||
| 							values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface() |  | ||||||
| 						} |  | ||||||
| 					} | 					} | ||||||
| 
 | 					elem = reflectValue.Index(int(db.RowsAffected)) | ||||||
| 					db.AddError(rows.Scan(values...)) | 					if onConflictDonothing { | ||||||
| 
 | 						for _, field := range fields { | ||||||
| 					for idx, field := range fields { | 							if _, ok := field.ValueOf(elem); !ok { | ||||||
| 						if len(joinFields) != 0 && joinFields[idx][0] != nil { | 								db.RowsAffected++ | ||||||
| 							value := reflect.ValueOf(values[idx]).Elem() | 								goto BEGIN | ||||||
| 							relValue := joinFields[idx][0].ReflectValueOf(elem) |  | ||||||
| 
 |  | ||||||
| 							if relValue.Kind() == reflect.Ptr && relValue.IsNil() { |  | ||||||
| 								if value.IsNil() { |  | ||||||
| 									continue |  | ||||||
| 								} |  | ||||||
| 								relValue.Set(reflect.New(relValue.Type().Elem())) |  | ||||||
| 							} |  | ||||||
| 
 |  | ||||||
| 							field.Set(relValue, values[idx]) |  | ||||||
| 						} else if field != nil { |  | ||||||
| 							field.Set(elem, values[idx]) |  | ||||||
| 						} |  | ||||||
| 					} |  | ||||||
| 				} |  | ||||||
| 
 |  | ||||||
| 				if isPtr { |  | ||||||
| 					reflectValue = reflect.Append(reflectValue, elem) |  | ||||||
| 				} else { |  | ||||||
| 					reflectValue = reflect.Append(reflectValue, elem.Elem()) |  | ||||||
| 				} |  | ||||||
| 			} |  | ||||||
| 
 |  | ||||||
| 			db.Statement.ReflectValue.Set(reflectValue) |  | ||||||
| 		case reflect.Struct, reflect.Ptr: |  | ||||||
| 			if reflectValue.Type() != Schema.ModelType { |  | ||||||
| 				Schema, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy) |  | ||||||
| 			} |  | ||||||
| 
 |  | ||||||
| 			if initialized || rows.Next() { |  | ||||||
| 				for idx, column := range columns { |  | ||||||
| 					if field := Schema.LookUpField(column); field != nil && field.Readable { |  | ||||||
| 						values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface() |  | ||||||
| 					} else if names := strings.Split(column, "__"); len(names) > 1 { |  | ||||||
| 						if rel, ok := Schema.Relationships.Relations[names[0]]; ok { |  | ||||||
| 							if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { |  | ||||||
| 								values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface() |  | ||||||
| 								continue |  | ||||||
| 							} | 							} | ||||||
| 						} | 						} | ||||||
| 						values[idx] = &sql.RawBytes{} | 					} | ||||||
| 					} else if len(columns) == 1 { | 				} else { | ||||||
| 						values[idx] = dest | 					elem = reflect.New(reflectValueType) | ||||||
|  | 				} | ||||||
|  | 
 | ||||||
|  | 				db.scanIntoStruct(sch, rows, elem, values, columns, fields, joinFields) | ||||||
|  | 
 | ||||||
|  | 				if !update { | ||||||
|  | 					if isPtr { | ||||||
|  | 						reflectValue = reflect.Append(reflectValue, elem) | ||||||
| 					} else { | 					} else { | ||||||
| 						values[idx] = &sql.RawBytes{} | 						reflectValue = reflect.Append(reflectValue, elem.Elem()) | ||||||
| 					} | 					} | ||||||
| 				} | 				} | ||||||
|  | 			} | ||||||
| 
 | 
 | ||||||
| 				db.RowsAffected++ | 			if !update { | ||||||
| 				db.AddError(rows.Scan(values...)) | 				db.Statement.ReflectValue.Set(reflectValue) | ||||||
| 
 | 			} | ||||||
| 				for idx, column := range columns { | 		case reflect.Struct, reflect.Ptr: | ||||||
| 					if field := Schema.LookUpField(column); field != nil && field.Readable { | 			if initialized || rows.Next() { | ||||||
| 						field.Set(reflectValue, values[idx]) | 				db.scanIntoStruct(sch, rows, reflectValue, values, columns, fields, joinFields) | ||||||
| 					} else if names := strings.Split(column, "__"); len(names) > 1 { |  | ||||||
| 						if rel, ok := Schema.Relationships.Relations[names[0]]; ok { |  | ||||||
| 							if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { |  | ||||||
| 								relValue := rel.Field.ReflectValueOf(reflectValue) |  | ||||||
| 								value := reflect.ValueOf(values[idx]).Elem() |  | ||||||
| 
 |  | ||||||
| 								if relValue.Kind() == reflect.Ptr && relValue.IsNil() { |  | ||||||
| 									if value.IsNil() { |  | ||||||
| 										continue |  | ||||||
| 									} |  | ||||||
| 									relValue.Set(reflect.New(relValue.Type().Elem())) |  | ||||||
| 								} |  | ||||||
| 
 |  | ||||||
| 								field.Set(relValue, values[idx]) |  | ||||||
| 							} |  | ||||||
| 						} |  | ||||||
| 					} |  | ||||||
| 				} |  | ||||||
| 			} | 			} | ||||||
| 		default: | 		default: | ||||||
| 			db.AddError(rows.Scan(dest)) | 			db.AddError(rows.Scan(dest)) | ||||||
|  | |||||||
| @ -7,9 +7,9 @@ require ( | |||||||
| 	github.com/jinzhu/now v1.1.2 | 	github.com/jinzhu/now v1.1.2 | ||||||
| 	github.com/lib/pq v1.10.3 | 	github.com/lib/pq v1.10.3 | ||||||
| 	gorm.io/driver/mysql v1.1.2 | 	gorm.io/driver/mysql v1.1.2 | ||||||
| 	gorm.io/driver/postgres v1.1.2 | 	gorm.io/driver/postgres v1.2.0 | ||||||
| 	gorm.io/driver/sqlite v1.1.6 | 	gorm.io/driver/sqlite v1.2.0 | ||||||
| 	gorm.io/driver/sqlserver v1.1.0 | 	gorm.io/driver/sqlserver v1.1.1 | ||||||
| 	gorm.io/gorm v1.21.16 | 	gorm.io/gorm v1.21.16 | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -1,9 +1,9 @@ | |||||||
| package tests_test | package tests_test | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"gorm.io/gorm" |  | ||||||
| 	"gorm.io/gorm/callbacks" |  | ||||||
| 	"testing" | 	"testing" | ||||||
|  | 
 | ||||||
|  | 	"gorm.io/gorm" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func TestReturningWithNullToZeroValues(t *testing.T) { | func TestReturningWithNullToZeroValues(t *testing.T) { | ||||||
| @ -20,11 +20,6 @@ func TestReturningWithNullToZeroValues(t *testing.T) { | |||||||
| 			Name string `gorm:"default:null"` | 			Name string `gorm:"default:null"` | ||||||
| 		} | 		} | ||||||
| 		u1 := user{} | 		u1 := user{} | ||||||
| 		c := DB.Callback().Create().Get("gorm:create") |  | ||||||
| 		t.Cleanup(func() { |  | ||||||
| 			DB.Callback().Create().Replace("gorm:create", c) |  | ||||||
| 		}) |  | ||||||
| 		DB.Callback().Create().Replace("gorm:create", callbacks.Create(&callbacks.Config{WithReturning: true})) |  | ||||||
| 
 | 
 | ||||||
| 		if results := DB.Create(&u1); results.Error != nil { | 		if results := DB.Create(&u1); results.Error != nil { | ||||||
| 			t.Fatalf("errors happened on create: %v", results.Error) | 			t.Fatalf("errors happened on create: %v", results.Error) | ||||||
|  | |||||||
| @ -9,6 +9,7 @@ import ( | |||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
| 	"gorm.io/gorm" | 	"gorm.io/gorm" | ||||||
|  | 	"gorm.io/gorm/clause" | ||||||
| 	"gorm.io/gorm/utils" | 	"gorm.io/gorm/utils" | ||||||
| 	. "gorm.io/gorm/utils/tests" | 	. "gorm.io/gorm/utils/tests" | ||||||
| ) | ) | ||||||
| @ -166,13 +167,16 @@ func TestUpdates(t *testing.T) { | |||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// update with gorm exprs
 | 	// update with gorm exprs
 | ||||||
| 	if err := DB.Model(&user3).Updates(map[string]interface{}{"age": gorm.Expr("age + ?", 100)}).Error; err != nil { | 	if err := DB.Debug().Model(&user3).Clauses(clause.Returning{Columns: []clause.Column{{Name: "age"}}}).Updates(map[string]interface{}{"age": gorm.Expr("age + ?", 100)}).Error; err != nil { | ||||||
| 		t.Errorf("Not error should happen when updating with gorm expr, but got %v", err) | 		t.Errorf("Not error should happen when updating with gorm expr, but got %v", err) | ||||||
| 	} | 	} | ||||||
| 	var user4 User | 	var user4 User | ||||||
| 	DB.First(&user4, user3.ID) | 	DB.First(&user4, user3.ID) | ||||||
| 
 | 
 | ||||||
| 	user3.Age += 100 | 	// sqlite, postgres support returning
 | ||||||
|  | 	if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" { | ||||||
|  | 		user3.Age += 100 | ||||||
|  | 	} | ||||||
| 	AssertObjEqual(t, user4, user3, "UpdatedAt", "Age") | 	AssertObjEqual(t, user4, user3, "UpdatedAt", "Age") | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu