got tests working

This commit is contained in:
Tim Thomas 2020-03-03 17:02:06 -06:00
parent 47bd34caf0
commit ef73cc681d
4 changed files with 98 additions and 45 deletions

View File

@ -8,6 +8,7 @@ import (
"strings"
)
// preloadCallback used to preload associations
func preloadCallback(scope *Scope) {
if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip {
@ -30,8 +31,7 @@ func preloadCallback(scope *Scope) {
var (
preloadedMap = map[string]bool{}
preloadedParentQueryMap = map[string]string{}
preloadedParentQueryVarsMap = map[string][]interface{}{}
parentQueryMap = map[string]*SqlExpr{}
fields = scope.Fields()
)
@ -41,8 +41,10 @@ func preloadCallback(scope *Scope) {
currentScope = scope
currentFields = fields
)
parentSQL := currentScope.SQL
parentSQLVars := currentScope.SQLVars
parentQuery := new(SqlExpr)
parentQuery.expr = currentScope.SQL
parentQuery.args = currentScope.SQLVars
cleanParentSql(parentQuery)
for idx, preloadField := range preloadFields {
var currentPreloadConditions []interface{}
@ -54,9 +56,8 @@ func preloadCallback(scope *Scope) {
// if not preloaded
if preloadKey := strings.Join(preloadFields[:idx+1], "."); !preloadedMap[preloadKey] {
parentKey := strings.Join(preloadFields[:idx], ".")
if _, ok := preloadedParentQueryMap[parentKey]; ok {
parentSQL = preloadedParentQueryMap[parentKey]
parentSQLVars = preloadedParentQueryVarsMap[parentKey]
if _, ok := parentQueryMap[parentKey]; ok {
parentQuery = cleanParentSql(parentQueryMap[parentKey])
}
@ -72,13 +73,13 @@ func preloadCallback(scope *Scope) {
switch field.Relationship.Kind {
case "has_one":
preloadedParentQueryMap[preloadKey], preloadedParentQueryVarsMap[preloadKey] = currentScope.handleHasOnePreload(field, currentPreloadConditions, parentSQL, parentSQLVars)
parentQueryMap[preloadKey] = currentScope.handleHasOnePreload(field, currentPreloadConditions, parentQuery)
case "has_many":
preloadedParentQueryMap[preloadKey], preloadedParentQueryVarsMap[preloadKey] = currentScope.handleHasManyPreload(field, currentPreloadConditions, parentSQL, parentSQLVars)
parentQueryMap[preloadKey] = currentScope.handleHasManyPreload(field, currentPreloadConditions, parentQuery)
case "belongs_to":
preloadedParentQueryMap[preloadKey], preloadedParentQueryVarsMap[preloadKey] = currentScope.handleBelongsToPreload(field, currentPreloadConditions, parentSQL, parentSQLVars)
parentQueryMap[preloadKey] = currentScope.handleBelongsToPreload(field, currentPreloadConditions, parentQuery)
case "many_to_many":
currentScope.handleManyToManyPreload(field, currentPreloadConditions, parentSQL, parentSQLVars)
parentQueryMap[preloadKey] = currentScope.handleManyToManyPreload(field, currentPreloadConditions, parentQuery)
default:
scope.Err(errors.New("unsupported relation"))
}
@ -141,25 +142,23 @@ func (scope *Scope) generatePreloadDBWithConditions(conditions []interface{}) (*
}
// handleHasOnePreload used to preload has one associations
func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}, parentSQL string, parentSQLVars []interface{}) (string, []interface{}) {
func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}, parentQuery *SqlExpr) *SqlExpr {
relation := field.Relationship
// get relations's primary keys
primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value)
if len(primaryKeys) == 0 {
return "", []interface{}{}
// skip query if parent does not exist
if scope.Value == nil {
return nil
}
// preload conditions
preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
subQuerySQL := parentSQL
subQuerySQL := parentQuery.expr
subQuerySQL = "SELECT " + toQueryCondition(scope, relation.AssociationForeignDBNames) + " FROM (" + subQuerySQL + ") " + scope.Quote("preHO_" + field.DBName)
// find relations
query := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), subQuerySQL)
//values := toQueryValues(primaryKeys)
values := parentSQLVars
values := parentQuery.args
if relation.PolymorphicType != "" {
query += fmt.Sprintf(" AND %v = ?", scope.Quote(relation.PolymorphicDBName))
values = append(values, relation.PolymorphicValue)
@ -195,30 +194,27 @@ func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{},
scope.Err(field.Set(result))
}
}
return preloadQuery.QueryExpr().expr, preloadQuery.QueryExpr().args
return preloadQuery.QueryExpr()
}
// handleHasManyPreload used to preload has many associations
func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}, parentSQL string, parentSQLVars []interface{}) (string, []interface{}) {
func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}, parentSql *SqlExpr) *SqlExpr {
relation := field.Relationship
// get relations's primary keys
primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value)
if len(primaryKeys) == 0 {
return "", []interface{}{}
// skip query if parent does not exist
if scope.Value == nil {
return nil
}
// preload conditions
preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
subQuerySQL := parentSQL
subQuerySQL := parentSql.expr
subQuerySQL = "SELECT " + toQueryCondition(scope, relation.AssociationForeignDBNames) + " FROM (" + subQuerySQL + ") " + scope.Quote("preHM_" + field.DBName)
// find relations
//scope.Err(preloadDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.AssociationForeignDBNames), subQuerySQL), scope.SQLVars).Find(results, preloadConditions...).Error)
query := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), subQuerySQL)
//values := toQueryValues(primaryKeys)
values := parentSQLVars
values := parentSql.args
if relation.PolymorphicType != "" {
query += fmt.Sprintf(" AND %v = ?", scope.Quote(relation.PolymorphicDBName))
values = append(values, relation.PolymorphicValue)
@ -255,28 +251,27 @@ func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{},
} else {
scope.Err(field.Set(resultsValue))
}
return preloadQuery.QueryExpr().expr, preloadQuery.QueryExpr().args
return preloadQuery.QueryExpr()
}
// handleBelongsToPreload used to preload belongs to associations
func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{}, parentSQL string, parentSQLVars []interface{}) (string, []interface{}) {
func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{}, parentSQL *SqlExpr) *SqlExpr {
relation := field.Relationship
// preload conditions
preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
// get relations's primary keys
primaryKeys := scope.getColumnAsArray(relation.ForeignFieldNames, scope.Value)
if len(primaryKeys) == 0 {
return "", []interface{}{}
if scope.Value == nil {
return nil
}
subQuerySQL := parentSQL
subQuerySQL := parentSQL.expr
subQuerySQL = "SELECT " + toQueryCondition(scope, relation.ForeignDBNames) + " FROM (" + subQuerySQL + ") " + scope.Quote("preBT_" + field.DBName)
// find relations
results := makeSlice(field.Struct.Type)
preloadQuery := preloadDB.Model(results).Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.AssociationForeignDBNames), subQuerySQL),parentSQLVars)
preloadQuery := preloadDB.Model(results).Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.AssociationForeignDBNames), subQuerySQL),parentSQL.args)
scope.Err(preloadQuery.Find(results, preloadConditions...).Error)
// assign find results
@ -307,11 +302,11 @@ func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{
scope.Err(field.Set(result))
}
}
return preloadQuery.QueryExpr().expr, preloadQuery.QueryExpr().args
return preloadQuery.QueryExpr()
}
// handleManyToManyPreload used to preload many to many associations
func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface{}, parentSQL string, parentSQLVars []interface{}) {
func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface{}, parentSql *SqlExpr) *SqlExpr {
var (
relation = field.Relationship
joinTableHandler = relation.JoinTableHandler
@ -343,11 +338,9 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface
preloadDB = preloadDB.Select("*")
}
// scope.Value here needs to be a subquery
preloadDB = joinTableHandler.JoinWith(joinTableHandler, preloadDB, scope.Value)
preloadSQL := preloadDB.QueryExpr().expr
fmt.Println(preloadSQL);
subQuerySQL := parentSql
subQuerySQL.expr = "SELECT " + toQueryCondition(scope, relation.ForeignDBNames) + " FROM (" + subQuerySQL.expr + ") " + scope.Quote("preMM_" + field.DBName)
preloadDB = joinTableHandler.JoinWithQuery(joinTableHandler, preloadDB, scope.Value, subQuerySQL)
// preload inline conditions
if len(preloadConditions) > 0 {
@ -357,7 +350,7 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface
rows, err := preloadDB.Rows()
if scope.Err(err) != nil {
return
return nil
}
defer rows.Close()
@ -439,4 +432,5 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface
f.Set(v)
}
}
return preloadDB.QueryExpr()
}

