Preload belongs_to, has_one, has_many relations
This commit is contained in:
parent
3b784c37c4
commit
dd635a17c2
@ -3,8 +3,6 @@ package gorm
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
|
||||||
"github.com/jinzhu/gorm"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func Query(scope *Scope) {
|
func Query(scope *Scope) {
|
||||||
@ -15,7 +13,6 @@ func Query(scope *Scope) {
|
|||||||
isPtr bool
|
isPtr bool
|
||||||
anyRecordFound bool
|
anyRecordFound bool
|
||||||
destType reflect.Type
|
destType reflect.Type
|
||||||
primaryKeys []interface{}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var dest = scope.IndirectValue()
|
var dest = scope.IndirectValue()
|
||||||
@ -50,8 +47,6 @@ func Query(scope *Scope) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
preloadMap := map[string]map[string]*gorm.Field{}
|
|
||||||
|
|
||||||
columns, _ := rows.Columns()
|
columns, _ := rows.Columns()
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
@ -66,7 +61,6 @@ func Query(scope *Scope) {
|
|||||||
var values = make([]interface{}, len(columns))
|
var values = make([]interface{}, len(columns))
|
||||||
|
|
||||||
fields := scope.New(elem.Addr().Interface()).Fields()
|
fields := scope.New(elem.Addr().Interface()).Fields()
|
||||||
var primaryKey interface{}
|
|
||||||
for index, column := range columns {
|
for index, column := range columns {
|
||||||
if field, ok := fields[column]; ok {
|
if field, ok := fields[column]; ok {
|
||||||
if field.Field.Kind() == reflect.Ptr {
|
if field.Field.Kind() == reflect.Ptr {
|
||||||
@ -74,10 +68,6 @@ func Query(scope *Scope) {
|
|||||||
} else {
|
} else {
|
||||||
values[index] = reflect.New(reflect.PtrTo(field.Field.Type())).Interface()
|
values[index] = reflect.New(reflect.PtrTo(field.Field.Type())).Interface()
|
||||||
}
|
}
|
||||||
if field.IsPrimaryKey {
|
|
||||||
primaryKey = values[index]
|
|
||||||
primaryKeys = append(primaryKeys, primaryKey)
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
var value interface{}
|
var value interface{}
|
||||||
values[index] = &value
|
values[index] = &value
|
||||||
@ -104,34 +94,6 @@ func Query(scope *Scope) {
|
|||||||
dest.Set(reflect.Append(dest, elem))
|
dest.Set(reflect.Append(dest, elem))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if scope.Search.Preload != nil {
|
|
||||||
for key := range scope.Search.Preload {
|
|
||||||
if field := fields[key]; field != nil {
|
|
||||||
if preloadMap[key] == nil {
|
|
||||||
preloadMap[key] = map[string]reflect.Value{}
|
|
||||||
}
|
|
||||||
preloadMap[key][fmt.Sprintf("%v", primaryKey)] = field
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, value := range preloadMap {
|
|
||||||
var typ reflect.Type
|
|
||||||
var relation *Relation
|
|
||||||
for _, v := range value {
|
|
||||||
typ = v.Field.Type()
|
|
||||||
relation = v.Relationship
|
|
||||||
break
|
|
||||||
}
|
|
||||||
sliceType := reflect.SliceOf(typ)
|
|
||||||
slice := reflect.MakeSlice(sliceType, 0, 0)
|
|
||||||
slicePtr := reflect.New(sliceType)
|
|
||||||
slicePtr.Elem().Set(slice)
|
|
||||||
if relation == "has_many" {
|
|
||||||
scope.NewDB().Find(slicePtr.Interface(), primaryKeys)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if !anyRecordFound && !isSlice {
|
if !anyRecordFound && !isSlice {
|
||||||
@ -147,4 +109,5 @@ func AfterQuery(scope *Scope) {
|
|||||||
func init() {
|
func init() {
|
||||||
DefaultCallback.Query().Register("gorm:query", Query)
|
DefaultCallback.Query().Register("gorm:query", Query)
|
||||||
DefaultCallback.Query().Register("gorm:after_query", AfterQuery)
|
DefaultCallback.Query().Register("gorm:after_query", AfterQuery)
|
||||||
|
DefaultCallback.Query().Register("gorm:preload", Preload)
|
||||||
}
|
}
|
||||||
|
5
main.go
5
main.go
@ -79,8 +79,9 @@ func (s *DB) DB() *sql.DB {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) New() *DB {
|
func (s *DB) New() *DB {
|
||||||
s.search = nil
|
clone := s.clone()
|
||||||
return s.clone()
|
clone.search = nil
|
||||||
|
return clone
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return the underlying sql.DB or sql.Tx instance.
|
// Return the underlying sql.DB or sql.Tx instance.
|
||||||
|
71
preload.go
Normal file
71
preload.go
Normal file
@ -0,0 +1,71 @@
|
|||||||
|
package gorm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Preload(scope *Scope) {
|
||||||
|
// Get Fields
|
||||||
|
var fields map[string]*Field
|
||||||
|
if scope.IndirectValue().Kind() == reflect.Slice {
|
||||||
|
elem := reflect.New(scope.IndirectValue().Type().Elem()).Elem()
|
||||||
|
fields = scope.New(elem.Addr().Interface()).Fields()
|
||||||
|
} else {
|
||||||
|
fields = scope.Fields()
|
||||||
|
}
|
||||||
|
|
||||||
|
if scope.Search.Preload != nil {
|
||||||
|
for key := range scope.Search.Preload {
|
||||||
|
for _, field := range fields {
|
||||||
|
if field.Name == key && field.Relationship != nil {
|
||||||
|
results := makeSlice(field.Field)
|
||||||
|
relation := field.Relationship
|
||||||
|
primaryName := scope.PrimaryKeyField().Name
|
||||||
|
|
||||||
|
switch relation.Kind {
|
||||||
|
case "has_one":
|
||||||
|
sql := fmt.Sprintf("%v IN (?)", scope.Quote(relation.ForeignKey))
|
||||||
|
scope.NewDB().Find(results, sql, scope.getColumnAsArray(primaryName))
|
||||||
|
case "has_many":
|
||||||
|
sql := fmt.Sprintf("%v IN (?)", scope.Quote(relation.ForeignKey))
|
||||||
|
scope.NewDB().Find(results, sql, scope.getColumnAsArray(primaryName))
|
||||||
|
case "belongs_to":
|
||||||
|
scope.NewDB().Find(results, scope.getColumnAsArray(relation.ForeignKey))
|
||||||
|
case "many_to_many":
|
||||||
|
scope.Err(errors.New("not supported relation"))
|
||||||
|
default:
|
||||||
|
scope.Err(errors.New("not supported relation"))
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeSlice(value reflect.Value) interface{} {
|
||||||
|
typ := value.Type()
|
||||||
|
if value.Kind() == reflect.Slice {
|
||||||
|
typ = typ.Elem()
|
||||||
|
}
|
||||||
|
sliceType := reflect.SliceOf(typ)
|
||||||
|
slice := reflect.New(sliceType)
|
||||||
|
slice.Elem().Set(reflect.MakeSlice(sliceType, 0, 0))
|
||||||
|
return slice.Interface()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (scope *Scope) getColumnAsArray(column string) (primaryKeys []interface{}) {
|
||||||
|
values := scope.IndirectValue()
|
||||||
|
switch values.Kind() {
|
||||||
|
case reflect.Slice:
|
||||||
|
for i := 0; i < values.Len(); i++ {
|
||||||
|
value := values.Index(i)
|
||||||
|
primaryKeys = append(primaryKeys, value.FieldByName(column).Interface())
|
||||||
|
}
|
||||||
|
case reflect.Struct:
|
||||||
|
return []interface{}{values.FieldByName(column).Interface()}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
11
preload_test.go
Normal file
11
preload_test.go
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
package gorm_test
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestPreload(t *testing.T) {
|
||||||
|
user := User{Name: "PreloadUser", BillingAddress: Address{Address1: "Billing Address"}, ShippingAddress: Address{Address1: "Shipping Address"}, Languages: []Language{{Name: "Preload L1"}, {Name: "Preload L2"}}}
|
||||||
|
DB.Save(&user)
|
||||||
|
|
||||||
|
var users []User
|
||||||
|
DB.Preload("BillingAddress").Preload("ShippingAddress").Preload("Languages").Find(&users)
|
||||||
|
}
|
@ -30,7 +30,7 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri
|
|||||||
return scope.primaryCondiation(scope.AddToVars(value))
|
return scope.primaryCondiation(scope.AddToVars(value))
|
||||||
case sql.NullInt64:
|
case sql.NullInt64:
|
||||||
return scope.primaryCondiation(scope.AddToVars(value.Int64))
|
return scope.primaryCondiation(scope.AddToVars(value.Int64))
|
||||||
case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string:
|
case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string, []interface{}:
|
||||||
str = fmt.Sprintf("(%v in (?))", scope.Quote(scope.PrimaryKey()))
|
str = fmt.Sprintf("(%v in (?))", scope.Quote(scope.PrimaryKey()))
|
||||||
clause["args"] = []interface{}{value}
|
clause["args"] = []interface{}{value}
|
||||||
case map[string]interface{}:
|
case map[string]interface{}:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user