got tests working

This commit is contained in:
Tim Thomas 2020-03-03 17:02:06 -06:00
parent 47bd34caf0
commit ef73cc681d
4 changed files with 98 additions and 45 deletions

View File

@ -8,6 +8,7 @@ import (
"strings" "strings"
) )
// preloadCallback used to preload associations // preloadCallback used to preload associations
func preloadCallback(scope *Scope) { func preloadCallback(scope *Scope) {
if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip { if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip {
@ -30,8 +31,7 @@ func preloadCallback(scope *Scope) {
var ( var (
preloadedMap = map[string]bool{} preloadedMap = map[string]bool{}
preloadedParentQueryMap = map[string]string{} parentQueryMap = map[string]*SqlExpr{}
preloadedParentQueryVarsMap = map[string][]interface{}{}
fields = scope.Fields() fields = scope.Fields()
) )
@ -41,8 +41,10 @@ func preloadCallback(scope *Scope) {
currentScope = scope currentScope = scope
currentFields = fields currentFields = fields
) )
parentSQL := currentScope.SQL parentQuery := new(SqlExpr)
parentSQLVars := currentScope.SQLVars parentQuery.expr = currentScope.SQL
parentQuery.args = currentScope.SQLVars
cleanParentSql(parentQuery)
for idx, preloadField := range preloadFields { for idx, preloadField := range preloadFields {
var currentPreloadConditions []interface{} var currentPreloadConditions []interface{}
@ -54,9 +56,8 @@ func preloadCallback(scope *Scope) {
// if not preloaded // if not preloaded
if preloadKey := strings.Join(preloadFields[:idx+1], "."); !preloadedMap[preloadKey] { if preloadKey := strings.Join(preloadFields[:idx+1], "."); !preloadedMap[preloadKey] {
parentKey := strings.Join(preloadFields[:idx], ".") parentKey := strings.Join(preloadFields[:idx], ".")
if _, ok := preloadedParentQueryMap[parentKey]; ok { if _, ok := parentQueryMap[parentKey]; ok {
parentSQL = preloadedParentQueryMap[parentKey] parentQuery = cleanParentSql(parentQueryMap[parentKey])
parentSQLVars = preloadedParentQueryVarsMap[parentKey]
} }
@ -72,13 +73,13 @@ func preloadCallback(scope *Scope) {
switch field.Relationship.Kind { switch field.Relationship.Kind {
case "has_one": case "has_one":
preloadedParentQueryMap[preloadKey], preloadedParentQueryVarsMap[preloadKey] = currentScope.handleHasOnePreload(field, currentPreloadConditions, parentSQL, parentSQLVars) parentQueryMap[preloadKey] = currentScope.handleHasOnePreload(field, currentPreloadConditions, parentQuery)
case "has_many": case "has_many":
preloadedParentQueryMap[preloadKey], preloadedParentQueryVarsMap[preloadKey] = currentScope.handleHasManyPreload(field, currentPreloadConditions, parentSQL, parentSQLVars) parentQueryMap[preloadKey] = currentScope.handleHasManyPreload(field, currentPreloadConditions, parentQuery)
case "belongs_to": case "belongs_to":
preloadedParentQueryMap[preloadKey], preloadedParentQueryVarsMap[preloadKey] = currentScope.handleBelongsToPreload(field, currentPreloadConditions, parentSQL, parentSQLVars) parentQueryMap[preloadKey] = currentScope.handleBelongsToPreload(field, currentPreloadConditions, parentQuery)
case "many_to_many": case "many_to_many":
currentScope.handleManyToManyPreload(field, currentPreloadConditions, parentSQL, parentSQLVars) parentQueryMap[preloadKey] = currentScope.handleManyToManyPreload(field, currentPreloadConditions, parentQuery)
default: default:
scope.Err(errors.New("unsupported relation")) scope.Err(errors.New("unsupported relation"))
} }
@ -141,25 +142,23 @@ func (scope *Scope) generatePreloadDBWithConditions(conditions []interface{}) (*
} }
// handleHasOnePreload used to preload has one associations // handleHasOnePreload used to preload has one associations
func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}, parentSQL string, parentSQLVars []interface{}) (string, []interface{}) { func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}, parentQuery *SqlExpr) *SqlExpr {
relation := field.Relationship relation := field.Relationship
// get relations's primary keys // skip query if parent does not exist
primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value) if scope.Value == nil {
if len(primaryKeys) == 0 { return nil
return "", []interface{}{}
} }
// preload conditions // preload conditions
preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions) preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
subQuerySQL := parentSQL subQuerySQL := parentQuery.expr
subQuerySQL = "SELECT " + toQueryCondition(scope, relation.AssociationForeignDBNames) + " FROM (" + subQuerySQL + ") " + scope.Quote("preHO_" + field.DBName) subQuerySQL = "SELECT " + toQueryCondition(scope, relation.AssociationForeignDBNames) + " FROM (" + subQuerySQL + ") " + scope.Quote("preHO_" + field.DBName)
// find relations // find relations
query := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), subQuerySQL) query := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), subQuerySQL)
//values := toQueryValues(primaryKeys) values := parentQuery.args
values := parentSQLVars
if relation.PolymorphicType != "" { if relation.PolymorphicType != "" {
query += fmt.Sprintf(" AND %v = ?", scope.Quote(relation.PolymorphicDBName)) query += fmt.Sprintf(" AND %v = ?", scope.Quote(relation.PolymorphicDBName))
values = append(values, relation.PolymorphicValue) values = append(values, relation.PolymorphicValue)
@ -195,30 +194,27 @@ func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{},
scope.Err(field.Set(result)) scope.Err(field.Set(result))
} }
} }
return preloadQuery.QueryExpr().expr, preloadQuery.QueryExpr().args return preloadQuery.QueryExpr()
} }
// handleHasManyPreload used to preload has many associations // handleHasManyPreload used to preload has many associations
func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}, parentSQL string, parentSQLVars []interface{}) (string, []interface{}) { func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}, parentSql *SqlExpr) *SqlExpr {
relation := field.Relationship relation := field.Relationship
// get relations's primary keys // skip query if parent does not exist
primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value) if scope.Value == nil {
if len(primaryKeys) == 0 { return nil
return "", []interface{}{}
} }
// preload conditions // preload conditions
preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions) preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
subQuerySQL := parentSQL subQuerySQL := parentSql.expr
subQuerySQL = "SELECT " + toQueryCondition(scope, relation.AssociationForeignDBNames) + " FROM (" + subQuerySQL + ") " + scope.Quote("preHM_" + field.DBName) subQuerySQL = "SELECT " + toQueryCondition(scope, relation.AssociationForeignDBNames) + " FROM (" + subQuerySQL + ") " + scope.Quote("preHM_" + field.DBName)
// find relations // find relations
//scope.Err(preloadDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.AssociationForeignDBNames), subQuerySQL), scope.SQLVars).Find(results, preloadConditions...).Error)
query := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), subQuerySQL) query := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), subQuerySQL)
//values := toQueryValues(primaryKeys) values := parentSql.args
values := parentSQLVars
if relation.PolymorphicType != "" { if relation.PolymorphicType != "" {
query += fmt.Sprintf(" AND %v = ?", scope.Quote(relation.PolymorphicDBName)) query += fmt.Sprintf(" AND %v = ?", scope.Quote(relation.PolymorphicDBName))
values = append(values, relation.PolymorphicValue) values = append(values, relation.PolymorphicValue)
@ -255,28 +251,27 @@ func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{},
} else { } else {
scope.Err(field.Set(resultsValue)) scope.Err(field.Set(resultsValue))
} }
return preloadQuery.QueryExpr().expr, preloadQuery.QueryExpr().args return preloadQuery.QueryExpr()
} }
// handleBelongsToPreload used to preload belongs to associations // handleBelongsToPreload used to preload belongs to associations
func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{}, parentSQL string, parentSQLVars []interface{}) (string, []interface{}) { func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{}, parentSQL *SqlExpr) *SqlExpr {
relation := field.Relationship relation := field.Relationship
// preload conditions // preload conditions
preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions) preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
// get relations's primary keys // get relations's primary keys
primaryKeys := scope.getColumnAsArray(relation.ForeignFieldNames, scope.Value) if scope.Value == nil {
if len(primaryKeys) == 0 { return nil
return "", []interface{}{}
} }
subQuerySQL := parentSQL subQuerySQL := parentSQL.expr
subQuerySQL = "SELECT " + toQueryCondition(scope, relation.ForeignDBNames) + " FROM (" + subQuerySQL + ") " + scope.Quote("preBT_" + field.DBName) subQuerySQL = "SELECT " + toQueryCondition(scope, relation.ForeignDBNames) + " FROM (" + subQuerySQL + ") " + scope.Quote("preBT_" + field.DBName)
// find relations // find relations
results := makeSlice(field.Struct.Type) results := makeSlice(field.Struct.Type)
preloadQuery := preloadDB.Model(results).Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.AssociationForeignDBNames), subQuerySQL),parentSQLVars) preloadQuery := preloadDB.Model(results).Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.AssociationForeignDBNames), subQuerySQL),parentSQL.args)
scope.Err(preloadQuery.Find(results, preloadConditions...).Error) scope.Err(preloadQuery.Find(results, preloadConditions...).Error)
// assign find results // assign find results
@ -307,11 +302,11 @@ func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{
scope.Err(field.Set(result)) scope.Err(field.Set(result))
} }
} }
return preloadQuery.QueryExpr().expr, preloadQuery.QueryExpr().args return preloadQuery.QueryExpr()
} }
// handleManyToManyPreload used to preload many to many associations // handleManyToManyPreload used to preload many to many associations
func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface{}, parentSQL string, parentSQLVars []interface{}) { func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface{}, parentSql *SqlExpr) *SqlExpr {
var ( var (
relation = field.Relationship relation = field.Relationship
joinTableHandler = relation.JoinTableHandler joinTableHandler = relation.JoinTableHandler
@ -343,11 +338,9 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface
preloadDB = preloadDB.Select("*") preloadDB = preloadDB.Select("*")
} }
// scope.Value here needs to be a subquery subQuerySQL := parentSql
preloadDB = joinTableHandler.JoinWith(joinTableHandler, preloadDB, scope.Value) subQuerySQL.expr = "SELECT " + toQueryCondition(scope, relation.ForeignDBNames) + " FROM (" + subQuerySQL.expr + ") " + scope.Quote("preMM_" + field.DBName)
preloadDB = joinTableHandler.JoinWithQuery(joinTableHandler, preloadDB, scope.Value, subQuerySQL)
preloadSQL := preloadDB.QueryExpr().expr
fmt.Println(preloadSQL);
// preload inline conditions // preload inline conditions
if len(preloadConditions) > 0 { if len(preloadConditions) > 0 {
@ -357,7 +350,7 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface
rows, err := preloadDB.Rows() rows, err := preloadDB.Rows()
if scope.Err(err) != nil { if scope.Err(err) != nil {
return return nil
} }
defer rows.Close() defer rows.Close()
@ -439,4 +432,5 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface
f.Set(v) f.Set(v)
} }
} }
return preloadDB.QueryExpr()
} }

