Merge 1da8451bb126d4ac27f530b765792ed19840a3dc into 0fd395ab37aefd2d50854f0556a4311dccc6f45a

This commit is contained in:
hector 2018-07-19 02:47:45 +00:00 committed by GitHub
commit 549f4e1239
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 25 additions and 14 deletions

View File

@ -15,8 +15,9 @@ type Association struct {
}
// Find find out all related associations
func (association *Association) Find(value interface{}) *Association {
association.scope.related(value, association.column)
func (association *Association) Find(value interface{}, options ...interface{}) *Association {
options = append(options, association.column)
association.scope.related(value, options...)
return association.setErr(association.scope.db.Error)
}
@ -258,7 +259,7 @@ func (association *Association) Clear() *Association {
}
// Count return the count of current associations
func (association *Association) Count() int {
func (association *Association) Count(funcs ...func(*DB) *DB) int {
var (
count = 0
relationship = association.field.Relationship
@ -290,7 +291,7 @@ func (association *Association) Count() int {
)
}
if err := query.Model(fieldValue).Count(&count).Error; err != nil {
if err := query.Model(fieldValue).Scopes(funcs...).Count(&count).Error; err != nil {
association.Error = err
}
return count

View File

@ -354,8 +354,8 @@ func (s *DB) Count(value interface{}) *DB {
}
// Related get related associations
func (s *DB) Related(value interface{}, foreignKeys ...string) *DB {
return s.NewScope(s.Value).related(value, foreignKeys...).db
func (s *DB) Related(value interface{}, options ...interface{}) *DB {
return s.NewScope(s.Value).related(value, options...).db
}
// FirstOrInit find first matched record or initialize a new one with given conditions (only works with struct, map conditions)

View File

@ -586,10 +586,11 @@ func (scope *Scope) buildCondition(clause map[string]interface{}, include bool)
scope.Err(fmt.Errorf("invalid query condition: %v", value))
return
}
scopeQuotedTableName := newScope.QuotedTableName()
for _, field := range newScope.Fields() {
if !field.IsIgnored && !field.IsBlank {
sqls = append(sqls, fmt.Sprintf("(%v.%v %s %v)", quotedTableName, scope.Quote(field.DBName), equalSQL, scope.AddToVars(field.Field.Interface())))
sqls = append(sqls, fmt.Sprintf("(%v.%v %s %v)", scopeQuotedTableName, scope.Quote(field.DBName), equalSQL, scope.AddToVars(field.Field.Interface())))
}
}
return strings.Join(sqls, " AND ")
@ -1044,10 +1045,19 @@ func (scope *Scope) changeableField(field *Field) bool {
return true
}
func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope {
func (scope *Scope) related(value interface{}, options ...interface{}) *Scope {
toScope := scope.db.NewScope(value)
tx := scope.db.Set("gorm:association:source", scope.Value)
foreignKeys := []string{}
dbFuncs := []func(*DB) *DB{}
for _, option := range options {
if key, ok := option.(string); ok {
foreignKeys = append(foreignKeys, key)
}
if dbFunc, ok := option.(func(*DB) *DB); ok {
dbFuncs = append(dbFuncs, dbFunc)
}
}
for _, foreignKey := range append(foreignKeys, toScope.typeName()+"Id", scope.typeName()+"Id") {
fromField, _ := scope.FieldByName(foreignKey)
toField, _ := toScope.FieldByName(foreignKey)
@ -1056,14 +1066,14 @@ func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope {
if relationship := fromField.Relationship; relationship != nil {
if relationship.Kind == "many_to_many" {
joinTableHandler := relationship.JoinTableHandler
scope.Err(joinTableHandler.JoinWith(joinTableHandler, tx, scope.Value).Find(value).Error)
scope.Err(joinTableHandler.JoinWith(joinTableHandler, tx, scope.Value).Scopes(dbFuncs...).Find(value).Error)
} else if relationship.Kind == "belongs_to" {
for idx, foreignKey := range relationship.ForeignDBNames {
if field, ok := scope.FieldByName(foreignKey); ok {
tx = tx.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.AssociationForeignDBNames[idx])), field.Field.Interface())
}
}
scope.Err(tx.Find(value).Error)
scope.Err(tx.Scopes(dbFuncs...).Find(value).Error)
} else if relationship.Kind == "has_many" || relationship.Kind == "has_one" {
for idx, foreignKey := range relationship.ForeignDBNames {
if field, ok := scope.FieldByName(relationship.AssociationForeignDBNames[idx]); ok {
@ -1074,16 +1084,16 @@ func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope {
if relationship.PolymorphicType != "" {
tx = tx.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), relationship.PolymorphicValue)
}
scope.Err(tx.Find(value).Error)
scope.Err(tx.Scopes(dbFuncs...).Find(value).Error)
}
} else {
sql := fmt.Sprintf("%v = ?", scope.Quote(toScope.PrimaryKey()))
scope.Err(tx.Where(sql, fromField.Field.Interface()).Find(value).Error)
scope.Err(tx.Where(sql, fromField.Field.Interface()).Scopes(dbFuncs...).Find(value).Error)
}
return scope
} else if toField != nil {
sql := fmt.Sprintf("%v = ?", scope.Quote(toField.DBName))
scope.Err(tx.Where(sql, scope.PrimaryKeyValue()).Find(value).Error)
scope.Err(tx.Where(sql, scope.PrimaryKeyValue()).Scopes(dbFuncs...).Find(value).Error)
return scope
}
}