diff --git a/model_struct.go b/model_struct.go index 10423ae2..ac8e1369 100644 --- a/model_struct.go +++ b/model_struct.go @@ -191,6 +191,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { var relationship = &Relationship{} foreignKey := gormSettings["FOREIGNKEY"] + associationForeignKey := gormSettings["ASSOCIATIONFOREIGNKEY"] if polymorphic := gormSettings["POLYMORPHIC"]; polymorphic != "" { if polymorphicField := getForeignField(polymorphic+"Id", toScope.GetStructFields()); polymorphicField != nil { if polymorphicType := getForeignField(polymorphic+"Type", toScope.GetStructFields()); polymorphicType != nil { @@ -218,7 +219,6 @@ func (scope *Scope) GetModelStruct() *ModelStruct { if many2many := gormSettings["MANY2MANY"]; many2many != "" { relationship.Kind = "many_to_many" - associationForeignKey := gormSettings["ASSOCIATIONFOREIGNKEY"] if associationForeignKey == "" { associationForeignKey = elemType.Name() + "Id" } @@ -269,6 +269,12 @@ func (scope *Scope) GetModelStruct() *ModelStruct { relationship.ForeignDBName = foreignField.DBName foreignField.IsForeignKey = true field.Relationship = relationship + if associationForeignKey != "" { + if associatedField := getForeignField(associationForeignKey, toScope.GetStructFields()); associatedField != nil { + relationship.AssociationForeignFieldName = associatedField.Name + relationship.AssociationForeignDBName = associatedField.DBName + } + } } else { if foreignKey == "" { foreignKey = modelStruct.ModelType.Name() + "Id" @@ -277,6 +283,12 @@ func (scope *Scope) GetModelStruct() *ModelStruct { if foreignField := getForeignField(foreignKey, toScope.GetStructFields()); foreignField != nil { relationship.ForeignFieldName = foreignField.Name relationship.ForeignDBName = foreignField.DBName + if associationForeignKey != "" { + if associatedField := getForeignField(associationForeignKey, fields); associatedField != nil { + relationship.AssociationForeignFieldName = associatedField.Name + relationship.AssociationForeignDBName = associatedField.DBName + } + } foreignField.IsForeignKey = true field.Relationship = relationship } else if relationship.ForeignFieldName != "" { diff --git a/preload.go b/preload.go index 03910c44..51c22ee5 100644 --- a/preload.go +++ b/preload.go @@ -97,17 +97,22 @@ func makeSlice(typ reflect.Type) interface{} { } func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) { - primaryName := scope.PrimaryField().Name - primaryKeys := scope.getColumnAsArray(primaryName) - if len(primaryKeys) == 0 { + var keyName string + relation := field.Relationship + if relation.AssociationForeignFieldName != "" { + keyName = relation.AssociationForeignFieldName + } else { + keyName = scope.PrimaryField().Name + } + associatedKeys := scope.getColumnAsArray(keyName) + if len(associatedKeys) == 0 { return } results := makeSlice(field.Struct.Type) - relation := field.Relationship condition := fmt.Sprintf("%v IN (?)", scope.Quote(relation.ForeignDBName)) - scope.Err(scope.NewDB().Where(condition, primaryKeys).Find(results, conditions...).Error) + scope.Err(scope.NewDB().Where(condition, associatedKeys).Find(results, conditions...).Error) resultValues := reflect.Indirect(reflect.ValueOf(results)) for i := 0; i < resultValues.Len(); i++ { @@ -116,7 +121,7 @@ func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) value := getRealValue(result, relation.ForeignFieldName) objects := scope.IndirectValue() for j := 0; j < objects.Len(); j++ { - if equalAsString(getRealValue(objects.Index(j), primaryName), value) { + if equalAsString(getRealValue(objects.Index(j), keyName), value) { reflect.Indirect(objects.Index(j)).FieldByName(field.Name).Set(result) break } @@ -131,17 +136,22 @@ func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) } func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) { - primaryName := scope.PrimaryField().Name - primaryKeys := scope.getColumnAsArray(primaryName) - if len(primaryKeys) == 0 { + var keyName string + relation := field.Relationship + if relation.AssociationForeignFieldName != "" { + keyName = relation.AssociationForeignFieldName + } else { + keyName = scope.PrimaryField().Name + } + associatedKeys := scope.getColumnAsArray(keyName) + if len(associatedKeys) == 0 { return } results := makeSlice(field.Struct.Type) - relation := field.Relationship condition := fmt.Sprintf("%v IN (?)", scope.Quote(relation.ForeignDBName)) - scope.Err(scope.NewDB().Where(condition, primaryKeys).Find(results, conditions...).Error) + scope.Err(scope.NewDB().Where(condition, associatedKeys).Find(results, conditions...).Error) resultValues := reflect.Indirect(reflect.ValueOf(results)) if scope.IndirectValue().Kind() == reflect.Slice { @@ -151,7 +161,7 @@ func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) objects := scope.IndirectValue() for j := 0; j < objects.Len(); j++ { object := reflect.Indirect(objects.Index(j)) - if equalAsString(getRealValue(object, primaryName), value) { + if equalAsString(getRealValue(object, keyName), value) { f := object.FieldByName(field.Name) f.Set(reflect.Append(f, result)) break @@ -171,15 +181,27 @@ func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{ } results := makeSlice(field.Struct.Type) - associationPrimaryKey := scope.New(results).PrimaryField().Name - scope.Err(scope.NewDB().Where(primaryKeys).Find(results, conditions...).Error) + var keyName string + if relation.AssociationForeignFieldName != "" { + keyName = relation.AssociationForeignFieldName + } else { + keyName = scope.New(results).PrimaryField().Name + } + + foreignKey, ok := scope.New(results).FieldByName(keyName) + if !ok { + return + } + + condition := fmt.Sprintf("%v IN (?)", scope.Quote(foreignKey.DBName)) + scope.Err(scope.NewDB().Where(condition, primaryKeys).Find(results, conditions...).Error) resultValues := reflect.Indirect(reflect.ValueOf(results)) for i := 0; i < resultValues.Len(); i++ { result := resultValues.Index(i) if scope.IndirectValue().Kind() == reflect.Slice { - value := getRealValue(result, associationPrimaryKey) + value := getRealValue(result, keyName) objects := scope.IndirectValue() for j := 0; j < objects.Len(); j++ { object := reflect.Indirect(objects.Index(j))