From b0e1bccf4ad5f803df27a8974491bcbc04a4b02c Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 4 Mar 2020 11:32:36 +0800 Subject: [PATCH] Support scan into map, slice, struct --- callbacks/query.go | 21 +------- callbacks/scan.go | 98 ++++++++++++++++++++++++++++++++++++ finisher_api.go | 2 +- schema/schema_helper_test.go | 40 ++------------- tests/tests.go | 93 +++++++++++++++++++++++++++++++++- tests/utils.go | 41 +++++++++++---- utils/utils.go | 4 +- 7 files changed, 228 insertions(+), 71 deletions(-) create mode 100644 callbacks/scan.go diff --git a/callbacks/query.go b/callbacks/query.go index 21b58aaf..26c0e0ad 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -1,7 +1,6 @@ package callbacks import ( - "database/sql" "reflect" "github.com/jinzhu/gorm" @@ -22,25 +21,7 @@ func Query(db *gorm.DB) { } defer rows.Close() - columns, _ := rows.Columns() - values := make([]interface{}, len(columns)) - - for idx, column := range columns { - if field, ok := db.Statement.Schema.FieldsByDBName[column]; ok { - values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface() - } else { - values[idx] = sql.RawBytes{} - } - } - - for rows.Next() { - db.RowsAffected++ - rows.Scan(values...) - } - - if db.RowsAffected == 0 && db.Statement.RaiseErrorOnNotFound { - db.AddError(gorm.ErrRecordNotFound) - } + Scan(rows, db) } func Preload(db *gorm.DB) { diff --git a/callbacks/scan.go b/callbacks/scan.go new file mode 100644 index 00000000..c9f948b1 --- /dev/null +++ b/callbacks/scan.go @@ -0,0 +1,98 @@ +package callbacks + +import ( + "database/sql" + "reflect" + + "github.com/jinzhu/gorm" +) + +func Scan(rows *sql.Rows, db *gorm.DB) { + columns, _ := rows.Columns() + values := make([]interface{}, len(columns)) + + switch dest := db.Statement.Dest.(type) { + case map[string]interface{}, *map[string]interface{}: + for idx, _ := range columns { + values[idx] = new(interface{}) + } + + if rows.Next() { + db.RowsAffected++ + rows.Scan(values...) + } + + mapValue, ok := dest.(map[string]interface{}) + if ok { + if v, ok := dest.(*map[string]interface{}); ok { + mapValue = *v + } + } + + for idx, column := range columns { + mapValue[column] = *(values[idx].(*interface{})) + } + case *[]map[string]interface{}: + for idx, _ := range columns { + values[idx] = new(interface{}) + } + + for rows.Next() { + db.RowsAffected++ + rows.Scan(values...) + + v := map[string]interface{}{} + for idx, column := range columns { + v[column] = *(values[idx].(*interface{})) + } + *dest = append(*dest, v) + } + default: + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + isPtr := db.Statement.ReflectValue.Type().Elem().Kind() == reflect.Ptr + db.Statement.ReflectValue.Set(reflect.MakeSlice(db.Statement.ReflectValue.Type(), 0, 0)) + + for rows.Next() { + elem := reflect.New(db.Statement.Schema.ModelType).Elem() + for idx, column := range columns { + if field := db.Statement.Schema.LookUpField(column); field != nil { + values[idx] = field.ReflectValueOf(elem).Addr().Interface() + } else if db.RowsAffected == 0 { + values[idx] = sql.RawBytes{} + } + } + + db.RowsAffected++ + if err := rows.Scan(values...); err != nil { + db.AddError(err) + } + + if isPtr { + db.Statement.ReflectValue.Set(reflect.Append(db.Statement.ReflectValue, elem.Addr())) + } else { + db.Statement.ReflectValue.Set(reflect.Append(db.Statement.ReflectValue, elem)) + } + } + case reflect.Struct: + for idx, column := range columns { + if field := db.Statement.Schema.LookUpField(column); field != nil { + values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface() + } else { + values[idx] = sql.RawBytes{} + } + } + + if rows.Next() { + db.RowsAffected++ + if err := rows.Scan(values...); err != nil { + db.AddError(err) + } + } + } + } + + if db.RowsAffected == 0 && db.Statement.RaiseErrorOnNotFound { + db.AddError(gorm.ErrRecordNotFound) + } +} diff --git a/finisher_api.go b/finisher_api.go index 83988546..c918c08a 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -26,7 +26,6 @@ func (db *DB) First(out interface{}, where ...interface{}) (tx *DB) { // TODO handle where tx = db.getInstance().Limit(1).Order(clause.OrderByColumn{ Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, - Desc: true, }) tx.Statement.RaiseErrorOnNotFound = true tx.Statement.Dest = out @@ -47,6 +46,7 @@ func (db *DB) Take(out interface{}, where ...interface{}) (tx *DB) { func (db *DB) Last(out interface{}, where ...interface{}) (tx *DB) { tx = db.getInstance().Limit(1).Order(clause.OrderByColumn{ Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, + Desc: true, }) tx.Statement.RaiseErrorOnNotFound = true tx.Statement.Dest = out diff --git a/schema/schema_helper_test.go b/schema/schema_helper_test.go index 196d19c4..146ba13a 100644 --- a/schema/schema_helper_test.go +++ b/schema/schema_helper_test.go @@ -1,7 +1,6 @@ package schema_test import ( - "database/sql/driver" "fmt" "reflect" "strings" @@ -13,7 +12,7 @@ import ( func checkSchema(t *testing.T, s *schema.Schema, v schema.Schema, primaryFields []string) { t.Run("CheckSchema/"+s.Name, func(t *testing.T) { - tests.AssertEqual(t, s, v, "Name", "Table") + tests.AssertObjEqual(t, s, v, "Name", "Table") for idx, field := range primaryFields { var found bool @@ -53,7 +52,7 @@ func checkSchemaField(t *testing.T, s *schema.Schema, f *schema.Field, fc func(* if parsedField, ok := s.FieldsByName[f.Name]; !ok { t.Errorf("schema %v failed to look up field with name %v", s, f.Name) } else { - tests.AssertEqual(t, parsedField, f, "Name", "DBName", "BindNames", "DataType", "DBDataType", "PrimaryKey", "AutoIncrement", "Creatable", "Updatable", "HasDefaultValue", "DefaultValue", "NotNull", "Unique", "Comment", "Size", "Precision", "Tag", "TagSettings") + tests.AssertObjEqual(t, parsedField, f, "Name", "DBName", "BindNames", "DataType", "DBDataType", "PrimaryKey", "AutoIncrement", "Creatable", "Updatable", "HasDefaultValue", "DefaultValue", "NotNull", "Unique", "Comment", "Size", "Precision", "Tag", "TagSettings") if field, ok := s.FieldsByDBName[f.DBName]; !ok || parsedField != field { t.Errorf("schema %v failed to look up field with dbname %v", s, f.DBName) @@ -195,39 +194,8 @@ func checkSchemaRelation(t *testing.T, s *schema.Schema, relation Relation) { func checkField(t *testing.T, s *schema.Schema, value reflect.Value, values map[string]interface{}) { for k, v := range values { t.Run("CheckField/"+k, func(t *testing.T) { - var ( - checker func(fv interface{}, v interface{}) - field = s.FieldsByDBName[k] - fv, _ = field.ValueOf(value) - ) - - checker = func(fv interface{}, v interface{}) { - if reflect.ValueOf(fv).Type() == reflect.ValueOf(v).Type() && fv != v { - t.Errorf("expects: %p, but got %p", v, fv) - } else if reflect.ValueOf(v).Type().ConvertibleTo(reflect.ValueOf(fv).Type()) { - if reflect.ValueOf(v).Convert(reflect.ValueOf(fv).Type()).Interface() != fv { - t.Errorf("expects: %p, but got %p", v, fv) - } - } else if reflect.ValueOf(fv).Type().ConvertibleTo(reflect.ValueOf(v).Type()) { - if reflect.ValueOf(fv).Convert(reflect.ValueOf(fv).Type()).Interface() != v { - t.Errorf("expects: %p, but got %p", v, fv) - } - } else if valuer, isValuer := fv.(driver.Valuer); isValuer { - valuerv, _ := valuer.Value() - checker(valuerv, v) - } else if valuer, isValuer := v.(driver.Valuer); isValuer { - valuerv, _ := valuer.Value() - checker(fv, valuerv) - } else if reflect.ValueOf(fv).Kind() == reflect.Ptr { - checker(reflect.ValueOf(fv).Elem().Interface(), v) - } else if reflect.ValueOf(v).Kind() == reflect.Ptr { - checker(fv, reflect.ValueOf(v).Elem().Interface()) - } else { - t.Errorf("expects: %+v, but got %+v", v, fv) - } - } - - checker(fv, v) + fv, _ := s.FieldsByDBName[k].ValueOf(value) + tests.AssertEqual(t, v, fv) }) } } diff --git a/tests/tests.go b/tests/tests.go index 5e47c09e..2f0dfd34 100644 --- a/tests/tests.go +++ b/tests/tests.go @@ -1,6 +1,9 @@ package tests import ( + "log" + "reflect" + "strconv" "testing" "time" @@ -14,6 +17,7 @@ func Now() *time.Time { func RunTestsSuit(t *testing.T, db *gorm.DB) { TestCreate(t, db) + TestFind(t, db) } func TestCreate(t *testing.T, db *gorm.DB) { @@ -38,7 +42,94 @@ func TestCreate(t *testing.T, db *gorm.DB) { if err := db.Where("id = ?", user.ID).First(&newUser).Error; err != nil { t.Errorf("errors happened when query: %v", err) } else { - AssertEqual(t, newUser, user, "Name", "Age", "Birthday") + AssertObjEqual(t, newUser, user, "Name", "Age", "Birthday") + } + }) +} + +func TestFind(t *testing.T, db *gorm.DB) { + db.Migrator().DropTable(&User{}) + db.AutoMigrate(&User{}) + + t.Run("Find", func(t *testing.T) { + var users = []User{{ + Name: "find", + Age: 1, + Birthday: Now(), + }, { + Name: "find", + Age: 2, + Birthday: Now(), + }, { + Name: "find", + Age: 3, + Birthday: Now(), + }} + + if err := db.Create(&users).Error; err != nil { + t.Errorf("errors happened when create users: %v", err) + } + + t.Run("First", func(t *testing.T) { + var first User + if err := db.Where("name = ?", "find").First(&first).Error; err != nil { + t.Errorf("errors happened when query first: %v", err) + } else { + AssertObjEqual(t, first, users[0], "Name", "Age", "Birthday") + } + }) + + t.Run("Last", func(t *testing.T) { + var last User + if err := db.Where("name = ?", "find").Last(&last).Error; err != nil { + t.Errorf("errors happened when query last: %v", err) + } else { + AssertObjEqual(t, last, users[2], "Name", "Age", "Birthday") + } + }) + + var all []User + if err := db.Where("name = ?", "find").Find(&all).Error; err != nil || len(all) != 3 { + t.Errorf("errors happened when query find: %v, length: %v", err, len(all)) + } else { + for idx, user := range users { + t.Run("FindAll#"+strconv.Itoa(idx+1), func(t *testing.T) { + AssertObjEqual(t, all[idx], user, "Name", "Age", "Birthday") + }) + } + } + + t.Run("FirstMap", func(t *testing.T) { + var first = map[string]interface{}{} + if err := db.Model(&User{}).Where("name = ?", "find").First(first).Error; err != nil { + t.Errorf("errors happened when query first: %v", err) + } else { + for _, name := range []string{"Name", "Age", "Birthday"} { + t.Run(name, func(t *testing.T) { + dbName := db.NamingStrategy.ColumnName("", name) + reflectValue := reflect.Indirect(reflect.ValueOf(users[0])) + AssertEqual(t, first[dbName], reflectValue.FieldByName(name).Interface()) + }) + } + } + }) + + var allMap = []map[string]interface{}{} + if err := db.Model(&User{}).Where("name = ?", "find").Find(&allMap).Error; err != nil { + t.Errorf("errors happened when query first: %v", err) + } else { + log.Printf("all map %+v %+v", len(allMap), allMap) + for idx, user := range users { + t.Run("FindAllMap#"+strconv.Itoa(idx+1), func(t *testing.T) { + for _, name := range []string{"Name", "Age", "Birthday"} { + t.Run(name, func(t *testing.T) { + dbName := db.NamingStrategy.ColumnName("", name) + reflectValue := reflect.Indirect(reflect.ValueOf(user)) + AssertEqual(t, allMap[idx][dbName], reflectValue.FieldByName(name).Interface()) + }) + } + }) + } } }) } diff --git a/tests/utils.go b/tests/utils.go index 292a357d..9d61c422 100644 --- a/tests/utils.go +++ b/tests/utils.go @@ -6,24 +6,43 @@ import ( "time" ) -func AssertEqual(t *testing.T, r, e interface{}, names ...string) { +func AssertObjEqual(t *testing.T, r, e interface{}, names ...string) { for _, name := range names { got := reflect.Indirect(reflect.ValueOf(r)).FieldByName(name).Interface() - expects := reflect.Indirect(reflect.ValueOf(e)).FieldByName(name).Interface() + expect := reflect.Indirect(reflect.ValueOf(e)).FieldByName(name).Interface() + t.Run(name, func(t *testing.T) { + AssertEqual(t, got, expect) + }) + } +} - if !reflect.DeepEqual(got, expects) { - got = reflect.Indirect(reflect.ValueOf(got)).Interface() - expects = reflect.Indirect(reflect.ValueOf(got)).Interface() +func AssertEqual(t *testing.T, got, expect interface{}) { + if !reflect.DeepEqual(got, expect) { + isEqual := func() { if curTime, ok := got.(time.Time); ok { format := "2006-01-02T15:04:05Z07:00" - if curTime.Format(format) != expects.(time.Time).Format(format) { - t.Errorf("expects: %v, got %v", expects.(time.Time).Format(format), curTime.Format(format)) + if curTime.Format(format) != expect.(time.Time).Format(format) { + t.Errorf("expect: %v, got %v", expect.(time.Time).Format(format), curTime.Format(format)) } - } else { - t.Run(name, func(t *testing.T) { - t.Errorf("expects: %v, got %v", expects, got) - }) + } else if got != expect { + t.Errorf("expect: %#v, got %#v", expect, got) } } + + if got != nil { + got = reflect.Indirect(reflect.ValueOf(got)).Interface() + } + + if expect != nil { + expect = reflect.Indirect(reflect.ValueOf(expect)).Interface() + } + + if reflect.ValueOf(got).Type().ConvertibleTo(reflect.ValueOf(expect).Type()) { + got = reflect.ValueOf(got).Convert(reflect.ValueOf(expect).Type()).Interface() + isEqual() + } else if reflect.ValueOf(expect).Type().ConvertibleTo(reflect.ValueOf(got).Type()) { + expect = reflect.ValueOf(got).Convert(reflect.ValueOf(got).Type()).Interface() + isEqual() + } } } diff --git a/utils/utils.go b/utils/utils.go index 315ba930..86ea557b 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -6,8 +6,8 @@ import ( "runtime" ) -var goSrcRegexp = regexp.MustCompile(`/gorm/.*.go`) -var goTestRegexp = regexp.MustCompile(`/gorm/.*test.go`) +var goSrcRegexp = regexp.MustCompile(`/gorm/.*\.go`) +var goTestRegexp = regexp.MustCompile(`/gorm/.*test\.go`) func FileWithLineNum() string { for i := 2; i < 15; i++ {