diff --git a/callback_query_preload.go b/callback_query_preload.go index 30f6b585..481bfbe3 100644 --- a/callback_query_preload.go +++ b/callback_query_preload.go @@ -14,8 +14,14 @@ func preloadCallback(scope *Scope) { return } - if _, ok := scope.Get("gorm:auto_preload"); ok { - autoPreload(scope) + if ap, ok := scope.Get("gorm:auto_preload"); ok { + // If gorm:auto_preload IS NOT a bool then auto preload. + // Else if it IS a bool, use the value + if apb, ok := ap.(bool); !ok { + autoPreload(scope) + } else if apb { + autoPreload(scope) + } } if scope.Search.preload == nil || scope.HasError() { diff --git a/errors.go b/errors.go index da2cf13c..27c9a92d 100644 --- a/errors.go +++ b/errors.go @@ -6,7 +6,7 @@ import ( ) var ( - // ErrRecordNotFound record not found error, happens when haven't find any matched data when looking up with a struct + // ErrRecordNotFound record not found error, happens when only haven't find any matched data when looking up with a struct, finding a slice won't return this error ErrRecordNotFound = errors.New("record not found") // ErrInvalidSQL invalid SQL error, happens when you passed invalid SQL ErrInvalidSQL = errors.New("invalid SQL") diff --git a/preload_test.go b/preload_test.go index 311ad0be..1db625c9 100644 --- a/preload_test.go +++ b/preload_test.go @@ -123,6 +123,31 @@ func TestAutoPreload(t *testing.T) { } } +func TestAutoPreloadFalseDoesntPreload(t *testing.T) { + user1 := getPreloadUser("auto_user1") + DB.Save(user1) + + preloadDB := DB.Set("gorm:auto_preload", false).Where("role = ?", "Preload") + var user User + preloadDB.Find(&user) + + if user.BillingAddress.Address1 != "" { + t.Error("AutoPreload was set to fasle, but still fetched data") + } + + user2 := getPreloadUser("auto_user2") + DB.Save(user2) + + var users []User + preloadDB.Find(&users) + + for _, user := range users { + if user.BillingAddress.Address1 != "" { + t.Error("AutoPreload was set to fasle, but still fetched data") + } + } +} + func TestNestedPreload1(t *testing.T) { type ( Level1 struct { diff --git a/scope.go b/scope.go index a05c1d61..ca861d8a 100644 --- a/scope.go +++ b/scope.go @@ -586,10 +586,10 @@ func (scope *Scope) buildCondition(clause map[string]interface{}, include bool) scope.Err(fmt.Errorf("invalid query condition: %v", value)) return } - + scopeQuotedTableName := newScope.QuotedTableName() for _, field := range newScope.Fields() { if !field.IsIgnored && !field.IsBlank { - sqls = append(sqls, fmt.Sprintf("(%v.%v %s %v)", quotedTableName, scope.Quote(field.DBName), equalSQL, scope.AddToVars(field.Field.Interface()))) + sqls = append(sqls, fmt.Sprintf("(%v.%v %s %v)", scopeQuotedTableName, scope.Quote(field.DBName), equalSQL, scope.AddToVars(field.Field.Interface()))) } } return strings.Join(sqls, " AND ")