From fad7c3f662b4b6d01dac8b26622ecdf998efc08c Mon Sep 17 00:00:00 2001 From: zardak Date: Fri, 29 Jul 2016 16:17:54 +0300 Subject: [PATCH] Deduplicate ids when preloading. Omit zero-value primary keys from preloading. Fix preload deduplication for multiple primary keys --- association.go | 20 ++++++++++---------- callback_query_preload.go | 6 +++--- join_table_handler.go | 4 ++-- scope.go | 39 +++++++++++++++++++++++++++++++++++---- 4 files changed, 50 insertions(+), 19 deletions(-) diff --git a/association.go b/association.go index 0f94683d..12d66cda 100644 --- a/association.go +++ b/association.go @@ -84,7 +84,7 @@ func (association *Association) Replace(values ...interface{}) *Association { } } - newPrimaryKeys := scope.getColumnAsArray(associationForeignFieldNames, field.Interface()) + newPrimaryKeys := scope.getColumnAsArrayUnique(associationForeignFieldNames, field.Interface()) if len(newPrimaryKeys) > 0 { sql := fmt.Sprintf("%v NOT IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(newPrimaryKeys)) @@ -102,7 +102,7 @@ func (association *Association) Replace(values ...interface{}) *Association { } } - if sourcePrimaryKeys := scope.getColumnAsArray(sourceForeignFieldNames, scope.Value); len(sourcePrimaryKeys) > 0 { + if sourcePrimaryKeys := scope.getColumnAsArrayUnique(sourceForeignFieldNames, scope.Value); len(sourcePrimaryKeys) > 0 { newDB = newDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(sourcePrimaryKeys)), toQueryValues(sourcePrimaryKeys)...) association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB, relationship)) @@ -147,7 +147,7 @@ func (association *Association) Delete(values ...interface{}) *Association { deletingResourcePrimaryDBNames = append(deletingResourcePrimaryDBNames, field.DBName) } - deletingPrimaryKeys := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, values...) + deletingPrimaryKeys := scope.getColumnAsArrayUnique(deletingResourcePrimaryFieldNames, values...) if relationship.Kind == "many_to_many" { // source value's foreign keys @@ -167,7 +167,7 @@ func (association *Association) Delete(values ...interface{}) *Association { } // association value's foreign keys - deletingPrimaryKeys := scope.getColumnAsArray(associationForeignFieldNames, values...) + deletingPrimaryKeys := scope.getColumnAsArrayUnique(associationForeignFieldNames, values...) sql := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(deletingPrimaryKeys)) newDB = newDB.Where(sql, toQueryValues(deletingPrimaryKeys)...) @@ -180,7 +180,7 @@ func (association *Association) Delete(values ...interface{}) *Association { if relationship.Kind == "belongs_to" { // find with deleting relation's foreign keys - primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, values...) + primaryKeys := scope.getColumnAsArrayUnique(relationship.AssociationForeignFieldNames, values...) newDB = newDB.Where( fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)..., @@ -197,7 +197,7 @@ func (association *Association) Delete(values ...interface{}) *Association { } } else if relationship.Kind == "has_one" || relationship.Kind == "has_many" { // find all relations - primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, scope.Value) + primaryKeys := scope.getColumnAsArrayUnique(relationship.AssociationForeignFieldNames, scope.Value) newDB = newDB.Where( fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)..., @@ -222,7 +222,7 @@ func (association *Association) Delete(values ...interface{}) *Association { for i := 0; i < field.Len(); i++ { reflectValue := field.Index(i) - primaryKey := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, reflectValue.Interface())[0] + primaryKey := scope.getColumnAsArrayUnique(deletingResourcePrimaryFieldNames, reflectValue.Interface())[0] var isDeleted = false for _, pk := range deletingPrimaryKeys { if equalAsString(primaryKey, pk) { @@ -237,7 +237,7 @@ func (association *Association) Delete(values ...interface{}) *Association { association.field.Set(leftValues) } else if field.Kind() == reflect.Struct { - primaryKey := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, field.Interface())[0] + primaryKey := scope.getColumnAsArrayUnique(deletingResourcePrimaryFieldNames, field.Interface())[0] for _, pk := range deletingPrimaryKeys { if equalAsString(primaryKey, pk) { association.field.Set(reflect.Zero(field.Type())) @@ -268,13 +268,13 @@ func (association *Association) Count() int { if relationship.Kind == "many_to_many" { query = relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, query, scope.Value) } else if relationship.Kind == "has_many" || relationship.Kind == "has_one" { - primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, scope.Value) + primaryKeys := scope.getColumnAsArrayUnique(relationship.AssociationForeignFieldNames, scope.Value) query = query.Where( fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)..., ) } else if relationship.Kind == "belongs_to" { - primaryKeys := scope.getColumnAsArray(relationship.ForeignFieldNames, scope.Value) + primaryKeys := scope.getColumnAsArrayUnique(relationship.ForeignFieldNames, scope.Value) query = query.Where( fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)..., diff --git a/callback_query_preload.go b/callback_query_preload.go index d9ec8bdd..d56b9ffd 100644 --- a/callback_query_preload.go +++ b/callback_query_preload.go @@ -101,7 +101,7 @@ func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) relation := field.Relationship // get relations's primary keys - primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value) + primaryKeys := scope.getColumnAsArrayUnique(relation.AssociationForeignFieldNames, scope.Value) if len(primaryKeys) == 0 { return } @@ -150,7 +150,7 @@ func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) relation := field.Relationship // get relations's primary keys - primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value) + primaryKeys := scope.getColumnAsArrayUnique(relation.AssociationForeignFieldNames, scope.Value) if len(primaryKeys) == 0 { return } @@ -204,7 +204,7 @@ func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{ preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions) // get relations's primary keys - primaryKeys := scope.getColumnAsArray(relation.ForeignFieldNames, scope.Value) + primaryKeys := scope.getColumnAsArrayUnique(relation.ForeignFieldNames, scope.Value) if len(primaryKeys) == 0 { return } diff --git a/join_table_handler.go b/join_table_handler.go index 18c12a85..a0bef0a3 100644 --- a/join_table_handler.go +++ b/join_table_handler.go @@ -178,7 +178,7 @@ func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, so } } - foreignFieldValues := scope.getColumnAsArray(foreignFieldNames, scope.Value) + foreignFieldValues := scope.getColumnAsArrayUnique(foreignFieldNames, scope.Value) var condString string if len(foreignFieldValues) > 0 { @@ -189,7 +189,7 @@ func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, so condString = fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, quotedForeignDBNames), toQueryMarks(foreignFieldValues)) - keys := scope.getColumnAsArray(foreignFieldNames, scope.Value) + keys := scope.getColumnAsArrayUnique(foreignFieldNames, scope.Value) values = append(values, toQueryValues(keys)) } else { condString = fmt.Sprintf("1 <> 1") diff --git a/scope.go b/scope.go index 23a5701b..b60c6603 100644 --- a/scope.go +++ b/scope.go @@ -1204,10 +1204,7 @@ func (scope *Scope) autoIndex() *Scope { func (scope *Scope) getColumnAsArray(columns []string, values ...interface{}) (results [][]interface{}) { for _, value := range values { - indirectValue := reflect.ValueOf(value) - for indirectValue.Kind() == reflect.Ptr { - indirectValue = indirectValue.Elem() - } + indirectValue := indirect(reflect.ValueOf(value)) switch indirectValue.Kind() { case reflect.Slice: @@ -1230,6 +1227,40 @@ func (scope *Scope) getColumnAsArray(columns []string, values ...interface{}) (r return } +func (scope *Scope) getColumnAsArrayUnique(columns []string, values ...interface{}) (results [][]interface{}) { + unfilteredResults := scope.getColumnAsArray(columns, values...) + + rootMap := map[interface{}]interface{}{} + + for _, valueSet := range unfilteredResults { + currentMap := rootMap + appendResult := false + + for _, value := range valueSet { + if isBlank(reflect.ValueOf(value)) { + appendResult = false + break + } + + innerMap, ok := currentMap[value] + if !ok { + innerMap = map[interface{}]interface{}{} + currentMap[value] = innerMap + + appendResult = true + } + + currentMap = innerMap.(map[interface{}]interface{}) + } + + if appendResult { + results = append(results, valueSet) + } + } + + return +} + func (scope *Scope) getColumnAsScope(column string) *Scope { indirectScopeValue := scope.IndirectValue()