From 1da8451bb126d4ac27f530b765792ed19840a3dc Mon Sep 17 00:00:00 2001 From: hectorqin <1069315972@qq.com> Date: Thu, 19 Jul 2018 10:44:13 +0800 Subject: [PATCH] support association Find and Count with customized func, change buildCondition tablename to struct's tablename when query is interface{} --- association.go | 9 +++++---- main.go | 4 ++-- scope.go | 26 ++++++++++++++++++-------- 3 files changed, 25 insertions(+), 14 deletions(-) diff --git a/association.go b/association.go index 8c6d9864..1e9c10fc 100644 --- a/association.go +++ b/association.go @@ -15,8 +15,9 @@ type Association struct { } // Find find out all related associations -func (association *Association) Find(value interface{}) *Association { - association.scope.related(value, association.column) +func (association *Association) Find(value interface{}, options ...interface{}) *Association { + options = append(options, association.column) + association.scope.related(value, options...) return association.setErr(association.scope.db.Error) } @@ -258,7 +259,7 @@ func (association *Association) Clear() *Association { } // Count return the count of current associations -func (association *Association) Count() int { +func (association *Association) Count(funcs ...func(*DB) *DB) int { var ( count = 0 relationship = association.field.Relationship @@ -290,7 +291,7 @@ func (association *Association) Count() int { ) } - if err := query.Model(fieldValue).Count(&count).Error; err != nil { + if err := query.Model(fieldValue).Scopes(funcs...).Count(&count).Error; err != nil { association.Error = err } return count diff --git a/main.go b/main.go index 25c3a06b..cb8626a3 100644 --- a/main.go +++ b/main.go @@ -354,8 +354,8 @@ func (s *DB) Count(value interface{}) *DB { } // Related get related associations -func (s *DB) Related(value interface{}, foreignKeys ...string) *DB { - return s.NewScope(s.Value).related(value, foreignKeys...).db +func (s *DB) Related(value interface{}, options ...interface{}) *DB { + return s.NewScope(s.Value).related(value, options...).db } // FirstOrInit find first matched record or initialize a new one with given conditions (only works with struct, map conditions) diff --git a/scope.go b/scope.go index 397ccf0b..b2ca476b 100644 --- a/scope.go +++ b/scope.go @@ -586,10 +586,11 @@ func (scope *Scope) buildCondition(clause map[string]interface{}, include bool) scope.Err(fmt.Errorf("invalid query condition: %v", value)) return } + scopeQuotedTableName := newScope.QuotedTableName() for _, field := range newScope.Fields() { if !field.IsIgnored && !field.IsBlank { - sqls = append(sqls, fmt.Sprintf("(%v.%v %s %v)", quotedTableName, scope.Quote(field.DBName), equalSQL, scope.AddToVars(field.Field.Interface()))) + sqls = append(sqls, fmt.Sprintf("(%v.%v %s %v)", scopeQuotedTableName, scope.Quote(field.DBName), equalSQL, scope.AddToVars(field.Field.Interface()))) } } return strings.Join(sqls, " AND ") @@ -1044,10 +1045,19 @@ func (scope *Scope) changeableField(field *Field) bool { return true } -func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope { +func (scope *Scope) related(value interface{}, options ...interface{}) *Scope { toScope := scope.db.NewScope(value) tx := scope.db.Set("gorm:association:source", scope.Value) - + foreignKeys := []string{} + dbFuncs := []func(*DB) *DB{} + for _, option := range options { + if key, ok := option.(string); ok { + foreignKeys = append(foreignKeys, key) + } + if dbFunc, ok := option.(func(*DB) *DB); ok { + dbFuncs = append(dbFuncs, dbFunc) + } + } for _, foreignKey := range append(foreignKeys, toScope.typeName()+"Id", scope.typeName()+"Id") { fromField, _ := scope.FieldByName(foreignKey) toField, _ := toScope.FieldByName(foreignKey) @@ -1056,14 +1066,14 @@ func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope { if relationship := fromField.Relationship; relationship != nil { if relationship.Kind == "many_to_many" { joinTableHandler := relationship.JoinTableHandler - scope.Err(joinTableHandler.JoinWith(joinTableHandler, tx, scope.Value).Find(value).Error) + scope.Err(joinTableHandler.JoinWith(joinTableHandler, tx, scope.Value).Scopes(dbFuncs...).Find(value).Error) } else if relationship.Kind == "belongs_to" { for idx, foreignKey := range relationship.ForeignDBNames { if field, ok := scope.FieldByName(foreignKey); ok { tx = tx.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.AssociationForeignDBNames[idx])), field.Field.Interface()) } } - scope.Err(tx.Find(value).Error) + scope.Err(tx.Scopes(dbFuncs...).Find(value).Error) } else if relationship.Kind == "has_many" || relationship.Kind == "has_one" { for idx, foreignKey := range relationship.ForeignDBNames { if field, ok := scope.FieldByName(relationship.AssociationForeignDBNames[idx]); ok { @@ -1074,16 +1084,16 @@ func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope { if relationship.PolymorphicType != "" { tx = tx.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), relationship.PolymorphicValue) } - scope.Err(tx.Find(value).Error) + scope.Err(tx.Scopes(dbFuncs...).Find(value).Error) } } else { sql := fmt.Sprintf("%v = ?", scope.Quote(toScope.PrimaryKey())) - scope.Err(tx.Where(sql, fromField.Field.Interface()).Find(value).Error) + scope.Err(tx.Where(sql, fromField.Field.Interface()).Scopes(dbFuncs...).Find(value).Error) } return scope } else if toField != nil { sql := fmt.Sprintf("%v = ?", scope.Quote(toField.DBName)) - scope.Err(tx.Where(sql, scope.PrimaryKeyValue()).Find(value).Error) + scope.Err(tx.Where(sql, scope.PrimaryKeyValue()).Scopes(dbFuncs...).Find(value).Error) return scope } }