diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 00000000..e237def9 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,3 @@ +{ + "go.generateTestsFlags": ["-v"] +} \ No newline at end of file diff --git a/callback_query.go b/callback_query.go index 593e5d30..12c0b841 100644 --- a/callback_query.go +++ b/callback_query.go @@ -18,7 +18,7 @@ func queryCallback(scope *Scope) { if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip { return } - + //we are only preloading relations, dont touch base model if _, skip := scope.InstanceGet("gorm:only_preload"); skip { return @@ -76,7 +76,7 @@ func queryCallback(scope *Scope) { elem = reflect.New(resultType).Elem() } - scope.scan(rows, columns, scope.New(elem.Addr().Interface()).Fields()) + scope.scan(rows, columns, scope.New(elem.Addr().Interface()).Fields(), elem.Addr().Interface()) if isSlice { if isPtr { diff --git a/interface_test.go b/interface_test.go new file mode 100644 index 00000000..f325e1bd --- /dev/null +++ b/interface_test.go @@ -0,0 +1,89 @@ +package gorm_test + +import ( + "reflect" + "testing" + + "github.com/kr/pretty" +) + +type ( + UserInterface interface { + UserName() string + UserType() string + } + + UserCommon struct { + Name string + Type string + } + + BasicUser struct { + User + } + + AdminUser struct { + BasicUser + } + + GroupUser struct { + GroupID int64 + User UserInterface + } + + Group struct { + Users []GroupUser + } +) + +func (m *BasicUser) UserName() string { + return m.Name +} + +func (m *BasicUser) Type() string { + return "basic" +} + +func (m *AdminUser) Type() string { + return "admin" +} + +// ScanType returns the scan type for the field +func (m *GroupUser) ScanType(field string) reflect.Type { + switch field { + case "User": + // The geometry data should be encoded as a []byte first + return reflect.TypeOf(User{}) + default: + return reflect.TypeOf(nil) + } +} + +// ScanField handle exporting scanned fields +func (m *GroupUser) ScanField(field string, data interface{}) error { + switch field { + case "User": + m.User = data.(UserInterface) + } + + return nil +} + +var tt *testing.T + +func TestInterface(t *testing.T) { + tt = t + DB.AutoMigrate(&UserCommon{}) + + user1 := UserCommon{Name: "RowUser1", type: "basic"} + + DB.Save(&user1) + + t.Log("loading the users") + users := make([]*UserWrapper, 0) + + if DB.Table("users").Find(&users).Error != nil { + t.Errorf("No errors should happen if set table for find") + } + t.Logf(pretty.Sprint(users)) +} diff --git a/model_struct.go b/model_struct.go index 8c27e209..2ae21006 100644 --- a/model_struct.go +++ b/model_struct.go @@ -21,12 +21,12 @@ var modelStructsMap sync.Map // ModelStruct model definition type ModelStruct struct { - PrimaryFields []*StructField - StructFields []*StructField - ModelType reflect.Type + PrimaryFields []*StructField + StructFields []*StructField + ModelType reflect.Type defaultTableName string - l sync.Mutex + l sync.Mutex } // TableName returns model's table name @@ -59,6 +59,7 @@ type StructField struct { IsNormal bool IsIgnored bool IsScanner bool + IsInterface bool HasDefaultValue bool Tag reflect.StructTag TagSettings map[string]string @@ -100,6 +101,7 @@ func (structField *StructField) clone() *StructField { IsNormal: structField.IsNormal, IsIgnored: structField.IsIgnored, IsScanner: structField.IsScanner, + IsInterface: structField.IsInterface, HasDefaultValue: structField.HasDefaultValue, Tag: structField.Tag, TagSettings: map[string]string{}, @@ -171,317 +173,222 @@ func (scope *Scope) GetModelStruct() *ModelStruct { // Get all fields for i := 0; i < reflectType.NumField(); i++ { - if fieldStruct := reflectType.Field(i); ast.IsExported(fieldStruct.Name) { - field := &StructField{ - Struct: fieldStruct, - Name: fieldStruct.Name, - Names: []string{fieldStruct.Name}, - Tag: fieldStruct.Tag, - TagSettings: parseTagSetting(fieldStruct.Tag), + fieldStruct := reflectType.Field(i) + + field := &StructField{ + Struct: fieldStruct, + Name: fieldStruct.Name, + Names: []string{fieldStruct.Name}, + Tag: fieldStruct.Tag, + TagSettings: parseTagSetting(fieldStruct.Tag), + } + + if !ast.IsExported(fieldStruct.Name) { + if _, ok := field.TagSettingsGet("INTERFACE"); ok { + field.IsInterface = true + } else { + continue + } + } + + // is ignored field + if _, ok := field.TagSettingsGet("-"); ok { + field.IsIgnored = true + } else { + if _, ok := field.TagSettingsGet("PRIMARY_KEY"); ok { + field.IsPrimaryKey = true + modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field) } - // is ignored field - if _, ok := field.TagSettingsGet("-"); ok { - field.IsIgnored = true + if _, ok := field.TagSettingsGet("DEFAULT"); ok { + field.HasDefaultValue = true + } + + if _, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok && !field.IsPrimaryKey { + field.HasDefaultValue = true + } + + indirectType := fieldStruct.Type + for indirectType.Kind() == reflect.Ptr { + indirectType = indirectType.Elem() + } + + fieldValue := reflect.New(indirectType).Interface() + if _, isScanner := fieldValue.(sql.Scanner); isScanner { + // is scanner + field.IsScanner, field.IsNormal = true, true + if indirectType.Kind() == reflect.Struct { + for i := 0; i < indirectType.NumField(); i++ { + for key, value := range parseTagSetting(indirectType.Field(i).Tag) { + if _, ok := field.TagSettingsGet(key); !ok { + field.TagSettingsSet(key, value) + } + } + } + } + } else if _, isTime := fieldValue.(*time.Time); isTime { + // is time + field.IsNormal = true + } else if _, ok := field.TagSettingsGet("EMBEDDED"); ok || fieldStruct.Anonymous { + // is embedded struct + for _, subField := range scope.New(fieldValue).GetModelStruct().StructFields { + subField = subField.clone() + subField.Names = append([]string{fieldStruct.Name}, subField.Names...) + if prefix, ok := field.TagSettingsGet("EMBEDDED_PREFIX"); ok { + subField.DBName = prefix + subField.DBName + } + + if subField.IsPrimaryKey { + if _, ok := subField.TagSettingsGet("PRIMARY_KEY"); ok { + modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, subField) + } else { + subField.IsPrimaryKey = false + } + } + + if subField.Relationship != nil && subField.Relationship.JoinTableHandler != nil { + if joinTableHandler, ok := subField.Relationship.JoinTableHandler.(*JoinTableHandler); ok { + newJoinTableHandler := &JoinTableHandler{} + newJoinTableHandler.Setup(subField.Relationship, joinTableHandler.TableName, reflectType, joinTableHandler.Destination.ModelType) + subField.Relationship.JoinTableHandler = newJoinTableHandler + } + } + + modelStruct.StructFields = append(modelStruct.StructFields, subField) + } + continue } else { - if _, ok := field.TagSettingsGet("PRIMARY_KEY"); ok { - field.IsPrimaryKey = true - modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field) - } + // build relationships + switch indirectType.Kind() { + case reflect.Slice: + defer func(field *StructField) { + var ( + relationship = &Relationship{} + toScope = scope.New(reflect.New(field.Struct.Type).Interface()) + foreignKeys []string + associationForeignKeys []string + elemType = field.Struct.Type + ) - if _, ok := field.TagSettingsGet("DEFAULT"); ok { - field.HasDefaultValue = true - } - - if _, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok && !field.IsPrimaryKey { - field.HasDefaultValue = true - } - - indirectType := fieldStruct.Type - for indirectType.Kind() == reflect.Ptr { - indirectType = indirectType.Elem() - } - - fieldValue := reflect.New(indirectType).Interface() - if _, isScanner := fieldValue.(sql.Scanner); isScanner { - // is scanner - field.IsScanner, field.IsNormal = true, true - if indirectType.Kind() == reflect.Struct { - for i := 0; i < indirectType.NumField(); i++ { - for key, value := range parseTagSetting(indirectType.Field(i).Tag) { - if _, ok := field.TagSettingsGet(key); !ok { - field.TagSettingsSet(key, value) - } - } - } - } - } else if _, isTime := fieldValue.(*time.Time); isTime { - // is time - field.IsNormal = true - } else if _, ok := field.TagSettingsGet("EMBEDDED"); ok || fieldStruct.Anonymous { - // is embedded struct - for _, subField := range scope.New(fieldValue).GetModelStruct().StructFields { - subField = subField.clone() - subField.Names = append([]string{fieldStruct.Name}, subField.Names...) - if prefix, ok := field.TagSettingsGet("EMBEDDED_PREFIX"); ok { - subField.DBName = prefix + subField.DBName + if foreignKey, _ := field.TagSettingsGet("FOREIGNKEY"); foreignKey != "" { + foreignKeys = strings.Split(foreignKey, ",") } - if subField.IsPrimaryKey { - if _, ok := subField.TagSettingsGet("PRIMARY_KEY"); ok { - modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, subField) - } else { - subField.IsPrimaryKey = false - } + if foreignKey, _ := field.TagSettingsGet("ASSOCIATION_FOREIGNKEY"); foreignKey != "" { + associationForeignKeys = strings.Split(foreignKey, ",") + } else if foreignKey, _ := field.TagSettingsGet("ASSOCIATIONFOREIGNKEY"); foreignKey != "" { + associationForeignKeys = strings.Split(foreignKey, ",") } - if subField.Relationship != nil && subField.Relationship.JoinTableHandler != nil { - if joinTableHandler, ok := subField.Relationship.JoinTableHandler.(*JoinTableHandler); ok { - newJoinTableHandler := &JoinTableHandler{} - newJoinTableHandler.Setup(subField.Relationship, joinTableHandler.TableName, reflectType, joinTableHandler.Destination.ModelType) - subField.Relationship.JoinTableHandler = newJoinTableHandler - } + for elemType.Kind() == reflect.Slice || elemType.Kind() == reflect.Ptr { + elemType = elemType.Elem() } - modelStruct.StructFields = append(modelStruct.StructFields, subField) - } - continue - } else { - // build relationships - switch indirectType.Kind() { - case reflect.Slice: - defer func(field *StructField) { - var ( - relationship = &Relationship{} - toScope = scope.New(reflect.New(field.Struct.Type).Interface()) - foreignKeys []string - associationForeignKeys []string - elemType = field.Struct.Type - ) + if elemType.Kind() == reflect.Struct { + if many2many, _ := field.TagSettingsGet("MANY2MANY"); many2many != "" { + relationship.Kind = "many_to_many" - if foreignKey, _ := field.TagSettingsGet("FOREIGNKEY"); foreignKey != "" { - foreignKeys = strings.Split(foreignKey, ",") - } + { // Foreign Keys for Source + joinTableDBNames := []string{} - if foreignKey, _ := field.TagSettingsGet("ASSOCIATION_FOREIGNKEY"); foreignKey != "" { - associationForeignKeys = strings.Split(foreignKey, ",") - } else if foreignKey, _ := field.TagSettingsGet("ASSOCIATIONFOREIGNKEY"); foreignKey != "" { - associationForeignKeys = strings.Split(foreignKey, ",") - } - - for elemType.Kind() == reflect.Slice || elemType.Kind() == reflect.Ptr { - elemType = elemType.Elem() - } - - if elemType.Kind() == reflect.Struct { - if many2many, _ := field.TagSettingsGet("MANY2MANY"); many2many != "" { - relationship.Kind = "many_to_many" - - { // Foreign Keys for Source - joinTableDBNames := []string{} - - if foreignKey, _ := field.TagSettingsGet("JOINTABLE_FOREIGNKEY"); foreignKey != "" { - joinTableDBNames = strings.Split(foreignKey, ",") - } - - // if no foreign keys defined with tag - if len(foreignKeys) == 0 { - for _, field := range modelStruct.PrimaryFields { - foreignKeys = append(foreignKeys, field.DBName) - } - } - - for idx, foreignKey := range foreignKeys { - if foreignField := getForeignField(foreignKey, modelStruct.StructFields); foreignField != nil { - // source foreign keys (db names) - relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.DBName) - - // setup join table foreign keys for source - if len(joinTableDBNames) > idx { - // if defined join table's foreign key - relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBNames[idx]) - } else { - defaultJointableForeignKey := ToColumnName(reflectType.Name()) + "_" + foreignField.DBName - relationship.ForeignDBNames = append(relationship.ForeignDBNames, defaultJointableForeignKey) - } - } - } - } - - { // Foreign Keys for Association (Destination) - associationJoinTableDBNames := []string{} - - if foreignKey, _ := field.TagSettingsGet("ASSOCIATION_JOINTABLE_FOREIGNKEY"); foreignKey != "" { - associationJoinTableDBNames = strings.Split(foreignKey, ",") - } - - // if no association foreign keys defined with tag - if len(associationForeignKeys) == 0 { - for _, field := range toScope.PrimaryFields() { - associationForeignKeys = append(associationForeignKeys, field.DBName) - } - } - - for idx, name := range associationForeignKeys { - if field, ok := toScope.FieldByName(name); ok { - // association foreign keys (db names) - relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.DBName) - - // setup join table foreign keys for association - if len(associationJoinTableDBNames) > idx { - relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationJoinTableDBNames[idx]) - } else { - // join table foreign keys for association - joinTableDBName := ToColumnName(elemType.Name()) + "_" + field.DBName - relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, joinTableDBName) - } - } - } - } - - joinTableHandler := JoinTableHandler{} - joinTableHandler.Setup(relationship, ToTableName(many2many), reflectType, elemType) - relationship.JoinTableHandler = &joinTableHandler - field.Relationship = relationship - } else { - // User has many comments, associationType is User, comment use UserID as foreign key - var associationType = reflectType.Name() - var toFields = toScope.GetStructFields() - relationship.Kind = "has_many" - - if polymorphic, _ := field.TagSettingsGet("POLYMORPHIC"); polymorphic != "" { - // Dog has many toys, tag polymorphic is Owner, then associationType is Owner - // Toy use OwnerID, OwnerType ('dogs') as foreign key - if polymorphicType := getForeignField(polymorphic+"Type", toFields); polymorphicType != nil { - associationType = polymorphic - relationship.PolymorphicType = polymorphicType.Name - relationship.PolymorphicDBName = polymorphicType.DBName - // if Dog has multiple set of toys set name of the set (instead of default 'dogs') - if value, ok := field.TagSettingsGet("POLYMORPHIC_VALUE"); ok { - relationship.PolymorphicValue = value - } else { - relationship.PolymorphicValue = scope.TableName() - } - polymorphicType.IsForeignKey = true - } + if foreignKey, _ := field.TagSettingsGet("JOINTABLE_FOREIGNKEY"); foreignKey != "" { + joinTableDBNames = strings.Split(foreignKey, ",") } // if no foreign keys defined with tag if len(foreignKeys) == 0 { - // if no association foreign keys defined with tag - if len(associationForeignKeys) == 0 { - for _, field := range modelStruct.PrimaryFields { - foreignKeys = append(foreignKeys, associationType+field.Name) - associationForeignKeys = append(associationForeignKeys, field.Name) - } - } else { - // generate foreign keys from defined association foreign keys - for _, scopeFieldName := range associationForeignKeys { - if foreignField := getForeignField(scopeFieldName, modelStruct.StructFields); foreignField != nil { - foreignKeys = append(foreignKeys, associationType+foreignField.Name) - associationForeignKeys = append(associationForeignKeys, foreignField.Name) - } - } - } - } else { - // generate association foreign keys from foreign keys - if len(associationForeignKeys) == 0 { - for _, foreignKey := range foreignKeys { - if strings.HasPrefix(foreignKey, associationType) { - associationForeignKey := strings.TrimPrefix(foreignKey, associationType) - if foreignField := getForeignField(associationForeignKey, modelStruct.StructFields); foreignField != nil { - associationForeignKeys = append(associationForeignKeys, associationForeignKey) - } - } - } - if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 { - associationForeignKeys = []string{scope.PrimaryKey()} - } - } else if len(foreignKeys) != len(associationForeignKeys) { - scope.Err(errors.New("invalid foreign keys, should have same length")) - return + for _, field := range modelStruct.PrimaryFields { + foreignKeys = append(foreignKeys, field.DBName) } } for idx, foreignKey := range foreignKeys { - if foreignField := getForeignField(foreignKey, toFields); foreignField != nil { - if associationField := getForeignField(associationForeignKeys[idx], modelStruct.StructFields); associationField != nil { - // source foreign keys - foreignField.IsForeignKey = true - relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, associationField.Name) - relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationField.DBName) + if foreignField := getForeignField(foreignKey, modelStruct.StructFields); foreignField != nil { + // source foreign keys (db names) + relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.DBName) - // association foreign keys - relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) - relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) + // setup join table foreign keys for source + if len(joinTableDBNames) > idx { + // if defined join table's foreign key + relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBNames[idx]) + } else { + defaultJointableForeignKey := ToColumnName(reflectType.Name()) + "_" + foreignField.DBName + relationship.ForeignDBNames = append(relationship.ForeignDBNames, defaultJointableForeignKey) } } } + } - if len(relationship.ForeignFieldNames) != 0 { - field.Relationship = relationship + { // Foreign Keys for Association (Destination) + associationJoinTableDBNames := []string{} + + if foreignKey, _ := field.TagSettingsGet("ASSOCIATION_JOINTABLE_FOREIGNKEY"); foreignKey != "" { + associationJoinTableDBNames = strings.Split(foreignKey, ",") + } + + // if no association foreign keys defined with tag + if len(associationForeignKeys) == 0 { + for _, field := range toScope.PrimaryFields() { + associationForeignKeys = append(associationForeignKeys, field.DBName) + } + } + + for idx, name := range associationForeignKeys { + if field, ok := toScope.FieldByName(name); ok { + // association foreign keys (db names) + relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.DBName) + + // setup join table foreign keys for association + if len(associationJoinTableDBNames) > idx { + relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationJoinTableDBNames[idx]) + } else { + // join table foreign keys for association + joinTableDBName := ToColumnName(elemType.Name()) + "_" + field.DBName + relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, joinTableDBName) + } + } } } + + joinTableHandler := JoinTableHandler{} + joinTableHandler.Setup(relationship, ToTableName(many2many), reflectType, elemType) + relationship.JoinTableHandler = &joinTableHandler + field.Relationship = relationship } else { - field.IsNormal = true - } - }(field) - case reflect.Struct: - defer func(field *StructField) { - var ( - // user has one profile, associationType is User, profile use UserID as foreign key - // user belongs to profile, associationType is Profile, user use ProfileID as foreign key - associationType = reflectType.Name() - relationship = &Relationship{} - toScope = scope.New(reflect.New(field.Struct.Type).Interface()) - toFields = toScope.GetStructFields() - tagForeignKeys []string - tagAssociationForeignKeys []string - ) + // User has many comments, associationType is User, comment use UserID as foreign key + var associationType = reflectType.Name() + var toFields = toScope.GetStructFields() + relationship.Kind = "has_many" - if foreignKey, _ := field.TagSettingsGet("FOREIGNKEY"); foreignKey != "" { - tagForeignKeys = strings.Split(foreignKey, ",") - } - - if foreignKey, _ := field.TagSettingsGet("ASSOCIATION_FOREIGNKEY"); foreignKey != "" { - tagAssociationForeignKeys = strings.Split(foreignKey, ",") - } else if foreignKey, _ := field.TagSettingsGet("ASSOCIATIONFOREIGNKEY"); foreignKey != "" { - tagAssociationForeignKeys = strings.Split(foreignKey, ",") - } - - if polymorphic, _ := field.TagSettingsGet("POLYMORPHIC"); polymorphic != "" { - // Cat has one toy, tag polymorphic is Owner, then associationType is Owner - // Toy use OwnerID, OwnerType ('cats') as foreign key - if polymorphicType := getForeignField(polymorphic+"Type", toFields); polymorphicType != nil { - associationType = polymorphic - relationship.PolymorphicType = polymorphicType.Name - relationship.PolymorphicDBName = polymorphicType.DBName - // if Cat has several different types of toys set name for each (instead of default 'cats') - if value, ok := field.TagSettingsGet("POLYMORPHIC_VALUE"); ok { - relationship.PolymorphicValue = value - } else { - relationship.PolymorphicValue = scope.TableName() + if polymorphic, _ := field.TagSettingsGet("POLYMORPHIC"); polymorphic != "" { + // Dog has many toys, tag polymorphic is Owner, then associationType is Owner + // Toy use OwnerID, OwnerType ('dogs') as foreign key + if polymorphicType := getForeignField(polymorphic+"Type", toFields); polymorphicType != nil { + associationType = polymorphic + relationship.PolymorphicType = polymorphicType.Name + relationship.PolymorphicDBName = polymorphicType.DBName + // if Dog has multiple set of toys set name of the set (instead of default 'dogs') + if value, ok := field.TagSettingsGet("POLYMORPHIC_VALUE"); ok { + relationship.PolymorphicValue = value + } else { + relationship.PolymorphicValue = scope.TableName() + } + polymorphicType.IsForeignKey = true } - polymorphicType.IsForeignKey = true } - } - // Has One - { - var foreignKeys = tagForeignKeys - var associationForeignKeys = tagAssociationForeignKeys // if no foreign keys defined with tag if len(foreignKeys) == 0 { // if no association foreign keys defined with tag if len(associationForeignKeys) == 0 { - for _, primaryField := range modelStruct.PrimaryFields { - foreignKeys = append(foreignKeys, associationType+primaryField.Name) - associationForeignKeys = append(associationForeignKeys, primaryField.Name) + for _, field := range modelStruct.PrimaryFields { + foreignKeys = append(foreignKeys, associationType+field.Name) + associationForeignKeys = append(associationForeignKeys, field.Name) } } else { - // generate foreign keys form association foreign keys - for _, associationForeignKey := range tagAssociationForeignKeys { - if foreignField := getForeignField(associationForeignKey, modelStruct.StructFields); foreignField != nil { + // generate foreign keys from defined association foreign keys + for _, scopeFieldName := range associationForeignKeys { + if foreignField := getForeignField(scopeFieldName, modelStruct.StructFields); foreignField != nil { foreignKeys = append(foreignKeys, associationType+foreignField.Name) associationForeignKeys = append(associationForeignKeys, foreignField.Name) } @@ -509,73 +416,13 @@ func (scope *Scope) GetModelStruct() *ModelStruct { for idx, foreignKey := range foreignKeys { if foreignField := getForeignField(foreignKey, toFields); foreignField != nil { - if scopeField := getForeignField(associationForeignKeys[idx], modelStruct.StructFields); scopeField != nil { - foreignField.IsForeignKey = true + if associationField := getForeignField(associationForeignKeys[idx], modelStruct.StructFields); associationField != nil { // source foreign keys - relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, scopeField.Name) - relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, scopeField.DBName) - - // association foreign keys - relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) - relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) - } - } - } - } - - if len(relationship.ForeignFieldNames) != 0 { - relationship.Kind = "has_one" - field.Relationship = relationship - } else { - var foreignKeys = tagForeignKeys - var associationForeignKeys = tagAssociationForeignKeys - - if len(foreignKeys) == 0 { - // generate foreign keys & association foreign keys - if len(associationForeignKeys) == 0 { - for _, primaryField := range toScope.PrimaryFields() { - foreignKeys = append(foreignKeys, field.Name+primaryField.Name) - associationForeignKeys = append(associationForeignKeys, primaryField.Name) - } - } else { - // generate foreign keys with association foreign keys - for _, associationForeignKey := range associationForeignKeys { - if foreignField := getForeignField(associationForeignKey, toFields); foreignField != nil { - foreignKeys = append(foreignKeys, field.Name+foreignField.Name) - associationForeignKeys = append(associationForeignKeys, foreignField.Name) - } - } - } - } else { - // generate foreign keys & association foreign keys - if len(associationForeignKeys) == 0 { - for _, foreignKey := range foreignKeys { - if strings.HasPrefix(foreignKey, field.Name) { - associationForeignKey := strings.TrimPrefix(foreignKey, field.Name) - if foreignField := getForeignField(associationForeignKey, toFields); foreignField != nil { - associationForeignKeys = append(associationForeignKeys, associationForeignKey) - } - } - } - if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 { - associationForeignKeys = []string{toScope.PrimaryKey()} - } - } else if len(foreignKeys) != len(associationForeignKeys) { - scope.Err(errors.New("invalid foreign keys, should have same length")) - return - } - } - - for idx, foreignKey := range foreignKeys { - if foreignField := getForeignField(foreignKey, modelStruct.StructFields); foreignField != nil { - if associationField := getForeignField(associationForeignKeys[idx], toFields); associationField != nil { foreignField.IsForeignKey = true - - // association foreign keys relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, associationField.Name) relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationField.DBName) - // source foreign keys + // association foreign keys relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) } @@ -583,14 +430,179 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } if len(relationship.ForeignFieldNames) != 0 { - relationship.Kind = "belongs_to" field.Relationship = relationship } } - }(field) - default: - field.IsNormal = true - } + } else { + field.IsNormal = true + } + }(field) + case reflect.Struct: + defer func(field *StructField) { + var ( + // user has one profile, associationType is User, profile use UserID as foreign key + // user belongs to profile, associationType is Profile, user use ProfileID as foreign key + associationType = reflectType.Name() + relationship = &Relationship{} + toScope = scope.New(reflect.New(field.Struct.Type).Interface()) + toFields = toScope.GetStructFields() + tagForeignKeys []string + tagAssociationForeignKeys []string + ) + + if foreignKey, _ := field.TagSettingsGet("FOREIGNKEY"); foreignKey != "" { + tagForeignKeys = strings.Split(foreignKey, ",") + } + + if foreignKey, _ := field.TagSettingsGet("ASSOCIATION_FOREIGNKEY"); foreignKey != "" { + tagAssociationForeignKeys = strings.Split(foreignKey, ",") + } else if foreignKey, _ := field.TagSettingsGet("ASSOCIATIONFOREIGNKEY"); foreignKey != "" { + tagAssociationForeignKeys = strings.Split(foreignKey, ",") + } + + if polymorphic, _ := field.TagSettingsGet("POLYMORPHIC"); polymorphic != "" { + // Cat has one toy, tag polymorphic is Owner, then associationType is Owner + // Toy use OwnerID, OwnerType ('cats') as foreign key + if polymorphicType := getForeignField(polymorphic+"Type", toFields); polymorphicType != nil { + associationType = polymorphic + relationship.PolymorphicType = polymorphicType.Name + relationship.PolymorphicDBName = polymorphicType.DBName + // if Cat has several different types of toys set name for each (instead of default 'cats') + if value, ok := field.TagSettingsGet("POLYMORPHIC_VALUE"); ok { + relationship.PolymorphicValue = value + } else { + relationship.PolymorphicValue = scope.TableName() + } + polymorphicType.IsForeignKey = true + } + } + + // Has One + { + var foreignKeys = tagForeignKeys + var associationForeignKeys = tagAssociationForeignKeys + // if no foreign keys defined with tag + if len(foreignKeys) == 0 { + // if no association foreign keys defined with tag + if len(associationForeignKeys) == 0 { + for _, primaryField := range modelStruct.PrimaryFields { + foreignKeys = append(foreignKeys, associationType+primaryField.Name) + associationForeignKeys = append(associationForeignKeys, primaryField.Name) + } + } else { + // generate foreign keys form association foreign keys + for _, associationForeignKey := range tagAssociationForeignKeys { + if foreignField := getForeignField(associationForeignKey, modelStruct.StructFields); foreignField != nil { + foreignKeys = append(foreignKeys, associationType+foreignField.Name) + associationForeignKeys = append(associationForeignKeys, foreignField.Name) + } + } + } + } else { + // generate association foreign keys from foreign keys + if len(associationForeignKeys) == 0 { + for _, foreignKey := range foreignKeys { + if strings.HasPrefix(foreignKey, associationType) { + associationForeignKey := strings.TrimPrefix(foreignKey, associationType) + if foreignField := getForeignField(associationForeignKey, modelStruct.StructFields); foreignField != nil { + associationForeignKeys = append(associationForeignKeys, associationForeignKey) + } + } + } + if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 { + associationForeignKeys = []string{scope.PrimaryKey()} + } + } else if len(foreignKeys) != len(associationForeignKeys) { + scope.Err(errors.New("invalid foreign keys, should have same length")) + return + } + } + + for idx, foreignKey := range foreignKeys { + if foreignField := getForeignField(foreignKey, toFields); foreignField != nil { + if scopeField := getForeignField(associationForeignKeys[idx], modelStruct.StructFields); scopeField != nil { + foreignField.IsForeignKey = true + // source foreign keys + relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, scopeField.Name) + relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, scopeField.DBName) + + // association foreign keys + relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) + relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) + } + } + } + } + + if len(relationship.ForeignFieldNames) != 0 { + relationship.Kind = "has_one" + field.Relationship = relationship + } else { + var foreignKeys = tagForeignKeys + var associationForeignKeys = tagAssociationForeignKeys + + if len(foreignKeys) == 0 { + // generate foreign keys & association foreign keys + if len(associationForeignKeys) == 0 { + for _, primaryField := range toScope.PrimaryFields() { + foreignKeys = append(foreignKeys, field.Name+primaryField.Name) + associationForeignKeys = append(associationForeignKeys, primaryField.Name) + } + } else { + // generate foreign keys with association foreign keys + for _, associationForeignKey := range associationForeignKeys { + if foreignField := getForeignField(associationForeignKey, toFields); foreignField != nil { + foreignKeys = append(foreignKeys, field.Name+foreignField.Name) + associationForeignKeys = append(associationForeignKeys, foreignField.Name) + } + } + } + } else { + // generate foreign keys & association foreign keys + if len(associationForeignKeys) == 0 { + for _, foreignKey := range foreignKeys { + if strings.HasPrefix(foreignKey, field.Name) { + associationForeignKey := strings.TrimPrefix(foreignKey, field.Name) + if foreignField := getForeignField(associationForeignKey, toFields); foreignField != nil { + associationForeignKeys = append(associationForeignKeys, associationForeignKey) + } + } + } + if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 { + associationForeignKeys = []string{toScope.PrimaryKey()} + } + } else if len(foreignKeys) != len(associationForeignKeys) { + scope.Err(errors.New("invalid foreign keys, should have same length")) + return + } + } + + for idx, foreignKey := range foreignKeys { + if foreignField := getForeignField(foreignKey, modelStruct.StructFields); foreignField != nil { + if associationField := getForeignField(associationForeignKeys[idx], toFields); associationField != nil { + foreignField.IsForeignKey = true + + // association foreign keys + relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, associationField.Name) + relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationField.DBName) + + // source foreign keys + relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) + relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) + } + } + } + + if len(relationship.ForeignFieldNames) != 0 { + relationship.Kind = "belongs_to" + field.Relationship = relationship + } + } + }(field) + case reflect.Interface: + field.IsInterface = true + default: + field.IsNormal = true } } diff --git a/scope.go b/scope.go index 806ccb7d..8e8758ab 100644 --- a/scope.go +++ b/scope.go @@ -10,6 +10,8 @@ import ( "regexp" "strings" "time" + + "github.com/kr/pretty" ) // Scope contain current operation's information when you perform any operation on the database @@ -473,18 +475,23 @@ func (scope *Scope) quoteIfPossible(str string) string { return str } -func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field) { +func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field, elem ...interface{}) { var ( ignored interface{} values = make([]interface{}, len(columns)) selectFields []*Field selectedColumnsMap = map[string]int{} resetFields = map[int]*Field{} + interfaceFields = map[string]interface{}{} + rootElem interface{} ) + if len(elem) > 0 { + rootElem = elem[0] + } + for index, column := range columns { values[index] = &ignored - selectFields = fields offset := 0 if idx, ok := selectedColumnsMap[column]; ok { @@ -494,7 +501,17 @@ func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field) { for fieldIndex, field := range selectFields { if field.DBName == column { - if field.Field.Kind() == reflect.Ptr { + if field.IsInterface { + pretty.Log(column) + if i, ok := rootElem.(interface { + ScanType(field string) reflect.Type + }); ok { + t := i.ScanType(field.DBName) + val := reflect.New(t).Interface() + values[index] = val + interfaceFields[field.DBName] = values[index] + } + } else if field.Field.Kind() == reflect.Ptr { values[index] = field.Field.Addr().Interface() } else { reflectValue := reflect.New(reflect.PtrTo(field.Struct.Type)) @@ -514,6 +531,20 @@ func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field) { scope.Err(rows.Scan(values...)) + for k, v := range interfaceFields { + if i, ok := elem[0].(interface { + ScanField(field string, data interface{}) error + }); ok { + val := reflect.ValueOf(v) + if val.Kind() == reflect.Ptr { + val = val.Elem() + } + if err := i.ScanField(k, val.Interface()); err != nil { + fmt.Println(err) + } + } + } + for index, field := range resetFields { if v := reflect.ValueOf(values[index]).Elem().Elem(); v.IsValid() { field.Field.Set(v)