got tests working
This commit is contained in:
parent
47bd34caf0
commit
ef73cc681d
@ -8,6 +8,7 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
|
||||
// preloadCallback used to preload associations
|
||||
func preloadCallback(scope *Scope) {
|
||||
if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip {
|
||||
@ -30,8 +31,7 @@ func preloadCallback(scope *Scope) {
|
||||
|
||||
var (
|
||||
preloadedMap = map[string]bool{}
|
||||
preloadedParentQueryMap = map[string]string{}
|
||||
preloadedParentQueryVarsMap = map[string][]interface{}{}
|
||||
parentQueryMap = map[string]*SqlExpr{}
|
||||
fields = scope.Fields()
|
||||
)
|
||||
|
||||
@ -41,8 +41,10 @@ func preloadCallback(scope *Scope) {
|
||||
currentScope = scope
|
||||
currentFields = fields
|
||||
)
|
||||
parentSQL := currentScope.SQL
|
||||
parentSQLVars := currentScope.SQLVars
|
||||
parentQuery := new(SqlExpr)
|
||||
parentQuery.expr = currentScope.SQL
|
||||
parentQuery.args = currentScope.SQLVars
|
||||
cleanParentSql(parentQuery)
|
||||
|
||||
for idx, preloadField := range preloadFields {
|
||||
var currentPreloadConditions []interface{}
|
||||
@ -54,9 +56,8 @@ func preloadCallback(scope *Scope) {
|
||||
// if not preloaded
|
||||
if preloadKey := strings.Join(preloadFields[:idx+1], "."); !preloadedMap[preloadKey] {
|
||||
parentKey := strings.Join(preloadFields[:idx], ".")
|
||||
if _, ok := preloadedParentQueryMap[parentKey]; ok {
|
||||
parentSQL = preloadedParentQueryMap[parentKey]
|
||||
parentSQLVars = preloadedParentQueryVarsMap[parentKey]
|
||||
if _, ok := parentQueryMap[parentKey]; ok {
|
||||
parentQuery = cleanParentSql(parentQueryMap[parentKey])
|
||||
}
|
||||
|
||||
|
||||
@ -72,13 +73,13 @@ func preloadCallback(scope *Scope) {
|
||||
|
||||
switch field.Relationship.Kind {
|
||||
case "has_one":
|
||||
preloadedParentQueryMap[preloadKey], preloadedParentQueryVarsMap[preloadKey] = currentScope.handleHasOnePreload(field, currentPreloadConditions, parentSQL, parentSQLVars)
|
||||
parentQueryMap[preloadKey] = currentScope.handleHasOnePreload(field, currentPreloadConditions, parentQuery)
|
||||
case "has_many":
|
||||
preloadedParentQueryMap[preloadKey], preloadedParentQueryVarsMap[preloadKey] = currentScope.handleHasManyPreload(field, currentPreloadConditions, parentSQL, parentSQLVars)
|
||||
parentQueryMap[preloadKey] = currentScope.handleHasManyPreload(field, currentPreloadConditions, parentQuery)
|
||||
case "belongs_to":
|
||||
preloadedParentQueryMap[preloadKey], preloadedParentQueryVarsMap[preloadKey] = currentScope.handleBelongsToPreload(field, currentPreloadConditions, parentSQL, parentSQLVars)
|
||||
parentQueryMap[preloadKey] = currentScope.handleBelongsToPreload(field, currentPreloadConditions, parentQuery)
|
||||
case "many_to_many":
|
||||
currentScope.handleManyToManyPreload(field, currentPreloadConditions, parentSQL, parentSQLVars)
|
||||
parentQueryMap[preloadKey] = currentScope.handleManyToManyPreload(field, currentPreloadConditions, parentQuery)
|
||||
default:
|
||||
scope.Err(errors.New("unsupported relation"))
|
||||
}
|
||||
@ -141,25 +142,23 @@ func (scope *Scope) generatePreloadDBWithConditions(conditions []interface{}) (*
|
||||
}
|
||||
|
||||
// handleHasOnePreload used to preload has one associations
|
||||
func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}, parentSQL string, parentSQLVars []interface{}) (string, []interface{}) {
|
||||
func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}, parentQuery *SqlExpr) *SqlExpr {
|
||||
relation := field.Relationship
|
||||
|
||||
// get relations's primary keys
|
||||
primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value)
|
||||
if len(primaryKeys) == 0 {
|
||||
return "", []interface{}{}
|
||||
// skip query if parent does not exist
|
||||
if scope.Value == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// preload conditions
|
||||
preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
|
||||
|
||||
subQuerySQL := parentSQL
|
||||
subQuerySQL := parentQuery.expr
|
||||
subQuerySQL = "SELECT " + toQueryCondition(scope, relation.AssociationForeignDBNames) + " FROM (" + subQuerySQL + ") " + scope.Quote("preHO_" + field.DBName)
|
||||
|
||||
// find relations
|
||||
query := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), subQuerySQL)
|
||||
//values := toQueryValues(primaryKeys)
|
||||
values := parentSQLVars
|
||||
values := parentQuery.args
|
||||
if relation.PolymorphicType != "" {
|
||||
query += fmt.Sprintf(" AND %v = ?", scope.Quote(relation.PolymorphicDBName))
|
||||
values = append(values, relation.PolymorphicValue)
|
||||
@ -195,30 +194,27 @@ func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{},
|
||||
scope.Err(field.Set(result))
|
||||
}
|
||||
}
|
||||
return preloadQuery.QueryExpr().expr, preloadQuery.QueryExpr().args
|
||||
return preloadQuery.QueryExpr()
|
||||
}
|
||||
|
||||
// handleHasManyPreload used to preload has many associations
|
||||
func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}, parentSQL string, parentSQLVars []interface{}) (string, []interface{}) {
|
||||
func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}, parentSql *SqlExpr) *SqlExpr {
|
||||
relation := field.Relationship
|
||||
|
||||
// get relations's primary keys
|
||||
primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value)
|
||||
if len(primaryKeys) == 0 {
|
||||
return "", []interface{}{}
|
||||
// skip query if parent does not exist
|
||||
if scope.Value == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// preload conditions
|
||||
preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
|
||||
|
||||
subQuerySQL := parentSQL
|
||||
subQuerySQL := parentSql.expr
|
||||
subQuerySQL = "SELECT " + toQueryCondition(scope, relation.AssociationForeignDBNames) + " FROM (" + subQuerySQL + ") " + scope.Quote("preHM_" + field.DBName)
|
||||
|
||||
// find relations
|
||||
//scope.Err(preloadDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.AssociationForeignDBNames), subQuerySQL), scope.SQLVars).Find(results, preloadConditions...).Error)
|
||||
query := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), subQuerySQL)
|
||||
//values := toQueryValues(primaryKeys)
|
||||
values := parentSQLVars
|
||||
values := parentSql.args
|
||||
if relation.PolymorphicType != "" {
|
||||
query += fmt.Sprintf(" AND %v = ?", scope.Quote(relation.PolymorphicDBName))
|
||||
values = append(values, relation.PolymorphicValue)
|
||||
@ -255,28 +251,27 @@ func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{},
|
||||
} else {
|
||||
scope.Err(field.Set(resultsValue))
|
||||
}
|
||||
return preloadQuery.QueryExpr().expr, preloadQuery.QueryExpr().args
|
||||
return preloadQuery.QueryExpr()
|
||||
}
|
||||
|
||||
// handleBelongsToPreload used to preload belongs to associations
|
||||
func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{}, parentSQL string, parentSQLVars []interface{}) (string, []interface{}) {
|
||||
func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{}, parentSQL *SqlExpr) *SqlExpr {
|
||||
relation := field.Relationship
|
||||
|
||||
// preload conditions
|
||||
preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
|
||||
|
||||
// get relations's primary keys
|
||||
primaryKeys := scope.getColumnAsArray(relation.ForeignFieldNames, scope.Value)
|
||||
if len(primaryKeys) == 0 {
|
||||
return "", []interface{}{}
|
||||
if scope.Value == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
subQuerySQL := parentSQL
|
||||
subQuerySQL := parentSQL.expr
|
||||
subQuerySQL = "SELECT " + toQueryCondition(scope, relation.ForeignDBNames) + " FROM (" + subQuerySQL + ") " + scope.Quote("preBT_" + field.DBName)
|
||||
|
||||
// find relations
|
||||
results := makeSlice(field.Struct.Type)
|
||||
preloadQuery := preloadDB.Model(results).Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.AssociationForeignDBNames), subQuerySQL),parentSQLVars)
|
||||
preloadQuery := preloadDB.Model(results).Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.AssociationForeignDBNames), subQuerySQL),parentSQL.args)
|
||||
scope.Err(preloadQuery.Find(results, preloadConditions...).Error)
|
||||
|
||||
// assign find results
|
||||
@ -307,11 +302,11 @@ func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{
|
||||
scope.Err(field.Set(result))
|
||||
}
|
||||
}
|
||||
return preloadQuery.QueryExpr().expr, preloadQuery.QueryExpr().args
|
||||
return preloadQuery.QueryExpr()
|
||||
}
|
||||
|
||||
// handleManyToManyPreload used to preload many to many associations
|
||||
func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface{}, parentSQL string, parentSQLVars []interface{}) {
|
||||
func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface{}, parentSql *SqlExpr) *SqlExpr {
|
||||
var (
|
||||
relation = field.Relationship
|
||||
joinTableHandler = relation.JoinTableHandler
|
||||
@ -343,11 +338,9 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface
|
||||
preloadDB = preloadDB.Select("*")
|
||||
}
|
||||
|
||||
// scope.Value here needs to be a subquery
|
||||
preloadDB = joinTableHandler.JoinWith(joinTableHandler, preloadDB, scope.Value)
|
||||
|
||||
preloadSQL := preloadDB.QueryExpr().expr
|
||||
fmt.Println(preloadSQL);
|
||||
subQuerySQL := parentSql
|
||||
subQuerySQL.expr = "SELECT " + toQueryCondition(scope, relation.ForeignDBNames) + " FROM (" + subQuerySQL.expr + ") " + scope.Quote("preMM_" + field.DBName)
|
||||
preloadDB = joinTableHandler.JoinWithQuery(joinTableHandler, preloadDB, scope.Value, subQuerySQL)
|
||||
|
||||
// preload inline conditions
|
||||
if len(preloadConditions) > 0 {
|
||||
@ -357,7 +350,7 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface
|
||||
rows, err := preloadDB.Rows()
|
||||
|
||||
if scope.Err(err) != nil {
|
||||
return
|
||||
return nil
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
@ -439,4 +432,5 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface
|
||||
f.Set(v)
|
||||
}
|
||||
}
|
||||
return preloadDB.QueryExpr()
|
||||
}
|
||||
|
@ -19,6 +19,8 @@ type JoinTableHandlerInterface interface {
|
||||
Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error
|
||||
// JoinWith query with `Join` conditions
|
||||
JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB
|
||||
// JoinWith query with `Join` query
|
||||
JoinWithQuery(handler JoinTableHandlerInterface, db *DB, source interface{}, query *SqlExpr) *DB
|
||||
// SourceForeignKeys return source foreign keys
|
||||
SourceForeignKeys() []JoinTableForeignKey
|
||||
// DestinationForeignKeys return destination foreign keys
|
||||
@ -159,6 +161,53 @@ func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sour
|
||||
return db.Table(handler.Table(db)).Where(strings.Join(conditions, " AND "), values...).Delete("").Error
|
||||
}
|
||||
|
||||
// JoinWith query with `Join` query
|
||||
func (s JoinTableHandler) JoinWithQuery(handler JoinTableHandlerInterface, db *DB, source interface{}, query *SqlExpr) *DB {
|
||||
var (
|
||||
scope = db.NewScope(source)
|
||||
tableName = handler.Table(db)
|
||||
quotedTableName = scope.Quote(tableName)
|
||||
joinConditions []string
|
||||
values []interface{}
|
||||
)
|
||||
|
||||
if s.Source.ModelType == scope.GetModelStruct().ModelType {
|
||||
destinationTableName := db.NewScope(reflect.New(s.Destination.ModelType).Interface()).QuotedTableName()
|
||||
for _, foreignKey := range s.Destination.ForeignKeys {
|
||||
joinConditions = append(joinConditions, fmt.Sprintf("%v.%v = %v.%v", quotedTableName, scope.Quote(foreignKey.DBName), destinationTableName, scope.Quote(foreignKey.AssociationDBName)))
|
||||
}
|
||||
|
||||
var foreignDBNames []string
|
||||
var foreignFieldNames []string
|
||||
|
||||
for _, foreignKey := range s.Source.ForeignKeys {
|
||||
foreignDBNames = append(foreignDBNames, foreignKey.DBName)
|
||||
if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok {
|
||||
foreignFieldNames = append(foreignFieldNames, field.Name)
|
||||
}
|
||||
}
|
||||
|
||||
var condString string
|
||||
var quotedForeignDBNames []string
|
||||
for _, dbName := range foreignDBNames {
|
||||
quotedForeignDBNames = append(quotedForeignDBNames, tableName+"."+dbName)
|
||||
}
|
||||
|
||||
if query != nil {
|
||||
values = query.args
|
||||
condString = fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, quotedForeignDBNames), query.expr)
|
||||
} else {
|
||||
values = []interface{}{}
|
||||
condString = fmt.Sprintf("1 <> 1")
|
||||
}
|
||||
return db.Joins(fmt.Sprintf("INNER JOIN %v ON %v", quotedTableName, strings.Join(joinConditions, " AND "))).
|
||||
Where(condString, values)
|
||||
}
|
||||
|
||||
db.Error = errors.New("wrong source type for join table handler")
|
||||
return db
|
||||
}
|
||||
|
||||
// JoinWith query with `Join` conditions
|
||||
func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB {
|
||||
var (
|
||||
|
@ -97,6 +97,10 @@ func TestPreload(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAutoPreload(t *testing.T) {
|
||||
var users []User
|
||||
DB.Find(&users)
|
||||
DB.Delete(&users)
|
||||
|
||||
user1 := getPreloadUser("auto_user1")
|
||||
DB.Save(user1)
|
||||
|
||||
@ -108,7 +112,6 @@ func TestAutoPreload(t *testing.T) {
|
||||
user2 := getPreloadUser("auto_user2")
|
||||
DB.Save(user2)
|
||||
|
||||
var users []User
|
||||
preloadDB.Find(&users)
|
||||
|
||||
for _, user := range users {
|
||||
|
7
utils.go
7
utils.go
@ -25,6 +25,7 @@ var NowFunc = func() time.Time {
|
||||
var commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UID", "UI", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"}
|
||||
var commonInitialismsReplacer *strings.Replacer
|
||||
|
||||
var cleanInRegexp = regexp.MustCompile("\\((\\?,)+\\?\\)")
|
||||
var goSrcRegexp = regexp.MustCompile(`jinzhu/gorm(@.*)?/.*.go`)
|
||||
var goTestRegexp = regexp.MustCompile(`jinzhu/gorm(@.*)?/.*test.go`)
|
||||
|
||||
@ -224,3 +225,9 @@ func addExtraSpaceIfExist(str string) string {
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func cleanParentSql(parentSql *SqlExpr) *SqlExpr {
|
||||
// having more than one query mark will confuse generating the correct number of query marks based upon args
|
||||
parentSql.expr = string(cleanInRegexp.ReplaceAll([]byte(parentSql.expr), []byte("(?)")))
|
||||
return parentSql
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user