View File

@ -19,6 +19,8 @@ type JoinTableHandlerInterface interface {
Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error
// JoinWith query with `Join` conditions
JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB
// JoinWith query with `Join` query
JoinWithQuery(handler JoinTableHandlerInterface, db *DB, source interface{}, query *SqlExpr) *DB
// SourceForeignKeys return source foreign keys
SourceForeignKeys() []JoinTableForeignKey
// DestinationForeignKeys return destination foreign keys
@ -159,6 +161,53 @@ func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sour
return db.Table(handler.Table(db)).Where(strings.Join(conditions, " AND "), values...).Delete("").Error
}
// JoinWith query with `Join` query
func (s JoinTableHandler) JoinWithQuery(handler JoinTableHandlerInterface, db *DB, source interface{}, query *SqlExpr) *DB {
var (
scope = db.NewScope(source)
tableName = handler.Table(db)
quotedTableName = scope.Quote(tableName)
joinConditions []string
values []interface{}
)
if s.Source.ModelType == scope.GetModelStruct().ModelType {
destinationTableName := db.NewScope(reflect.New(s.Destination.ModelType).Interface()).QuotedTableName()
for _, foreignKey := range s.Destination.ForeignKeys {
joinConditions = append(joinConditions, fmt.Sprintf("%v.%v = %v.%v", quotedTableName, scope.Quote(foreignKey.DBName), destinationTableName, scope.Quote(foreignKey.AssociationDBName)))
}
var foreignDBNames []string
var foreignFieldNames []string
for _, foreignKey := range s.Source.ForeignKeys {
foreignDBNames = append(foreignDBNames, foreignKey.DBName)
if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok {
foreignFieldNames = append(foreignFieldNames, field.Name)
}
}
var condString string
var quotedForeignDBNames []string
for _, dbName := range foreignDBNames {
quotedForeignDBNames = append(quotedForeignDBNames, tableName+"."+dbName)
}
if query != nil {
values = query.args
condString = fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, quotedForeignDBNames), query.expr)
} else {
values = []interface{}{}
condString = fmt.Sprintf("1 <> 1")
}
return db.Joins(fmt.Sprintf("INNER JOIN %v ON %v", quotedTableName, strings.Join(joinConditions, " AND "))).
Where(condString, values)
}
db.Error = errors.New("wrong source type for join table handler")
return db
}
// JoinWith query with `Join` conditions
func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB {
var (

View File

@ -97,6 +97,10 @@ func TestPreload(t *testing.T) {
}
func TestAutoPreload(t *testing.T) {
var users []User
DB.Find(&users)
DB.Delete(&users)
user1 := getPreloadUser("auto_user1")
DB.Save(user1)
@ -108,7 +112,6 @@ func TestAutoPreload(t *testing.T) {
user2 := getPreloadUser("auto_user2")
DB.Save(user2)
var users []User
preloadDB.Find(&users)
for _, user := range users {

View File

@ -25,6 +25,7 @@ var NowFunc = func() time.Time {
var commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UID", "UI", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"}
var commonInitialismsReplacer *strings.Replacer
var cleanInRegexp = regexp.MustCompile("\\((\\?,)+\\?\\)")
var goSrcRegexp = regexp.MustCompile(`jinzhu/gorm(@.*)?/.*.go`)
var goTestRegexp = regexp.MustCompile(`jinzhu/gorm(@.*)?/.*test.go`)
@ -224,3 +225,9 @@ func addExtraSpaceIfExist(str string) string {
}
return ""
}
func cleanParentSql(parentSql *SqlExpr) *SqlExpr {
// having more than one query mark will confuse generating the correct number of query marks based upon args
parentSql.expr = string(cleanInRegexp.ReplaceAll([]byte(parentSql.expr), []byte("(?)")))
return parentSql
}