Added support of embedded interface fields

This commit is contained in:
Dmitry Yu Okunev 2015-05-18 11:06:05 +03:00
parent 371cd41204
commit 6c805b6a0e
No known key found for this signature in database
GPG Key ID: AD8AE40C8E30679C
4 changed files with 88 additions and 10 deletions

View File

@ -1172,6 +1172,44 @@ db.Where("email = ?", "x@example.org").Attrs(User{RegisteredIp: "111.111.111.111
//// INSERT INTO "users" (email,registered_ip) VALUES ("x@example.org", "111.111.111.111") // if record not found //// INSERT INTO "users" (email,registered_ip) VALUES ("x@example.org", "111.111.111.111") // if record not found
``` ```
## Tables with embedded iterface fields
```go
type Delta struct {
Id int;
Was interface{} `gorm:"embedded"`;
Became interface{} `gorm:"embedded"`;
}
type Login struct {
Login string;
Comment string;
}
login_old := Login{Login: "Login1"};
login_new := Login{Login: "Login2", Comment: "2015-05-18"};
delta := Delta {
Was: &login_old,
Became: &login_new,
}
db.SingularTable(true);
db.CreateTable(&delta);
//// CREATE TABLE "delta__Login__Login" ("id" integer,"was__login" varchar(255),"was__comment" varchar(255),"became__login" varchar(255),"became__comment" varchar(255) , PRIMARY KEY (id))
db.Save(&delta);
//// INSERT INTO "delta__Login__Login" ("was__login","was__comment","became__login","became__comment") VALUES ('Login1','','Login2','2015-05-18')
found := Delta {
Was: &Login{},
Became: &Login{},
}
db.Where(&Delta{Was: &Login{Login: "Login1"}, Became: &Login{}}).First(&found);
//// SELECT * FROM "delta__Login__Login" WHERE ("was__login" = 'Login1') ORDER BY "delta__Login__Login"."id" ASC LIMIT 1
```
## TODO ## TODO
* db.Select("Languages", "Name").Update(&user) * db.Select("Languages", "Name").Update(&user)
db.Omit("Languages").Update(&user) db.Omit("Languages").Update(&user)

View File

@ -76,6 +76,9 @@ func (scope *Scope) Fields() map[string]*Field {
func getField(indirectValue reflect.Value, structField *StructField) *Field { func getField(indirectValue reflect.Value, structField *StructField) *Field {
field := &Field{StructField: structField} field := &Field{StructField: structField}
for _, name := range structField.Names { for _, name := range structField.Names {
if (reflect.Indirect(indirectValue).Kind() == reflect.Interface) {
indirectValue = indirectValue.Elem()
}
indirectValue = reflect.Indirect(indirectValue).FieldByName(name) indirectValue = reflect.Indirect(indirectValue).FieldByName(name)
} }
field.Field = indirectValue field.Field = indirectValue

View File

@ -126,7 +126,8 @@ func (s *DB) LogMode(enable bool) *DB {
} }
func (s *DB) SingularTable(enable bool) { func (s *DB) SingularTable(enable bool) {
modelStructs = map[reflect.Type]*ModelStruct{} modelStructs_byScopeType = map[reflect.Type]*ModelStruct{}
modelStructs_byTableName = map[string ]*ModelStruct{}
s.parent.singularTable = enable s.parent.singularTable = enable
} }

View File

@ -11,7 +11,8 @@ import (
"time" "time"
) )
var modelStructs = map[reflect.Type]*ModelStruct{} var modelStructs_byScopeType = map[reflect.Type]*ModelStruct{}
var modelStructs_byTableName = map[string ]*ModelStruct{}
type ModelStruct struct { type ModelStruct struct {
PrimaryFields []*StructField PrimaryFields []*StructField
@ -33,6 +34,7 @@ type StructField struct {
Struct reflect.StructField Struct reflect.StructField
IsForeignKey bool IsForeignKey bool
Relationship *Relationship Relationship *Relationship
Value reflect.Value
} }
func (structField *StructField) clone() *StructField { func (structField *StructField) clone() *StructField {
@ -49,6 +51,7 @@ func (structField *StructField) clone() *StructField {
Struct: structField.Struct, Struct: structField.Struct,
IsForeignKey: structField.IsForeignKey, IsForeignKey: structField.IsForeignKey,
Relationship: structField.Relationship, Relationship: structField.Relationship,
Value: structField.Value,
} }
} }
@ -67,6 +70,7 @@ var pluralMapKeys = []*regexp.Regexp{regexp.MustCompile("ch$"), regexp.MustCompi
var pluralMapValues = []string{"ches", "sses", "shes", "days", "ies", "xes", "${1}s"} var pluralMapValues = []string{"ches", "sses", "shes", "days", "ies", "xes", "${1}s"}
func (scope *Scope) GetModelStruct() *ModelStruct { func (scope *Scope) GetModelStruct() *ModelStruct {
var tableName string
var modelStruct ModelStruct var modelStruct ModelStruct
reflectValue := reflect.Indirect(reflect.ValueOf(scope.Value)) reflectValue := reflect.Indirect(reflect.ValueOf(scope.Value))
@ -84,7 +88,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
scopeType = scopeType.Elem() scopeType = scopeType.Elem()
} }
if value, ok := modelStructs[scopeType]; ok { if value, ok := modelStructs_byScopeType[scopeType]; ok {
return value return value
} }
@ -94,11 +98,20 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
} }
// Set tablename // Set tablename
for i := 0; i < scopeType.NumField(); i++ {
if fieldStruct := scopeType.Field(i); ast.IsExported(fieldStruct.Name) {
if (fieldStruct.Type.Kind() == reflect.Interface) {
value := reflect.ValueOf(reflect.ValueOf(scope.Value).Elem().Field(i).Interface())
tableName = tableName + "__" + value.Elem().Type().Name()
}
}
}
if fm := reflect.New(scopeType).MethodByName("TableName"); fm.IsValid() { if fm := reflect.New(scopeType).MethodByName("TableName"); fm.IsValid() {
if results := fm.Call([]reflect.Value{}); len(results) > 0 { if results := fm.Call([]reflect.Value{}); len(results) > 0 {
if name, ok := results[0].Interface().(string); ok { if name, ok := results[0].Interface().(string); ok {
tableName = name + tableName
modelStruct.TableName = func(*DB) string { modelStruct.TableName = func(*DB) string {
return name return tableName
} }
} }
} }
@ -112,23 +125,35 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
} }
} }
tableName = name + tableName
modelStruct.TableName = func(*DB) string { modelStruct.TableName = func(*DB) string {
return name return tableName
} }
} }
if value, ok := modelStructs_byTableName[tableName]; ok {
return value
}
// Get all fields // Get all fields
cachable_byScopeType := true
fields := []*StructField{} fields := []*StructField{}
for i := 0; i < scopeType.NumField(); i++ { for i := 0; i < scopeType.NumField(); i++ {
if fieldStruct := scopeType.Field(i); ast.IsExported(fieldStruct.Name) { if fieldStruct := scopeType.Field(i); ast.IsExported(fieldStruct.Name) {
var value reflect.Value
if (fieldStruct.Type.Kind() == reflect.Interface) {
value = reflect.ValueOf(reflect.ValueOf(scope.Value).Elem().Field(i).Interface())
cachable_byScopeType = false
}
field := &StructField{ field := &StructField{
Struct: fieldStruct, Struct: fieldStruct,
Value: value,
Name: fieldStruct.Name, Name: fieldStruct.Name,
Names: []string{fieldStruct.Name}, Names: []string{fieldStruct.Name},
Tag: fieldStruct.Tag, Tag: fieldStruct.Tag,
} }
if fieldStruct.Tag.Get("sql") == "-" { if (fieldStruct.Tag.Get("sql") == "-") {
field.IsIgnored = true field.IsIgnored = true
} else { } else {
sqlSettings := parseTagSetting(field.Tag.Get("sql")) sqlSettings := parseTagSetting(field.Tag.Get("sql"))
@ -170,8 +195,15 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
} }
if !field.IsNormal { if !field.IsNormal {
var iface interface{}
gormSettings := parseTagSetting(field.Tag.Get("gorm")) gormSettings := parseTagSetting(field.Tag.Get("gorm"))
toScope := scope.New(reflect.New(fieldStruct.Type).Interface()) if (fieldStruct.Type.Kind() == reflect.Interface) {
indirectType = (*field).Value.Elem().Type()
iface = (*field).Value.Elem().Interface()
} else {
iface = reflect.New(fieldStruct.Type).Interface()
}
toScope := scope.New(iface)
getForeignField := func(column string, fields []*StructField) *StructField { getForeignField := func(column string, fields []*StructField) *StructField {
for _, field := range fields { for _, field := range fields {
@ -244,7 +276,8 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
if _, ok := gormSettings["EMBEDDED"]; ok || fieldStruct.Anonymous { if _, ok := gormSettings["EMBEDDED"]; ok || fieldStruct.Anonymous {
for _, toField := range toScope.GetStructFields() { for _, toField := range toScope.GetStructFields() {
toField = toField.clone() toField = toField.clone()
toField.Names = append([]string{fieldStruct.Name}, toField.Names...) toField.DBName = field.DBName+"__"+toField.DBName
toField.Names = append([]string{fieldStruct.Name}, toField.Names...)
modelStruct.StructFields = append(modelStruct.StructFields, toField) modelStruct.StructFields = append(modelStruct.StructFields, toField)
if toField.IsPrimaryKey { if toField.IsPrimaryKey {
modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, toField) modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, toField)
@ -294,8 +327,11 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
} }
}() }()
modelStructs[scopeType] = &modelStruct if (cachable_byScopeType) {
modelStructs_byScopeType[scopeType] = &modelStruct
} else {
modelStructs_byTableName[tableName] = &modelStruct
}
return &modelStruct return &modelStruct
} }