diff --git a/callbacks/preload.go b/callbacks/preload.go index 9a9ce46e..3dd0dea3 100644 --- a/callbacks/preload.go +++ b/callbacks/preload.go @@ -75,7 +75,7 @@ func embeddedValues(embeddedRelations *schema.Relationships) []string { names := make([]string, 0, len(embeddedRelations.Relations)+len(embeddedRelations.EmbeddedRelations)) for _, relation := range embeddedRelations.Relations { // skip first struct name - names = append(names, strings.Join(relation.Field.BindNames[1:], ".")) + names = append(names, strings.Join(relation.Field.EmbeddedBindNames[1:], ".")) } for _, relations := range embeddedRelations.EmbeddedRelations { names = append(names, embeddedValues(relations)...) @@ -123,8 +123,18 @@ func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relati if joined, nestedJoins := isJoined(name); joined { switch rv := db.Statement.ReflectValue; rv.Kind() { case reflect.Slice, reflect.Array: - for i := 0; i < rv.Len(); i++ { - reflectValue := rel.Field.ReflectValueOf(db.Statement.Context, rv.Index(i)) + if rv.Len() > 0 { + reflectValue := rel.FieldSchema.MakeSlice().Elem() + reflectValue.SetLen(rv.Len()) + for i := 0; i < rv.Len(); i++ { + frv := rel.Field.ReflectValueOf(db.Statement.Context, rv.Index(i)) + if frv.Kind() != reflect.Ptr { + reflectValue.Index(i).Set(frv.Addr()) + } else { + reflectValue.Index(i).Set(frv) + } + } + tx := preloadDB(db, reflectValue, reflectValue.Interface()) if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil { return err diff --git a/chainable_api.go b/chainable_api.go index 1ec9b865..33370603 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -429,6 +429,15 @@ func (db *DB) Assign(attrs ...interface{}) (tx *DB) { return } +// Unscoped disables the global scope of soft deletion in a query. +// By default, GORM uses soft deletion, marking records as "deleted" +// by setting a timestamp on a specific field (e.g., `deleted_at`). +// Unscoped allows queries to include records marked as deleted, +// overriding the soft deletion behavior. +// Example: +// var users []User +// db.Unscoped().Find(&users) +// // Retrieves all users, including deleted ones. func (db *DB) Unscoped() (tx *DB) { tx = db.getInstance() tx.Statement.Unscoped = true diff --git a/clause/where.go b/clause/where.go index 9ac78578..2c3c90f1 100644 --- a/clause/where.go +++ b/clause/where.go @@ -215,7 +215,12 @@ func (not NotConditions) Build(builder Builder) { for idx, c := range not.Exprs { if idx > 0 { - builder.WriteString(AndWithSpace) + switch c.(type) { + case OrConditions: + builder.WriteString(OrWithSpace) + default: + builder.WriteString(AndWithSpace) + } } e, wrapInParentheses := c.(Expr) diff --git a/clause/where_test.go b/clause/where_test.go index 7d5aca1f..ad23a4ed 100644 --- a/clause/where_test.go +++ b/clause/where_test.go @@ -113,6 +113,22 @@ func TestWhere(t *testing.T) { "SELECT * FROM `users` WHERE NOT (`score` <= ? AND `age` <= ?)", []interface{}{100, 60}, }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Where{ + Exprs: []clause.Expression{ + clause.Not(clause.AndConditions{ + Exprs: []clause.Expression{ + clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, + clause.Gt{Column: "age", Value: 18}, + }}, clause.OrConditions{ + Exprs: []clause.Expression{ + clause.Lt{Column: "score", Value: 100}, + }, + }), + }}}, + "SELECT * FROM `users` WHERE NOT ((`users`.`id` = ? AND `age` > ?) OR `score` < ?)", + []interface{}{"1", 18, 100}, + }, } for idx, result := range results { diff --git a/errors.go b/errors.go index cd76f1f5..025f5d64 100644 --- a/errors.go +++ b/errors.go @@ -49,4 +49,6 @@ var ( ErrDuplicatedKey = errors.New("duplicated key not allowed") // ErrForeignKeyViolated occurs when there is a foreign key constraint violation ErrForeignKeyViolated = errors.New("violates foreign key constraint") + // ErrCheckConstraintViolated occurs when there is a check constraint violation + ErrCheckConstraintViolated = errors.New("violates check constraint") ) diff --git a/migrator/migrator.go b/migrator/migrator.go index acce5df2..189a141f 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -127,6 +127,11 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { } } else { if err := m.RunWithValue(value, func(stmt *gorm.Statement) error { + + if stmt.Schema == nil { + return errors.New("failed to get schema") + } + columnTypes, err := queryTx.Migrator().ColumnTypes(value) if err != nil { return err @@ -211,6 +216,11 @@ func (m Migrator) CreateTable(values ...interface{}) error { for _, value := range m.ReorderModels(values, false) { tx := m.DB.Session(&gorm.Session{}) if err := m.RunWithValue(value, func(stmt *gorm.Statement) (err error) { + + if stmt.Schema == nil { + return errors.New("failed to get schema") + } + var ( createTableSQL = "CREATE TABLE ? (" values = []interface{}{m.CurrentTable(stmt)} @@ -363,6 +373,9 @@ func (m Migrator) RenameTable(oldName, newName interface{}) error { func (m Migrator) AddColumn(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { // avoid using the same name field + if stmt.Schema == nil { + return errors.New("failed to get schema") + } f := stmt.Schema.LookUpField(name) if f == nil { return fmt.Errorf("failed to look up field with name: %s", name) @@ -382,8 +395,10 @@ func (m Migrator) AddColumn(value interface{}, name string) error { // DropColumn drop value's `name` column func (m Migrator) DropColumn(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { - if field := stmt.Schema.LookUpField(name); field != nil { - name = field.DBName + if stmt.Schema != nil { + if field := stmt.Schema.LookUpField(name); field != nil { + name = field.DBName + } } return m.DB.Exec( @@ -395,13 +410,15 @@ func (m Migrator) DropColumn(value interface{}, name string) error { // AlterColumn alter value's `field` column' type based on schema definition func (m Migrator) AlterColumn(value interface{}, field string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { - if field := stmt.Schema.LookUpField(field); field != nil { - fileType := m.FullDataTypeOf(field) - return m.DB.Exec( - "ALTER TABLE ? ALTER COLUMN ? TYPE ?", - m.CurrentTable(stmt), clause.Column{Name: field.DBName}, fileType, - ).Error + if stmt.Schema != nil { + if field := stmt.Schema.LookUpField(field); field != nil { + fileType := m.FullDataTypeOf(field) + return m.DB.Exec( + "ALTER TABLE ? ALTER COLUMN ? TYPE ?", + m.CurrentTable(stmt), clause.Column{Name: field.DBName}, fileType, + ).Error + } } return fmt.Errorf("failed to look up field with name: %s", field) }) @@ -413,8 +430,10 @@ func (m Migrator) HasColumn(value interface{}, field string) bool { m.RunWithValue(value, func(stmt *gorm.Statement) error { currentDatabase := m.DB.Migrator().CurrentDatabase() name := field - if field := stmt.Schema.LookUpField(field); field != nil { - name = field.DBName + if stmt.Schema != nil { + if field := stmt.Schema.LookUpField(field); field != nil { + name = field.DBName + } } return m.DB.Raw( @@ -429,12 +448,14 @@ func (m Migrator) HasColumn(value interface{}, field string) bool { // RenameColumn rename value's field name from oldName to newName func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { - if field := stmt.Schema.LookUpField(oldName); field != nil { - oldName = field.DBName - } + if stmt.Schema != nil { + if field := stmt.Schema.LookUpField(oldName); field != nil { + oldName = field.DBName + } - if field := stmt.Schema.LookUpField(newName); field != nil { - newName = field.DBName + if field := stmt.Schema.LookUpField(newName); field != nil { + newName = field.DBName + } } return m.DB.Exec( @@ -794,6 +815,9 @@ type BuildIndexOptionsInterface interface { // CreateIndex create index `name` func (m Migrator) CreateIndex(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { + if stmt.Schema == nil { + return errors.New("failed to get schema") + } if idx := stmt.Schema.LookIndex(name); idx != nil { opts := m.DB.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt) values := []interface{}{clause.Column{Name: idx.Name}, m.CurrentTable(stmt), opts} @@ -826,8 +850,10 @@ func (m Migrator) CreateIndex(value interface{}, name string) error { // DropIndex drop index `name` func (m Migrator) DropIndex(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { - if idx := stmt.Schema.LookIndex(name); idx != nil { - name = idx.Name + if stmt.Schema != nil { + if idx := stmt.Schema.LookIndex(name); idx != nil { + name = idx.Name + } } return m.DB.Exec("DROP INDEX ? ON ?", clause.Column{Name: name}, m.CurrentTable(stmt)).Error @@ -839,8 +865,10 @@ func (m Migrator) HasIndex(value interface{}, name string) bool { var count int64 m.RunWithValue(value, func(stmt *gorm.Statement) error { currentDatabase := m.DB.Migrator().CurrentDatabase() - if idx := stmt.Schema.LookIndex(name); idx != nil { - name = idx.Name + if stmt.Schema != nil { + if idx := stmt.Schema.LookIndex(name); idx != nil { + name = idx.Name + } } return m.DB.Raw( diff --git a/scan.go b/scan.go index 415b9f0d..e95e6d30 100644 --- a/scan.go +++ b/scan.go @@ -257,9 +257,11 @@ func Scan(rows Rows, db *DB, mode ScanMode) { continue } } - values[idx] = &sql.RawBytes{} + var val interface{} + values[idx] = &val } else { - values[idx] = &sql.RawBytes{} + var val interface{} + values[idx] = &val } } } diff --git a/schema/field.go b/schema/field.go index ca2e1148..a16c98ab 100644 --- a/schema/field.go +++ b/schema/field.go @@ -56,6 +56,7 @@ type Field struct { Name string DBName string BindNames []string + EmbeddedBindNames []string DataType DataType GORMDataType DataType PrimaryKey bool @@ -112,6 +113,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { Name: fieldStruct.Name, DBName: tagSetting["COLUMN"], BindNames: []string{fieldStruct.Name}, + EmbeddedBindNames: []string{fieldStruct.Name}, FieldType: fieldStruct.Type, IndirectFieldType: fieldStruct.Type, StructField: fieldStruct, @@ -403,6 +405,9 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { ef.Schema = schema ef.OwnerSchema = field.EmbeddedSchema ef.BindNames = append([]string{fieldStruct.Name}, ef.BindNames...) + if _, ok := field.TagSettings["EMBEDDED"]; ok || !fieldStruct.Anonymous { + ef.EmbeddedBindNames = append([]string{fieldStruct.Name}, ef.EmbeddedBindNames...) + } // index is negative means is pointer if field.FieldType.Kind() == reflect.Struct { ef.StructField.Index = append([]int{fieldStruct.Index[0]}, ef.StructField.Index...) diff --git a/schema/relationship.go b/schema/relationship.go index 2e94fc2c..c11918a5 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -150,12 +150,12 @@ func (schema *Schema) setRelation(relation *Relationship) { } // set embedded relation - if len(relation.Field.BindNames) <= 1 { + if len(relation.Field.EmbeddedBindNames) <= 1 { return } relationships := &schema.Relationships - for i, name := range relation.Field.BindNames { - if i < len(relation.Field.BindNames)-1 { + for i, name := range relation.Field.EmbeddedBindNames { + if i < len(relation.Field.EmbeddedBindNames)-1 { if relationships.EmbeddedRelations == nil { relationships.EmbeddedRelations = map[string]*Relationships{} } diff --git a/schema/relationship_test.go b/schema/relationship_test.go index 23d79bbb..f1acf2d9 100644 --- a/schema/relationship_test.go +++ b/schema/relationship_test.go @@ -121,6 +121,29 @@ func TestSelfReferentialBelongsToOverrideReferences(t *testing.T) { }) } +func TestBelongsToWithMixin(t *testing.T) { + type Profile struct { + gorm.Model + Refer string + Name string + } + + type ProfileMixin struct { + Profile Profile `gorm:"References:Refer"` + ProfileRefer int + } + + type User struct { + gorm.Model + ProfileMixin + } + + checkStructRelation(t, &User{}, Relation{ + Name: "Profile", Type: schema.BelongsTo, Schema: "User", FieldSchema: "Profile", + References: []Reference{{"Refer", "Profile", "ProfileRefer", "User", "", false}}, + }) +} + func TestHasOneOverrideForeignKey(t *testing.T) { type Profile struct { gorm.Model @@ -776,6 +799,10 @@ func TestEmbeddedBelongsTo(t *testing.T) { type NestedAddress struct { Address } + type CountryMixin struct { + CountryID int + Country Country + } type Org struct { ID int PostalAddress Address `gorm:"embedded;embeddedPrefix:postal_address_"` @@ -786,6 +813,7 @@ func TestEmbeddedBelongsTo(t *testing.T) { Address } NestedAddress *NestedAddress `gorm:"embedded;embeddedPrefix:nested_address_"` + CountryMixin } s, err := schema.Parse(&Org{}, &sync.Map{}, schema.NamingStrategy{}) @@ -815,15 +843,11 @@ func TestEmbeddedBelongsTo(t *testing.T) { }, }, "NestedAddress": { - EmbeddedRelations: map[string]EmbeddedRelations{ - "Address": { - Relations: map[string]Relation{ - "Country": { - Name: "Country", Type: schema.BelongsTo, Schema: "Org", FieldSchema: "Country", - References: []Reference{ - {PrimaryKey: "ID", PrimarySchema: "Country", ForeignKey: "CountryID", ForeignSchema: "Org"}, - }, - }, + Relations: map[string]Relation{ + "Country": { + Name: "Country", Type: schema.BelongsTo, Schema: "Org", FieldSchema: "Country", + References: []Reference{ + {PrimaryKey: "ID", PrimarySchema: "Country", ForeignKey: "CountryID", ForeignSchema: "Org"}, }, }, }, diff --git a/tests/go.mod b/tests/go.mod index d58469c4..10fa7ec8 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -26,7 +26,7 @@ require ( github.com/jinzhu/inflection v1.0.0 // indirect github.com/kr/text v0.2.0 // indirect github.com/mattn/go-sqlite3 v1.14.22 // indirect - github.com/microsoft/go-mssqldb v1.7.0 // indirect + github.com/microsoft/go-mssqldb v1.7.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rogpeppe/go-internal v1.12.0 // indirect golang.org/x/crypto v0.22.0 // indirect @@ -37,3 +37,5 @@ require ( replace gorm.io/gorm => ../ replace github.com/jackc/pgx/v5 => github.com/jackc/pgx/v5 v5.4.3 + +replace github.com/microsoft/go-mssqldb => github.com/microsoft/go-mssqldb v1.7.0 diff --git a/tests/preload_test.go b/tests/preload_test.go index 3c5dca1f..29746cff 100644 --- a/tests/preload_test.go +++ b/tests/preload_test.go @@ -1,14 +1,14 @@ package tests_test import ( + "context" "encoding/json" "regexp" "sort" "strconv" "sync" "testing" - - "github.com/stretchr/testify/require" + "time" "gorm.io/gorm" "gorm.io/gorm/clause" @@ -337,7 +337,7 @@ func TestNestedPreloadWithNestedJoin(t *testing.T) { DB.Migrator().DropTable(&Preload{}, &Join{}, &Nested{}, &Value{}) DB.Migrator().AutoMigrate(&Preload{}, &Join{}, &Nested{}, &Value{}) - value := Value{ + value1 := Value{ Name: "value", Nested: Nested{ Preloads: []*Preload{ @@ -346,32 +346,98 @@ func TestNestedPreloadWithNestedJoin(t *testing.T) { Join: Join{Value: "j1"}, }, } - if err := DB.Create(&value).Error; err != nil { + value2 := Value{ + Name: "value2", + Nested: Nested{ + Preloads: []*Preload{ + {Value: "p3"}, {Value: "p4"}, {Value: "p5"}, + }, + Join: Join{Value: "j2"}, + }, + } + + values := []*Value{&value1, &value2} + if err := DB.Create(&values).Error; err != nil { t.Errorf("failed to create value, got err: %v", err) } var find1 Value - err := DB.Joins("Nested").Joins("Nested.Join").Preload("Nested.Preloads").First(&find1).Error + err := DB.Joins("Nested").Joins("Nested.Join").Preload("Nested.Preloads").First(&find1, value1.ID).Error if err != nil { t.Errorf("failed to find value, got err: %v", err) } - AssertEqual(t, find1, value) + AssertEqual(t, find1, value1) var find2 Value // Joins will automatically add Nested queries. - err = DB.Joins("Nested.Join").Preload("Nested.Preloads").First(&find2).Error + err = DB.Joins("Nested.Join").Preload("Nested.Preloads").First(&find2, value2.ID).Error if err != nil { t.Errorf("failed to find value, got err: %v", err) } - AssertEqual(t, find2, value) + AssertEqual(t, find2, value2) var finds []Value err = DB.Joins("Nested.Join").Joins("Nested").Preload("Nested.Preloads").Find(&finds).Error if err != nil { t.Errorf("failed to find value, got err: %v", err) } - require.Len(t, finds, 1) - AssertEqual(t, finds[0], value) + AssertEqual(t, len(finds), 2) + AssertEqual(t, finds[0], value1) + AssertEqual(t, finds[1], value2) +} + +func TestMergeNestedPreloadWithNestedJoin(t *testing.T) { + users := []User{ + { + Name: "TestMergeNestedPreloadWithNestedJoin-1", + Manager: &User{ + Name: "Alexis Manager", + Tools: []Tools{ + {Name: "Alexis Tool 1"}, + {Name: "Alexis Tool 2"}, + }, + }, + }, + { + Name: "TestMergeNestedPreloadWithNestedJoin-2", + Manager: &User{ + Name: "Jinzhu Manager", + Tools: []Tools{ + {Name: "Jinzhu Tool 1"}, + {Name: "Jinzhu Tool 2"}, + }, + }, + }, + } + + DB.Create(&users) + + query := make([]string, 0) + sess := DB.Session(&gorm.Session{Logger: Tracer{ + Logger: DB.Config.Logger, + Test: func(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) { + sql, _ := fc() + query = append(query, sql) + }, + }}) + + var result []User + err := sess. + Joins("Manager"). + Preload("Manager.Tools"). + Where("users.name Like ?", "TestMergeNestedPreloadWithNestedJoin%"). + Find(&result).Error + + if err != nil { + t.Fatalf("failed to preload and find users: %v", err) + } + + AssertEqual(t, result, users) + AssertEqual(t, len(query), 2) // Check preload queries are merged + + if !regexp.MustCompile(`SELECT \* FROM .*tools.* WHERE .*IN.*`).MatchString(query[0]) { + t.Fatalf("Expected first query to preload manager tools, got: %s", query[0]) + } } func TestNestedPreloadWithPointerJoin(t *testing.T) { @@ -518,7 +584,7 @@ func TestEmbedPreload(t *testing.T) { }, }, { name: "nested address country", - preloads: map[string][]interface{}{"NestedAddress.EmbeddedAddress.Country": {}}, + preloads: map[string][]interface{}{"NestedAddress.Country": {}}, expect: Org{ ID: org.ID, PostalAddress: EmbeddedAddress{ diff --git a/tests/query_test.go b/tests/query_test.go index c0259a14..d0ed675a 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -559,6 +559,11 @@ func TestNot(t *testing.T) { if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE NOT \\(manager IS NULL AND age >= .+\\) AND .users.\\..deleted_at. IS NULL").MatchString(result.Statement.SQL.String()) { t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) } + + result = dryDB.Not(DB.Where("manager IS NULL").Or("age >= ?", 20)).Find(&User{}) + if !regexp.MustCompile(`SELECT \* FROM .*users.* WHERE NOT \(manager IS NULL OR age >= .+\) AND .users.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) + } } func TestNotWithAllFields(t *testing.T) { diff --git a/utils/utils.go b/utils/utils.go index 347a331f..b8d30b35 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -32,12 +32,16 @@ func sourceDir(file string) string { // FileWithLineNum return the file name and line number of the current file func FileWithLineNum() string { - // the second caller usually from gorm internal, so set i start from 2 - for i := 2; i < 15; i++ { - _, file, line, ok := runtime.Caller(i) - if ok && (!strings.HasPrefix(file, gormSourceDir) || strings.HasSuffix(file, "_test.go")) && - !strings.HasSuffix(file, ".gen.go") { - return file + ":" + strconv.FormatInt(int64(line), 10) + pcs := [13]uintptr{} + // the third caller usually from gorm internal + len := runtime.Callers(3, pcs[:]) + frames := runtime.CallersFrames(pcs[:len]) + for i := 0; i < len; i++ { + // second return value is "more", not "ok" + frame, _ := frames.Next() + if (!strings.HasPrefix(frame.File, gormSourceDir) || + strings.HasSuffix(frame.File, "_test.go")) && !strings.HasSuffix(frame.File, ".gen.go") { + return string(strconv.AppendInt(append([]byte(frame.File), ':'), int64(frame.Line), 10)) } }