diff --git a/callback_query_preload.go b/callback_query_preload.go index c8bb54a4..0ee7bd64 100644 --- a/callback_query_preload.go +++ b/callback_query_preload.go @@ -8,6 +8,7 @@ import ( "strings" ) + // preloadCallback used to preload associations func preloadCallback(scope *Scope) { if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip { @@ -30,8 +31,7 @@ func preloadCallback(scope *Scope) { var ( preloadedMap = map[string]bool{} - preloadedParentQueryMap = map[string]string{} - preloadedParentQueryVarsMap = map[string][]interface{}{} + parentQueryMap = map[string]*SqlExpr{} fields = scope.Fields() ) @@ -41,8 +41,10 @@ func preloadCallback(scope *Scope) { currentScope = scope currentFields = fields ) - parentSQL := currentScope.SQL - parentSQLVars := currentScope.SQLVars + parentQuery := new(SqlExpr) + parentQuery.expr = currentScope.SQL + parentQuery.args = currentScope.SQLVars + cleanParentSql(parentQuery) for idx, preloadField := range preloadFields { var currentPreloadConditions []interface{} @@ -54,9 +56,8 @@ func preloadCallback(scope *Scope) { // if not preloaded if preloadKey := strings.Join(preloadFields[:idx+1], "."); !preloadedMap[preloadKey] { parentKey := strings.Join(preloadFields[:idx], ".") - if _, ok := preloadedParentQueryMap[parentKey]; ok { - parentSQL = preloadedParentQueryMap[parentKey] - parentSQLVars = preloadedParentQueryVarsMap[parentKey] + if _, ok := parentQueryMap[parentKey]; ok { + parentQuery = cleanParentSql(parentQueryMap[parentKey]) } @@ -72,13 +73,13 @@ func preloadCallback(scope *Scope) { switch field.Relationship.Kind { case "has_one": - preloadedParentQueryMap[preloadKey], preloadedParentQueryVarsMap[preloadKey] = currentScope.handleHasOnePreload(field, currentPreloadConditions, parentSQL, parentSQLVars) + parentQueryMap[preloadKey] = currentScope.handleHasOnePreload(field, currentPreloadConditions, parentQuery) case "has_many": - preloadedParentQueryMap[preloadKey], preloadedParentQueryVarsMap[preloadKey] = currentScope.handleHasManyPreload(field, currentPreloadConditions, parentSQL, parentSQLVars) + parentQueryMap[preloadKey] = currentScope.handleHasManyPreload(field, currentPreloadConditions, parentQuery) case "belongs_to": - preloadedParentQueryMap[preloadKey], preloadedParentQueryVarsMap[preloadKey] = currentScope.handleBelongsToPreload(field, currentPreloadConditions, parentSQL, parentSQLVars) + parentQueryMap[preloadKey] = currentScope.handleBelongsToPreload(field, currentPreloadConditions, parentQuery) case "many_to_many": - currentScope.handleManyToManyPreload(field, currentPreloadConditions, parentSQL, parentSQLVars) + parentQueryMap[preloadKey] = currentScope.handleManyToManyPreload(field, currentPreloadConditions, parentQuery) default: scope.Err(errors.New("unsupported relation")) } @@ -141,25 +142,23 @@ func (scope *Scope) generatePreloadDBWithConditions(conditions []interface{}) (* } // handleHasOnePreload used to preload has one associations -func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}, parentSQL string, parentSQLVars []interface{}) (string, []interface{}) { +func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}, parentQuery *SqlExpr) *SqlExpr { relation := field.Relationship - // get relations's primary keys - primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value) - if len(primaryKeys) == 0 { - return "", []interface{}{} + // skip query if parent does not exist + if scope.Value == nil { + return nil } // preload conditions preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions) - subQuerySQL := parentSQL + subQuerySQL := parentQuery.expr subQuerySQL = "SELECT " + toQueryCondition(scope, relation.AssociationForeignDBNames) + " FROM (" + subQuerySQL + ") " + scope.Quote("preHO_" + field.DBName) // find relations query := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), subQuerySQL) - //values := toQueryValues(primaryKeys) - values := parentSQLVars + values := parentQuery.args if relation.PolymorphicType != "" { query += fmt.Sprintf(" AND %v = ?", scope.Quote(relation.PolymorphicDBName)) values = append(values, relation.PolymorphicValue) @@ -195,30 +194,27 @@ func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}, scope.Err(field.Set(result)) } } - return preloadQuery.QueryExpr().expr, preloadQuery.QueryExpr().args + return preloadQuery.QueryExpr() } // handleHasManyPreload used to preload has many associations -func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}, parentSQL string, parentSQLVars []interface{}) (string, []interface{}) { +func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}, parentSql *SqlExpr) *SqlExpr { relation := field.Relationship - // get relations's primary keys - primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value) - if len(primaryKeys) == 0 { - return "", []interface{}{} + // skip query if parent does not exist + if scope.Value == nil { + return nil } // preload conditions preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions) - subQuerySQL := parentSQL + subQuerySQL := parentSql.expr subQuerySQL = "SELECT " + toQueryCondition(scope, relation.AssociationForeignDBNames) + " FROM (" + subQuerySQL + ") " + scope.Quote("preHM_" + field.DBName) // find relations - //scope.Err(preloadDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.AssociationForeignDBNames), subQuerySQL), scope.SQLVars).Find(results, preloadConditions...).Error) query := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), subQuerySQL) - //values := toQueryValues(primaryKeys) - values := parentSQLVars + values := parentSql.args if relation.PolymorphicType != "" { query += fmt.Sprintf(" AND %v = ?", scope.Quote(relation.PolymorphicDBName)) values = append(values, relation.PolymorphicValue) @@ -255,28 +251,27 @@ func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}, } else { scope.Err(field.Set(resultsValue)) } - return preloadQuery.QueryExpr().expr, preloadQuery.QueryExpr().args + return preloadQuery.QueryExpr() } // handleBelongsToPreload used to preload belongs to associations -func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{}, parentSQL string, parentSQLVars []interface{}) (string, []interface{}) { +func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{}, parentSQL *SqlExpr) *SqlExpr { relation := field.Relationship // preload conditions preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions) // get relations's primary keys - primaryKeys := scope.getColumnAsArray(relation.ForeignFieldNames, scope.Value) - if len(primaryKeys) == 0 { - return "", []interface{}{} + if scope.Value == nil { + return nil } - subQuerySQL := parentSQL + subQuerySQL := parentSQL.expr subQuerySQL = "SELECT " + toQueryCondition(scope, relation.ForeignDBNames) + " FROM (" + subQuerySQL + ") " + scope.Quote("preBT_" + field.DBName) // find relations results := makeSlice(field.Struct.Type) - preloadQuery := preloadDB.Model(results).Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.AssociationForeignDBNames), subQuerySQL),parentSQLVars) + preloadQuery := preloadDB.Model(results).Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.AssociationForeignDBNames), subQuerySQL),parentSQL.args) scope.Err(preloadQuery.Find(results, preloadConditions...).Error) // assign find results @@ -307,11 +302,11 @@ func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{ scope.Err(field.Set(result)) } } - return preloadQuery.QueryExpr().expr, preloadQuery.QueryExpr().args + return preloadQuery.QueryExpr() } // handleManyToManyPreload used to preload many to many associations -func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface{}, parentSQL string, parentSQLVars []interface{}) { +func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface{}, parentSql *SqlExpr) *SqlExpr { var ( relation = field.Relationship joinTableHandler = relation.JoinTableHandler @@ -343,11 +338,9 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface preloadDB = preloadDB.Select("*") } - // scope.Value here needs to be a subquery - preloadDB = joinTableHandler.JoinWith(joinTableHandler, preloadDB, scope.Value) - - preloadSQL := preloadDB.QueryExpr().expr - fmt.Println(preloadSQL); + subQuerySQL := parentSql + subQuerySQL.expr = "SELECT " + toQueryCondition(scope, relation.ForeignDBNames) + " FROM (" + subQuerySQL.expr + ") " + scope.Quote("preMM_" + field.DBName) + preloadDB = joinTableHandler.JoinWithQuery(joinTableHandler, preloadDB, scope.Value, subQuerySQL) // preload inline conditions if len(preloadConditions) > 0 { @@ -357,7 +350,7 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface rows, err := preloadDB.Rows() if scope.Err(err) != nil { - return + return nil } defer rows.Close() @@ -439,4 +432,5 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface f.Set(v) } } + return preloadDB.QueryExpr() } diff --git a/join_table_handler.go b/join_table_handler.go index a036d46d..3a733c82 100644 --- a/join_table_handler.go +++ b/join_table_handler.go @@ -19,6 +19,8 @@ type JoinTableHandlerInterface interface { Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error // JoinWith query with `Join` conditions JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB + // JoinWith query with `Join` query + JoinWithQuery(handler JoinTableHandlerInterface, db *DB, source interface{}, query *SqlExpr) *DB // SourceForeignKeys return source foreign keys SourceForeignKeys() []JoinTableForeignKey // DestinationForeignKeys return destination foreign keys @@ -159,6 +161,53 @@ func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sour return db.Table(handler.Table(db)).Where(strings.Join(conditions, " AND "), values...).Delete("").Error } +// JoinWith query with `Join` query +func (s JoinTableHandler) JoinWithQuery(handler JoinTableHandlerInterface, db *DB, source interface{}, query *SqlExpr) *DB { + var ( + scope = db.NewScope(source) + tableName = handler.Table(db) + quotedTableName = scope.Quote(tableName) + joinConditions []string + values []interface{} + ) + + if s.Source.ModelType == scope.GetModelStruct().ModelType { + destinationTableName := db.NewScope(reflect.New(s.Destination.ModelType).Interface()).QuotedTableName() + for _, foreignKey := range s.Destination.ForeignKeys { + joinConditions = append(joinConditions, fmt.Sprintf("%v.%v = %v.%v", quotedTableName, scope.Quote(foreignKey.DBName), destinationTableName, scope.Quote(foreignKey.AssociationDBName))) + } + + var foreignDBNames []string + var foreignFieldNames []string + + for _, foreignKey := range s.Source.ForeignKeys { + foreignDBNames = append(foreignDBNames, foreignKey.DBName) + if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok { + foreignFieldNames = append(foreignFieldNames, field.Name) + } + } + + var condString string + var quotedForeignDBNames []string + for _, dbName := range foreignDBNames { + quotedForeignDBNames = append(quotedForeignDBNames, tableName+"."+dbName) + } + + if query != nil { + values = query.args + condString = fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, quotedForeignDBNames), query.expr) + } else { + values = []interface{}{} + condString = fmt.Sprintf("1 <> 1") + } + return db.Joins(fmt.Sprintf("INNER JOIN %v ON %v", quotedTableName, strings.Join(joinConditions, " AND "))). + Where(condString, values) + } + + db.Error = errors.New("wrong source type for join table handler") + return db +} + // JoinWith query with `Join` conditions func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB { var ( diff --git a/preload_test.go b/preload_test.go index dd29fb5e..7be4dd28 100644 --- a/preload_test.go +++ b/preload_test.go @@ -97,6 +97,10 @@ func TestPreload(t *testing.T) { } func TestAutoPreload(t *testing.T) { + var users []User + DB.Find(&users) + DB.Delete(&users) + user1 := getPreloadUser("auto_user1") DB.Save(user1) @@ -108,7 +112,6 @@ func TestAutoPreload(t *testing.T) { user2 := getPreloadUser("auto_user2") DB.Save(user2) - var users []User preloadDB.Find(&users) for _, user := range users { diff --git a/utils.go b/utils.go index d2ae9465..ae1a7ab6 100644 --- a/utils.go +++ b/utils.go @@ -25,6 +25,7 @@ var NowFunc = func() time.Time { var commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UID", "UI", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"} var commonInitialismsReplacer *strings.Replacer +var cleanInRegexp = regexp.MustCompile("\\((\\?,)+\\?\\)") var goSrcRegexp = regexp.MustCompile(`jinzhu/gorm(@.*)?/.*.go`) var goTestRegexp = regexp.MustCompile(`jinzhu/gorm(@.*)?/.*test.go`) @@ -224,3 +225,9 @@ func addExtraSpaceIfExist(str string) string { } return "" } + +func cleanParentSql(parentSql *SqlExpr) *SqlExpr { + // having more than one query mark will confuse generating the correct number of query marks based upon args + parentSql.expr = string(cleanInRegexp.ReplaceAll([]byte(parentSql.expr), []byte("(?)"))) + return parentSql +}