diff --git a/callback_query.go b/callback_query.go index 9358f332..ee6871f8 100644 --- a/callback_query.go +++ b/callback_query.go @@ -3,8 +3,6 @@ package gorm import ( "fmt" "reflect" - - "github.com/jinzhu/gorm" ) func Query(scope *Scope) { @@ -15,7 +13,6 @@ func Query(scope *Scope) { isPtr bool anyRecordFound bool destType reflect.Type - primaryKeys []interface{} ) var dest = scope.IndirectValue() @@ -50,8 +47,6 @@ func Query(scope *Scope) { return } - preloadMap := map[string]map[string]*gorm.Field{} - columns, _ := rows.Columns() defer rows.Close() for rows.Next() { @@ -66,7 +61,6 @@ func Query(scope *Scope) { var values = make([]interface{}, len(columns)) fields := scope.New(elem.Addr().Interface()).Fields() - var primaryKey interface{} for index, column := range columns { if field, ok := fields[column]; ok { if field.Field.Kind() == reflect.Ptr { @@ -74,10 +68,6 @@ func Query(scope *Scope) { } else { values[index] = reflect.New(reflect.PtrTo(field.Field.Type())).Interface() } - if field.IsPrimaryKey { - primaryKey = values[index] - primaryKeys = append(primaryKeys, primaryKey) - } } else { var value interface{} values[index] = &value @@ -104,34 +94,6 @@ func Query(scope *Scope) { dest.Set(reflect.Append(dest, elem)) } } - - if scope.Search.Preload != nil { - for key := range scope.Search.Preload { - if field := fields[key]; field != nil { - if preloadMap[key] == nil { - preloadMap[key] = map[string]reflect.Value{} - } - preloadMap[key][fmt.Sprintf("%v", primaryKey)] = field - } - } - } - } - - for _, value := range preloadMap { - var typ reflect.Type - var relation *Relation - for _, v := range value { - typ = v.Field.Type() - relation = v.Relationship - break - } - sliceType := reflect.SliceOf(typ) - slice := reflect.MakeSlice(sliceType, 0, 0) - slicePtr := reflect.New(sliceType) - slicePtr.Elem().Set(slice) - if relation == "has_many" { - scope.NewDB().Find(slicePtr.Interface(), primaryKeys) - } } if !anyRecordFound && !isSlice { @@ -147,4 +109,5 @@ func AfterQuery(scope *Scope) { func init() { DefaultCallback.Query().Register("gorm:query", Query) DefaultCallback.Query().Register("gorm:after_query", AfterQuery) + DefaultCallback.Query().Register("gorm:preload", Preload) } diff --git a/main.go b/main.go index 66897e12..9d833969 100644 --- a/main.go +++ b/main.go @@ -79,8 +79,9 @@ func (s *DB) DB() *sql.DB { } func (s *DB) New() *DB { - s.search = nil - return s.clone() + clone := s.clone() + clone.search = nil + return clone } // Return the underlying sql.DB or sql.Tx instance. diff --git a/preload.go b/preload.go new file mode 100644 index 00000000..eb6cb080 --- /dev/null +++ b/preload.go @@ -0,0 +1,71 @@ +package gorm + +import ( + "errors" + "fmt" + "reflect" +) + +func Preload(scope *Scope) { + // Get Fields + var fields map[string]*Field + if scope.IndirectValue().Kind() == reflect.Slice { + elem := reflect.New(scope.IndirectValue().Type().Elem()).Elem() + fields = scope.New(elem.Addr().Interface()).Fields() + } else { + fields = scope.Fields() + } + + if scope.Search.Preload != nil { + for key := range scope.Search.Preload { + for _, field := range fields { + if field.Name == key && field.Relationship != nil { + results := makeSlice(field.Field) + relation := field.Relationship + primaryName := scope.PrimaryKeyField().Name + + switch relation.Kind { + case "has_one": + sql := fmt.Sprintf("%v IN (?)", scope.Quote(relation.ForeignKey)) + scope.NewDB().Find(results, sql, scope.getColumnAsArray(primaryName)) + case "has_many": + sql := fmt.Sprintf("%v IN (?)", scope.Quote(relation.ForeignKey)) + scope.NewDB().Find(results, sql, scope.getColumnAsArray(primaryName)) + case "belongs_to": + scope.NewDB().Find(results, scope.getColumnAsArray(relation.ForeignKey)) + case "many_to_many": + scope.Err(errors.New("not supported relation")) + default: + scope.Err(errors.New("not supported relation")) + } + break + } + } + } + } +} + +func makeSlice(value reflect.Value) interface{} { + typ := value.Type() + if value.Kind() == reflect.Slice { + typ = typ.Elem() + } + sliceType := reflect.SliceOf(typ) + slice := reflect.New(sliceType) + slice.Elem().Set(reflect.MakeSlice(sliceType, 0, 0)) + return slice.Interface() +} + +func (scope *Scope) getColumnAsArray(column string) (primaryKeys []interface{}) { + values := scope.IndirectValue() + switch values.Kind() { + case reflect.Slice: + for i := 0; i < values.Len(); i++ { + value := values.Index(i) + primaryKeys = append(primaryKeys, value.FieldByName(column).Interface()) + } + case reflect.Struct: + return []interface{}{values.FieldByName(column).Interface()} + } + return +} diff --git a/preload_test.go b/preload_test.go new file mode 100644 index 00000000..939dbd3e --- /dev/null +++ b/preload_test.go @@ -0,0 +1,11 @@ +package gorm_test + +import "testing" + +func TestPreload(t *testing.T) { + user := User{Name: "PreloadUser", BillingAddress: Address{Address1: "Billing Address"}, ShippingAddress: Address{Address1: "Shipping Address"}, Languages: []Language{{Name: "Preload L1"}, {Name: "Preload L2"}}} + DB.Save(&user) + + var users []User + DB.Preload("BillingAddress").Preload("ShippingAddress").Preload("Languages").Find(&users) +} diff --git a/scope_private.go b/scope_private.go index 05f6671a..f481f378 100644 --- a/scope_private.go +++ b/scope_private.go @@ -30,7 +30,7 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri return scope.primaryCondiation(scope.AddToVars(value)) case sql.NullInt64: return scope.primaryCondiation(scope.AddToVars(value.Int64)) - case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string: + case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string, []interface{}: str = fmt.Sprintf("(%v in (?))", scope.Quote(scope.PrimaryKey())) clause["args"] = []interface{}{value} case map[string]interface{}: