Improve returning support
This commit is contained in:
		
							parent
							
								
									d3211908a0
								
							
						
					
					
						commit
						af3fbdc2fc
					
				| @ -13,7 +13,6 @@ var ( | ||||
| 
 | ||||
| type Config struct { | ||||
| 	LastInsertIDReversed bool | ||||
| 	WithReturning        bool | ||||
| 	CreateClauses        []string | ||||
| 	QueryClauses         []string | ||||
| 	UpdateClauses        []string | ||||
| @ -25,6 +24,19 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) { | ||||
| 		return !db.SkipDefaultTransaction | ||||
| 	} | ||||
| 
 | ||||
| 	if len(config.CreateClauses) == 0 { | ||||
| 		config.CreateClauses = createClauses | ||||
| 	} | ||||
| 	if len(config.QueryClauses) == 0 { | ||||
| 		config.QueryClauses = queryClauses | ||||
| 	} | ||||
| 	if len(config.DeleteClauses) == 0 { | ||||
| 		config.DeleteClauses = deleteClauses | ||||
| 	} | ||||
| 	if len(config.UpdateClauses) == 0 { | ||||
| 		config.UpdateClauses = updateClauses | ||||
| 	} | ||||
| 
 | ||||
| 	createCallback := db.Callback().Create() | ||||
| 	createCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction) | ||||
| 	createCallback.Register("gorm:before_create", BeforeCreate) | ||||
| @ -33,18 +45,12 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) { | ||||
| 	createCallback.Register("gorm:save_after_associations", SaveAfterAssociations(true)) | ||||
| 	createCallback.Register("gorm:after_create", AfterCreate) | ||||
| 	createCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) | ||||
| 	if len(config.CreateClauses) == 0 { | ||||
| 		config.CreateClauses = createClauses | ||||
| 	} | ||||
| 	createCallback.Clauses = config.CreateClauses | ||||
| 
 | ||||
| 	queryCallback := db.Callback().Query() | ||||
| 	queryCallback.Register("gorm:query", Query) | ||||
| 	queryCallback.Register("gorm:preload", Preload) | ||||
| 	queryCallback.Register("gorm:after_query", AfterQuery) | ||||
| 	if len(config.QueryClauses) == 0 { | ||||
| 		config.QueryClauses = queryClauses | ||||
| 	} | ||||
| 	queryCallback.Clauses = config.QueryClauses | ||||
| 
 | ||||
| 	deleteCallback := db.Callback().Delete() | ||||
| @ -54,9 +60,6 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) { | ||||
| 	deleteCallback.Register("gorm:delete", Delete) | ||||
| 	deleteCallback.Register("gorm:after_delete", AfterDelete) | ||||
| 	deleteCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) | ||||
| 	if len(config.DeleteClauses) == 0 { | ||||
| 		config.DeleteClauses = deleteClauses | ||||
| 	} | ||||
| 	deleteCallback.Clauses = config.DeleteClauses | ||||
| 
 | ||||
| 	updateCallback := db.Callback().Update() | ||||
| @ -64,13 +67,10 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) { | ||||
| 	updateCallback.Register("gorm:setup_reflect_value", SetupUpdateReflectValue) | ||||
| 	updateCallback.Register("gorm:before_update", BeforeUpdate) | ||||
| 	updateCallback.Register("gorm:save_before_associations", SaveBeforeAssociations(false)) | ||||
| 	updateCallback.Register("gorm:update", Update) | ||||
| 	updateCallback.Register("gorm:update", Update(config)) | ||||
| 	updateCallback.Register("gorm:save_after_associations", SaveAfterAssociations(false)) | ||||
| 	updateCallback.Register("gorm:after_update", AfterUpdate) | ||||
| 	updateCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) | ||||
| 	if len(config.UpdateClauses) == 0 { | ||||
| 		config.UpdateClauses = updateClauses | ||||
| 	} | ||||
| 	updateCallback.Clauses = config.UpdateClauses | ||||
| 
 | ||||
| 	rowCallback := db.Callback().Row() | ||||
|  | ||||
| @ -31,21 +31,38 @@ func BeforeCreate(db *gorm.DB) { | ||||
| } | ||||
| 
 | ||||
