Adds support for "associationforeignkey" tag in the has_one and belongs_to cases to handle when the foreign key is not the primary key.

This commit is contained in:
Crystalin 2015-06-15 11:17:04 +02:00
parent cde05781a0
commit 5808585d18
2 changed files with 50 additions and 16 deletions

View File

@ -191,6 +191,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
var relationship = &Relationship{} var relationship = &Relationship{}
foreignKey := gormSettings["FOREIGNKEY"] foreignKey := gormSettings["FOREIGNKEY"]
associationForeignKey := gormSettings["ASSOCIATIONFOREIGNKEY"]
if polymorphic := gormSettings["POLYMORPHIC"]; polymorphic != "" { if polymorphic := gormSettings["POLYMORPHIC"]; polymorphic != "" {
if polymorphicField := getForeignField(polymorphic+"Id", toScope.GetStructFields()); polymorphicField != nil { if polymorphicField := getForeignField(polymorphic+"Id", toScope.GetStructFields()); polymorphicField != nil {
if polymorphicType := getForeignField(polymorphic+"Type", toScope.GetStructFields()); polymorphicType != nil { if polymorphicType := getForeignField(polymorphic+"Type", toScope.GetStructFields()); polymorphicType != nil {
@ -218,7 +219,6 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
if many2many := gormSettings["MANY2MANY"]; many2many != "" { if many2many := gormSettings["MANY2MANY"]; many2many != "" {
relationship.Kind = "many_to_many" relationship.Kind = "many_to_many"
associationForeignKey := gormSettings["ASSOCIATIONFOREIGNKEY"]
if associationForeignKey == "" { if associationForeignKey == "" {
associationForeignKey = elemType.Name() + "Id" associationForeignKey = elemType.Name() + "Id"
} }
@ -269,6 +269,12 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
relationship.ForeignDBName = foreignField.DBName relationship.ForeignDBName = foreignField.DBName
foreignField.IsForeignKey = true foreignField.IsForeignKey = true
field.Relationship = relationship field.Relationship = relationship
if associationForeignKey != "" {
if associatedField := getForeignField(associationForeignKey, toScope.GetStructFields()); associatedField != nil {
relationship.AssociationForeignFieldName = associatedField.Name
relationship.AssociationForeignDBName = associatedField.DBName
}
}
} else { } else {
if foreignKey == "" { if foreignKey == "" {
foreignKey = modelStruct.ModelType.Name() + "Id" foreignKey = modelStruct.ModelType.Name() + "Id"
@ -277,6 +283,12 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
if foreignField := getForeignField(foreignKey, toScope.GetStructFields()); foreignField != nil { if foreignField := getForeignField(foreignKey, toScope.GetStructFields()); foreignField != nil {
relationship.ForeignFieldName = foreignField.Name relationship.ForeignFieldName = foreignField.Name
relationship.ForeignDBName = foreignField.DBName relationship.ForeignDBName = foreignField.DBName
if associationForeignKey != "" {
if associatedField := getForeignField(associationForeignKey, fields); associatedField != nil {
relationship.AssociationForeignFieldName = associatedField.Name
relationship.AssociationForeignDBName = associatedField.DBName
}
}
foreignField.IsForeignKey = true foreignField.IsForeignKey = true
field.Relationship = relationship field.Relationship = relationship
} else if relationship.ForeignFieldName != "" { } else if relationship.ForeignFieldName != "" {

View File

@ -97,17 +97,22 @@ func makeSlice(typ reflect.Type) interface{} {
} }
func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) { func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) {
primaryName := scope.PrimaryField().Name var keyName string
primaryKeys := scope.getColumnAsArray(primaryName) relation := field.Relationship
if len(primaryKeys) == 0 { if relation.AssociationForeignFieldName != "" {
keyName = relation.AssociationForeignFieldName
} else {
keyName = scope.PrimaryField().Name
}
associatedKeys := scope.getColumnAsArray(keyName)
if len(associatedKeys) == 0 {
return return
} }
results := makeSlice(field.Struct.Type) results := makeSlice(field.Struct.Type)
relation := field.Relationship
condition := fmt.Sprintf("%v IN (?)", scope.Quote(relation.ForeignDBName)) condition := fmt.Sprintf("%v IN (?)", scope.Quote(relation.ForeignDBName))
scope.Err(scope.NewDB().Where(condition, primaryKeys).Find(results, conditions...).Error) scope.Err(scope.NewDB().Where(condition, associatedKeys).Find(results, conditions...).Error)
resultValues := reflect.Indirect(reflect.ValueOf(results)) resultValues := reflect.Indirect(reflect.ValueOf(results))
for i := 0; i < resultValues.Len(); i++ { for i := 0; i < resultValues.Len(); i++ {
@ -116,7 +121,7 @@ func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{})
value := getRealValue(result, relation.ForeignFieldName) value := getRealValue(result, relation.ForeignFieldName)
objects := scope.IndirectValue() objects := scope.IndirectValue()
for j := 0; j < objects.Len(); j++ { for j := 0; j < objects.Len(); j++ {
if equalAsString(getRealValue(objects.Index(j), primaryName), value) { if equalAsString(getRealValue(objects.Index(j), keyName), value) {
reflect.Indirect(objects.Index(j)).FieldByName(field.Name).Set(result) reflect.Indirect(objects.Index(j)).FieldByName(field.Name).Set(result)
break break
} }
@ -131,17 +136,22 @@ func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{})
} }
func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) { func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) {
primaryName := scope.PrimaryField().Name var keyName string
primaryKeys := scope.getColumnAsArray(primaryName) relation := field.Relationship
if len(primaryKeys) == 0 { if relation.AssociationForeignFieldName != "" {
keyName = relation.AssociationForeignFieldName
} else {
keyName = scope.PrimaryField().Name
}
associatedKeys := scope.getColumnAsArray(keyName)
if len(associatedKeys) == 0 {
return return
} }
results := makeSlice(field.Struct.Type) results := makeSlice(field.Struct.Type)
relation := field.Relationship
condition := fmt.Sprintf("%v IN (?)", scope.Quote(relation.ForeignDBName)) condition := fmt.Sprintf("%v IN (?)", scope.Quote(relation.ForeignDBName))
scope.Err(scope.NewDB().Where(condition, primaryKeys).Find(results, conditions...).Error) scope.Err(scope.NewDB().Where(condition, associatedKeys).Find(results, conditions...).Error)
resultValues := reflect.Indirect(reflect.ValueOf(results)) resultValues := reflect.Indirect(reflect.ValueOf(results))
if scope.IndirectValue().Kind() == reflect.Slice { if scope.IndirectValue().Kind() == reflect.Slice {
@ -151,7 +161,7 @@ func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{})
objects := scope.IndirectValue() objects := scope.IndirectValue()
for j := 0; j < objects.Len(); j++ { for j := 0; j < objects.Len(); j++ {
object := reflect.Indirect(objects.Index(j)) object := reflect.Indirect(objects.Index(j))
if equalAsString(getRealValue(object, primaryName), value) { if equalAsString(getRealValue(object, keyName), value) {
f := object.FieldByName(field.Name) f := object.FieldByName(field.Name)
f.Set(reflect.Append(f, result)) f.Set(reflect.Append(f, result))
break break
@ -171,15 +181,27 @@ func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{
} }
results := makeSlice(field.Struct.Type) results := makeSlice(field.Struct.Type)
associationPrimaryKey := scope.New(results).PrimaryField().Name
scope.Err(scope.NewDB().Where(primaryKeys).Find(results, conditions...).Error) var keyName string
if relation.AssociationForeignFieldName != "" {
keyName = relation.AssociationForeignFieldName
} else {
keyName = scope.New(results).PrimaryField().Name
}
foreignKey, ok := scope.New(results).FieldByName(keyName)
if !ok {
return
}
condition := fmt.Sprintf("%v IN (?)", scope.Quote(foreignKey.DBName))
scope.Err(scope.NewDB().Where(condition, primaryKeys).Find(results, conditions...).Error)
resultValues := reflect.Indirect(reflect.ValueOf(results)) resultValues := reflect.Indirect(reflect.ValueOf(results))
for i := 0; i < resultValues.Len(); i++ { for i := 0; i < resultValues.Len(); i++ {
result := resultValues.Index(i) result := resultValues.Index(i)
if scope.IndirectValue().Kind() == reflect.Slice { if scope.IndirectValue().Kind() == reflect.Slice {
value := getRealValue(result, associationPrimaryKey) value := getRealValue(result, keyName)
objects := scope.IndirectValue() objects := scope.IndirectValue()
for j := 0; j < objects.Len(); j++ { for j := 0; j < objects.Len(); j++ {
object := reflect.Indirect(objects.Index(j)) object := reflect.Indirect(objects.Index(j))