proposal for struct scanning

This commit is contained in:
Nick Murray 2020-07-25 22:16:19 +08:00
parent 69d8111893
commit 33f8b60b28
2 changed files with 48 additions and 26 deletions

42
scan.go
View File

@ -162,18 +162,32 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
} }
if initialized || rows.Next() { if initialized || rows.Next() {
fieldsIndex := 0
var fieldnames []string
for _, f := range Schema.Fields {
fieldnames = append(fieldnames, f.DBName)
}
for idx, column := range columns { for idx, column := range columns {
if field := Schema.LookUpField(column); field != nil && field.Readable { if fieldsIndex >= len(Schema.Fields) {
values[idx] = &sql.RawBytes{}
continue
}
for _, field := range Schema.Fields[fieldsIndex:] {
if (field.DBName == column || field.Name == column) && field.Readable {
values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface() values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface()
} else if names := strings.Split(column, "__"); len(names) > 1 { fieldsIndex++
break
}
if names := strings.Split(column, "__"); len(names) > 1 {
if rel, ok := Schema.Relationships.Relations[names[0]]; ok { if rel, ok := Schema.Relationships.Relations[names[0]]; ok {
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface() values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface()
continue break
} }
} }
values[idx] = &sql.RawBytes{} }
} else { }
if values[idx] == nil {
values[idx] = &sql.RawBytes{} values[idx] = &sql.RawBytes{}
} }
} }
@ -181,23 +195,31 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
db.RowsAffected++ db.RowsAffected++
db.AddError(rows.Scan(values...)) db.AddError(rows.Scan(values...))
fieldsIndex = 0
for idx, column := range columns { for idx, column := range columns {
if field := Schema.LookUpField(column); field != nil && field.Readable { if fieldsIndex >= len(Schema.Fields) {
continue
}
for _, field := range Schema.Fields[fieldsIndex:] {
if (field.DBName == column || field.Name == column) && field.Readable {
field.Set(db.Statement.ReflectValue, values[idx]) field.Set(db.Statement.ReflectValue, values[idx])
} else if names := strings.Split(column, "__"); len(names) > 1 { fieldsIndex++
break
}
if names := strings.Split(column, "__"); len(names) > 1 {
if rel, ok := Schema.Relationships.Relations[names[0]]; ok { if rel, ok := Schema.Relationships.Relations[names[0]]; ok {
relValue := rel.Field.ReflectValueOf(db.Statement.ReflectValue) relValue := rel.Field.ReflectValueOf(db.Statement.ReflectValue)
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
value := reflect.ValueOf(values[idx]).Elem() value := reflect.ValueOf(values[idx]).Elem()
if relValue.Kind() == reflect.Ptr && relValue.IsNil() { if relValue.Kind() == reflect.Ptr && relValue.IsNil() {
if value.IsNil() { if value.IsNil() {
continue break
} }
relValue.Set(reflect.New(relValue.Type().Elem())) relValue.Set(reflect.New(relValue.Type().Elem()))
} }
field.Set(relValue, values[idx]) field.Set(relValue, values[idx])
break
}
} }
} }
} }

View File

@ -146,9 +146,9 @@ func TestFillSmallerStruct(t *testing.T) {
DB.Save(&user) DB.Save(&user)
type SimpleUser struct { type SimpleUser struct {
ID int64 ID int64
Name string
UpdatedAt time.Time
CreatedAt time.Time CreatedAt time.Time
UpdatedAt time.Time
Name string
} }
var simpleUser SimpleUser var simpleUser SimpleUser
@ -174,7 +174,7 @@ func TestFillSmallerStruct(t *testing.T) {
result := DB.Session(&gorm.Session{DryRun: true}).Model(&User{}).Find(&simpleUsers, user.ID) result := DB.Session(&gorm.Session{DryRun: true}).Model(&User{}).Find(&simpleUsers, user.ID)
if !regexp.MustCompile("SELECT .*id.*name.*updated_at.*created_at.* FROM .*users").MatchString(result.Statement.SQL.String()) { if !regexp.MustCompile("SELECT .*id.*created_at.*updated_at.*name.* FROM .*users").MatchString(result.Statement.SQL.String()) {
t.Fatalf("SQL should include selected names, but got %v", result.Statement.SQL.String()) t.Fatalf("SQL should include selected names, but got %v", result.Statement.SQL.String())
} }