Deduplicate ids when preloading. Omit zero-value primary keys from preloading.

Fix preload deduplication for multiple primary keys
This commit is contained in:
zardak 2016-07-29 16:17:54 +03:00
parent f26fa242cc
commit fad7c3f662
4 changed files with 50 additions and 19 deletions

View File

@ -84,7 +84,7 @@ func (association *Association) Replace(values ...interface{}) *Association {
} }
} }
newPrimaryKeys := scope.getColumnAsArray(associationForeignFieldNames, field.Interface()) newPrimaryKeys := scope.getColumnAsArrayUnique(associationForeignFieldNames, field.Interface())
if len(newPrimaryKeys) > 0 { if len(newPrimaryKeys) > 0 {
sql := fmt.Sprintf("%v NOT IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(newPrimaryKeys)) sql := fmt.Sprintf("%v NOT IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(newPrimaryKeys))
@ -102,7 +102,7 @@ func (association *Association) Replace(values ...interface{}) *Association {
} }
} }
if sourcePrimaryKeys := scope.getColumnAsArray(sourceForeignFieldNames, scope.Value); len(sourcePrimaryKeys) > 0 { if sourcePrimaryKeys := scope.getColumnAsArrayUnique(sourceForeignFieldNames, scope.Value); len(sourcePrimaryKeys) > 0 {
newDB = newDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(sourcePrimaryKeys)), toQueryValues(sourcePrimaryKeys)...) newDB = newDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(sourcePrimaryKeys)), toQueryValues(sourcePrimaryKeys)...)
association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB, relationship)) association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB, relationship))
@ -147,7 +147,7 @@ func (association *Association) Delete(values ...interface{}) *Association {
deletingResourcePrimaryDBNames = append(deletingResourcePrimaryDBNames, field.DBName) deletingResourcePrimaryDBNames = append(deletingResourcePrimaryDBNames, field.DBName)
} }
deletingPrimaryKeys := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, values...) deletingPrimaryKeys := scope.getColumnAsArrayUnique(deletingResourcePrimaryFieldNames, values...)
if relationship.Kind == "many_to_many" { if relationship.Kind == "many_to_many" {
// source value's foreign keys // source value's foreign keys
@ -167,7 +167,7 @@ func (association *Association) Delete(values ...interface{}) *Association {
} }
// association value's foreign keys // association value's foreign keys
deletingPrimaryKeys := scope.getColumnAsArray(associationForeignFieldNames, values...) deletingPrimaryKeys := scope.getColumnAsArrayUnique(associationForeignFieldNames, values...)
sql := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(deletingPrimaryKeys)) sql := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(deletingPrimaryKeys))
newDB = newDB.Where(sql, toQueryValues(deletingPrimaryKeys)...) newDB = newDB.Where(sql, toQueryValues(deletingPrimaryKeys)...)
@ -180,7 +180,7 @@ func (association *Association) Delete(values ...interface{}) *Association {
if relationship.Kind == "belongs_to" { if relationship.Kind == "belongs_to" {
// find with deleting relation's foreign keys // find with deleting relation's foreign keys
primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, values...) primaryKeys := scope.getColumnAsArrayUnique(relationship.AssociationForeignFieldNames, values...)
newDB = newDB.Where( newDB = newDB.Where(
fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)), fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)),
toQueryValues(primaryKeys)..., toQueryValues(primaryKeys)...,
@ -197,7 +197,7 @@ func (association *Association) Delete(values ...interface{}) *Association {
} }
} else if relationship.Kind == "has_one" || relationship.Kind == "has_many" { } else if relationship.Kind == "has_one" || relationship.Kind == "has_many" {
// find all relations // find all relations
primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, scope.Value) primaryKeys := scope.getColumnAsArrayUnique(relationship.AssociationForeignFieldNames, scope.Value)
newDB = newDB.Where( newDB = newDB.Where(
fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)), fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)),
toQueryValues(primaryKeys)..., toQueryValues(primaryKeys)...,
@ -222,7 +222,7 @@ func (association *Association) Delete(values ...interface{}) *Association {
for i := 0; i < field.Len(); i++ { for i := 0; i < field.Len(); i++ {
reflectValue := field.Index(i) reflectValue := field.Index(i)
primaryKey := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, reflectValue.Interface())[0] primaryKey := scope.getColumnAsArrayUnique(deletingResourcePrimaryFieldNames, reflectValue.Interface())[0]
var isDeleted = false var isDeleted = false
for _, pk := range deletingPrimaryKeys { for _, pk := range deletingPrimaryKeys {
if equalAsString(primaryKey, pk) { if equalAsString(primaryKey, pk) {
@ -237,7 +237,7 @@ func (association *Association) Delete(values ...interface{}) *Association {
association.field.Set(leftValues) association.field.Set(leftValues)
} else if field.Kind() == reflect.Struct { } else if field.Kind() == reflect.Struct {
primaryKey := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, field.Interface())[0] primaryKey := scope.getColumnAsArrayUnique(deletingResourcePrimaryFieldNames, field.Interface())[0]
for _, pk := range deletingPrimaryKeys { for _, pk := range deletingPrimaryKeys {
if equalAsString(primaryKey, pk) { if equalAsString(primaryKey, pk) {
association.field.Set(reflect.Zero(field.Type())) association.field.Set(reflect.Zero(field.Type()))
@ -268,13 +268,13 @@ func (association *Association) Count() int {
if relationship.Kind == "many_to_many" { if relationship.Kind == "many_to_many" {
query = relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, query, scope.Value) query = relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, query, scope.Value)
} else if relationship.Kind == "has_many" || relationship.Kind == "has_one" { } else if relationship.Kind == "has_many" || relationship.Kind == "has_one" {
primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, scope.Value) primaryKeys := scope.getColumnAsArrayUnique(relationship.AssociationForeignFieldNames, scope.Value)
query = query.Where( query = query.Where(
fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)), fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)),
toQueryValues(primaryKeys)..., toQueryValues(primaryKeys)...,
) )
} else if relationship.Kind == "belongs_to" { } else if relationship.Kind == "belongs_to" {
primaryKeys := scope.getColumnAsArray(relationship.ForeignFieldNames, scope.Value) primaryKeys := scope.getColumnAsArrayUnique(relationship.ForeignFieldNames, scope.Value)
query = query.Where( query = query.Where(
fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(primaryKeys)), fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(primaryKeys)),
toQueryValues(primaryKeys)..., toQueryValues(primaryKeys)...,

View File

@ -101,7 +101,7 @@ func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{})
relation := field.Relationship relation := field.Relationship
// get relations's primary keys // get relations's primary keys
primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value) primaryKeys := scope.getColumnAsArrayUnique(relation.AssociationForeignFieldNames, scope.Value)
if len(primaryKeys) == 0 { if len(primaryKeys) == 0 {
return return
} }
@ -150,7 +150,7 @@ func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{})
relation := field.Relationship relation := field.Relationship
// get relations's primary keys // get relations's primary keys
primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value) primaryKeys := scope.getColumnAsArrayUnique(relation.AssociationForeignFieldNames, scope.Value)
if len(primaryKeys) == 0 { if len(primaryKeys) == 0 {
return return
} }
@ -204,7 +204,7 @@ func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{
preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions) preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
// get relations's primary keys // get relations's primary keys
primaryKeys := scope.getColumnAsArray(relation.ForeignFieldNames, scope.Value) primaryKeys := scope.getColumnAsArrayUnique(relation.ForeignFieldNames, scope.Value)
if len(primaryKeys) == 0 { if len(primaryKeys) == 0 {
return return
} }

View File

@ -178,7 +178,7 @@ func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, so
} }
} }
foreignFieldValues := scope.getColumnAsArray(foreignFieldNames, scope.Value) foreignFieldValues := scope.getColumnAsArrayUnique(foreignFieldNames, scope.Value)
var condString string var condString string
if len(foreignFieldValues) > 0 { if len(foreignFieldValues) > 0 {
@ -189,7 +189,7 @@ func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, so
condString = fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, quotedForeignDBNames), toQueryMarks(foreignFieldValues)) condString = fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, quotedForeignDBNames), toQueryMarks(foreignFieldValues))
keys := scope.getColumnAsArray(foreignFieldNames, scope.Value) keys := scope.getColumnAsArrayUnique(foreignFieldNames, scope.Value)
values = append(values, toQueryValues(keys)) values = append(values, toQueryValues(keys))
} else { } else {
condString = fmt.Sprintf("1 <> 1") condString = fmt.Sprintf("1 <> 1")

View File

@ -1204,10 +1204,7 @@ func (scope *Scope) autoIndex() *Scope {
func (scope *Scope) getColumnAsArray(columns []string, values ...interface{}) (results [][]interface{}) { func (scope *Scope) getColumnAsArray(columns []string, values ...interface{}) (results [][]interface{}) {
for _, value := range values { for _, value := range values {
indirectValue := reflect.ValueOf(value) indirectValue := indirect(reflect.ValueOf(value))
for indirectValue.Kind() == reflect.Ptr {
indirectValue = indirectValue.Elem()
}
switch indirectValue.Kind() { switch indirectValue.Kind() {
case reflect.Slice: case reflect.Slice:
@ -1230,6 +1227,40 @@ func (scope *Scope) getColumnAsArray(columns []string, values ...interface{}) (r
return return
} }
func (scope *Scope) getColumnAsArrayUnique(columns []string, values ...interface{}) (results [][]interface{}) {
unfilteredResults := scope.getColumnAsArray(columns, values...)
rootMap := map[interface{}]interface{}{}
for _, valueSet := range unfilteredResults {
currentMap := rootMap
appendResult := false
for _, value := range valueSet {
if isBlank(reflect.ValueOf(value)) {
appendResult = false
break
}
innerMap, ok := currentMap[value]
if !ok {
innerMap = map[interface{}]interface{}{}
currentMap[value] = innerMap
appendResult = true
}
currentMap = innerMap.(map[interface{}]interface{})
}
if appendResult {
results = append(results, valueSet)
}
}
return
}
func (scope *Scope) getColumnAsScope(column string) *Scope { func (scope *Scope) getColumnAsScope(column string) *Scope {
indirectScopeValue := scope.IndirectValue() indirectScopeValue := scope.IndirectValue()