Support scan into map, slice, struct
This commit is contained in:
		
							parent
							
								
									1403ee70c3
								
							
						
					
					
						commit
						b0e1bccf4a
					
				| @ -1,7 +1,6 @@ | ||||
| package callbacks | ||||
| 
 | ||||
| import ( | ||||
| 	"database/sql" | ||||
| 	"reflect" | ||||
| 
 | ||||
| 	"github.com/jinzhu/gorm" | ||||
| @ -22,25 +21,7 @@ func Query(db *gorm.DB) { | ||||
| 	} | ||||
| 	defer rows.Close() | ||||
| 
 | ||||
| 	columns, _ := rows.Columns() | ||||
| 	values := make([]interface{}, len(columns)) | ||||
| 
 | ||||
| 	for idx, column := range columns { | ||||
| 		if field, ok := db.Statement.Schema.FieldsByDBName[column]; ok { | ||||
| 			values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface() | ||||
| 		} else { | ||||
| 			values[idx] = sql.RawBytes{} | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	for rows.Next() { | ||||
| 		db.RowsAffected++ | ||||
| 		rows.Scan(values...) | ||||
| 	} | ||||
| 
 | ||||
| 	if db.RowsAffected == 0 && db.Statement.RaiseErrorOnNotFound { | ||||
| 		db.AddError(gorm.ErrRecordNotFound) | ||||
| 	} | ||||
| 	Scan(rows, db) | ||||
| } | ||||
| 
 | ||||
| func Preload(db *gorm.DB) { | ||||
|  | ||||
							
								
								
									
										98
									
								
								callbacks/scan.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										98
									
								
								callbacks/scan.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,98 @@ | ||||
| package callbacks | ||||
| 
 | ||||
| import ( | ||||
| 	"database/sql" | ||||
| 	"reflect" | ||||
| 
 | ||||
| 	"github.com/jinzhu/gorm" | ||||
| ) | ||||
| 
 | ||||
| func Scan(rows *sql.Rows, db *gorm.DB) { | ||||
| 	columns, _ := rows.Columns() | ||||
| 	values := make([]interface{}, len(columns)) | ||||
| 
 | ||||
| 	switch dest := db.Statement.Dest.(type) { | ||||
| 	case map[string]interface{}, *map[string]interface{}: | ||||
| 		for idx, _ := range columns { | ||||
| 			values[idx] = new(interface{}) | ||||
| 		} | ||||
| 
 | ||||
| 		if rows.Next() { | ||||
| 			db.RowsAffected++ | ||||
| 			rows.Scan(values...) | ||||
| 		} | ||||
| 
 | ||||
| 		mapValue, ok := dest.(map[string]interface{}) | ||||
| 		if ok { | ||||
| 			if v, ok := dest.(*map[string]interface{}); ok { | ||||
| 				mapValue = *v | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		for idx, column := range columns { | ||||
| 			mapValue[column] = *(values[idx].(*interface{})) | ||||
| 		} | ||||
| 	case *[]map[string]interface{}: | ||||
| 		for idx, _ := range columns { | ||||
| 			values[idx] = new(interface{}) | ||||
| 		} | ||||
| 
 | ||||
| 		for rows.Next() { | ||||
| 			db.RowsAffected++ | ||||
| 			rows.Scan(values...) | ||||
| 
 | ||||
| 			v := map[string]interface{}{} | ||||
| 			for idx, column := range columns { | ||||
| 				v[column] = *(values[idx].(*interface{})) | ||||
| 			} | ||||
| 			*dest = append(*dest, v) | ||||
| 		} | ||||
| 	default: | ||||
| 		switch db.Statement.ReflectValue.Kind() { | ||||
| 		case reflect.Slice, reflect.Array: | ||||
| 			isPtr := db.Statement.ReflectValue.Type().Elem().Kind() == reflect.Ptr | ||||
| 			db.Statement.ReflectValue.Set(reflect.MakeSlice(db.Statement.ReflectValue.Type(), 0, 0)) | ||||
| 
 | ||||
| 			for rows.Next() { | ||||
| 				elem := reflect.New(db.Statement.Schema.ModelType).Elem() | ||||
| 				for idx, column := range columns { | ||||
| 					if field := db.Statement.Schema.LookUpField(column); field != nil { | ||||
| 						values[idx] = field.ReflectValueOf(elem).Addr().Interface() | ||||
| 					} else if db.RowsAffected == 0 { | ||||
| 						values[idx] = sql.RawBytes{} | ||||
| 					} | ||||
| 				} | ||||
| 
 | ||||
| 				db.RowsAffected++ | ||||
| 				if err := rows.Scan(values...); err != nil { | ||||
| 					db.AddError(err) | ||||
| 				} | ||||
| 
 | ||||
| 				if isPtr { | ||||
| 					db.Statement.ReflectValue.Set(reflect.Append(db.Statement.ReflectValue, elem.Addr())) | ||||
| 				} else { | ||||
| 					db.Statement.ReflectValue.Set(reflect.Append(db.Statement.ReflectValue, elem)) | ||||
| 				} | ||||
| 			} | ||||
| 		case reflect.Struct: | ||||
| 			for idx, column := range columns { | ||||
| 				if field := db.Statement.Schema.LookUpField(column); field != nil { | ||||
| 					values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface() | ||||
| 				} else { | ||||
| 					values[idx] = sql.RawBytes{} | ||||
| 				} | ||||
| 			} | ||||
| 
 | ||||
| 			if rows.Next() { | ||||
| 				db.RowsAffected++ | ||||
| 				if err := rows.Scan(values...); err != nil { | ||||
| 					db.AddError(err) | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	if db.RowsAffected == 0 && db.Statement.RaiseErrorOnNotFound { | ||||
| 		db.AddError(gorm.ErrRecordNotFound) | ||||
| 	} | ||||
| } | ||||
| @ -26,7 +26,6 @@ func (db *DB) First(out interface{}, where ...interface{}) (tx *DB) { | ||||
| 	// TODO handle where
 | ||||
| 	tx = db.getInstance().Limit(1).Order(clause.OrderByColumn{ | ||||
| 		Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, | ||||
| 		Desc:   true, | ||||
| 	}) | ||||
| 	tx.Statement.RaiseErrorOnNotFound = true | ||||
| 	tx.Statement.Dest = out | ||||
| @ -47,6 +46,7 @@ func (db *DB) Take(out interface{}, where ...interface{}) (tx *DB) { | ||||
| func (db *DB) Last(out interface{}, where ...interface{}) (tx *DB) { | ||||
| 	tx = db.getInstance().Limit(1).Order(clause.OrderByColumn{ | ||||
| 		Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, | ||||
| 		Desc:   true, | ||||
| 	}) | ||||
| 	tx.Statement.RaiseErrorOnNotFound = true | ||||
| 	tx.Statement.Dest = out | ||||
|  | ||||
| @ -1,7 +1,6 @@ | ||||
| package schema_test | ||||
| 
 | ||||
| import ( | ||||
| 	"database/sql/driver" | ||||
| 	"fmt" | ||||
| 	"reflect" | ||||
| 	"strings" | ||||
| @ -13,7 +12,7 @@ import ( | ||||
| 
 | ||||
| func checkSchema(t *testing.T, s *schema.Schema, v schema.Schema, primaryFields []string) { | ||||
| 	t.Run("CheckSchema/"+s.Name, func(t *testing.T) { | ||||
| 		tests.AssertEqual(t, s, v, "Name", "Table") | ||||
| 		tests.AssertObjEqual(t, s, v, "Name", "Table") | ||||
| 
 | ||||
| 		for idx, field := range primaryFields { | ||||
| 			var found bool | ||||
| @ -53,7 +52,7 @@ func checkSchemaField(t *testing.T, s *schema.Schema, f *schema.Field, fc func(* | ||||
| 		if parsedField, ok := s.FieldsByName[f.Name]; !ok { | ||||
| 			t.Errorf("schema %v failed to look up field with name %v", s, f.Name) | ||||
| 		} else { | ||||
| 			tests.AssertEqual(t, parsedField, f, "Name", "DBName", "BindNames", "DataType", "DBDataType", "PrimaryKey", "AutoIncrement", "Creatable", "Updatable", "HasDefaultValue", "DefaultValue", "NotNull", "Unique", "Comment", "Size", "Precision", "Tag", "TagSettings") | ||||
| 			tests.AssertObjEqual(t, parsedField, f, "Name", "DBName", "BindNames", "DataType", "DBDataType", "PrimaryKey", "AutoIncrement", "Creatable", "Updatable", "HasDefaultValue", "DefaultValue", "NotNull", "Unique", "Comment", "Size", "Precision", "Tag", "TagSettings") | ||||
| 
 | ||||
| 			if field, ok := s.FieldsByDBName[f.DBName]; !ok || parsedField != field { | ||||
| 				t.Errorf("schema %v failed to look up field with dbname %v", s, f.DBName) | ||||
| @ -195,39 +194,8 @@ func checkSchemaRelation(t *testing.T, s *schema.Schema, relation Relation) { | ||||
| func checkField(t *testing.T, s *schema.Schema, value reflect.Value, values map[string]interface{}) { | ||||
| 	for k, v := range values { | ||||
| 		t.Run("CheckField/"+k, func(t *testing.T) { | ||||
| 			var ( | ||||
| 				checker func(fv interface{}, v interface{}) | ||||
| 				field   = s.FieldsByDBName[k] | ||||
| 				fv, _   = field.ValueOf(value) | ||||
| 			) | ||||
| 
 | ||||
| 			checker = func(fv interface{}, v interface{}) { | ||||
| 				if reflect.ValueOf(fv).Type() == reflect.ValueOf(v).Type() && fv != v { | ||||
| 					t.Errorf("expects: %p, but got %p", v, fv) | ||||
| 				} else if reflect.ValueOf(v).Type().ConvertibleTo(reflect.ValueOf(fv).Type()) { | ||||
| 					if reflect.ValueOf(v).Convert(reflect.ValueOf(fv).Type()).Interface() != fv { | ||||
| 						t.Errorf("expects: %p, but got %p", v, fv) | ||||
| 					} | ||||
| 				} else if reflect.ValueOf(fv).Type().ConvertibleTo(reflect.ValueOf(v).Type()) { | ||||
| 					if reflect.ValueOf(fv).Convert(reflect.ValueOf(fv).Type()).Interface() != v { | ||||
| 						t.Errorf("expects: %p, but got %p", v, fv) | ||||
| 					} | ||||
| 				} else if valuer, isValuer := fv.(driver.Valuer); isValuer { | ||||
| 					valuerv, _ := valuer.Value() | ||||
| 					checker(valuerv, v) | ||||
| 				} else if valuer, isValuer := v.(driver.Valuer); isValuer { | ||||
| 					valuerv, _ := valuer.Value() | ||||
| 					checker(fv, valuerv) | ||||
| 				} else if reflect.ValueOf(fv).Kind() == reflect.Ptr { | ||||
| 					checker(reflect.ValueOf(fv).Elem().Interface(), v) | ||||
| 				} else if reflect.ValueOf(v).Kind() == reflect.Ptr { | ||||
| 					checker(fv, reflect.ValueOf(v).Elem().Interface()) | ||||
| 				} else { | ||||
| 					t.Errorf("expects: %+v, but got %+v", v, fv) | ||||
| 				} | ||||
| 			} | ||||
| 
 | ||||
| 			checker(fv, v) | ||||
| 			fv, _ := s.FieldsByDBName[k].ValueOf(value) | ||||
| 			tests.AssertEqual(t, v, fv) | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @ -1,6 +1,9 @@ | ||||
| package tests | ||||
| 
 | ||||
| import ( | ||||
| 	"log" | ||||
| 	"reflect" | ||||
| 	"strconv" | ||||
| 	"testing" | ||||
| 	"time" | ||||
| 
 | ||||
| @ -14,6 +17,7 @@ func Now() *time.Time { | ||||
| 
 | ||||
| func RunTestsSuit(t *testing.T, db *gorm.DB) { | ||||
| 	TestCreate(t, db) | ||||
| 	TestFind(t, db) | ||||
| } | ||||
| 
 | ||||
| func TestCreate(t *testing.T, db *gorm.DB) { | ||||
| @ -38,7 +42,94 @@ func TestCreate(t *testing.T, db *gorm.DB) { | ||||
| 		if err := db.Where("id = ?", user.ID).First(&newUser).Error; err != nil { | ||||
| 			t.Errorf("errors happened when query: %v", err) | ||||
| 		} else { | ||||
| 			AssertEqual(t, newUser, user, "Name", "Age", "Birthday") | ||||
| 			AssertObjEqual(t, newUser, user, "Name", "Age", "Birthday") | ||||
| 		} | ||||
| 	}) | ||||
| } | ||||
| 
 | ||||
| func TestFind(t *testing.T, db *gorm.DB) { | ||||
| 	db.Migrator().DropTable(&User{}) | ||||
| 	db.AutoMigrate(&User{}) | ||||
| 
 | ||||
| 	t.Run("Find", func(t *testing.T) { | ||||
| 		var users = []User{{ | ||||
| 			Name:     "find", | ||||
| 			Age:      1, | ||||
| 			Birthday: Now(), | ||||
| 		}, { | ||||
| 			Name:     "find", | ||||
| 			Age:      2, | ||||
| 			Birthday: Now(), | ||||
| 		}, { | ||||
| 			Name:     "find", | ||||
| 			Age:      3, | ||||
| 			Birthday: Now(), | ||||
| 		}} | ||||
| 
 | ||||
| 		if err := db.Create(&users).Error; err != nil { | ||||
| 			t.Errorf("errors happened when create users: %v", err) | ||||
| 		} | ||||
| 
 | ||||
| 		t.Run("First", func(t *testing.T) { | ||||
| 			var first User | ||||
| 			if err := db.Where("name = ?", "find").First(&first).Error; err != nil { | ||||
| 				t.Errorf("errors happened when query first: %v", err) | ||||
| 			} else { | ||||
| 				AssertObjEqual(t, first, users[0], "Name", "Age", "Birthday") | ||||
| 			} | ||||
| 		}) | ||||
| 
 | ||||
| 		t.Run("Last", func(t *testing.T) { | ||||
| 			var last User | ||||
| 			if err := db.Where("name = ?", "find").Last(&last).Error; err != nil { | ||||
| 				t.Errorf("errors happened when query last: %v", err) | ||||
| 			} else { | ||||
| 				AssertObjEqual(t, last, users[2], "Name", "Age", "Birthday") | ||||
| 			} | ||||
| 		}) | ||||
| 
 | ||||
| 		var all []User | ||||
| 		if err := db.Where("name = ?", "find").Find(&all).Error; err != nil || len(all) != 3 { | ||||
| 			t.Errorf("errors happened when query find: %v, length: %v", err, len(all)) | ||||
| 		} else { | ||||
| 			for idx, user := range users { | ||||
| 				t.Run("FindAll#"+strconv.Itoa(idx+1), func(t *testing.T) { | ||||
| 					AssertObjEqual(t, all[idx], user, "Name", "Age", "Birthday") | ||||
| 				}) | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		t.Run("FirstMap", func(t *testing.T) { | ||||
| 			var first = map[string]interface{}{} | ||||
| 			if err := db.Model(&User{}).Where("name = ?", "find").First(first).Error; err != nil { | ||||
| 				t.Errorf("errors happened when query first: %v", err) | ||||
| 			} else { | ||||
| 				for _, name := range []string{"Name", "Age", "Birthday"} { | ||||
| 					t.Run(name, func(t *testing.T) { | ||||
| 						dbName := db.NamingStrategy.ColumnName("", name) | ||||
| 						reflectValue := reflect.Indirect(reflect.ValueOf(users[0])) | ||||
| 						AssertEqual(t, first[dbName], reflectValue.FieldByName(name).Interface()) | ||||
| 					}) | ||||
| 				} | ||||
| 			} | ||||
| 		}) | ||||
| 
 | ||||
| 		var allMap = []map[string]interface{}{} | ||||
| 		if err := db.Model(&User{}).Where("name = ?", "find").Find(&allMap).Error; err != nil { | ||||
| 			t.Errorf("errors happened when query first: %v", err) | ||||
| 		} else { | ||||
| 			log.Printf("all map %+v %+v", len(allMap), allMap) | ||||
| 			for idx, user := range users { | ||||
| 				t.Run("FindAllMap#"+strconv.Itoa(idx+1), func(t *testing.T) { | ||||
| 					for _, name := range []string{"Name", "Age", "Birthday"} { | ||||
| 						t.Run(name, func(t *testing.T) { | ||||
| 							dbName := db.NamingStrategy.ColumnName("", name) | ||||
| 							reflectValue := reflect.Indirect(reflect.ValueOf(user)) | ||||
| 							AssertEqual(t, allMap[idx][dbName], reflectValue.FieldByName(name).Interface()) | ||||
| 						}) | ||||
| 					} | ||||
| 				}) | ||||
| 			} | ||||
| 		} | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| @ -6,24 +6,43 @@ import ( | ||||
| 	"time" | ||||
| ) | ||||
| 
 | ||||
| func AssertEqual(t *testing.T, r, e interface{}, names ...string) { | ||||
| func AssertObjEqual(t *testing.T, r, e interface{}, names ...string) { | ||||
| 	for _, name := range names { | ||||
| 		got := reflect.Indirect(reflect.ValueOf(r)).FieldByName(name).Interface() | ||||
| 		expects := reflect.Indirect(reflect.ValueOf(e)).FieldByName(name).Interface() | ||||
| 		expect := reflect.Indirect(reflect.ValueOf(e)).FieldByName(name).Interface() | ||||
| 		t.Run(name, func(t *testing.T) { | ||||
| 			AssertEqual(t, got, expect) | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| 		if !reflect.DeepEqual(got, expects) { | ||||
| 			got = reflect.Indirect(reflect.ValueOf(got)).Interface() | ||||
| 			expects = reflect.Indirect(reflect.ValueOf(got)).Interface() | ||||
| func AssertEqual(t *testing.T, got, expect interface{}) { | ||||
| 	if !reflect.DeepEqual(got, expect) { | ||||
| 		isEqual := func() { | ||||
| 			if curTime, ok := got.(time.Time); ok { | ||||
| 				format := "2006-01-02T15:04:05Z07:00" | ||||
| 				if curTime.Format(format) != expects.(time.Time).Format(format) { | ||||
| 					t.Errorf("expects: %v, got %v", expects.(time.Time).Format(format), curTime.Format(format)) | ||||
| 				if curTime.Format(format) != expect.(time.Time).Format(format) { | ||||
| 					t.Errorf("expect: %v, got %v", expect.(time.Time).Format(format), curTime.Format(format)) | ||||
| 				} | ||||
| 			} else { | ||||
| 				t.Run(name, func(t *testing.T) { | ||||
| 					t.Errorf("expects: %v, got %v", expects, got) | ||||
| 				}) | ||||
| 			} else if got != expect { | ||||
| 				t.Errorf("expect: %#v, got %#v", expect, got) | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		if got != nil { | ||||
| 			got = reflect.Indirect(reflect.ValueOf(got)).Interface() | ||||
| 		} | ||||
| 
 | ||||
| 		if expect != nil { | ||||
| 			expect = reflect.Indirect(reflect.ValueOf(expect)).Interface() | ||||
| 		} | ||||
| 
 | ||||
| 		if reflect.ValueOf(got).Type().ConvertibleTo(reflect.ValueOf(expect).Type()) { | ||||
| 			got = reflect.ValueOf(got).Convert(reflect.ValueOf(expect).Type()).Interface() | ||||
| 			isEqual() | ||||
| 		} else if reflect.ValueOf(expect).Type().ConvertibleTo(reflect.ValueOf(got).Type()) { | ||||
| 			expect = reflect.ValueOf(got).Convert(reflect.ValueOf(got).Type()).Interface() | ||||
| 			isEqual() | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @ -6,8 +6,8 @@ import ( | ||||
| 	"runtime" | ||||
| ) | ||||
| 
 | ||||
| var goSrcRegexp = regexp.MustCompile(`/gorm/.*.go`) | ||||
| var goTestRegexp = regexp.MustCompile(`/gorm/.*test.go`) | ||||
| var goSrcRegexp = regexp.MustCompile(`/gorm/.*\.go`) | ||||
| var goTestRegexp = regexp.MustCompile(`/gorm/.*test\.go`) | ||||
| 
 | ||||
| func FileWithLineNum() string { | ||||
| 	for i := 2; i < 15; i++ { | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu