From 3b784c37c471eebafd925c5e3be93063edd19e50 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 1 Feb 2015 23:19:29 +0800 Subject: [PATCH] First try for the Preload feature --- README.md | 3 ++- callback_query.go | 39 ++++++++++++++++++++++++++++++++++++++- main.go | 8 ++++---- search.go | 12 +++++++++--- 4 files changed, 53 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index f8349ac5..e336dd7e 100644 --- a/README.md +++ b/README.md @@ -1092,8 +1092,9 @@ db.Where("email = ?", "x@example.org").Attrs(User{RegisteredIp: "111.111.111.111 db.Mode(&User{}).Do("EditForm").Get("edit_form_html") DefaultTimeZone, R/W Splitting, Validation * Github Pages -* Includes * AlertColumn, DropColumn +* db.Preload("Addresses.Map", "active = ?", true).Preload("Profile").Find(&users) +* db.Find(&users).Related(&users) # Author diff --git a/callback_query.go b/callback_query.go index 02824366..9358f332 100644 --- a/callback_query.go +++ b/callback_query.go @@ -3,6 +3,8 @@ package gorm import ( "fmt" "reflect" + + "github.com/jinzhu/gorm" ) func Query(scope *Scope) { @@ -13,6 +15,7 @@ func Query(scope *Scope) { isPtr bool anyRecordFound bool destType reflect.Type + primaryKeys []interface{} ) var dest = scope.IndirectValue() @@ -47,8 +50,9 @@ func Query(scope *Scope) { return } - columns, _ := rows.Columns() + preloadMap := map[string]map[string]*gorm.Field{} + columns, _ := rows.Columns() defer rows.Close() for rows.Next() { scope.db.RowsAffected += 1 @@ -62,6 +66,7 @@ func Query(scope *Scope) { var values = make([]interface{}, len(columns)) fields := scope.New(elem.Addr().Interface()).Fields() + var primaryKey interface{} for index, column := range columns { if field, ok := fields[column]; ok { if field.Field.Kind() == reflect.Ptr { @@ -69,6 +74,10 @@ func Query(scope *Scope) { } else { values[index] = reflect.New(reflect.PtrTo(field.Field.Type())).Interface() } + if field.IsPrimaryKey { + primaryKey = values[index] + primaryKeys = append(primaryKeys, primaryKey) + } } else { var value interface{} values[index] = &value @@ -95,6 +104,34 @@ func Query(scope *Scope) { dest.Set(reflect.Append(dest, elem)) } } + + if scope.Search.Preload != nil { + for key := range scope.Search.Preload { + if field := fields[key]; field != nil { + if preloadMap[key] == nil { + preloadMap[key] = map[string]reflect.Value{} + } + preloadMap[key][fmt.Sprintf("%v", primaryKey)] = field + } + } + } + } + + for _, value := range preloadMap { + var typ reflect.Type + var relation *Relation + for _, v := range value { + typ = v.Field.Type() + relation = v.Relationship + break + } + sliceType := reflect.SliceOf(typ) + slice := reflect.MakeSlice(sliceType, 0, 0) + slicePtr := reflect.New(sliceType) + slicePtr.Elem().Set(slice) + if relation == "has_many" { + scope.NewDB().Find(slicePtr.Interface(), primaryKeys) + } } if !anyRecordFound && !isSlice { diff --git a/main.go b/main.go index 11929ce1..66897e12 100644 --- a/main.go +++ b/main.go @@ -156,10 +156,6 @@ func (s *DB) Joins(query string) *DB { return s.clone().search.joins(query).db } -func (s *DB) Includes(value interface{}) *DB { - return s.clone().search.includes(value).db -} - func (s *DB) Scopes(funcs ...func(*DB) *DB) *DB { c := s for _, f := range funcs { @@ -432,6 +428,10 @@ func (s *DB) Association(column string) *Association { return &Association{Scope: scope, Column: column, Error: s.Error, PrimaryKey: primaryKey, PrimaryType: primaryType, Field: field} } +func (s *DB) Preload(column string, conditions ...interface{}) *DB { + return s.clone().search.preload(column, conditions...).db +} + // Set set value by name func (s *DB) Set(name string, value interface{}) *DB { return s.clone().InstantSet(name, value) diff --git a/search.go b/search.go index 78ed5300..51ec62da 100644 --- a/search.go +++ b/search.go @@ -13,6 +13,7 @@ type search struct { Orders []string Joins string Selects []map[string]interface{} + Preload map[string][]interface{} Offset string Limit string Group string @@ -23,6 +24,7 @@ type search struct { func (s *search) clone() *search { return &search{ + Preload: s.Preload, WhereConditions: s.WhereConditions, OrConditions: s.OrConditions, NotConditions: s.NotConditions, @@ -100,12 +102,16 @@ func (s *search) having(query string, values ...interface{}) *search { return s } -func (s *search) includes(value interface{}) *search { +func (s *search) joins(query string) *search { + s.Joins = query return s } -func (s *search) joins(query string) *search { - s.Joins = query +func (s *search) preload(column string, values ...interface{}) *search { + if s.Preload == nil { + s.Preload = map[string][]interface{}{} + } + s.Preload[column] = values return s }