| func Create(config *Config) func(db *gorm.DB) { | ||||
| 	if config.WithReturning { | ||||
| 		return CreateWithReturning | ||||
| 	withReturning := false | ||||
| 	for _, clause := range config.CreateClauses { | ||||
| 		if clause == "RETURNING" { | ||||
| 			withReturning = true | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return func(db *gorm.DB) { | ||||
| 		if db.Error != nil { | ||||
| 			return | ||||
| 		} | ||||
| 		onReturning := false | ||||
| 
 | ||||
| 		if db.Statement.Schema != nil && !db.Statement.Unscoped { | ||||
| 		if db.Statement.Schema != nil { | ||||
| 			if !db.Statement.Unscoped { | ||||
| 				for _, c := range db.Statement.Schema.CreateClauses { | ||||
| 					db.Statement.AddClause(c) | ||||
| 				} | ||||
| 			} | ||||
| 
 | ||||
| 			if withReturning && len(db.Statement.Schema.FieldsWithDefaultDBValue) > 0 { | ||||
| 				onReturning = true | ||||
| 				if _, ok := db.Statement.Clauses["RETURNING"]; !ok { | ||||
| 					fromColumns := make([]clause.Column, 0, len(db.Statement.Schema.FieldsWithDefaultDBValue)) | ||||
| 					for _, field := range db.Statement.Schema.FieldsWithDefaultDBValue { | ||||
| 						fromColumns = append(fromColumns, clause.Column{Name: field.DBName}) | ||||
| 					} | ||||
| 					db.Statement.AddClause(clause.Returning{Columns: fromColumns}) | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		if db.Statement.SQL.String() == "" { | ||||
| 			db.Statement.SQL.Grow(180) | ||||
| 			db.Statement.AddClauseIfNotExists(clause.Insert{}) | ||||
| @ -55,6 +72,21 @@ func Create(config *Config) func(db *gorm.DB) { | ||||
| 		} | ||||
| 
 | ||||
| 		if !db.DryRun && db.Error == nil { | ||||
| 			if onReturning { | ||||
| 				doNothing := false | ||||
| 				if c, ok := db.Statement.Clauses["ON CONFLICT"]; ok { | ||||
| 					onConflict, _ := c.Expression.(clause.OnConflict) | ||||
| 					doNothing = onConflict.DoNothing | ||||
| 				} | ||||
| 				if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil { | ||||
| 					if doNothing { | ||||
| 						gorm.Scan(rows, db, gorm.ScanUpdate|gorm.ScanOnConflictDoNothing) | ||||
| 					} else { | ||||
| 						gorm.Scan(rows, db, gorm.ScanUpdate) | ||||
| 					} | ||||
| 					rows.Close() | ||||
| 				} | ||||
| 			} else { | ||||
| 				result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) | ||||
| 
 | ||||
| 				if err != nil { | ||||
| @ -63,7 +95,6 @@ func Create(config *Config) func(db *gorm.DB) { | ||||
| 				} | ||||
| 
 | ||||
| 				db.RowsAffected, _ = result.RowsAffected() | ||||
| 
 | ||||
| 				if db.RowsAffected != 0 && db.Statement.Schema != nil && | ||||
| 					db.Statement.Schema.PrioritizedPrimaryField != nil && db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue { | ||||
| 					if insertID, err := result.LastInsertId(); err == nil && insertID > 0 { | ||||
| @ -106,130 +137,6 @@ func Create(config *Config) func(db *gorm.DB) { | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| } | ||||
| 
 | ||||
| func CreateWithReturning(db *gorm.DB) { | ||||
| 	if db.Error == nil { | ||||
| 		if db.Statement.Schema != nil && !db.Statement.Unscoped { | ||||
| 			for _, c := range db.Statement.Schema.CreateClauses { | ||||
| 				db.Statement.AddClause(c) | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		if db.Statement.SQL.String() == "" { | ||||
| 			db.Statement.AddClauseIfNotExists(clause.Insert{}) | ||||
| 			db.Statement.AddClause(ConvertToCreateValues(db.Statement)) | ||||
| 
 | ||||
| 			db.Statement.Build(db.Statement.BuildClauses...) | ||||
| 		} | ||||
| 
 | ||||
| 		if sch := db.Statement.Schema; sch != nil && len(sch.FieldsWithDefaultDBValue) > 0 { | ||||
| 			db.Statement.WriteString(" RETURNING ") | ||||
| 
 | ||||
| 			var ( | ||||
| 				fields = make([]*schema.Field, len(sch.FieldsWithDefaultDBValue)) | ||||
| 				values = make([]interface{}, len(sch.FieldsWithDefaultDBValue)) | ||||
| 			) | ||||
| 
 | ||||
| 			for idx, field := range sch.FieldsWithDefaultDBValue { | ||||
| 				if idx > 0 { | ||||
| 					db.Statement.WriteByte(',') | ||||
| 				} | ||||
| 
 | ||||
| 				fields[idx] = field | ||||
| 				db.Statement.WriteQuoted(field.DBName) | ||||
| 			} | ||||
| 
 | ||||
| 			if !db.DryRun && db.Error == nil { | ||||
| 				db.RowsAffected = 0 | ||||
| 				rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) | ||||
| 
 | ||||
| 				if err == nil { | ||||
| 					defer rows.Close() | ||||
| 
 | ||||
| 					switch db.Statement.ReflectValue.Kind() { | ||||
| 					case reflect.Slice, reflect.Array: | ||||
| 						var ( | ||||
| 							c                = db.Statement.Clauses["ON CONFLICT"] | ||||
| 							onConflict, _    = c.Expression.(clause.OnConflict) | ||||
| 							resetFieldValues = map[int]reflect.Value{} | ||||
| 						) | ||||
| 
 | ||||
| 						for rows.Next() { | ||||
| 						BEGIN: | ||||
| 							reflectValue := db.Statement.ReflectValue.Index(int(db.RowsAffected)) | ||||
| 							if reflect.Indirect(reflectValue).Kind() != reflect.Struct { | ||||
| 								break | ||||
| 							} | ||||
| 
 | ||||
| 							for idx, field := range fields { | ||||
| 								fieldValue := field.ReflectValueOf(reflectValue) | ||||
| 
 | ||||
| 								if onConflict.DoNothing && !fieldValue.IsZero() { | ||||
| 									db.RowsAffected++ | ||||
| 
 | ||||
| 									if int(db.RowsAffected) >= db.Statement.ReflectValue.Len() { | ||||
| 										return | ||||
| 									} | ||||
| 
 | ||||
| 									goto BEGIN | ||||
| 								} | ||||
| 
 | ||||
| 								if field.FieldType.Kind() == reflect.Ptr { | ||||
| 									values[idx] = fieldValue.Addr().Interface() | ||||
| 								} else { | ||||
| 									reflectValue := reflect.New(reflect.PtrTo(field.FieldType)) | ||||
| 									reflectValue.Elem().Set(fieldValue.Addr()) | ||||
| 									values[idx] = reflectValue.Interface() | ||||
| 									resetFieldValues[idx] = fieldValue | ||||
| 								} | ||||
| 							} | ||||
| 
 | ||||
| 							db.RowsAffected++ | ||||
| 							if err := rows.Scan(values...); err != nil { | ||||
| 								db.AddError(err) | ||||
| 							} | ||||
| 
 | ||||
| 							for idx, fv := range resetFieldValues { | ||||
| 								if v := reflect.ValueOf(values[idx]).Elem(); !v.IsNil() { | ||||
| 									fv.Set(v.Elem()) | ||||
| 								} | ||||
| 							} | ||||
| 						} | ||||
| 					case reflect.Struct: | ||||
| 						resetFieldValues := map[int]reflect.Value{} | ||||
| 						for idx, field := range fields { | ||||
| 							if field.FieldType.Kind() == reflect.Ptr { | ||||
| 								values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface() | ||||
| 							} else { | ||||
| 								reflectValue := reflect.New(reflect.PtrTo(field.FieldType)) | ||||
| 								fieldValue := field.ReflectValueOf(db.Statement.ReflectValue) | ||||
| 								reflectValue.Elem().Set(fieldValue.Addr()) | ||||
| 								values[idx] = reflectValue.Interface() | ||||
| 								resetFieldValues[idx] = fieldValue | ||||
| 							} | ||||
| 						} | ||||
| 						if rows.Next() { | ||||
| 							db.RowsAffected++ | ||||
| 							db.AddError(rows.Scan(values...)) | ||||
| 							for idx, fv := range resetFieldValues { | ||||
| 								if v := reflect.ValueOf(values[idx]).Elem(); !v.IsNil() { | ||||
| 									fv.Set(v.Elem()) | ||||
| 								} | ||||
| 							} | ||||
| 						} | ||||
| 					} | ||||
| 				} else { | ||||
| 					db.AddError(err) | ||||
| 				} | ||||
| 			} | ||||
| 		} else if !db.DryRun && db.Error == nil { | ||||
| 			if result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); err == nil { | ||||
| 				db.RowsAffected, _ = result.RowsAffected() | ||||
| 			} else { | ||||
| 				db.AddError(err) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
|  | ||||
| @ -22,7 +22,7 @@ func Query(db *gorm.DB) { | ||||
| 			} | ||||
| 			defer rows.Close() | ||||
| 
 | ||||
| 			gorm.Scan(rows, db, false) | ||||
| 			gorm.Scan(rows, db, 0) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @ -50,7 +50,15 @@ func BeforeUpdate(db *gorm.DB) { | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func Update(db *gorm.DB) { | ||||
| func Update(config *Config) func(db *gorm.DB) { | ||||
| 	withReturning := false | ||||
| 	for _, clause := range config.UpdateClauses { | ||||
| 		if clause == "RETURNING" { | ||||
| 			withReturning = true | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return func(db *gorm.DB) { | ||||
| 		if db.Error != nil { | ||||
| 			return | ||||
| 		} | ||||
| @ -78,6 +86,12 @@ func Update(db *gorm.DB) { | ||||
| 		} | ||||
| 
 | ||||
| 		if !db.DryRun && db.Error == nil { | ||||
| 			if _, ok := db.Statement.Clauses["RETURNING"]; withReturning && ok { | ||||
| 				if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil { | ||||
| 					gorm.Scan(rows, db, gorm.ScanUpdate) | ||||
| 					rows.Close() | ||||
| 				} | ||||
| 			} else { | ||||
| 				result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) | ||||
| 
 | ||||
| 				if err == nil { | ||||
| @ -86,6 +100,8 @@ func Update(db *gorm.DB) { | ||||
| 					db.AddError(err) | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func AfterUpdate(db *gorm.DB) { | ||||
|  | ||||
| @ -511,7 +511,7 @@ func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error { | ||||
| 		} | ||||
| 		tx.Statement.ReflectValue = elem | ||||
| 	} | ||||
| 	Scan(rows, tx, true) | ||||
| 	Scan(rows, tx, ScanInitialized) | ||||
| 	return tx.Error | ||||
| } | ||||
| 
 | ||||
|  | ||||
							
								
								
									
										356
									
								
								scan.go
									
									
									
									
									
								
							
							
						
						
									
										356
									
								
								scan.go
									
									
									
									
									
								
							| @ -49,165 +49,14 @@ func scanIntoMap(mapValue map[string]interface{}, values []interface{}, columns | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func Scan(rows *sql.Rows, db *DB, initialized bool) { | ||||
| 	columns, _ := rows.Columns() | ||||
| 	values := make([]interface{}, len(columns)) | ||||
| 	db.RowsAffected = 0 | ||||
| 
 | ||||
| 	switch dest := db.Statement.Dest.(type) { | ||||
| 	case map[string]interface{}, *map[string]interface{}: | ||||
| 		if initialized || rows.Next() { | ||||
| 			columnTypes, _ := rows.ColumnTypes() | ||||
| 			prepareValues(values, db, columnTypes, columns) | ||||
| 
 | ||||
| 			db.RowsAffected++ | ||||
| 			db.AddError(rows.Scan(values...)) | ||||
| 
 | ||||
| 			mapValue, ok := dest.(map[string]interface{}) | ||||
| 			if !ok { | ||||
| 				if v, ok := dest.(*map[string]interface{}); ok { | ||||
| 					mapValue = *v | ||||
| 				} | ||||
| 			} | ||||
| 			scanIntoMap(mapValue, values, columns) | ||||
| 		} | ||||
| 	case *[]map[string]interface{}: | ||||
| 		columnTypes, _ := rows.ColumnTypes() | ||||
| 		for initialized || rows.Next() { | ||||
| 			prepareValues(values, db, columnTypes, columns) | ||||
| 
 | ||||
| 			initialized = false | ||||
| 			db.RowsAffected++ | ||||
| 			db.AddError(rows.Scan(values...)) | ||||
| 
 | ||||
| 			mapValue := map[string]interface{}{} | ||||
| 			scanIntoMap(mapValue, values, columns) | ||||
| 			*dest = append(*dest, mapValue) | ||||
| 		} | ||||
| 	case *int, *int8, *int16, *int32, *int64, | ||||
| 		*uint, *uint8, *uint16, *uint32, *uint64, *uintptr, | ||||
| 		*float32, *float64, | ||||
| 		*bool, *string, *time.Time, | ||||
| 		*sql.NullInt32, *sql.NullInt64, *sql.NullFloat64, | ||||
| 		*sql.NullBool, *sql.NullString, *sql.NullTime: | ||||
| 		for initialized || rows.Next() { | ||||
| 			initialized = false | ||||
| 			db.RowsAffected++ | ||||
| 			db.AddError(rows.Scan(dest)) | ||||
| 		} | ||||
| 	default: | ||||
| 		Schema := db.Statement.Schema | ||||
| 		reflectValue := db.Statement.ReflectValue | ||||
| 		if reflectValue.Kind() == reflect.Interface { | ||||
| 			reflectValue = reflectValue.Elem() | ||||
| 		} | ||||
| 
 | ||||
| 		switch reflectValue.Kind() { | ||||
| 		case reflect.Slice, reflect.Array: | ||||
| 			var ( | ||||
| 				reflectValueType = reflectValue.Type().Elem() | ||||
| 				isPtr            = reflectValueType.Kind() == reflect.Ptr | ||||
| 				fields           = make([]*schema.Field, len(columns)) | ||||
| 				joinFields       [][2]*schema.Field | ||||
| 			) | ||||
| 
 | ||||
| 			if isPtr { | ||||
| 				reflectValueType = reflectValueType.Elem() | ||||
| 			} | ||||
| 
 | ||||
| 			db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20)) | ||||
| 
 | ||||
| 			if Schema != nil { | ||||
| 				if reflectValueType != Schema.ModelType && reflectValueType.Kind() == reflect.Struct { | ||||
| 					Schema, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy) | ||||
| 				} | ||||
| 
 | ||||
| func (db *DB) scanIntoStruct(sch *schema.Schema, rows *sql.Rows, reflectValue reflect.Value, values []interface{}, columns []string, fields []*schema.Field, joinFields [][2]*schema.Field) { | ||||
| 	for idx, column := range columns { | ||||
| 					if field := Schema.LookUpField(column); field != nil && field.Readable { | ||||
| 						fields[idx] = field | ||||
| 					} else if names := strings.Split(column, "__"); len(names) > 1 { | ||||
| 						if rel, ok := Schema.Relationships.Relations[names[0]]; ok { | ||||
| 							if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { | ||||
| 								fields[idx] = field | ||||
| 
 | ||||
| 								if len(joinFields) == 0 { | ||||
| 									joinFields = make([][2]*schema.Field, len(columns)) | ||||
| 								} | ||||
| 								joinFields[idx] = [2]*schema.Field{rel.Field, field} | ||||
| 								continue | ||||
| 							} | ||||
| 						} | ||||
| 						values[idx] = &sql.RawBytes{} | ||||
| 					} else { | ||||
| 						values[idx] = &sql.RawBytes{} | ||||
| 					} | ||||
| 				} | ||||
| 			} | ||||
| 
 | ||||
| 			// pluck values into slice of data
 | ||||
| 			isPluck := false | ||||
| 			if len(fields) == 1 { | ||||
| 				if _, ok := reflect.New(reflectValueType).Interface().(sql.Scanner); ok || // is scanner
 | ||||
| 					reflectValueType.Kind() != reflect.Struct || // is not struct
 | ||||
| 					Schema.ModelType.ConvertibleTo(schema.TimeReflectType) { // is time
 | ||||
| 					isPluck = true | ||||
| 				} | ||||
| 			} | ||||
| 
 | ||||
| 			for initialized || rows.Next() { | ||||
| 				initialized = false | ||||
| 				db.RowsAffected++ | ||||
| 
 | ||||
| 				elem := reflect.New(reflectValueType) | ||||
| 				if isPluck { | ||||
| 					db.AddError(rows.Scan(elem.Interface())) | ||||
| 				} else { | ||||
| 					for idx, field := range fields { | ||||
| 						if field != nil { | ||||
| 							values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface() | ||||
| 						} | ||||
| 					} | ||||
| 
 | ||||
| 					db.AddError(rows.Scan(values...)) | ||||
| 
 | ||||
| 					for idx, field := range fields { | ||||
| 						if len(joinFields) != 0 && joinFields[idx][0] != nil { | ||||
| 							value := reflect.ValueOf(values[idx]).Elem() | ||||
| 							relValue := joinFields[idx][0].ReflectValueOf(elem) | ||||
| 
 | ||||
| 							if relValue.Kind() == reflect.Ptr && relValue.IsNil() { | ||||
| 								if value.IsNil() { | ||||
| 									continue | ||||
| 								} | ||||
| 								relValue.Set(reflect.New(relValue.Type().Elem())) | ||||
| 							} | ||||
| 
 | ||||
| 							field.Set(relValue, values[idx]) | ||||
| 						} else if field != nil { | ||||
| 							field.Set(elem, values[idx]) | ||||
| 						} | ||||
| 					} | ||||
| 				} | ||||
| 
 | ||||
| 				if isPtr { | ||||
| 					reflectValue = reflect.Append(reflectValue, elem) | ||||
| 				} else { | ||||
| 					reflectValue = reflect.Append(reflectValue, elem.Elem()) | ||||
| 				} | ||||
| 			} | ||||
| 
 | ||||
| 			db.Statement.ReflectValue.Set(reflectValue) | ||||
| 		case reflect.Struct, reflect.Ptr: | ||||
| 			if reflectValue.Type() != Schema.ModelType { | ||||
| 				Schema, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy) | ||||
| 			} | ||||
| 
 | ||||
| 			if initialized || rows.Next() { | ||||
| 				for idx, column := range columns { | ||||
| 					if field := Schema.LookUpField(column); field != nil && field.Readable { | ||||
| 		if sch == nil { | ||||
| 			values[idx] = reflectValue.Interface() | ||||
| 		} else if field := sch.LookUpField(column); field != nil && field.Readable { | ||||
| 			values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface() | ||||
| 		} else if names := strings.Split(column, "__"); len(names) > 1 { | ||||
| 						if rel, ok := Schema.Relationships.Relations[names[0]]; ok { | ||||
| 			if rel, ok := sch.Relationships.Relations[names[0]]; ok { | ||||
| 				if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { | ||||
| 					values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface() | ||||
| 					continue | ||||
| @ -215,7 +64,8 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { | ||||
| 			} | ||||
| 			values[idx] = &sql.RawBytes{} | ||||
| 		} else if len(columns) == 1 { | ||||
| 						values[idx] = dest | ||||
| 			sch = nil | ||||
| 			values[idx] = reflectValue.Interface() | ||||
| 		} else { | ||||
| 			values[idx] = &sql.RawBytes{} | ||||
| 		} | ||||
| @ -224,11 +74,12 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { | ||||
| 	db.RowsAffected++ | ||||
| 	db.AddError(rows.Scan(values...)) | ||||
| 
 | ||||
| 	if sch != nil { | ||||
| 		for idx, column := range columns { | ||||
| 					if field := Schema.LookUpField(column); field != nil && field.Readable { | ||||
| 			if field := sch.LookUpField(column); field != nil && field.Readable { | ||||
| 				field.Set(reflectValue, values[idx]) | ||||
| 			} else if names := strings.Split(column, "__"); len(names) > 1 { | ||||
| 						if rel, ok := Schema.Relationships.Relations[names[0]]; ok { | ||||
| 				if rel, ok := sch.Relationships.Relations[names[0]]; ok { | ||||
| 					if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { | ||||
| 						relValue := rel.Field.ReflectValueOf(reflectValue) | ||||
| 						value := reflect.ValueOf(values[idx]).Elem() | ||||
| @ -246,6 +97,193 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| type ScanMode uint8 | ||||
| 
 | ||||
| const ( | ||||
| 	ScanInitialized         ScanMode = 1 << 0 | ||||
| 	ScanUpdate                       = 1 << 1 | ||||
| 	ScanOnConflictDoNothing          = 1 << 2 | ||||
| ) | ||||
| 
 | ||||
| func Scan(rows *sql.Rows, db *DB, mode ScanMode) { | ||||
| 	var ( | ||||
| 		columns, _          = rows.Columns() | ||||
| 		values              = make([]interface{}, len(columns)) | ||||
| 		initialized         = mode&ScanInitialized != 0 | ||||
| 		update              = mode&ScanUpdate != 0 | ||||
| 		onConflictDonothing = mode&ScanOnConflictDoNothing != 0 | ||||
| 	) | ||||
| 
 | ||||
| 	db.RowsAffected = 0 | ||||
| 
 | ||||
| 	switch dest := db.Statement.Dest.(type) { | ||||
| 	case map[string]interface{}, *map[string]interface{}: | ||||
| 		if update && db.Statement.Schema != nil { | ||||
| 			switch db.Statement.ReflectValue.Kind() { | ||||
| 			case reflect.Struct: | ||||
| 				fields := make([]*schema.Field, len(columns)) | ||||
| 				for idx, column := range columns { | ||||
| 					if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable { | ||||
| 						fields[idx] = field | ||||
| 					} | ||||
| 				} | ||||
| 
 | ||||
| 				if initialized || rows.Next() { | ||||
| 					db.scanIntoStruct(db.Statement.Schema, rows, db.Statement.ReflectValue, values, columns, fields, nil) | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		if initialized || rows.Next() { | ||||
| 			columnTypes, _ := rows.ColumnTypes() | ||||
| 			prepareValues(values, db, columnTypes, columns) | ||||
| 
 | ||||
| 			db.RowsAffected++ | ||||
| 			db.AddError(rows.Scan(values...)) | ||||
| 
 | ||||
| 			mapValue, ok := dest.(map[string]interface{}) | ||||
| 			if !ok { | ||||
| 				if v, ok := dest.(*map[string]interface{}); ok { | ||||
| 					mapValue = *v | ||||
| 				} | ||||
| 			} | ||||
| 			scanIntoMap(mapValue, values, columns) | ||||
| 		} | ||||
| 	case *[]map[string]interface{}, []map[string]interface{}: | ||||
| 		columnTypes, _ := rows.ColumnTypes() | ||||
| 		for initialized || rows.Next() { | ||||
| 			prepareValues(values, db, columnTypes, columns) | ||||
| 
 | ||||
| 			initialized = false | ||||
| 			db.RowsAffected++ | ||||
| 			db.AddError(rows.Scan(values...)) | ||||
| 
 | ||||
| 			mapValue := map[string]interface{}{} | ||||
| 			scanIntoMap(mapValue, values, columns) | ||||
| 			if values, ok := dest.([]map[string]interface{}); ok { | ||||
| 				values = append(values, mapValue) | ||||
| 			} else if values, ok := dest.(*[]map[string]interface{}); ok { | ||||
| 				*values = append(*values, mapValue) | ||||
| 			} | ||||
| 		} | ||||
| 	case *int, *int8, *int16, *int32, *int64, | ||||
| 		*uint, *uint8, *uint16, *uint32, *uint64, *uintptr, | ||||
| 		*float32, *float64, | ||||
| 		*bool, *string, *time.Time, | ||||
| 		*sql.NullInt32, *sql.NullInt64, *sql.NullFloat64, | ||||
| 		*sql.NullBool, *sql.NullString, *sql.NullTime: | ||||
| 		for initialized || rows.Next() { | ||||
| 			initialized = false | ||||
| 			db.RowsAffected++ | ||||
| 			db.AddError(rows.Scan(dest)) | ||||
| 		} | ||||
| 	default: | ||||
| 		var ( | ||||
| 			fields       = make([]*schema.Field, len(columns)) | ||||
| 			joinFields   [][2]*schema.Field | ||||
| 			sch          = db.Statement.Schema | ||||
| 			reflectValue = db.Statement.ReflectValue | ||||
| 		) | ||||
| 
 | ||||
| 		if reflectValue.Kind() == reflect.Interface { | ||||
| 			reflectValue = reflectValue.Elem() | ||||
| 		} | ||||
| 
 | ||||
| 		reflectValueType := reflectValue.Type() | ||||
| 		switch reflectValueType.Kind() { | ||||
| 		case reflect.Array, reflect.Slice: | ||||
| 			reflectValueType = reflectValueType.Elem() | ||||
| 		} | ||||
| 		isPtr := reflectValueType.Kind() == reflect.Ptr | ||||
| 		if isPtr { | ||||
| 			reflectValueType = reflectValueType.Elem() | ||||
| 		} | ||||
| 
 | ||||
| 		if sch != nil { | ||||
| 			if reflectValueType != sch.ModelType && reflectValueType.Kind() == reflect.Struct { | ||||
| 				sch, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy) | ||||
| 			} | ||||
| 
 | ||||
| 			for idx, column := range columns { | ||||
| 				if field := sch.LookUpField(column); field != nil && field.Readable { | ||||
| 					fields[idx] = field | ||||
| 				} else if names := strings.Split(column, "__"); len(names) > 1 { | ||||
| 					if rel, ok := sch.Relationships.Relations[names[0]]; ok { | ||||
| 						if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { | ||||
| 							fields[idx] = field | ||||
| 
 | ||||
| 							if len(joinFields) == 0 { | ||||
| 								joinFields = make([][2]*schema.Field, len(columns)) | ||||
| 							} | ||||
| 							joinFields[idx] = [2]*schema.Field{rel.Field, field} | ||||
| 							continue | ||||
| 						} | ||||
| 					} | ||||
| 					values[idx] = &sql.RawBytes{} | ||||
| 				} else { | ||||
| 					values[idx] = &sql.RawBytes{} | ||||
| 				} | ||||
| 			} | ||||
| 
 | ||||
| 			if len(columns) == 1 { | ||||
| 				// isPluck
 | ||||
| 				if _, ok := reflect.New(reflectValueType).Interface().(sql.Scanner); (reflectValueType != sch.ModelType && ok) || // is scanner
 | ||||
| 					reflectValueType.Kind() != reflect.Struct || // is not struct
 | ||||
| 					sch.ModelType.ConvertibleTo(schema.TimeReflectType) { // is time
 | ||||
| 					sch = nil | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		switch reflectValue.Kind() { | ||||
| 		case reflect.Slice, reflect.Array: | ||||
| 			var elem reflect.Value | ||||
| 
 | ||||
| 			if !update { | ||||
| 				db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20)) | ||||
| 			} | ||||
| 
 | ||||
| 			for initialized || rows.Next() { | ||||
| 			BEGIN: | ||||
| 				initialized = false | ||||
| 
 | ||||
| 				if update { | ||||
| 					if int(db.RowsAffected) >= reflectValue.Len() { | ||||
| 						return | ||||
| 					} | ||||
| 					elem = reflectValue.Index(int(db.RowsAffected)) | ||||
| 					if onConflictDonothing { | ||||
| 						for _, field := range fields { | ||||
| 							if _, ok := field.ValueOf(elem); !ok { | ||||
| 								db.RowsAffected++ | ||||
| 								goto BEGIN | ||||
| 							} | ||||
| 						} | ||||
| 					} | ||||
| 				} else { | ||||
| 					elem = reflect.New(reflectValueType) | ||||
| 				} | ||||
| 
 | ||||
| 				db.scanIntoStruct(sch, rows, elem, values, columns, fields, joinFields) | ||||
| 
 | ||||
| 				if !update { | ||||
| 					if isPtr { | ||||
| 						reflectValue = reflect.Append(reflectValue, elem) | ||||
| 					} else { | ||||
| 						reflectValue = reflect.Append(reflectValue, elem.Elem()) | ||||
| 					} | ||||
| 				} | ||||
| 			} | ||||
| 
 | ||||
| 			if !update { | ||||
| 				db.Statement.ReflectValue.Set(reflectValue) | ||||
| 			} | ||||
| 		case reflect.Struct, reflect.Ptr: | ||||
| 			if initialized || rows.Next() { | ||||
| 				db.scanIntoStruct(sch, rows, reflectValue, values, columns, fields, joinFields) | ||||
| 			} | ||||
| 		default: | ||||
| 			db.AddError(rows.Scan(dest)) | ||||
| 		} | ||||
|  | ||||
| @ -7,9 +7,9 @@ require ( | ||||
| 	github.com/jinzhu/now v1.1.2 | ||||
| 	github.com/lib/pq v1.10.3 | ||||
| 	gorm.io/driver/mysql v1.1.2 | ||||
| 	gorm.io/driver/postgres v1.1.2 | ||||
| 	gorm.io/driver/sqlite v1.1.6 | ||||
| 	gorm.io/driver/sqlserver v1.1.0 | ||||
| 	gorm.io/driver/postgres v1.2.0 | ||||
| 	gorm.io/driver/sqlite v1.2.0 | ||||
| 	gorm.io/driver/sqlserver v1.1.1 | ||||
| 	gorm.io/gorm v1.21.16 | ||||
| ) | ||||
| 
 | ||||
|  | ||||
| @ -1,9 +1,9 @@ | ||||
| package tests_test | ||||
| 
 | ||||
| import ( | ||||
| 	"gorm.io/gorm" | ||||
| 	"gorm.io/gorm/callbacks" | ||||
| 	"testing" | ||||
| 
 | ||||
| 	"gorm.io/gorm" | ||||
| ) | ||||
| 
 | ||||
| func TestReturningWithNullToZeroValues(t *testing.T) { | ||||
| @ -20,11 +20,6 @@ func TestReturningWithNullToZeroValues(t *testing.T) { | ||||
| 			Name string `gorm:"default:null"` | ||||
| 		} | ||||
| 		u1 := user{} | ||||
| 		c := DB.Callback().Create().Get("gorm:create") | ||||
| 		t.Cleanup(func() { | ||||
| 			DB.Callback().Create().Replace("gorm:create", c) | ||||
| 		}) | ||||
| 		DB.Callback().Create().Replace("gorm:create", callbacks.Create(&callbacks.Config{WithReturning: true})) | ||||
| 
 | ||||
| 		if results := DB.Create(&u1); results.Error != nil { | ||||
| 			t.Fatalf("errors happened on create: %v", results.Error) | ||||
|  | ||||
| @ -9,6 +9,7 @@ import ( | ||||
| 	"time" | ||||
| 
 | ||||
| 	"gorm.io/gorm" | ||||
| 	"gorm.io/gorm/clause" | ||||
| 	"gorm.io/gorm/utils" | ||||
| 	. "gorm.io/gorm/utils/tests" | ||||
| ) | ||||
| @ -166,13 +167,16 @@ func TestUpdates(t *testing.T) { | ||||
| 	} | ||||
| 
 | ||||
| 	// update with gorm exprs
 | ||||
| 	if err := DB.Model(&user3).Updates(map[string]interface{}{"age": gorm.Expr("age + ?", 100)}).Error; err != nil { | ||||
| 	if err := DB.Debug().Model(&user3).Clauses(clause.Returning{Columns: []clause.Column{{Name: "age"}}}).Updates(map[string]interface{}{"age": gorm.Expr("age + ?", 100)}).Error; err != nil { | ||||
| 		t.Errorf("Not error should happen when updating with gorm expr, but got %v", err) | ||||
| 	} | ||||
| 	var user4 User | ||||
| 	DB.First(&user4, user3.ID) | ||||
| 
 | ||||
| 	// sqlite, postgres support returning
 | ||||
| 	if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" { | ||||
| 		user3.Age += 100 | ||||
| 	} | ||||
| 	AssertObjEqual(t, user4, user3, "UpdatedAt", "Age") | ||||
| } | ||||
| 
 | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu