Preload belongs_to, has_one, has_many relations

This commit is contained in:
Jinzhu 2015-02-11 13:43:53 +08:00
parent 3b784c37c4
commit dd635a17c2
5 changed files with 87 additions and 41 deletions

View File

@ -3,8 +3,6 @@ package gorm
import (
"fmt"
"reflect"
"github.com/jinzhu/gorm"
)
func Query(scope *Scope) {
@ -15,7 +13,6 @@ func Query(scope *Scope) {
isPtr bool
anyRecordFound bool
destType reflect.Type
primaryKeys []interface{}
)
var dest = scope.IndirectValue()
@ -50,8 +47,6 @@ func Query(scope *Scope) {
return
}
preloadMap := map[string]map[string]*gorm.Field{}
columns, _ := rows.Columns()
defer rows.Close()
for rows.Next() {
@ -66,7 +61,6 @@ func Query(scope *Scope) {
var values = make([]interface{}, len(columns))
fields := scope.New(elem.Addr().Interface()).Fields()
var primaryKey interface{}
for index, column := range columns {
if field, ok := fields[column]; ok {
if field.Field.Kind() == reflect.Ptr {
@ -74,10 +68,6 @@ func Query(scope *Scope) {
} else {
values[index] = reflect.New(reflect.PtrTo(field.Field.Type())).Interface()
}
if field.IsPrimaryKey {
primaryKey = values[index]
primaryKeys = append(primaryKeys, primaryKey)
}
} else {
var value interface{}
values[index] = &value
@ -104,34 +94,6 @@ func Query(scope *Scope) {
dest.Set(reflect.Append(dest, elem))
}
}
if scope.Search.Preload != nil {
for key := range scope.Search.Preload {
if field := fields[key]; field != nil {
if preloadMap[key] == nil {
preloadMap[key] = map[string]reflect.Value{}
}
preloadMap[key][fmt.Sprintf("%v", primaryKey)] = field
}
}
}
}
for _, value := range preloadMap {
var typ reflect.Type
var relation *Relation
for _, v := range value {
typ = v.Field.Type()
relation = v.Relationship
break
}
sliceType := reflect.SliceOf(typ)
slice := reflect.MakeSlice(sliceType, 0, 0)
slicePtr := reflect.New(sliceType)
slicePtr.Elem().Set(slice)
if relation == "has_many" {
scope.NewDB().Find(slicePtr.Interface(), primaryKeys)
}
}
if !anyRecordFound && !isSlice {
@ -147,4 +109,5 @@ func AfterQuery(scope *Scope) {
func init() {
DefaultCallback.Query().Register("gorm:query", Query)
DefaultCallback.Query().Register("gorm:after_query", AfterQuery)
DefaultCallback.Query().Register("gorm:preload", Preload)
}

View File

@ -79,8 +79,9 @@ func (s *DB) DB() *sql.DB {
}
func (s *DB) New() *DB {
s.search = nil
return s.clone()
clone := s.clone()
clone.search = nil
return clone
}
// Return the underlying sql.DB or sql.Tx instance.

71
preload.go Normal file
View File

@ -0,0 +1,71 @@
package gorm
import (
"errors"
"fmt"
"reflect"
)
func Preload(scope *Scope) {
// Get Fields
var fields map[string]*Field
if scope.IndirectValue().Kind() == reflect.Slice {
elem := reflect.New(scope.IndirectValue().Type().Elem()).Elem()
fields = scope.New(elem.Addr().Interface()).Fields()
} else {
fields = scope.Fields()
}
if scope.Search.Preload != nil {
for key := range scope.Search.Preload {
for _, field := range fields {
if field.Name == key && field.Relationship != nil {
results := makeSlice(field.Field)
relation := field.Relationship
primaryName := scope.PrimaryKeyField().Name
switch relation.Kind {
case "has_one":
sql := fmt.Sprintf("%v IN (?)", scope.Quote(relation.ForeignKey))
scope.NewDB().Find(results, sql, scope.getColumnAsArray(primaryName))
case "has_many":
sql := fmt.Sprintf("%v IN (?)", scope.Quote(relation.ForeignKey))
scope.NewDB().Find(results, sql, scope.getColumnAsArray(primaryName))
case "belongs_to":
scope.NewDB().Find(results, scope.getColumnAsArray(relation.ForeignKey))
case "many_to_many":
scope.Err(errors.New("not supported relation"))
default:
scope.Err(errors.New("not supported relation"))
}
break
}
}
}
}
}
func makeSlice(value reflect.Value) interface{} {
typ := value.Type()
if value.Kind() == reflect.Slice {
typ = typ.Elem()
}
sliceType := reflect.SliceOf(typ)
slice := reflect.New(sliceType)
slice.Elem().Set(reflect.MakeSlice(sliceType, 0, 0))
return slice.Interface()
}
func (scope *Scope) getColumnAsArray(column string) (primaryKeys []interface{}) {
values := scope.IndirectValue()
switch values.Kind() {
case reflect.Slice:
for i := 0; i < values.Len(); i++ {
value := values.Index(i)
primaryKeys = append(primaryKeys, value.FieldByName(column).Interface())
}
case reflect.Struct:
return []interface{}{values.FieldByName(column).Interface()}
}
return
}

11
preload_test.go Normal file
View File

@ -0,0 +1,11 @@
package gorm_test
import "testing"
func TestPreload(t *testing.T) {
user := User{Name: "PreloadUser", BillingAddress: Address{Address1: "Billing Address"}, ShippingAddress: Address{Address1: "Shipping Address"}, Languages: []Language{{Name: "Preload L1"}, {Name: "Preload L2"}}}
DB.Save(&user)
var users []User
DB.Preload("BillingAddress").Preload("ShippingAddress").Preload("Languages").Find(&users)
}

View File

@ -30,7 +30,7 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri
return scope.primaryCondiation(scope.AddToVars(value))
case sql.NullInt64:
return scope.primaryCondiation(scope.AddToVars(value.Int64))
case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string:
case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string, []interface{}:
str = fmt.Sprintf("(%v in (?))", scope.Quote(scope.PrimaryKey()))
clause["args"] = []interface{}{value}
case map[string]interface{}: