support association Find and Count with customized func, change buildCondition tablename to struct's tablename when query is interface{}
This commit is contained in:
parent
0fd395ab37
commit
1da8451bb1
@ -15,8 +15,9 @@ type Association struct {
|
||||
}
|
||||
|
||||
// Find find out all related associations
|
||||
func (association *Association) Find(value interface{}) *Association {
|
||||
association.scope.related(value, association.column)
|
||||
func (association *Association) Find(value interface{}, options ...interface{}) *Association {
|
||||
options = append(options, association.column)
|
||||
association.scope.related(value, options...)
|
||||
return association.setErr(association.scope.db.Error)
|
||||
}
|
||||
|
||||
@ -258,7 +259,7 @@ func (association *Association) Clear() *Association {
|
||||
}
|
||||
|
||||
// Count return the count of current associations
|
||||
func (association *Association) Count() int {
|
||||
func (association *Association) Count(funcs ...func(*DB) *DB) int {
|
||||
var (
|
||||
count = 0
|
||||
relationship = association.field.Relationship
|
||||
@ -290,7 +291,7 @@ func (association *Association) Count() int {
|
||||
)
|
||||
}
|
||||
|
||||
if err := query.Model(fieldValue).Count(&count).Error; err != nil {
|
||||
if err := query.Model(fieldValue).Scopes(funcs...).Count(&count).Error; err != nil {
|
||||
association.Error = err
|
||||
}
|
||||
return count
|
||||
|
4
main.go
4
main.go
@ -354,8 +354,8 @@ func (s *DB) Count(value interface{}) *DB {
|
||||
}
|
||||
|
||||
// Related get related associations
|
||||
func (s *DB) Related(value interface{}, foreignKeys ...string) *DB {
|
||||
return s.NewScope(s.Value).related(value, foreignKeys...).db
|
||||
func (s *DB) Related(value interface{}, options ...interface{}) *DB {
|
||||
return s.NewScope(s.Value).related(value, options...).db
|
||||
}
|
||||
|
||||
// FirstOrInit find first matched record or initialize a new one with given conditions (only works with struct, map conditions)
|
||||
|
26
scope.go
26
scope.go
@ -586,10 +586,11 @@ func (scope *Scope) buildCondition(clause map[string]interface{}, include bool)
|
||||
scope.Err(fmt.Errorf("invalid query condition: %v", value))
|
||||
return
|
||||
}
|
||||
scopeQuotedTableName := newScope.QuotedTableName()
|
||||
|
||||
for _, field := range newScope.Fields() {
|
||||
if !field.IsIgnored && !field.IsBlank {
|
||||
sqls = append(sqls, fmt.Sprintf("(%v.%v %s %v)", quotedTableName, scope.Quote(field.DBName), equalSQL, scope.AddToVars(field.Field.Interface())))
|
||||
sqls = append(sqls, fmt.Sprintf("(%v.%v %s %v)", scopeQuotedTableName, scope.Quote(field.DBName), equalSQL, scope.AddToVars(field.Field.Interface())))
|
||||
}
|
||||
}
|
||||
return strings.Join(sqls, " AND ")
|
||||
@ -1044,10 +1045,19 @@ func (scope *Scope) changeableField(field *Field) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope {
|
||||
func (scope *Scope) related(value interface{}, options ...interface{}) *Scope {
|
||||
toScope := scope.db.NewScope(value)
|
||||
tx := scope.db.Set("gorm:association:source", scope.Value)
|
||||
|
||||
foreignKeys := []string{}
|
||||
dbFuncs := []func(*DB) *DB{}
|
||||
for _, option := range options {
|
||||
if key, ok := option.(string); ok {
|
||||
foreignKeys = append(foreignKeys, key)
|
||||
}
|
||||
if dbFunc, ok := option.(func(*DB) *DB); ok {
|
||||
dbFuncs = append(dbFuncs, dbFunc)
|
||||
}
|
||||
}
|
||||
for _, foreignKey := range append(foreignKeys, toScope.typeName()+"Id", scope.typeName()+"Id") {
|
||||
fromField, _ := scope.FieldByName(foreignKey)
|
||||
toField, _ := toScope.FieldByName(foreignKey)
|
||||
@ -1056,14 +1066,14 @@ func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope {
|
||||
if relationship := fromField.Relationship; relationship != nil {
|
||||
if relationship.Kind == "many_to_many" {
|
||||
joinTableHandler := relationship.JoinTableHandler
|
||||
scope.Err(joinTableHandler.JoinWith(joinTableHandler, tx, scope.Value).Find(value).Error)
|
||||
scope.Err(joinTableHandler.JoinWith(joinTableHandler, tx, scope.Value).Scopes(dbFuncs...).Find(value).Error)
|
||||
} else if relationship.Kind == "belongs_to" {
|
||||
for idx, foreignKey := range relationship.ForeignDBNames {
|
||||
if field, ok := scope.FieldByName(foreignKey); ok {
|
||||
tx = tx.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.AssociationForeignDBNames[idx])), field.Field.Interface())
|
||||
}
|
||||
}
|
||||
scope.Err(tx.Find(value).Error)
|
||||
scope.Err(tx.Scopes(dbFuncs...).Find(value).Error)
|
||||
} else if relationship.Kind == "has_many" || relationship.Kind == "has_one" {
|
||||
for idx, foreignKey := range relationship.ForeignDBNames {
|
||||
if field, ok := scope.FieldByName(relationship.AssociationForeignDBNames[idx]); ok {
|
||||
@ -1074,16 +1084,16 @@ func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope {
|
||||
if relationship.PolymorphicType != "" {
|
||||
tx = tx.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), relationship.PolymorphicValue)
|
||||
}
|
||||
scope.Err(tx.Find(value).Error)
|
||||
scope.Err(tx.Scopes(dbFuncs...).Find(value).Error)
|
||||
}
|
||||
} else {
|
||||
sql := fmt.Sprintf("%v = ?", scope.Quote(toScope.PrimaryKey()))
|
||||
scope.Err(tx.Where(sql, fromField.Field.Interface()).Find(value).Error)
|
||||
scope.Err(tx.Where(sql, fromField.Field.Interface()).Scopes(dbFuncs...).Find(value).Error)
|
||||
}
|
||||
return scope
|
||||
} else if toField != nil {
|
||||
sql := fmt.Sprintf("%v = ?", scope.Quote(toField.DBName))
|
||||
scope.Err(tx.Where(sql, scope.PrimaryKeyValue()).Find(value).Error)
|
||||
scope.Err(tx.Where(sql, scope.PrimaryKeyValue()).Scopes(dbFuncs...).Find(value).Error)
|
||||
return scope
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user