support association Find and Count with customized func, change buildCondition tablename to struct's tablename when query is interface{}

This commit is contained in:
hectorqin 2018-07-19 10:44:13 +08:00
parent 0fd395ab37
commit 1da8451bb1
3 changed files with 25 additions and 14 deletions

View File

@ -15,8 +15,9 @@ type Association struct {
} }
// Find find out all related associations // Find find out all related associations
func (association *Association) Find(value interface{}) *Association { func (association *Association) Find(value interface{}, options ...interface{}) *Association {
association.scope.related(value, association.column) options = append(options, association.column)
association.scope.related(value, options...)
return association.setErr(association.scope.db.Error) return association.setErr(association.scope.db.Error)
} }
@ -258,7 +259,7 @@ func (association *Association) Clear() *Association {
} }
// Count return the count of current associations // Count return the count of current associations
func (association *Association) Count() int { func (association *Association) Count(funcs ...func(*DB) *DB) int {
var ( var (
count = 0 count = 0
relationship = association.field.Relationship relationship = association.field.Relationship
@ -290,7 +291,7 @@ func (association *Association) Count() int {
) )
} }
if err := query.Model(fieldValue).Count(&count).Error; err != nil { if err := query.Model(fieldValue).Scopes(funcs...).Count(&count).Error; err != nil {
association.Error = err association.Error = err
} }
return count return count

View File

@ -354,8 +354,8 @@ func (s *DB) Count(value interface{}) *DB {
} }
// Related get related associations // Related get related associations
func (s *DB) Related(value interface{}, foreignKeys ...string) *DB { func (s *DB) Related(value interface{}, options ...interface{}) *DB {
return s.NewScope(s.Value).related(value, foreignKeys...).db return s.NewScope(s.Value).related(value, options...).db
} }
// FirstOrInit find first matched record or initialize a new one with given conditions (only works with struct, map conditions) // FirstOrInit find first matched record or initialize a new one with given conditions (only works with struct, map conditions)

View File

@ -586,10 +586,11 @@ func (scope *Scope) buildCondition(clause map[string]interface{}, include bool)
scope.Err(fmt.Errorf("invalid query condition: %v", value)) scope.Err(fmt.Errorf("invalid query condition: %v", value))
return return
} }
scopeQuotedTableName := newScope.QuotedTableName()
for _, field := range newScope.Fields() { for _, field := range newScope.Fields() {
if !field.IsIgnored && !field.IsBlank { if !field.IsIgnored && !field.IsBlank {
sqls = append(sqls, fmt.Sprintf("(%v.%v %s %v)", quotedTableName, scope.Quote(field.DBName), equalSQL, scope.AddToVars(field.Field.Interface()))) sqls = append(sqls, fmt.Sprintf("(%v.%v %s %v)", scopeQuotedTableName, scope.Quote(field.DBName), equalSQL, scope.AddToVars(field.Field.Interface())))
} }
} }
return strings.Join(sqls, " AND ") return strings.Join(sqls, " AND ")
@ -1044,10 +1045,19 @@ func (scope *Scope) changeableField(field *Field) bool {
return true return true
} }
func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope { func (scope *Scope) related(value interface{}, options ...interface{}) *Scope {
toScope := scope.db.NewScope(value) toScope := scope.db.NewScope(value)
tx := scope.db.Set("gorm:association:source", scope.Value) tx := scope.db.Set("gorm:association:source", scope.Value)
foreignKeys := []string{}
dbFuncs := []func(*DB) *DB{}
for _, option := range options {
if key, ok := option.(string); ok {
foreignKeys = append(foreignKeys, key)
}
if dbFunc, ok := option.(func(*DB) *DB); ok {
dbFuncs = append(dbFuncs, dbFunc)
}
}
for _, foreignKey := range append(foreignKeys, toScope.typeName()+"Id", scope.typeName()+"Id") { for _, foreignKey := range append(foreignKeys, toScope.typeName()+"Id", scope.typeName()+"Id") {
fromField, _ := scope.FieldByName(foreignKey) fromField, _ := scope.FieldByName(foreignKey)
toField, _ := toScope.FieldByName(foreignKey) toField, _ := toScope.FieldByName(foreignKey)
@ -1056,14 +1066,14 @@ func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope {
if relationship := fromField.Relationship; relationship != nil { if relationship := fromField.Relationship; relationship != nil {
if relationship.Kind == "many_to_many" { if relationship.Kind == "many_to_many" {
joinTableHandler := relationship.JoinTableHandler joinTableHandler := relationship.JoinTableHandler
scope.Err(joinTableHandler.JoinWith(joinTableHandler, tx, scope.Value).Find(value).Error) scope.Err(joinTableHandler.JoinWith(joinTableHandler, tx, scope.Value).Scopes(dbFuncs...).Find(value).Error)
} else if relationship.Kind == "belongs_to" { } else if relationship.Kind == "belongs_to" {
for idx, foreignKey := range relationship.ForeignDBNames { for idx, foreignKey := range relationship.ForeignDBNames {
if field, ok := scope.FieldByName(foreignKey); ok { if field, ok := scope.FieldByName(foreignKey); ok {
tx = tx.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.AssociationForeignDBNames[idx])), field.Field.Interface()) tx = tx.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.AssociationForeignDBNames[idx])), field.Field.Interface())
} }
} }
scope.Err(tx.Find(value).Error) scope.Err(tx.Scopes(dbFuncs...).Find(value).Error)
} else if relationship.Kind == "has_many" || relationship.Kind == "has_one" { } else if relationship.Kind == "has_many" || relationship.Kind == "has_one" {
for idx, foreignKey := range relationship.ForeignDBNames { for idx, foreignKey := range relationship.ForeignDBNames {
if field, ok := scope.FieldByName(relationship.AssociationForeignDBNames[idx]); ok { if field, ok := scope.FieldByName(relationship.AssociationForeignDBNames[idx]); ok {
@ -1074,16 +1084,16 @@ func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope {
if relationship.PolymorphicType != "" { if relationship.PolymorphicType != "" {
tx = tx.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), relationship.PolymorphicValue) tx = tx.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), relationship.PolymorphicValue)
} }
scope.Err(tx.Find(value).Error) scope.Err(tx.Scopes(dbFuncs...).Find(value).Error)
} }
} else { } else {
sql := fmt.Sprintf("%v = ?", scope.Quote(toScope.PrimaryKey())) sql := fmt.Sprintf("%v = ?", scope.Quote(toScope.PrimaryKey()))
scope.Err(tx.Where(sql, fromField.Field.Interface()).Find(value).Error) scope.Err(tx.Where(sql, fromField.Field.Interface()).Scopes(dbFuncs...).Find(value).Error)
} }
return scope return scope
} else if toField != nil { } else if toField != nil {
sql := fmt.Sprintf("%v = ?", scope.Quote(toField.DBName)) sql := fmt.Sprintf("%v = ?", scope.Quote(toField.DBName))
scope.Err(tx.Where(sql, scope.PrimaryKeyValue()).Find(value).Error) scope.Err(tx.Where(sql, scope.PrimaryKeyValue()).Scopes(dbFuncs...).Find(value).Error)
return scope return scope
} }
} }