Preload belongs_to, has_one, has_many relations
This commit is contained in:
parent
3b784c37c4
commit
dd635a17c2
@ -3,8 +3,6 @@ package gorm
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
|
||||
"github.com/jinzhu/gorm"
|
||||
)
|
||||
|
||||
func Query(scope *Scope) {
|
||||
@ -15,7 +13,6 @@ func Query(scope *Scope) {
|
||||
isPtr bool
|
||||
anyRecordFound bool
|
||||
destType reflect.Type
|
||||
primaryKeys []interface{}
|
||||
)
|
||||
|
||||
var dest = scope.IndirectValue()
|
||||
@ -50,8 +47,6 @@ func Query(scope *Scope) {
|
||||
return
|
||||
}
|
||||
|
||||
preloadMap := map[string]map[string]*gorm.Field{}
|
||||
|
||||
columns, _ := rows.Columns()
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
@ -66,7 +61,6 @@ func Query(scope *Scope) {
|
||||
var values = make([]interface{}, len(columns))
|
||||
|
||||
fields := scope.New(elem.Addr().Interface()).Fields()
|
||||
var primaryKey interface{}
|
||||
for index, column := range columns {
|
||||
if field, ok := fields[column]; ok {
|
||||
if field.Field.Kind() == reflect.Ptr {
|
||||
@ -74,10 +68,6 @@ func Query(scope *Scope) {
|
||||
} else {
|
||||
values[index] = reflect.New(reflect.PtrTo(field.Field.Type())).Interface()
|
||||
}
|
||||
if field.IsPrimaryKey {
|
||||
primaryKey = values[index]
|
||||
primaryKeys = append(primaryKeys, primaryKey)
|
||||
}
|
||||
} else {
|
||||
var value interface{}
|
||||
values[index] = &value
|
||||
@ -104,34 +94,6 @@ func Query(scope *Scope) {
|
||||
dest.Set(reflect.Append(dest, elem))
|
||||
}
|
||||
}
|
||||
|
||||
if scope.Search.Preload != nil {
|
||||
for key := range scope.Search.Preload {
|
||||
if field := fields[key]; field != nil {
|
||||
if preloadMap[key] == nil {
|
||||
preloadMap[key] = map[string]reflect.Value{}
|
||||
}
|
||||
preloadMap[key][fmt.Sprintf("%v", primaryKey)] = field
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, value := range preloadMap {
|
||||
var typ reflect.Type
|
||||
var relation *Relation
|
||||
for _, v := range value {
|
||||
typ = v.Field.Type()
|
||||
relation = v.Relationship
|
||||
break
|
||||
}
|
||||
sliceType := reflect.SliceOf(typ)
|
||||
slice := reflect.MakeSlice(sliceType, 0, 0)
|
||||
slicePtr := reflect.New(sliceType)
|
||||
slicePtr.Elem().Set(slice)
|
||||
if relation == "has_many" {
|
||||
scope.NewDB().Find(slicePtr.Interface(), primaryKeys)
|
||||
}
|
||||
}
|
||||
|
||||
if !anyRecordFound && !isSlice {
|
||||
@ -147,4 +109,5 @@ func AfterQuery(scope *Scope) {
|
||||
func init() {
|
||||
DefaultCallback.Query().Register("gorm:query", Query)
|
||||
DefaultCallback.Query().Register("gorm:after_query", AfterQuery)
|
||||
DefaultCallback.Query().Register("gorm:preload", Preload)
|
||||
}
|
||||
|
5
main.go
5
main.go
@ -79,8 +79,9 @@ func (s *DB) DB() *sql.DB {
|
||||
}
|
||||
|
||||
func (s *DB) New() *DB {
|
||||
s.search = nil
|
||||
return s.clone()
|
||||
clone := s.clone()
|
||||
clone.search = nil
|
||||
return clone
|
||||
}
|
||||
|
||||
// Return the underlying sql.DB or sql.Tx instance.
|
||||
|
71
preload.go
Normal file
71
preload.go
Normal file
@ -0,0 +1,71 @@
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
func Preload(scope *Scope) {
|
||||
// Get Fields
|
||||
var fields map[string]*Field
|
||||
if scope.IndirectValue().Kind() == reflect.Slice {
|
||||
elem := reflect.New(scope.IndirectValue().Type().Elem()).Elem()
|
||||
fields = scope.New(elem.Addr().Interface()).Fields()
|
||||
} else {
|
||||
fields = scope.Fields()
|
||||
}
|
||||
|
||||
if scope.Search.Preload != nil {
|
||||
for key := range scope.Search.Preload {
|
||||
for _, field := range fields {
|
||||
if field.Name == key && field.Relationship != nil {
|
||||
results := makeSlice(field.Field)
|
||||
relation := field.Relationship
|
||||
primaryName := scope.PrimaryKeyField().Name
|
||||
|
||||
switch relation.Kind {
|
||||
case "has_one":
|
||||
sql := fmt.Sprintf("%v IN (?)", scope.Quote(relation.ForeignKey))
|
||||
scope.NewDB().Find(results, sql, scope.getColumnAsArray(primaryName))
|
||||
case "has_many":
|
||||
sql := fmt.Sprintf("%v IN (?)", scope.Quote(relation.ForeignKey))
|
||||
scope.NewDB().Find(results, sql, scope.getColumnAsArray(primaryName))
|
||||
case "belongs_to":
|
||||
scope.NewDB().Find(results, scope.getColumnAsArray(relation.ForeignKey))
|
||||
case "many_to_many":
|
||||
scope.Err(errors.New("not supported relation"))
|
||||
default:
|
||||
scope.Err(errors.New("not supported relation"))
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func makeSlice(value reflect.Value) interface{} {
|
||||
typ := value.Type()
|
||||
if value.Kind() == reflect.Slice {
|
||||
typ = typ.Elem()
|
||||
}
|
||||
sliceType := reflect.SliceOf(typ)
|
||||
slice := reflect.New(sliceType)
|
||||
slice.Elem().Set(reflect.MakeSlice(sliceType, 0, 0))
|
||||
return slice.Interface()
|
||||
}
|
||||
|
||||
func (scope *Scope) getColumnAsArray(column string) (primaryKeys []interface{}) {
|
||||
values := scope.IndirectValue()
|
||||
switch values.Kind() {
|
||||
case reflect.Slice:
|
||||
for i := 0; i < values.Len(); i++ {
|
||||
value := values.Index(i)
|
||||
primaryKeys = append(primaryKeys, value.FieldByName(column).Interface())
|
||||
}
|
||||
case reflect.Struct:
|
||||
return []interface{}{values.FieldByName(column).Interface()}
|
||||
}
|
||||
return
|
||||
}
|
11
preload_test.go
Normal file
11
preload_test.go
Normal file
@ -0,0 +1,11 @@
|
||||
package gorm_test
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestPreload(t *testing.T) {
|
||||
user := User{Name: "PreloadUser", BillingAddress: Address{Address1: "Billing Address"}, ShippingAddress: Address{Address1: "Shipping Address"}, Languages: []Language{{Name: "Preload L1"}, {Name: "Preload L2"}}}
|
||||
DB.Save(&user)
|
||||
|
||||
var users []User
|
||||
DB.Preload("BillingAddress").Preload("ShippingAddress").Preload("Languages").Find(&users)
|
||||
}
|
@ -30,7 +30,7 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri
|
||||
return scope.primaryCondiation(scope.AddToVars(value))
|
||||
case sql.NullInt64:
|
||||
return scope.primaryCondiation(scope.AddToVars(value.Int64))
|
||||
case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string:
|
||||
case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string, []interface{}:
|
||||
str = fmt.Sprintf("(%v in (?))", scope.Quote(scope.PrimaryKey()))
|
||||
clause["args"] = []interface{}{value}
|
||||
case map[string]interface{}:
|
||||
|
Loading…
x
Reference in New Issue
Block a user