Merge branch 'go-gorm:master' into fix/nested-joins-preloads

This commit is contained in:
Nico Schäfer 2024-05-02 12:40:26 +02:00 committed by GitHub
commit 3bef3bad27
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 233 additions and 55 deletions

View File

@ -75,7 +75,7 @@ func embeddedValues(embeddedRelations *schema.Relationships) []string {
names := make([]string, 0, len(embeddedRelations.Relations)+len(embeddedRelations.EmbeddedRelations)) names := make([]string, 0, len(embeddedRelations.Relations)+len(embeddedRelations.EmbeddedRelations))
for _, relation := range embeddedRelations.Relations { for _, relation := range embeddedRelations.Relations {
// skip first struct name // skip first struct name
names = append(names, strings.Join(relation.Field.BindNames[1:], ".")) names = append(names, strings.Join(relation.Field.EmbeddedBindNames[1:], "."))
} }
for _, relations := range embeddedRelations.EmbeddedRelations { for _, relations := range embeddedRelations.EmbeddedRelations {
names = append(names, embeddedValues(relations)...) names = append(names, embeddedValues(relations)...)
@ -123,8 +123,18 @@ func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relati
if joined, nestedJoins := isJoined(name); joined { if joined, nestedJoins := isJoined(name); joined {
switch rv := db.Statement.ReflectValue; rv.Kind() { switch rv := db.Statement.ReflectValue; rv.Kind() {
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
for i := 0; i < rv.Len(); i++ { if rv.Len() > 0 {
reflectValue := rel.Field.ReflectValueOf(db.Statement.Context, rv.Index(i)) reflectValue := rel.FieldSchema.MakeSlice().Elem()
reflectValue.SetLen(rv.Len())
for i := 0; i < rv.Len(); i++ {
frv := rel.Field.ReflectValueOf(db.Statement.Context, rv.Index(i))
if frv.Kind() != reflect.Ptr {
reflectValue.Index(i).Set(frv.Addr())
} else {
reflectValue.Index(i).Set(frv)
}
}
tx := preloadDB(db, reflectValue, reflectValue.Interface()) tx := preloadDB(db, reflectValue, reflectValue.Interface())
if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil { if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil {
return err return err

View File

@ -429,6 +429,15 @@ func (db *DB) Assign(attrs ...interface{}) (tx *DB) {
return return
} }
// Unscoped disables the global scope of soft deletion in a query.
// By default, GORM uses soft deletion, marking records as "deleted"
// by setting a timestamp on a specific field (e.g., `deleted_at`).
// Unscoped allows queries to include records marked as deleted,
// overriding the soft deletion behavior.
// Example:
// var users []User
// db.Unscoped().Find(&users)
// // Retrieves all users, including deleted ones.
func (db *DB) Unscoped() (tx *DB) { func (db *DB) Unscoped() (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.Unscoped = true tx.Statement.Unscoped = true

View File

@ -215,7 +215,12 @@ func (not NotConditions) Build(builder Builder) {
for idx, c := range not.Exprs { for idx, c := range not.Exprs {
if idx > 0 { if idx > 0 {
builder.WriteString(AndWithSpace) switch c.(type) {
case OrConditions:
builder.WriteString(OrWithSpace)
default:
builder.WriteString(AndWithSpace)
}
} }
e, wrapInParentheses := c.(Expr) e, wrapInParentheses := c.(Expr)

View File

@ -113,6 +113,22 @@ func TestWhere(t *testing.T) {
"SELECT * FROM `users` WHERE NOT (`score` <= ? AND `age` <= ?)", "SELECT * FROM `users` WHERE NOT (`score` <= ? AND `age` <= ?)",
[]interface{}{100, 60}, []interface{}{100, 60},
}, },
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
Exprs: []clause.Expression{
clause.Not(clause.AndConditions{
Exprs: []clause.Expression{
clause.Eq{Column: clause.PrimaryColumn, Value: "1"},
clause.Gt{Column: "age", Value: 18},
}}, clause.OrConditions{
Exprs: []clause.Expression{
clause.Lt{Column: "score", Value: 100},
},
}),
}}},
"SELECT * FROM `users` WHERE NOT ((`users`.`id` = ? AND `age` > ?) OR `score` < ?)",
[]interface{}{"1", 18, 100},
},
} }
for idx, result := range results { for idx, result := range results {

View File

@ -49,4 +49,6 @@ var (
ErrDuplicatedKey = errors.New("duplicated key not allowed") ErrDuplicatedKey = errors.New("duplicated key not allowed")
// ErrForeignKeyViolated occurs when there is a foreign key constraint violation // ErrForeignKeyViolated occurs when there is a foreign key constraint violation
ErrForeignKeyViolated = errors.New("violates foreign key constraint") ErrForeignKeyViolated = errors.New("violates foreign key constraint")
// ErrCheckConstraintViolated occurs when there is a check constraint violation
ErrCheckConstraintViolated = errors.New("violates check constraint")
) )

View File

@ -127,6 +127,11 @@ func (m Migrator) AutoMigrate(values ...interface{}) error {
} }
} else { } else {
if err := m.RunWithValue(value, func(stmt *gorm.Statement) error { if err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
if stmt.Schema == nil {
return errors.New("failed to get schema")
}
columnTypes, err := queryTx.Migrator().ColumnTypes(value) columnTypes, err := queryTx.Migrator().ColumnTypes(value)
if err != nil { if err != nil {
return err return err
@ -211,6 +216,11 @@ func (m Migrator) CreateTable(values ...interface{}) error {
for _, value := range m.ReorderModels(values, false) { for _, value := range m.ReorderModels(values, false) {
tx := m.DB.Session(&gorm.Session{}) tx := m.DB.Session(&gorm.Session{})
if err := m.RunWithValue(value, func(stmt *gorm.Statement) (err error) { if err := m.RunWithValue(value, func(stmt *gorm.Statement) (err error) {
if stmt.Schema == nil {
return errors.New("failed to get schema")
}
var ( var (
createTableSQL = "CREATE TABLE ? (" createTableSQL = "CREATE TABLE ? ("
values = []interface{}{m.CurrentTable(stmt)} values = []interface{}{m.CurrentTable(stmt)}
@ -363,6 +373,9 @@ func (m Migrator) RenameTable(oldName, newName interface{}) error {
func (m Migrator) AddColumn(value interface{}, name string) error { func (m Migrator) AddColumn(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error {
// avoid using the same name field // avoid using the same name field
if stmt.Schema == nil {
return errors.New("failed to get schema")
}
f := stmt.Schema.LookUpField(name) f := stmt.Schema.LookUpField(name)
if f == nil { if f == nil {
return fmt.Errorf("failed to look up field with name: %s", name) return fmt.Errorf("failed to look up field with name: %s", name)
@ -382,8 +395,10 @@ func (m Migrator) AddColumn(value interface{}, name string) error {
// DropColumn drop value's `name` column // DropColumn drop value's `name` column
func (m Migrator) DropColumn(value interface{}, name string) error { func (m Migrator) DropColumn(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if field := stmt.Schema.LookUpField(name); field != nil { if stmt.Schema != nil {
name = field.DBName if field := stmt.Schema.LookUpField(name); field != nil {
name = field.DBName
}
} }
return m.DB.Exec( return m.DB.Exec(
@ -395,13 +410,15 @@ func (m Migrator) DropColumn(value interface{}, name string) error {
// AlterColumn alter value's `field` column' type based on schema definition // AlterColumn alter value's `field` column' type based on schema definition
func (m Migrator) AlterColumn(value interface{}, field string) error { func (m Migrator) AlterColumn(value interface{}, field string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if field := stmt.Schema.LookUpField(field); field != nil { if stmt.Schema != nil {
fileType := m.FullDataTypeOf(field) if field := stmt.Schema.LookUpField(field); field != nil {
return m.DB.Exec( fileType := m.FullDataTypeOf(field)
"ALTER TABLE ? ALTER COLUMN ? TYPE ?", return m.DB.Exec(
m.CurrentTable(stmt), clause.Column{Name: field.DBName}, fileType, "ALTER TABLE ? ALTER COLUMN ? TYPE ?",
).Error m.CurrentTable(stmt), clause.Column{Name: field.DBName}, fileType,
).Error
}
} }
return fmt.Errorf("failed to look up field with name: %s", field) return fmt.Errorf("failed to look up field with name: %s", field)
}) })
@ -413,8 +430,10 @@ func (m Migrator) HasColumn(value interface{}, field string) bool {
m.RunWithValue(value, func(stmt *gorm.Statement) error { m.RunWithValue(value, func(stmt *gorm.Statement) error {
currentDatabase := m.DB.Migrator().CurrentDatabase() currentDatabase := m.DB.Migrator().CurrentDatabase()
name := field name := field
if field := stmt.Schema.LookUpField(field); field != nil { if stmt.Schema != nil {
name = field.DBName if field := stmt.Schema.LookUpField(field); field != nil {
name = field.DBName
}
} }
return m.DB.Raw( return m.DB.Raw(
@ -429,12 +448,14 @@ func (m Migrator) HasColumn(value interface{}, field string) bool {
// RenameColumn rename value's field name from oldName to newName // RenameColumn rename value's field name from oldName to newName
func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error { func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if field := stmt.Schema.LookUpField(oldName); field != nil { if stmt.Schema != nil {
oldName = field.DBName if field := stmt.Schema.LookUpField(oldName); field != nil {
} oldName = field.DBName
}
if field := stmt.Schema.LookUpField(newName); field != nil { if field := stmt.Schema.LookUpField(newName); field != nil {
newName = field.DBName newName = field.DBName
}
} }
return m.DB.Exec( return m.DB.Exec(
@ -794,6 +815,9 @@ type BuildIndexOptionsInterface interface {
// CreateIndex create index `name` // CreateIndex create index `name`
func (m Migrator) CreateIndex(value interface{}, name string) error { func (m Migrator) CreateIndex(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if stmt.Schema == nil {
return errors.New("failed to get schema")
}
if idx := stmt.Schema.LookIndex(name); idx != nil { if idx := stmt.Schema.LookIndex(name); idx != nil {
opts := m.DB.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt) opts := m.DB.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt)
values := []interface{}{clause.Column{Name: idx.Name}, m.CurrentTable(stmt), opts} values := []interface{}{clause.Column{Name: idx.Name}, m.CurrentTable(stmt), opts}
@ -826,8 +850,10 @@ func (m Migrator) CreateIndex(value interface{}, name string) error {
// DropIndex drop index `name` // DropIndex drop index `name`
func (m Migrator) DropIndex(value interface{}, name string) error { func (m Migrator) DropIndex(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if idx := stmt.Schema.LookIndex(name); idx != nil { if stmt.Schema != nil {
name = idx.Name if idx := stmt.Schema.LookIndex(name); idx != nil {
name = idx.Name
}
} }
return m.DB.Exec("DROP INDEX ? ON ?", clause.Column{Name: name}, m.CurrentTable(stmt)).Error return m.DB.Exec("DROP INDEX ? ON ?", clause.Column{Name: name}, m.CurrentTable(stmt)).Error
@ -839,8 +865,10 @@ func (m Migrator) HasIndex(value interface{}, name string) bool {
var count int64 var count int64
m.RunWithValue(value, func(stmt *gorm.Statement) error { m.RunWithValue(value, func(stmt *gorm.Statement) error {
currentDatabase := m.DB.Migrator().CurrentDatabase() currentDatabase := m.DB.Migrator().CurrentDatabase()
if idx := stmt.Schema.LookIndex(name); idx != nil { if stmt.Schema != nil {
name = idx.Name if idx := stmt.Schema.LookIndex(name); idx != nil {
name = idx.Name
}
} }
return m.DB.Raw( return m.DB.Raw(

View File

@ -257,9 +257,11 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
continue continue
} }
} }
values[idx] = &sql.RawBytes{} var val interface{}
values[idx] = &val
} else { } else {
values[idx] = &sql.RawBytes{} var val interface{}
values[idx] = &val
} }
} }
} }

View File

@ -56,6 +56,7 @@ type Field struct {
Name string Name string
DBName string DBName string
BindNames []string BindNames []string
EmbeddedBindNames []string
DataType DataType DataType DataType
GORMDataType DataType GORMDataType DataType
PrimaryKey bool PrimaryKey bool
@ -112,6 +113,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
Name: fieldStruct.Name, Name: fieldStruct.Name,
DBName: tagSetting["COLUMN"], DBName: tagSetting["COLUMN"],
BindNames: []string{fieldStruct.Name}, BindNames: []string{fieldStruct.Name},
EmbeddedBindNames: []string{fieldStruct.Name},
FieldType: fieldStruct.Type, FieldType: fieldStruct.Type,
IndirectFieldType: fieldStruct.Type, IndirectFieldType: fieldStruct.Type,
StructField: fieldStruct, StructField: fieldStruct,
@ -403,6 +405,9 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
ef.Schema = schema ef.Schema = schema
ef.OwnerSchema = field.EmbeddedSchema ef.OwnerSchema = field.EmbeddedSchema
ef.BindNames = append([]string{fieldStruct.Name}, ef.BindNames...) ef.BindNames = append([]string{fieldStruct.Name}, ef.BindNames...)
if _, ok := field.TagSettings["EMBEDDED"]; ok || !fieldStruct.Anonymous {
ef.EmbeddedBindNames = append([]string{fieldStruct.Name}, ef.EmbeddedBindNames...)
}
// index is negative means is pointer // index is negative means is pointer
if field.FieldType.Kind() == reflect.Struct { if field.FieldType.Kind() == reflect.Struct {
ef.StructField.Index = append([]int{fieldStruct.Index[0]}, ef.StructField.Index...) ef.StructField.Index = append([]int{fieldStruct.Index[0]}, ef.StructField.Index...)

View File

@ -150,12 +150,12 @@ func (schema *Schema) setRelation(relation *Relationship) {
} }
// set embedded relation // set embedded relation
if len(relation.Field.BindNames) <= 1 { if len(relation.Field.EmbeddedBindNames) <= 1 {
return return
} }
relationships := &schema.Relationships relationships := &schema.Relationships
for i, name := range relation.Field.BindNames { for i, name := range relation.Field.EmbeddedBindNames {
if i < len(relation.Field.BindNames)-1 { if i < len(relation.Field.EmbeddedBindNames)-1 {
if relationships.EmbeddedRelations == nil { if relationships.EmbeddedRelations == nil {
relationships.EmbeddedRelations = map[string]*Relationships{} relationships.EmbeddedRelations = map[string]*Relationships{}
} }

View File

@ -121,6 +121,29 @@ func TestSelfReferentialBelongsToOverrideReferences(t *testing.T) {
}) })
} }
func TestBelongsToWithMixin(t *testing.T) {
type Profile struct {
gorm.Model
Refer string
Name string
}
type ProfileMixin struct {
Profile Profile `gorm:"References:Refer"`
ProfileRefer int
}
type User struct {
gorm.Model
ProfileMixin
}
checkStructRelation(t, &User{}, Relation{
Name: "Profile", Type: schema.BelongsTo, Schema: "User", FieldSchema: "Profile",
References: []Reference{{"Refer", "Profile", "ProfileRefer", "User", "", false}},
})
}
func TestHasOneOverrideForeignKey(t *testing.T) { func TestHasOneOverrideForeignKey(t *testing.T) {
type Profile struct { type Profile struct {
gorm.Model gorm.Model
@ -776,6 +799,10 @@ func TestEmbeddedBelongsTo(t *testing.T) {
type NestedAddress struct { type NestedAddress struct {
Address Address
} }
type CountryMixin struct {
CountryID int
Country Country
}
type Org struct { type Org struct {
ID int ID int
PostalAddress Address `gorm:"embedded;embeddedPrefix:postal_address_"` PostalAddress Address `gorm:"embedded;embeddedPrefix:postal_address_"`
@ -786,6 +813,7 @@ func TestEmbeddedBelongsTo(t *testing.T) {
Address Address
} }
NestedAddress *NestedAddress `gorm:"embedded;embeddedPrefix:nested_address_"` NestedAddress *NestedAddress `gorm:"embedded;embeddedPrefix:nested_address_"`
CountryMixin
} }
s, err := schema.Parse(&Org{}, &sync.Map{}, schema.NamingStrategy{}) s, err := schema.Parse(&Org{}, &sync.Map{}, schema.NamingStrategy{})
@ -815,15 +843,11 @@ func TestEmbeddedBelongsTo(t *testing.T) {
}, },
}, },
"NestedAddress": { "NestedAddress": {
EmbeddedRelations: map[string]EmbeddedRelations{ Relations: map[string]Relation{
"Address": { "Country": {
Relations: map[string]Relation{ Name: "Country", Type: schema.BelongsTo, Schema: "Org", FieldSchema: "Country",
"Country": { References: []Reference{
Name: "Country", Type: schema.BelongsTo, Schema: "Org", FieldSchema: "Country", {PrimaryKey: "ID", PrimarySchema: "Country", ForeignKey: "CountryID", ForeignSchema: "Org"},
References: []Reference{
{PrimaryKey: "ID", PrimarySchema: "Country", ForeignKey: "CountryID", ForeignSchema: "Org"},
},
},
}, },
}, },
}, },

View File

@ -26,7 +26,7 @@ require (
github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/inflection v1.0.0 // indirect
github.com/kr/text v0.2.0 // indirect github.com/kr/text v0.2.0 // indirect
github.com/mattn/go-sqlite3 v1.14.22 // indirect github.com/mattn/go-sqlite3 v1.14.22 // indirect
github.com/microsoft/go-mssqldb v1.7.0 // indirect github.com/microsoft/go-mssqldb v1.7.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rogpeppe/go-internal v1.12.0 // indirect github.com/rogpeppe/go-internal v1.12.0 // indirect
golang.org/x/crypto v0.22.0 // indirect golang.org/x/crypto v0.22.0 // indirect
@ -37,3 +37,5 @@ require (
replace gorm.io/gorm => ../ replace gorm.io/gorm => ../
replace github.com/jackc/pgx/v5 => github.com/jackc/pgx/v5 v5.4.3 replace github.com/jackc/pgx/v5 => github.com/jackc/pgx/v5 v5.4.3
replace github.com/microsoft/go-mssqldb => github.com/microsoft/go-mssqldb v1.7.0

View File

@ -1,14 +1,14 @@
package tests_test package tests_test
import ( import (
"context"
"encoding/json" "encoding/json"
"regexp" "regexp"
"sort" "sort"
"strconv" "strconv"
"sync" "sync"
"testing" "testing"
"time"
"github.com/stretchr/testify/require"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/clause" "gorm.io/gorm/clause"
@ -337,7 +337,7 @@ func TestNestedPreloadWithNestedJoin(t *testing.T) {
DB.Migrator().DropTable(&Preload{}, &Join{}, &Nested{}, &Value{}) DB.Migrator().DropTable(&Preload{}, &Join{}, &Nested{}, &Value{})
DB.Migrator().AutoMigrate(&Preload{}, &Join{}, &Nested{}, &Value{}) DB.Migrator().AutoMigrate(&Preload{}, &Join{}, &Nested{}, &Value{})
value := Value{ value1 := Value{
Name: "value", Name: "value",
Nested: Nested{ Nested: Nested{
Preloads: []*Preload{ Preloads: []*Preload{
@ -346,32 +346,98 @@ func TestNestedPreloadWithNestedJoin(t *testing.T) {
Join: Join{Value: "j1"}, Join: Join{Value: "j1"},
}, },
} }
if err := DB.Create(&value).Error; err != nil { value2 := Value{
Name: "value2",
Nested: Nested{
Preloads: []*Preload{
{Value: "p3"}, {Value: "p4"}, {Value: "p5"},
},
Join: Join{Value: "j2"},
},
}
values := []*Value{&value1, &value2}
if err := DB.Create(&values).Error; err != nil {
t.Errorf("failed to create value, got err: %v", err) t.Errorf("failed to create value, got err: %v", err)
} }
var find1 Value var find1 Value
err := DB.Joins("Nested").Joins("Nested.Join").Preload("Nested.Preloads").First(&find1).Error err := DB.Joins("Nested").Joins("Nested.Join").Preload("Nested.Preloads").First(&find1, value1.ID).Error
if err != nil { if err != nil {
t.Errorf("failed to find value, got err: %v", err) t.Errorf("failed to find value, got err: %v", err)
} }
AssertEqual(t, find1, value) AssertEqual(t, find1, value1)
var find2 Value var find2 Value
// Joins will automatically add Nested queries. // Joins will automatically add Nested queries.
err = DB.Joins("Nested.Join").Preload("Nested.Preloads").First(&find2).Error err = DB.Joins("Nested.Join").Preload("Nested.Preloads").First(&find2, value2.ID).Error
if err != nil { if err != nil {
t.Errorf("failed to find value, got err: %v", err) t.Errorf("failed to find value, got err: %v", err)
} }
AssertEqual(t, find2, value) AssertEqual(t, find2, value2)
var finds []Value var finds []Value
err = DB.Joins("Nested.Join").Joins("Nested").Preload("Nested.Preloads").Find(&finds).Error err = DB.Joins("Nested.Join").Joins("Nested").Preload("Nested.Preloads").Find(&finds).Error
if err != nil { if err != nil {
t.Errorf("failed to find value, got err: %v", err) t.Errorf("failed to find value, got err: %v", err)
} }
require.Len(t, finds, 1) AssertEqual(t, len(finds), 2)
AssertEqual(t, finds[0], value) AssertEqual(t, finds[0], value1)
AssertEqual(t, finds[1], value2)
}
func TestMergeNestedPreloadWithNestedJoin(t *testing.T) {
users := []User{
{
Name: "TestMergeNestedPreloadWithNestedJoin-1",
Manager: &User{
Name: "Alexis Manager",
Tools: []Tools{
{Name: "Alexis Tool 1"},
{Name: "Alexis Tool 2"},
},
},
},
{
Name: "TestMergeNestedPreloadWithNestedJoin-2",
Manager: &User{
Name: "Jinzhu Manager",
Tools: []Tools{
{Name: "Jinzhu Tool 1"},
{Name: "Jinzhu Tool 2"},
},
},
},
}
DB.Create(&users)
query := make([]string, 0)
sess := DB.Session(&gorm.Session{Logger: Tracer{
Logger: DB.Config.Logger,
Test: func(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) {
sql, _ := fc()
query = append(query, sql)
},
}})
var result []User
err := sess.
Joins("Manager").
Preload("Manager.Tools").
Where("users.name Like ?", "TestMergeNestedPreloadWithNestedJoin%").
Find(&result).Error
if err != nil {
t.Fatalf("failed to preload and find users: %v", err)
}
AssertEqual(t, result, users)
AssertEqual(t, len(query), 2) // Check preload queries are merged
if !regexp.MustCompile(`SELECT \* FROM .*tools.* WHERE .*IN.*`).MatchString(query[0]) {
t.Fatalf("Expected first query to preload manager tools, got: %s", query[0])
}
} }
func TestNestedPreloadWithPointerJoin(t *testing.T) { func TestNestedPreloadWithPointerJoin(t *testing.T) {
@ -518,7 +584,7 @@ func TestEmbedPreload(t *testing.T) {
}, },
}, { }, {
name: "nested address country", name: "nested address country",
preloads: map[string][]interface{}{"NestedAddress.EmbeddedAddress.Country": {}}, preloads: map[string][]interface{}{"NestedAddress.Country": {}},
expect: Org{ expect: Org{
ID: org.ID, ID: org.ID,
PostalAddress: EmbeddedAddress{ PostalAddress: EmbeddedAddress{

View File

@ -559,6 +559,11 @@ func TestNot(t *testing.T) {
if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE NOT \\(manager IS NULL AND age >= .+\\) AND .users.\\..deleted_at. IS NULL").MatchString(result.Statement.SQL.String()) { if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE NOT \\(manager IS NULL AND age >= .+\\) AND .users.\\..deleted_at. IS NULL").MatchString(result.Statement.SQL.String()) {
t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String())
} }
result = dryDB.Not(DB.Where("manager IS NULL").Or("age >= ?", 20)).Find(&User{})
if !regexp.MustCompile(`SELECT \* FROM .*users.* WHERE NOT \(manager IS NULL OR age >= .+\) AND .users.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) {
t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String())
}
} }
func TestNotWithAllFields(t *testing.T) { func TestNotWithAllFields(t *testing.T) {

View File

@ -32,12 +32,16 @@ func sourceDir(file string) string {
// FileWithLineNum return the file name and line number of the current file // FileWithLineNum return the file name and line number of the current file
func FileWithLineNum() string { func FileWithLineNum() string {
// the second caller usually from gorm internal, so set i start from 2 pcs := [13]uintptr{}
for i := 2; i < 15; i++ { // the third caller usually from gorm internal
_, file, line, ok := runtime.Caller(i) len := runtime.Callers(3, pcs[:])
if ok && (!strings.HasPrefix(file, gormSourceDir) || strings.HasSuffix(file, "_test.go")) && frames := runtime.CallersFrames(pcs[:len])
!strings.HasSuffix(file, ".gen.go") { for i := 0; i < len; i++ {
return file + ":" + strconv.FormatInt(int64(line), 10) // second return value is "more", not "ok"
frame, _ := frames.Next()
if (!strings.HasPrefix(frame.File, gormSourceDir) ||
strings.HasSuffix(frame.File, "_test.go")) && !strings.HasSuffix(frame.File, ".gen.go") {
return string(strconv.AppendInt(append([]byte(frame.File), ':'), int64(frame.Line), 10))
} }
} }