From 33f8b60b289d02b24a246f496e6ba280253fe558 Mon Sep 17 00:00:00 2001 From: Nick Murray Date: Sat, 25 Jul 2020 22:16:19 +0800 Subject: [PATCH] proposal for struct scanning --- scan.go | 68 ++++++++++++++++++++++++++++++--------------- tests/query_test.go | 6 ++-- 2 files changed, 48 insertions(+), 26 deletions(-) diff --git a/scan.go b/scan.go index 0b199029..555123bb 100644 --- a/scan.go +++ b/scan.go @@ -162,18 +162,32 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { } if initialized || rows.Next() { + fieldsIndex := 0 + var fieldnames []string + for _, f := range Schema.Fields { + fieldnames = append(fieldnames, f.DBName) + } for idx, column := range columns { - if field := Schema.LookUpField(column); field != nil && field.Readable { - values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface() - } else if names := strings.Split(column, "__"); len(names) > 1 { - if rel, ok := Schema.Relationships.Relations[names[0]]; ok { - if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { - values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface() - continue + if fieldsIndex >= len(Schema.Fields) { + values[idx] = &sql.RawBytes{} + continue + } + for _, field := range Schema.Fields[fieldsIndex:] { + if (field.DBName == column || field.Name == column) && field.Readable { + values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface() + fieldsIndex++ + break + } + if names := strings.Split(column, "__"); len(names) > 1 { + if rel, ok := Schema.Relationships.Relations[names[0]]; ok { + if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { + values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface() + break + } } } - values[idx] = &sql.RawBytes{} - } else { + } + if values[idx] == nil { values[idx] = &sql.RawBytes{} } } @@ -181,23 +195,31 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { db.RowsAffected++ db.AddError(rows.Scan(values...)) + fieldsIndex = 0 for idx, column := range columns { - if field := Schema.LookUpField(column); field != nil && field.Readable { - field.Set(db.Statement.ReflectValue, values[idx]) - } else if names := strings.Split(column, "__"); len(names) > 1 { - if rel, ok := Schema.Relationships.Relations[names[0]]; ok { - relValue := rel.Field.ReflectValueOf(db.Statement.ReflectValue) - if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { - value := reflect.ValueOf(values[idx]).Elem() - - if relValue.Kind() == reflect.Ptr && relValue.IsNil() { - if value.IsNil() { - continue + if fieldsIndex >= len(Schema.Fields) { + continue + } + for _, field := range Schema.Fields[fieldsIndex:] { + if (field.DBName == column || field.Name == column) && field.Readable { + field.Set(db.Statement.ReflectValue, values[idx]) + fieldsIndex++ + break + } + if names := strings.Split(column, "__"); len(names) > 1 { + if rel, ok := Schema.Relationships.Relations[names[0]]; ok { + relValue := rel.Field.ReflectValueOf(db.Statement.ReflectValue) + if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { + value := reflect.ValueOf(values[idx]).Elem() + if relValue.Kind() == reflect.Ptr && relValue.IsNil() { + if value.IsNil() { + break + } + relValue.Set(reflect.New(relValue.Type().Elem())) } - relValue.Set(reflect.New(relValue.Type().Elem())) + field.Set(relValue, values[idx]) + break } - - field.Set(relValue, values[idx]) } } } diff --git a/tests/query_test.go b/tests/query_test.go index 59f1130b..69d88855 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -146,9 +146,9 @@ func TestFillSmallerStruct(t *testing.T) { DB.Save(&user) type SimpleUser struct { ID int64 - Name string - UpdatedAt time.Time CreatedAt time.Time + UpdatedAt time.Time + Name string } var simpleUser SimpleUser @@ -174,7 +174,7 @@ func TestFillSmallerStruct(t *testing.T) { result := DB.Session(&gorm.Session{DryRun: true}).Model(&User{}).Find(&simpleUsers, user.ID) - if !regexp.MustCompile("SELECT .*id.*name.*updated_at.*created_at.* FROM .*users").MatchString(result.Statement.SQL.String()) { + if !regexp.MustCompile("SELECT .*id.*created_at.*updated_at.*name.* FROM .*users").MatchString(result.Statement.SQL.String()) { t.Fatalf("SQL should include selected names, but got %v", result.Statement.SQL.String()) }