View File

@ -19,6 +19,8 @@ type JoinTableHandlerInterface interface {
Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error
// JoinWith query with `Join` conditions // JoinWith query with `Join` conditions
JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB
// JoinWith query with `Join` query
JoinWithQuery(handler JoinTableHandlerInterface, db *DB, source interface{}, query *SqlExpr) *DB
// SourceForeignKeys return source foreign keys // SourceForeignKeys return source foreign keys
SourceForeignKeys() []JoinTableForeignKey SourceForeignKeys() []JoinTableForeignKey
// DestinationForeignKeys return destination foreign keys // DestinationForeignKeys return destination foreign keys
@ -159,6 +161,53 @@ func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sour
return db.Table(handler.Table(db)).Where(strings.Join(conditions, " AND "), values...).Delete("").Error return db.Table(handler.Table(db)).Where(strings.Join(conditions, " AND "), values...).Delete("").Error
} }
// JoinWith query with `Join` query
func (s JoinTableHandler) JoinWithQuery(handler JoinTableHandlerInterface, db *DB, source interface{}, query *SqlExpr) *DB {
var (
scope = db.NewScope(source)
tableName = handler.Table(db)
quotedTableName = scope.Quote(tableName)
joinConditions []string
values []interface{}
)
if s.Source.ModelType == scope.GetModelStruct().ModelType {
destinationTableName := db.NewScope(reflect.New(s.Destination.ModelType).Interface()).QuotedTableName()
for _, foreignKey := range s.Destination.ForeignKeys {
joinConditions = append(joinConditions, fmt.Sprintf("%v.%v = %v.%v", quotedTableName, scope.Quote(foreignKey.DBName), destinationTableName, scope.Quote(foreignKey.AssociationDBName)))
}
var foreignDBNames []string
var foreignFieldNames []string
for _, foreignKey := range s.Source.ForeignKeys {
foreignDBNames = append(foreignDBNames, foreignKey.DBName)
if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok {
foreignFieldNames = append(foreignFieldNames, field.Name)
}
}
var condString string
var quotedForeignDBNames []string
for _, dbName := range foreignDBNames {
quotedForeignDBNames = append(quotedForeignDBNames, tableName+"."+dbName)
}
if query != nil {
values = query.args
condString = fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, quotedForeignDBNames), query.expr)
} else {
values = []interface{}{}
condString = fmt.Sprintf("1 <> 1")
}
return db.Joins(fmt.Sprintf("INNER JOIN %v ON %v", quotedTableName, strings.Join(joinConditions, " AND "))).
Where(condString, values)
}
db.Error = errors.New("wrong source type for join table handler")
return db
}
// JoinWith query with `Join` conditions // JoinWith query with `Join` conditions
func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB { func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB {
var ( var (

View File

@ -97,6 +97,10 @@ func TestPreload(t *testing.T) {
} }
func TestAutoPreload(t *testing.T) { func TestAutoPreload(t *testing.T) {
var users []User
DB.Find(&users)
DB.Delete(&users)
user1 := getPreloadUser("auto_user1") user1 := getPreloadUser("auto_user1")
DB.Save(user1) DB.Save(user1)
@ -108,7 +112,6 @@ func TestAutoPreload(t *testing.T) {
user2 := getPreloadUser("auto_user2") user2 := getPreloadUser("auto_user2")
DB.Save(user2) DB.Save(user2)
var users []User
preloadDB.Find(&users) preloadDB.Find(&users)
for _, user := range users { for _, user := range users {

View File

@ -25,6 +25,7 @@ var NowFunc = func() time.Time {
var commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UID", "UI", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"} var commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UID", "UI", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"}
var commonInitialismsReplacer *strings.Replacer var commonInitialismsReplacer *strings.Replacer
var cleanInRegexp = regexp.MustCompile("\\((\\?,)+\\?\\)")
var goSrcRegexp = regexp.MustCompile(`jinzhu/gorm(@.*)?/.*.go`) var goSrcRegexp = regexp.MustCompile(`jinzhu/gorm(@.*)?/.*.go`)
var goTestRegexp = regexp.MustCompile(`jinzhu/gorm(@.*)?/.*test.go`) var goTestRegexp = regexp.MustCompile(`jinzhu/gorm(@.*)?/.*test.go`)
@ -224,3 +225,9 @@ func addExtraSpaceIfExist(str string) string {
} }
return "" return ""
} }
func cleanParentSql(parentSql *SqlExpr) *SqlExpr {
// having more than one query mark will confuse generating the correct number of query marks based upon args
parentSql.expr = string(cleanInRegexp.ReplaceAll([]byte(parentSql.expr), []byte("(?)")))
return parentSql
}