Better support for scan into map, fix unfriendly data type for interface, close #3351
This commit is contained in:
		
							parent
							
								
									e6f4b711a7
								
							
						
					
					
						commit
						e73147fa8e
					
				
							
								
								
									
										72
									
								
								scan.go
									
									
									
									
									
								
							
							
						
						
									
										72
									
								
								scan.go
									
									
									
									
									
								
							| @ -2,12 +2,52 @@ package gorm | |||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"database/sql" | 	"database/sql" | ||||||
|  | 	"database/sql/driver" | ||||||
| 	"reflect" | 	"reflect" | ||||||
| 	"strings" | 	"strings" | ||||||
| 
 | 
 | ||||||
| 	"gorm.io/gorm/schema" | 	"gorm.io/gorm/schema" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | func prepareValues(values []interface{}, db *DB, columnTypes []*sql.ColumnType, columns []string) { | ||||||
|  | 	if db.Statement.Schema != nil { | ||||||
|  | 		for idx, name := range columns { | ||||||
|  | 			if field := db.Statement.Schema.LookUpField(name); field != nil { | ||||||
|  | 				values[idx] = reflect.New(reflect.PtrTo(field.FieldType)).Interface() | ||||||
|  | 				continue | ||||||
|  | 			} | ||||||
|  | 			values[idx] = new(interface{}) | ||||||
|  | 		} | ||||||
|  | 	} else if len(columnTypes) > 0 { | ||||||
|  | 		for idx, columnType := range columnTypes { | ||||||
|  | 			if columnType.ScanType() != nil { | ||||||
|  | 				values[idx] = reflect.New(reflect.PtrTo(columnType.ScanType())).Interface() | ||||||
|  | 			} else { | ||||||
|  | 				values[idx] = new(interface{}) | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} else { | ||||||
|  | 		for idx := range columns { | ||||||
|  | 			values[idx] = new(interface{}) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func scanIntoMap(mapValue map[string]interface{}, values []interface{}, columns []string) { | ||||||
|  | 	for idx, column := range columns { | ||||||
|  | 		if reflectValue := reflect.Indirect(reflect.Indirect(reflect.ValueOf(values[idx]))); reflectValue.IsValid() { | ||||||
|  | 			mapValue[column] = reflectValue.Interface() | ||||||
|  | 			if valuer, ok := mapValue[column].(driver.Valuer); ok { | ||||||
|  | 				mapValue[column], _ = valuer.Value() | ||||||
|  | 			} else if b, ok := mapValue[column].(sql.RawBytes); ok { | ||||||
|  | 				mapValue[column] = string(b) | ||||||
|  | 			} | ||||||
|  | 		} else { | ||||||
|  | 			mapValue[column] = nil | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func Scan(rows *sql.Rows, db *DB, initialized bool) { | func Scan(rows *sql.Rows, db *DB, initialized bool) { | ||||||
| 	columns, _ := rows.Columns() | 	columns, _ := rows.Columns() | ||||||
| 	values := make([]interface{}, len(columns)) | 	values := make([]interface{}, len(columns)) | ||||||
| @ -15,9 +55,8 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { | |||||||
| 	switch dest := db.Statement.Dest.(type) { | 	switch dest := db.Statement.Dest.(type) { | ||||||
| 	case map[string]interface{}, *map[string]interface{}: | 	case map[string]interface{}, *map[string]interface{}: | ||||||
| 		if initialized || rows.Next() { | 		if initialized || rows.Next() { | ||||||
| 			for idx := range columns { | 			columnTypes, _ := rows.ColumnTypes() | ||||||
| 				values[idx] = new(interface{}) | 			prepareValues(values, db, columnTypes, columns) | ||||||
| 			} |  | ||||||
| 
 | 
 | ||||||
| 			db.RowsAffected++ | 			db.RowsAffected++ | ||||||
| 			db.AddError(rows.Scan(values...)) | 			db.AddError(rows.Scan(values...)) | ||||||
| @ -28,38 +67,19 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { | |||||||
| 					mapValue = *v | 					mapValue = *v | ||||||
| 				} | 				} | ||||||
| 			} | 			} | ||||||
| 
 | 			scanIntoMap(mapValue, values, columns) | ||||||
| 			for idx, column := range columns { |  | ||||||
| 				if v, ok := values[idx].(*interface{}); ok { |  | ||||||
| 					if v == nil { |  | ||||||
| 						mapValue[column] = nil |  | ||||||
| 					} else { |  | ||||||
| 						mapValue[column] = *v |  | ||||||
| 					} |  | ||||||
| 				} |  | ||||||
| 			} |  | ||||||
| 		} | 		} | ||||||
| 	case *[]map[string]interface{}: | 	case *[]map[string]interface{}: | ||||||
|  | 		columnTypes, _ := rows.ColumnTypes() | ||||||
| 		for initialized || rows.Next() { | 		for initialized || rows.Next() { | ||||||
| 			for idx := range columns { | 			prepareValues(values, db, columnTypes, columns) | ||||||
| 				values[idx] = new(interface{}) |  | ||||||
| 			} |  | ||||||
| 
 | 
 | ||||||
| 			initialized = false | 			initialized = false | ||||||
| 			db.RowsAffected++ | 			db.RowsAffected++ | ||||||
| 			db.AddError(rows.Scan(values...)) | 			db.AddError(rows.Scan(values...)) | ||||||
| 
 | 
 | ||||||
| 			mapValue := map[string]interface{}{} | 			mapValue := map[string]interface{}{} | ||||||
| 			for idx, column := range columns { | 			scanIntoMap(mapValue, values, columns) | ||||||
| 				if v, ok := values[idx].(*interface{}); ok { |  | ||||||
| 					if v == nil { |  | ||||||
| 						mapValue[column] = nil |  | ||||||
| 					} else { |  | ||||||
| 						mapValue[column] = *v |  | ||||||
| 					} |  | ||||||
| 				} |  | ||||||
| 			} |  | ||||||
| 
 |  | ||||||
| 			*dest = append(*dest, mapValue) | 			*dest = append(*dest, mapValue) | ||||||
| 		} | 		} | ||||||
| 	case *int, *int64, *uint, *uint64, *float32, *float64: | 	case *int, *int64, *uint, *uint64, *float32, *float64: | ||||||
|  | |||||||
| @ -6,6 +6,7 @@ import ( | |||||||
| 	"regexp" | 	"regexp" | ||||||
| 	"sort" | 	"sort" | ||||||
| 	"strconv" | 	"strconv" | ||||||
|  | 	"strings" | ||||||
| 	"testing" | 	"testing" | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
| @ -61,6 +62,54 @@ func TestFind(t *testing.T) { | |||||||
| 			for _, name := range []string{"Name", "Age", "Birthday"} { | 			for _, name := range []string{"Name", "Age", "Birthday"} { | ||||||
| 				t.Run(name, func(t *testing.T) { | 				t.Run(name, func(t *testing.T) { | ||||||
| 					dbName := DB.NamingStrategy.ColumnName("", name) | 					dbName := DB.NamingStrategy.ColumnName("", name) | ||||||
|  | 
 | ||||||
|  | 					switch name { | ||||||
|  | 					case "Name": | ||||||
|  | 						if _, ok := first[dbName].(string); !ok { | ||||||
|  | 							t.Errorf("invalid data type for %v, got %#v", dbName, first[dbName]) | ||||||
|  | 						} | ||||||
|  | 					case "Age": | ||||||
|  | 						if _, ok := first[dbName].(uint); !ok { | ||||||
|  | 							t.Errorf("invalid data type for %v, got %#v", dbName, first[dbName]) | ||||||
|  | 						} | ||||||
|  | 					case "Birthday": | ||||||
|  | 						if _, ok := first[dbName].(*time.Time); !ok { | ||||||
|  | 							t.Errorf("invalid data type for %v, got %#v", dbName, first[dbName]) | ||||||
|  | 						} | ||||||
|  | 					} | ||||||
|  | 
 | ||||||
|  | 					reflectValue := reflect.Indirect(reflect.ValueOf(users[0])) | ||||||
|  | 					AssertEqual(t, first[dbName], reflectValue.FieldByName(name).Interface()) | ||||||
|  | 				}) | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	t.Run("FirstMapWithTable", func(t *testing.T) { | ||||||
|  | 		var first = map[string]interface{}{} | ||||||
|  | 		if err := DB.Table("users").Where("name = ?", "find").Find(first).Error; err != nil { | ||||||
|  | 			t.Errorf("errors happened when query first: %v", err) | ||||||
|  | 		} else { | ||||||
|  | 			for _, name := range []string{"Name", "Age", "Birthday"} { | ||||||
|  | 				t.Run(name, func(t *testing.T) { | ||||||
|  | 					dbName := DB.NamingStrategy.ColumnName("", name) | ||||||
|  | 					resultType := reflect.ValueOf(first[dbName]).Type().Name() | ||||||
|  | 
 | ||||||
|  | 					switch name { | ||||||
|  | 					case "Name": | ||||||
|  | 						if !strings.Contains(resultType, "string") { | ||||||
|  | 							t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, first[dbName]) | ||||||
|  | 						} | ||||||
|  | 					case "Age": | ||||||
|  | 						if !strings.Contains(resultType, "int") { | ||||||
|  | 							t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, first[dbName]) | ||||||
|  | 						} | ||||||
|  | 					case "Birthday": | ||||||
|  | 						if !strings.Contains(resultType, "Time") && !(DB.Dialector.Name() == "sqlite" && strings.Contains(resultType, "string")) { | ||||||
|  | 							t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, first[dbName]) | ||||||
|  | 						} | ||||||
|  | 					} | ||||||
|  | 
 | ||||||
| 					reflectValue := reflect.Indirect(reflect.ValueOf(users[0])) | 					reflectValue := reflect.Indirect(reflect.ValueOf(users[0])) | ||||||
| 					AssertEqual(t, first[dbName], reflectValue.FieldByName(name).Interface()) | 					AssertEqual(t, first[dbName], reflectValue.FieldByName(name).Interface()) | ||||||
| 				}) | 				}) | ||||||
| @ -86,13 +135,29 @@ func TestFind(t *testing.T) { | |||||||
| 	t.Run("FirstSliceOfMap", func(t *testing.T) { | 	t.Run("FirstSliceOfMap", func(t *testing.T) { | ||||||
| 		var allMap = []map[string]interface{}{} | 		var allMap = []map[string]interface{}{} | ||||||
| 		if err := DB.Model(&User{}).Where("name = ?", "find").Find(&allMap).Error; err != nil { | 		if err := DB.Model(&User{}).Where("name = ?", "find").Find(&allMap).Error; err != nil { | ||||||
| 			t.Errorf("errors happened when query first: %v", err) | 			t.Errorf("errors happened when query find: %v", err) | ||||||
| 		} else { | 		} else { | ||||||
| 			for idx, user := range users { | 			for idx, user := range users { | ||||||
| 				t.Run("FindAllMap#"+strconv.Itoa(idx+1), func(t *testing.T) { | 				t.Run("FindAllMap#"+strconv.Itoa(idx+1), func(t *testing.T) { | ||||||
| 					for _, name := range []string{"Name", "Age", "Birthday"} { | 					for _, name := range []string{"Name", "Age", "Birthday"} { | ||||||
| 						t.Run(name, func(t *testing.T) { | 						t.Run(name, func(t *testing.T) { | ||||||
| 							dbName := DB.NamingStrategy.ColumnName("", name) | 							dbName := DB.NamingStrategy.ColumnName("", name) | ||||||
|  | 
 | ||||||
|  | 							switch name { | ||||||
|  | 							case "Name": | ||||||
|  | 								if _, ok := allMap[idx][dbName].(string); !ok { | ||||||
|  | 									t.Errorf("invalid data type for %v, got %#v", dbName, allMap[idx][dbName]) | ||||||
|  | 								} | ||||||
|  | 							case "Age": | ||||||
|  | 								if _, ok := allMap[idx][dbName].(uint); !ok { | ||||||
|  | 									t.Errorf("invalid data type for %v, got %#v", dbName, allMap[idx][dbName]) | ||||||
|  | 								} | ||||||
|  | 							case "Birthday": | ||||||
|  | 								if _, ok := allMap[idx][dbName].(*time.Time); !ok { | ||||||
|  | 									t.Errorf("invalid data type for %v, got %#v", dbName, allMap[idx][dbName]) | ||||||
|  | 								} | ||||||
|  | 							} | ||||||
|  | 
 | ||||||
| 							reflectValue := reflect.Indirect(reflect.ValueOf(user)) | 							reflectValue := reflect.Indirect(reflect.ValueOf(user)) | ||||||
| 							AssertEqual(t, allMap[idx][dbName], reflectValue.FieldByName(name).Interface()) | 							AssertEqual(t, allMap[idx][dbName], reflectValue.FieldByName(name).Interface()) | ||||||
| 						}) | 						}) | ||||||
| @ -101,6 +166,43 @@ func TestFind(t *testing.T) { | |||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 	}) | 	}) | ||||||
|  | 
 | ||||||
|  | 	t.Run("FindSliceOfMapWithTable", func(t *testing.T) { | ||||||
|  | 		var allMap = []map[string]interface{}{} | ||||||
|  | 		if err := DB.Table("users").Where("name = ?", "find").Find(&allMap).Error; err != nil { | ||||||
|  | 			t.Errorf("errors happened when query find: %v", err) | ||||||
|  | 		} else { | ||||||
|  | 			for idx, user := range users { | ||||||
|  | 				t.Run("FindAllMap#"+strconv.Itoa(idx+1), func(t *testing.T) { | ||||||
|  | 					for _, name := range []string{"Name", "Age", "Birthday"} { | ||||||
|  | 						t.Run(name, func(t *testing.T) { | ||||||
|  | 							dbName := DB.NamingStrategy.ColumnName("", name) | ||||||
|  | 							resultType := reflect.ValueOf(allMap[idx][dbName]).Type().Name() | ||||||
|  | 
 | ||||||
|  | 							switch name { | ||||||
|  | 							case "Name": | ||||||
|  | 								if !strings.Contains(resultType, "string") { | ||||||
|  | 									t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, allMap[idx][dbName]) | ||||||
|  | 								} | ||||||
|  | 							case "Age": | ||||||
|  | 								if !strings.Contains(resultType, "int") { | ||||||
|  | 									t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, allMap[idx][dbName]) | ||||||
|  | 								} | ||||||
|  | 							case "Birthday": | ||||||
|  | 								if !strings.Contains(resultType, "Time") && !(DB.Dialector.Name() == "sqlite" && strings.Contains(resultType, "string")) { | ||||||
|  | 									t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, allMap[idx][dbName]) | ||||||
|  | 								} | ||||||
|  | 							} | ||||||
|  | 
 | ||||||
|  | 							reflectValue := reflect.Indirect(reflect.ValueOf(user)) | ||||||
|  | 							AssertEqual(t, allMap[idx][dbName], reflectValue.FieldByName(name).Interface()) | ||||||
|  | 						}) | ||||||
|  | 					} | ||||||
|  | 				}) | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestQueryWithAssociation(t *testing.T) { | func TestQueryWithAssociation(t *testing.T) { | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu