From 0a35abb59f2e64192dd0e87a8ff7b73eb5cde4b9 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 27 Feb 2018 21:11:39 +0800 Subject: [PATCH] Handle SelectAttrs, OmitAttrs when generate assignable attrs --- model/field.go | 132 ++++++++++++++++++++++++++++++++----------------- 1 file changed, 88 insertions(+), 44 deletions(-) diff --git a/model/field.go b/model/field.go index 0dcf0647..0589a059 100644 --- a/model/field.go +++ b/model/field.go @@ -20,44 +20,6 @@ type Field struct { Value reflect.Value } -// GetAssignments get assignments -func GetAssignments(tx *gorm.DB) chan [][]*Field { - fieldChan := make(chan [][]*Field) - - go func() { - // TODO handle select, omit, protected - switch dest := tx.Statement.Dest.(type) { - case map[string]interface{}: - fieldChan <- [][]*Field{mapToFields(dest, tx.Statement, schema.Parse(tx.Statement.Table))} - case []map[string]interface{}: - fields := [][]*Field{} - tableSchema := schema.Parse(tx.Statement.Table) - - for _, v := range dest { - fields = append(fields, mapToFields(v, tx.Statement, tableSchema)) - } - fieldChan <- fields - default: - if s := schema.Parse(tx.Statement.Dest); s != nil { - results := indirect(reflect.ValueOf(tx.Statement.Dest)) - - switch results.Kind() { - case reflect.Slice: - fields := [][]*Field{} - for i := 0; i < results.Len(); i++ { - fields = append(fields, structToField(results.Index(i), tx.Statement, s)) - } - fieldChan <- fields - case reflect.Struct: - fieldChan <- [][]*Field{structToField(results, tx.Statement, s)} - } - } - } - }() - - return fieldChan -} - // Set set a value to the field func (field *Field) Set(value interface{}) (err error) { if !field.Value.IsValid() { @@ -101,18 +63,62 @@ func (field *Field) Set(value interface{}) (err error) { return err } -func mapToFields(value map[string]interface{}, stmt *builder.Statement, s *schema.Schema) (fields []*Field) { - // sort +// GetAssignments get assignments +func GetAssignments(tx *gorm.DB) chan [][]*Field { + fieldChan := make(chan [][]*Field) + + go func() { + assignableChecker := generateAssignableChecker(selectAttrs(tx.Statement), omitAttrs(tx.Statement)) + + switch dest := tx.Statement.Dest.(type) { + case map[string]interface{}: + fieldChan <- [][]*Field{mapToFields(dest, schema.Parse(tx.Statement.Table), assignableChecker)} + case []map[string]interface{}: + fields := [][]*Field{} + tableSchema := schema.Parse(tx.Statement.Table) + + for _, v := range dest { + fields = append(fields, mapToFields(v, tableSchema, assignableChecker)) + } + fieldChan <- fields + default: + if s := schema.Parse(tx.Statement.Dest); s != nil { + results := indirect(reflect.ValueOf(tx.Statement.Dest)) + + switch results.Kind() { + case reflect.Slice: + fields := [][]*Field{} + for i := 0; i < results.Len(); i++ { + fields = append(fields, structToField(results.Index(i), s, assignableChecker)) + } + fieldChan <- fields + case reflect.Struct: + fieldChan <- [][]*Field{structToField(results, s, assignableChecker)} + } + } + } + }() + + return fieldChan +} + +func mapToFields(value map[string]interface{}, s *schema.Schema, assignableChecker func(*Field) bool) (fields []*Field) { // TODO assign those value to dest for k, v := range value { if s != nil { if f := s.FieldByName(k); f != nil { - fields = append(fields, &Field{Field: f, Value: reflect.ValueOf(v)}) + field := &Field{Field: f, Value: reflect.ValueOf(v)} + if assignableChecker(field) { + fields = append(fields, field) + } continue } } - fields = append(fields, &Field{Field: &schema.Field{DBName: k}, Value: reflect.ValueOf(v)}) + field := &Field{Field: &schema.Field{DBName: k}, Value: reflect.ValueOf(v)} + if assignableChecker(field) { + fields = append(fields, field) + } } sort.SliceStable(fields, func(i, j int) bool { @@ -121,14 +127,52 @@ func mapToFields(value map[string]interface{}, stmt *builder.Statement, s *schem return } -func structToField(value reflect.Value, stmt *builder.Statement, s *schema.Schema) (fields []*Field) { +func structToField(value reflect.Value, s *schema.Schema, assignableChecker func(*Field) bool) (fields []*Field) { // TODO use Offset to replace FieldByName? for _, sf := range s.Fields { obj := value for _, bn := range sf.BindNames { obj = value.FieldByName(bn) } - fields = append(fields, &Field{Field: sf, Value: obj, IsBlank: isBlank(obj)}) + field := &Field{Field: sf, Value: obj, IsBlank: isBlank(obj)} + if assignableChecker(field) { + fields = append(fields, field) + } } return } + +// generateAssignableChecker generate checker to check if field is assignable or not +func generateAssignableChecker(selectAttrs []string, omitAttrs []string) func(*Field) bool { + return func(field *Field) bool { + if len(selectAttrs) > 0 { + for _, attr := range selectAttrs { + if field.Name == attr || field.DBName == attr { + return true + } + } + return false + } + + for _, attr := range omitAttrs { + if field.Name == attr || field.DBName == attr { + return false + } + } + return true + } +} + +// omitAttrs return selected attributes of stmt +func selectAttrs(stmt *builder.Statement) []string { + columns := stmt.Select.Columns + for _, arg := range stmt.Select.Args { + columns = append(columns, fmt.Sprint(arg)) + } + return columns +} + +// omitAttrs return omitted attributes of stmt +func omitAttrs(stmt *builder.Statement) []string { + return stmt.Omit +}