diff --git a/README.md b/README.md index 21707a8c..78e6e458 100644 --- a/README.md +++ b/README.md @@ -1172,6 +1172,49 @@ db.Where("email = ?", "x@example.org").Attrs(User{RegisteredIp: "111.111.111.111 //// INSERT INTO "users" (email,registered_ip) VALUES ("x@example.org", "111.111.111.111") // if record not found ``` +## Embedded interface fields + +```go +type Delta struct { + Id int; + Was interface{} `gorm:"embedded:prefixed"`; + Became interface{} `gorm:"embedded:prefixed"`; +} + +type Login struct { + Login string; + Comment string; +} + +login_old := Login{Login: "Login1"}; +login_new := Login{Login: "Login2", Comment: "2015-05-18"}; + +delta := Delta { + Was: &login_old, + Became: &login_new, +} + +db.SingularTable(true); + +db.CreateTable(&delta); +//// CREATE TABLE "delta__Login__Login" ("id" integer,"was__login" varchar(255),"was__comment" varchar(255),"became__login" varchar(255),"became__comment" varchar(255) , PRIMARY KEY (id)) + +db.Save(&delta); +//// INSERT INTO "delta__Login__Login" ("was__login","was__comment","became__login","became__comment") VALUES ('Login1','','Login2','2015-05-18') + +found := Delta { + Was: &Login{}, + Became: &Login{}, +} +db.Where(&Delta{Was: &Login{Login: "Login1"}, Became: &Login{}}).First(&found); +//// SELECT * FROM "delta__Login__Login" WHERE ("was__login" = 'Login1') ORDER BY "delta__Login__Login"."id" ASC LIMIT 1 + +deltas := []Delta{{Was: &Login{}, Became: &Login{}}}; +db.Find(&deltas) +//// SELECT * FROM "deltas__Login__Login" +deltas = append(deltas[:0], deltas[1:]...) +``` + ## TODO * db.Select("Languages", "Name").Update(&user) db.Omit("Languages").Update(&user) diff --git a/field.go b/field.go index 8f5efa6d..0f8d406a 100644 --- a/field.go +++ b/field.go @@ -76,6 +76,13 @@ func (scope *Scope) Fields() map[string]*Field { func getField(indirectValue reflect.Value, structField *StructField) *Field { field := &Field{StructField: structField} for _, name := range structField.Names { + for ;reflect.Indirect(indirectValue).Kind() == reflect.Interface; { + if (indirectValue.Elem().IsValid()) { + indirectValue = indirectValue.Elem() + } else { + indirectValue.Set(reflect.New(structField.Value.Type())) + } + } indirectValue = reflect.Indirect(indirectValue).FieldByName(name) } field.Field = indirectValue diff --git a/main.go b/main.go index bf8acbae..90b2fb0d 100644 --- a/main.go +++ b/main.go @@ -126,7 +126,8 @@ func (s *DB) LogMode(enable bool) *DB { } func (s *DB) SingularTable(enable bool) { - modelStructs = map[reflect.Type]*ModelStruct{} + modelStructs_byScopeType = map[reflect.Type]*ModelStruct{} + modelStructs_byTableName = map[string ]*ModelStruct{} s.parent.singularTable = enable } diff --git a/model_struct.go b/model_struct.go index a70489fc..0a1631a9 100644 --- a/model_struct.go +++ b/model_struct.go @@ -11,7 +11,9 @@ import ( "time" ) -var modelStructs = map[reflect.Type]*ModelStruct{} +var modelStructs_byScopeType = map[reflect.Type]*ModelStruct{} +var modelStructs_byTableName = map[string ]*ModelStruct{} +var modelStruct_last *ModelStruct type ModelStruct struct { PrimaryFields []*StructField @@ -33,6 +35,7 @@ type StructField struct { Struct reflect.StructField IsForeignKey bool Relationship *Relationship + Value reflect.Value } func (structField *StructField) clone() *StructField { @@ -49,6 +52,7 @@ func (structField *StructField) clone() *StructField { Struct: structField.Struct, IsForeignKey: structField.IsForeignKey, Relationship: structField.Relationship, + Value: structField.Value, } } @@ -67,6 +71,7 @@ var pluralMapKeys = []*regexp.Regexp{regexp.MustCompile("ch$"), regexp.MustCompi var pluralMapValues = []string{"ches", "sses", "shes", "days", "ies", "xes", "${1}s"} func (scope *Scope) GetModelStruct() *ModelStruct { + var tableName string var modelStruct ModelStruct reflectValue := reflect.Indirect(reflect.ValueOf(scope.Value)) @@ -84,21 +89,42 @@ func (scope *Scope) GetModelStruct() *ModelStruct { scopeType = scopeType.Elem() } - if value, ok := modelStructs[scopeType]; ok { + if value, ok := modelStructs_byScopeType[scopeType]; ok { return value } modelStruct.ModelType = scopeType if scopeType.Kind() != reflect.Struct { + modelStruct_last = &modelStruct return &modelStruct } + // Getting table name appendix + for i := 0; i < scopeType.NumField(); i++ { + if fieldStruct := scopeType.Field(i); ast.IsExported(fieldStruct.Name) { + if (fieldStruct.Type.Kind() == reflect.Interface) { + // Interface field + value := reflect.ValueOf(scope.Value).Elem() + if (value.Kind() == reflect.Slice) { + // A slice, using the first element + value = value.Index(0) + } + value = reflect.ValueOf(value.Field(i).Interface()) + if (! value.IsValid()) { + // Invalid interfaces, using Model()'s result + return modelStruct_last + } + tableName = tableName + "__" + value.Elem().Type().Name() + } + } + } // Set tablename if fm := reflect.New(scopeType).MethodByName("TableName"); fm.IsValid() { if results := fm.Call([]reflect.Value{}); len(results) > 0 { if name, ok := results[0].Interface().(string); ok { + tableName = name + tableName modelStruct.TableName = func(*DB) string { - return name + return tableName } } } @@ -112,23 +138,48 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } } + tableName = name + tableName modelStruct.TableName = func(*DB) string { - return name + return tableName } } + if value, ok := modelStructs_byTableName[tableName]; ok { + modelStruct_last = value + return value + } + // Get all fields + cachable_byScopeType := true fields := []*StructField{} for i := 0; i < scopeType.NumField(); i++ { if fieldStruct := scopeType.Field(i); ast.IsExported(fieldStruct.Name) { + var value reflect.Value + if (fieldStruct.Type.Kind() == reflect.Interface) { + value = reflect.ValueOf(scope.Value).Elem() + if (value.Kind() == reflect.Slice) { + value = value.Index(0) + } + value = reflect.ValueOf(value.Field(i).Interface()) + cachable_byScopeType = false + } else { + value = reflect.Indirect(reflect.ValueOf(scope.Value)) + } + if (value.Kind() == reflect.Slice) { + if (value.Len() == 0) { + value = reflect.MakeSlice(value.Type(), 1, 1); + } + value = value.Index(0); + } field := &StructField{ Struct: fieldStruct, + Value: value, Name: fieldStruct.Name, Names: []string{fieldStruct.Name}, Tag: fieldStruct.Tag, } - if fieldStruct.Tag.Get("sql") == "-" { + if (fieldStruct.Tag.Get("sql") == "-") { field.IsIgnored = true } else { sqlSettings := parseTagSetting(field.Tag.Get("sql")) @@ -170,8 +221,15 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } if !field.IsNormal { + var iface interface{} gormSettings := parseTagSetting(field.Tag.Get("gorm")) - toScope := scope.New(reflect.New(fieldStruct.Type).Interface()) + if (fieldStruct.Type.Kind() == reflect.Interface) { + indirectType = (*field).Value.Elem().Type() + iface = (*field).Value.Elem().Interface() + } else { + iface = reflect.New(fieldStruct.Type).Interface() + } + toScope := scope.New(iface) getForeignField := func(column string, fields []*StructField) *StructField { for _, field := range fields { @@ -241,10 +299,13 @@ func (scope *Scope) GetModelStruct() *ModelStruct { field.IsNormal = true } case reflect.Struct: - if _, ok := gormSettings["EMBEDDED"]; ok || fieldStruct.Anonymous { + if embType, ok := gormSettings["EMBEDDED"]; ok || fieldStruct.Anonymous { for _, toField := range toScope.GetStructFields() { toField = toField.clone() - toField.Names = append([]string{fieldStruct.Name}, toField.Names...) + if (embType == "prefixed") { + toField.DBName = field.DBName+"__"+toField.DBName + } + toField.Names = append([]string{fieldStruct.Name}, toField.Names...) modelStruct.StructFields = append(modelStruct.StructFields, toField) if toField.IsPrimaryKey { modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, toField) @@ -292,10 +353,14 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } modelStruct.StructFields = append(modelStruct.StructFields, field) } + modelStruct_last = &modelStruct }() - modelStructs[scopeType] = &modelStruct - + if (cachable_byScopeType) { + modelStructs_byScopeType[scopeType] = &modelStruct + } else { + modelStructs_byTableName[tableName] = &modelStruct + } return &modelStruct }