diff --git a/README.md b/README.md index 1bc164ce..177456aa 100644 --- a/README.md +++ b/README.md @@ -1172,7 +1172,7 @@ db.Where("email = ?", "x@example.org").Attrs(User{RegisteredIp: "111.111.111.111 //// INSERT INTO "users" (email,registered_ip) VALUES ("x@example.org", "111.111.111.111") // if record not found ``` -## Tables with embedded iterface fields +## Embedded interface fields ```go type Delta struct { @@ -1207,7 +1207,12 @@ found := Delta { Became: &Login{}, } db.Where(&Delta{Was: &Login{Login: "Login1"}, Became: &Login{}}).First(&found); -//// SELECT * FROM "delta__Login__Login" WHERE ("was__login" = 'Login1') ORDER BY "delta__Login__Login"."id" ASC LIMIT 1 +//// SELECT * FROM "delta__Login__Login" WHERE ("was__login" = 'Login1') ORDER BY "delta__Login__Login"."id" ASC LIMIT 1 + +deltas := []Delta{{Was: &Login{}, Became: &Login{}}}; +db.Find(&deltas) +//// SELECT * FROM "deltas__Login__Login" +deltas = append(deltas[:0], deltas[1:]...) ``` ## TODO diff --git a/field.go b/field.go index 1e4353ff..0f8d406a 100644 --- a/field.go +++ b/field.go @@ -76,8 +76,12 @@ func (scope *Scope) Fields() map[string]*Field { func getField(indirectValue reflect.Value, structField *StructField) *Field { field := &Field{StructField: structField} for _, name := range structField.Names { - if (reflect.Indirect(indirectValue).Kind() == reflect.Interface) { - indirectValue = indirectValue.Elem() + for ;reflect.Indirect(indirectValue).Kind() == reflect.Interface; { + if (indirectValue.Elem().IsValid()) { + indirectValue = indirectValue.Elem() + } else { + indirectValue.Set(reflect.New(structField.Value.Type())) + } } indirectValue = reflect.Indirect(indirectValue).FieldByName(name) } diff --git a/model_struct.go b/model_struct.go index 18610038..0bbeff76 100644 --- a/model_struct.go +++ b/model_struct.go @@ -13,6 +13,7 @@ import ( var modelStructs_byScopeType = map[reflect.Type]*ModelStruct{} var modelStructs_byTableName = map[string ]*ModelStruct{} +var modelStruct_last *ModelStruct type ModelStruct struct { PrimaryFields []*StructField @@ -94,18 +95,30 @@ func (scope *Scope) GetModelStruct() *ModelStruct { modelStruct.ModelType = scopeType if scopeType.Kind() != reflect.Struct { + modelStruct_last = &modelStruct return &modelStruct } - // Set tablename + // Getting table name appendix for i := 0; i < scopeType.NumField(); i++ { if fieldStruct := scopeType.Field(i); ast.IsExported(fieldStruct.Name) { if (fieldStruct.Type.Kind() == reflect.Interface) { - value := reflect.ValueOf(reflect.ValueOf(scope.Value).Elem().Field(i).Interface()) + // Interface field + value := reflect.ValueOf(scope.Value).Elem() + if (value.Kind() == reflect.Slice) { + // A slice, using the first element + value = value.Index(0) + } + value = reflect.ValueOf(value.Field(i).Interface()) + if (! value.IsValid()) { + // Invalid interfaces, using Model()'s result + return modelStruct_last + } tableName = tableName + "__" + value.Elem().Type().Name() } } } + // Set tablename if fm := reflect.New(scopeType).MethodByName("TableName"); fm.IsValid() { if results := fm.Call([]reflect.Value{}); len(results) > 0 { if name, ok := results[0].Interface().(string); ok { @@ -132,6 +145,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } if value, ok := modelStructs_byTableName[tableName]; ok { + modelStruct_last = value return value } @@ -144,6 +158,8 @@ func (scope *Scope) GetModelStruct() *ModelStruct { if (fieldStruct.Type.Kind() == reflect.Interface) { value = reflect.ValueOf(reflect.ValueOf(scope.Value).Elem().Field(i).Interface()) cachable_byScopeType = false + } else { + value = reflect.Indirect(reflect.ValueOf(scope.Value)) } field := &StructField{ Struct: fieldStruct, @@ -325,6 +341,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } modelStruct.StructFields = append(modelStruct.StructFields, field) } + modelStruct_last = &modelStruct }() if (cachable_byScopeType) {