diff --git a/association.go b/association.go index 14fd1c35..c36a31b7 100644 --- a/association.go +++ b/association.go @@ -86,7 +86,7 @@ func (association *Association) Replace(values ...interface{}) *Association { } } - newPrimaryKeys := scope.getColumnAsArray(associationForeignFieldNames, field.Interface()) + newPrimaryKeys := scope.getColumnAsArrayUnique(associationForeignFieldNames, field.Interface()) if len(newPrimaryKeys) > 0 { sql := fmt.Sprintf("%v NOT IN (%v)", toQueryCondition(scope, associationForeignDBNames), toQueryMarks(newPrimaryKeys)) @@ -104,7 +104,7 @@ func (association *Association) Replace(values ...interface{}) *Association { } } - if sourcePrimaryKeys := scope.getColumnAsArray(sourceForeignFieldNames, scope.Value); len(sourcePrimaryKeys) > 0 { + if sourcePrimaryKeys := scope.getColumnAsArrayUnique(sourceForeignFieldNames, scope.Value); len(sourcePrimaryKeys) > 0 { newDB = newDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(sourcePrimaryKeys)), toQueryValues(sourcePrimaryKeys)...) association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB, relationship)) @@ -149,7 +149,7 @@ func (association *Association) Delete(values ...interface{}) *Association { deletingResourcePrimaryDBNames = append(deletingResourcePrimaryDBNames, field.DBName) } - deletingPrimaryKeys := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, values...) + deletingPrimaryKeys := scope.getColumnAsArrayUnique(deletingResourcePrimaryFieldNames, values...) if relationship.Kind == "many_to_many" { // source value's foreign keys @@ -169,7 +169,7 @@ func (association *Association) Delete(values ...interface{}) *Association { } // association value's foreign keys - deletingPrimaryKeys := scope.getColumnAsArray(associationForeignFieldNames, values...) + deletingPrimaryKeys := scope.getColumnAsArrayUnique(associationForeignFieldNames, values...) sql := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(deletingPrimaryKeys)) newDB = newDB.Where(sql, toQueryValues(deletingPrimaryKeys)...) @@ -182,7 +182,7 @@ func (association *Association) Delete(values ...interface{}) *Association { if relationship.Kind == "belongs_to" { // find with deleting relation's foreign keys - primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, values...) + primaryKeys := scope.getColumnAsArrayUnique(relationship.AssociationForeignFieldNames, values...) newDB = newDB.Where( fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)..., @@ -199,7 +199,7 @@ func (association *Association) Delete(values ...interface{}) *Association { } } else if relationship.Kind == "has_one" || relationship.Kind == "has_many" { // find all relations - primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, scope.Value) + primaryKeys := scope.getColumnAsArrayUnique(relationship.AssociationForeignFieldNames, scope.Value) newDB = newDB.Where( fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)..., @@ -224,7 +224,7 @@ func (association *Association) Delete(values ...interface{}) *Association { for i := 0; i < field.Len(); i++ { reflectValue := field.Index(i) - primaryKey := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, reflectValue.Interface())[0] + primaryKey := scope.getColumnAsArrayUnique(deletingResourcePrimaryFieldNames, reflectValue.Interface())[0] var isDeleted = false for _, pk := range deletingPrimaryKeys { if equalAsString(primaryKey, pk) { @@ -239,7 +239,7 @@ func (association *Association) Delete(values ...interface{}) *Association { association.field.Set(leftValues) } else if field.Kind() == reflect.Struct { - primaryKey := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, field.Interface())[0] + primaryKey := scope.getColumnAsArrayUnique(deletingResourcePrimaryFieldNames, field.Interface())[0] for _, pk := range deletingPrimaryKeys { if equalAsString(primaryKey, pk) { association.field.Set(reflect.Zero(field.Type())) @@ -270,13 +270,13 @@ func (association *Association) Count() int { if relationship.Kind == "many_to_many" { query = relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, query, scope.Value) } else if relationship.Kind == "has_many" || relationship.Kind == "has_one" { - primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, scope.Value) + primaryKeys := scope.getColumnAsArrayUnique(relationship.AssociationForeignFieldNames, scope.Value) query = query.Where( fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)..., ) } else if relationship.Kind == "belongs_to" { - primaryKeys := scope.getColumnAsArray(relationship.ForeignFieldNames, scope.Value) + primaryKeys := scope.getColumnAsArrayUnique(relationship.ForeignFieldNames, scope.Value) query = query.Where( fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)..., diff --git a/callback_query_preload.go b/callback_query_preload.go index fff252c9..393390a6 100644 --- a/callback_query_preload.go +++ b/callback_query_preload.go @@ -126,7 +126,7 @@ func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) relation := field.Relationship // get relations's primary keys - primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value) + primaryKeys := scope.getColumnAsArrayUnique(relation.AssociationForeignFieldNames, scope.Value) if len(primaryKeys) == 0 { return } @@ -175,7 +175,7 @@ func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) relation := field.Relationship // get relations's primary keys - primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value) + primaryKeys := scope.getColumnAsArrayUnique(relation.AssociationForeignFieldNames, scope.Value) if len(primaryKeys) == 0 { return } @@ -231,7 +231,7 @@ func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{ preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions) // get relations's primary keys - primaryKeys := scope.getColumnAsArray(relation.ForeignFieldNames, scope.Value) + primaryKeys := scope.getColumnAsArrayUnique(relation.ForeignFieldNames, scope.Value) if len(primaryKeys) == 0 { return } diff --git a/join_table_handler.go b/join_table_handler.go index 18c12a85..a0bef0a3 100644 --- a/join_table_handler.go +++ b/join_table_handler.go @@ -178,7 +178,7 @@ func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, so } } - foreignFieldValues := scope.getColumnAsArray(foreignFieldNames, scope.Value) + foreignFieldValues := scope.getColumnAsArrayUnique(foreignFieldNames, scope.Value) var condString string if len(foreignFieldValues) > 0 { @@ -189,7 +189,7 @@ func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, so condString = fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, quotedForeignDBNames), toQueryMarks(foreignFieldValues)) - keys := scope.getColumnAsArray(foreignFieldNames, scope.Value) + keys := scope.getColumnAsArrayUnique(foreignFieldNames, scope.Value) values = append(values, toQueryValues(keys)) } else { condString = fmt.Sprintf("1 <> 1") diff --git a/scope.go b/scope.go index 9a237998..47a768f3 100644 --- a/scope.go +++ b/scope.go @@ -1270,6 +1270,40 @@ func (scope *Scope) getColumnAsArray(columns []string, values ...interface{}) (r return } +func (scope *Scope) getColumnAsArrayUnique(columns []string, values ...interface{}) (results [][]interface{}) { + unfilteredResults := scope.getColumnAsArray(columns, values...) + + rootMap := map[interface{}]interface{}{} + + for _, valueSet := range unfilteredResults { + currentMap := rootMap + appendResult := false + + for _, value := range valueSet { + if isBlank(reflect.ValueOf(value)) { + appendResult = false + break + } + + innerMap, ok := currentMap[value] + if !ok { + innerMap = map[interface{}]interface{}{} + currentMap[value] = innerMap + + appendResult = true + } + + currentMap = innerMap.(map[interface{}]interface{}) + } + + if appendResult { + results = append(results, valueSet) + } + } + + return +} + func (scope *Scope) getColumnAsScope(column string) *Scope { indirectScopeValue := scope.IndirectValue()