diff --git a/callbacks/associations.go b/callbacks/associations.go index 38f21218..75bd6c6a 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -207,7 +207,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { } cacheKey := utils.ToStringKey(relPrimaryValues) - if len(relPrimaryValues) == 0 || (len(relPrimaryValues) == len(rel.FieldSchema.PrimaryFields) && !identityMap[cacheKey]) { + if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] { identityMap[cacheKey] = true if isPtr { elems = reflect.Append(elems, elem) diff --git a/callbacks/create.go b/callbacks/create.go index c585fbe9..9dc5b8b1 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -83,7 +83,7 @@ func Create(config *Config) func(db *gorm.DB) { ) if db.AddError(err) == nil { gorm.Scan(rows, db, mode) - rows.Close() + db.AddError(rows.Close()) } return diff --git a/callbacks/delete.go b/callbacks/delete.go index 525c0145..b05a9d08 100644 --- a/callbacks/delete.go +++ b/callbacks/delete.go @@ -168,7 +168,7 @@ func Delete(config *Config) func(db *gorm.DB) { if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil { gorm.Scan(rows, db, mode) - rows.Close() + db.AddError(rows.Close()) } } } diff --git a/callbacks/query.go b/callbacks/query.go index 6ca3a1fb..2f98a4b6 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -20,9 +20,8 @@ func Query(db *gorm.DB) { db.AddError(err) return } - defer rows.Close() - gorm.Scan(rows, db, 0) + db.AddError(rows.Close()) } } } diff --git a/callbacks/update.go b/callbacks/update.go index 1f4960b5..fa7640de 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -88,7 +88,7 @@ func Update(config *Config) func(db *gorm.DB) { db.Statement.Dest = db.Statement.ReflectValue.Addr().Interface() gorm.Scan(rows, db, mode) db.Statement.Dest = dest - rows.Close() + db.AddError(rows.Close()) } } else { result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) diff --git a/finisher_api.go b/finisher_api.go index b3bdedc8..d38d60b7 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -457,12 +457,12 @@ func (db *DB) Scan(dest interface{}) (tx *DB) { tx.Config = &config if rows, err := tx.Rows(); err == nil { - defer rows.Close() if rows.Next() { tx.ScanRows(rows, dest) } else { tx.RowsAffected = 0 } + tx.AddError(rows.Close()) } currentLogger.Trace(tx.Statement.Context, newLogger.BeginAt, func() (string, int64) { diff --git a/migrator/migrator.go b/migrator/migrator.go index 91bf60a7..18212dbb 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -430,13 +430,15 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy // ColumnTypes return columnTypes []gorm.ColumnType and execErr error func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) { columnTypes := make([]gorm.ColumnType, 0) - execErr := m.RunWithValue(value, func(stmt *gorm.Statement) error { + execErr := m.RunWithValue(value, func(stmt *gorm.Statement) (err error) { rows, err := m.DB.Session(&gorm.Session{}).Table(stmt.Table).Limit(1).Rows() if err != nil { return err } - defer rows.Close() + defer func() { + err = rows.Close() + }() var rawColumnTypes []*sql.ColumnType rawColumnTypes, err = rows.ColumnTypes() @@ -448,7 +450,7 @@ func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) { columnTypes = append(columnTypes, c) } - return nil + return }) return columnTypes, execErr diff --git a/tests/associations_belongs_to_test.go b/tests/associations_belongs_to_test.go index 3e4de726..e37da7d3 100644 --- a/tests/associations_belongs_to_test.go +++ b/tests/associations_belongs_to_test.go @@ -132,6 +132,13 @@ func TestBelongsToAssociation(t *testing.T) { AssertAssociationCount(t, user2, "Company", 0, "after clear") AssertAssociationCount(t, user2, "Manager", 0, "after clear") + + // unexist company id + unexistCompanyID := company.ID + 9999999 + user = User{Name: "invalid-user-with-invalid-belongs-to-foreign-key", CompanyID: &unexistCompanyID} + if err := DB.Create(&user).Error; err == nil { + t.Errorf("should have gotten foreign key violation error") + } } func TestBelongsToAssociationForSlice(t *testing.T) { diff --git a/tests/associations_test.go b/tests/associations_test.go index a4b1f1f2..f88d1523 100644 --- a/tests/associations_test.go +++ b/tests/associations_test.go @@ -179,12 +179,8 @@ func TestForeignKeyConstraintsBelongsTo(t *testing.T) { func TestFullSaveAssociations(t *testing.T) { coupon := &Coupon{ - ID: "full-save-association-coupon1", AppliesToProduct: []*CouponProduct{ - { - CouponId: "full-save-association-coupon1", - ProductId: "full-save-association-product1", - }, + {ProductId: "full-save-association-product1"}, }, AmountOff: 10, PercentOff: 0.0, @@ -198,11 +194,11 @@ func TestFullSaveAssociations(t *testing.T) { t.Errorf("Failed, got error: %v", err) } - if DB.First(&Coupon{}, "id = ?", "full-save-association-coupon1").Error != nil { + if DB.First(&Coupon{}, "id = ?", coupon.ID).Error != nil { t.Errorf("Failed to query saved coupon") } - if DB.First(&CouponProduct{}, "coupon_id = ? AND product_id = ?", "full-save-association-coupon1", "full-save-association-product1").Error != nil { + if DB.First(&CouponProduct{}, "coupon_id = ? AND product_id = ?", coupon.ID, "full-save-association-product1").Error != nil { t.Errorf("Failed to query saved association") } @@ -210,4 +206,18 @@ func TestFullSaveAssociations(t *testing.T) { if err := DB.Create(&orders).Error; err != nil { t.Errorf("failed to create orders, got %v", err) } + + coupon2 := Coupon{ + AppliesToProduct: []*CouponProduct{{Desc: "coupon-description"}}, + } + + DB.Session(&gorm.Session{FullSaveAssociations: true}).Create(&coupon2) + var result Coupon + if err := DB.Preload("AppliesToProduct").First(&result, "id = ?", coupon2.ID).Error; err != nil { + t.Errorf("Failed to create coupon w/o name, got error: %v", err) + } + + if len(result.AppliesToProduct) != 1 { + t.Errorf("Failed to preload AppliesToProduct") + } } diff --git a/utils/tests/models.go b/utils/tests/models.go index 337682d6..c84f9cae 100644 --- a/utils/tests/models.go +++ b/utils/tests/models.go @@ -62,15 +62,16 @@ type Language struct { } type Coupon struct { - ID string `gorm:"primarykey; size:255"` + ID int `gorm:"primarykey; size:255"` AppliesToProduct []*CouponProduct `gorm:"foreignKey:CouponId;constraint:OnDelete:CASCADE"` AmountOff uint32 `gorm:"amount_off"` PercentOff float32 `gorm:"percent_off"` } type CouponProduct struct { - CouponId string `gorm:"primarykey; size:255"` - ProductId string `gorm:"primarykey; size:255"` + CouponId int `gorm:"primarykey;size:255"` + ProductId string `gorm:"primarykey;size:255"` + Desc string } type Order struct {