Add ability to disable association upserts.

This commit is contained in:
Tony Dong 2025-08-02 11:22:41 -07:00
parent eb90a02a07
commit 55407b1f36
3 changed files with 254 additions and 4 deletions

View File

@ -358,7 +358,9 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
} }
func onConflictOption(stmt *gorm.Statement, s *schema.Schema, defaultUpdatingColumns []string) (onConflict clause.OnConflict) { func onConflictOption(stmt *gorm.Statement, s *schema.Schema, defaultUpdatingColumns []string) (onConflict clause.OnConflict) {
if len(defaultUpdatingColumns) > 0 || stmt.DB.FullSaveAssociations { if stmt.DB.DisableAssociationUpserts {
onConflict.DoNothing = true
} else if len(defaultUpdatingColumns) > 0 || stmt.DB.FullSaveAssociations {
onConflict.Columns = make([]clause.Column, 0, len(s.PrimaryFieldDBNames)) onConflict.Columns = make([]clause.Column, 0, len(s.PrimaryFieldDBNames))
for _, dbName := range s.PrimaryFieldDBNames { for _, dbName := range s.PrimaryFieldDBNames {
onConflict.Columns = append(onConflict.Columns, clause.Column{Name: dbName}) onConflict.Columns = append(onConflict.Columns, clause.Column{Name: dbName})

13
gorm.go
View File

@ -28,6 +28,8 @@ type Config struct {
NamingStrategy schema.Namer NamingStrategy schema.Namer
// FullSaveAssociations full save associations // FullSaveAssociations full save associations
FullSaveAssociations bool FullSaveAssociations bool
// DisableAssociationUpserts disable upserting of associations when they already exist
DisableAssociationUpserts bool
// Logger // Logger
Logger logger.Interface Logger logger.Interface
// NowFunc the function to be used when creating a new timestamp // NowFunc the function to be used when creating a new timestamp
@ -117,9 +119,10 @@ type Session struct {
SkipHooks bool SkipHooks bool
SkipDefaultTransaction bool SkipDefaultTransaction bool
DisableNestedTransaction bool DisableNestedTransaction bool
AllowGlobalUpdate bool AllowGlobalUpdate bool
FullSaveAssociations bool FullSaveAssociations bool
PropagateUnscoped bool DisableAssociationUpserts bool
PropagateUnscoped bool
QueryFields bool QueryFields bool
Context context.Context Context context.Context
Logger logger.Interface Logger logger.Interface
@ -272,6 +275,10 @@ func (db *DB) Session(config *Session) *DB {
txConfig.FullSaveAssociations = true txConfig.FullSaveAssociations = true
} }
if config.DisableAssociationUpserts {
txConfig.DisableAssociationUpserts = true
}
if config.PropagateUnscoped { if config.PropagateUnscoped {
txConfig.PropagateUnscoped = true txConfig.PropagateUnscoped = true
} }

View File

@ -0,0 +1,241 @@
package tests_test
import (
"testing"
"gorm.io/gorm"
. "gorm.io/gorm/utils/tests"
)
func TestDisableAssociationUpserts(t *testing.T) {
// Setup test models
type Profile struct {
ID uint
Name string
}
type UserWithProfile struct {
ID uint
Name string
ProfileID uint
Profile Profile
}
// Clean up and migrate
DB.Migrator().DropTable(&UserWithProfile{}, &Profile{})
if err := DB.AutoMigrate(&UserWithProfile{}, &Profile{}); err != nil {
t.Fatalf("Failed to migrate tables: %v", err)
}
// Test 1: Default behavior (associations are created but not updated on conflict)
t.Run("Default behavior", func(t *testing.T) {
profile := Profile{ID: 1, Name: "Original Profile"}
user := UserWithProfile{
ID: 1,
Name: "Test User",
ProfileID: 1,
Profile: profile,
}
// First create
if err := DB.Create(&user).Error; err != nil {
t.Fatalf("Failed to create user: %v", err)
}
// Verify profile was created
var savedProfile Profile
if err := DB.First(&savedProfile, 1).Error; err != nil {
t.Fatalf("Failed to find created profile: %v", err)
}
if savedProfile.Name != "Original Profile" {
t.Errorf("Expected profile name 'Original Profile', got '%s'", savedProfile.Name)
}
// Second create with updated profile (should not update existing profile by default)
user.Profile.Name = "Updated Profile"
if err := DB.Create(&user).Error; err != nil {
t.Fatalf("Failed to create user second time: %v", err)
}
// Verify profile was NOT updated (default DoNothing behavior)
var unchangedProfile Profile
if err := DB.First(&unchangedProfile, 1).Error; err != nil {
t.Fatalf("Failed to find profile after second create: %v", err)
}
if unchangedProfile.Name != "Original Profile" {
t.Errorf("Expected profile name to remain 'Original Profile', got '%s'", unchangedProfile.Name)
}
})
// Test 2: With FullSaveAssociations (should update associations)
t.Run("FullSaveAssociations behavior", func(t *testing.T) {
// Clean up
DB.Exec("DELETE FROM user_with_profiles")
DB.Exec("DELETE FROM profiles")
profile := Profile{ID: 2, Name: "Original Profile 2"}
user := UserWithProfile{
ID: 2,
Name: "Test User 2",
ProfileID: 2,
Profile: profile,
}
// First create
if err := DB.Create(&user).Error; err != nil {
t.Fatalf("Failed to create user: %v", err)
}
// Second create with FullSaveAssociations (should update existing profile)
user.Profile.Name = "Updated Profile 2"
if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Create(&user).Error; err != nil {
t.Fatalf("Failed to create user with FullSaveAssociations: %v", err)
}
// Verify profile was updated
var updatedProfile Profile
if err := DB.First(&updatedProfile, 2).Error; err != nil {
t.Fatalf("Failed to find profile after FullSaveAssociations create: %v", err)
}
if updatedProfile.Name != "Updated Profile 2" {
t.Errorf("Expected profile name 'Updated Profile 2', got '%s'", updatedProfile.Name)
}
})
// Test 3: With DisableAssociationUpserts (should never update associations)
t.Run("DisableAssociationUpserts behavior", func(t *testing.T) {
// Clean up
DB.Exec("DELETE FROM user_with_profiles")
DB.Exec("DELETE FROM profiles")
profile := Profile{ID: 3, Name: "Original Profile 3"}
user := UserWithProfile{
ID: 3,
Name: "Test User 3",
ProfileID: 3,
Profile: profile,
}
// Create with DisableAssociationUpserts enabled
dbWithDisabledUpserts := DB.Session(&gorm.Session{DisableAssociationUpserts: true})
// First create
if err := dbWithDisabledUpserts.Create(&user).Error; err != nil {
t.Fatalf("Failed to create user: %v", err)
}
// Verify profile was created
var savedProfile Profile
if err := DB.First(&savedProfile, 3).Error; err != nil {
t.Fatalf("Failed to find created profile: %v", err)
}
if savedProfile.Name != "Original Profile 3" {
t.Errorf("Expected profile name 'Original Profile 3', got '%s'", savedProfile.Name)
}
// Second create with updated profile AND FullSaveAssociations
// DisableAssociationUpserts should override FullSaveAssociations
user.Profile.Name = "Should Not Update"
if err := dbWithDisabledUpserts.Session(&gorm.Session{
FullSaveAssociations: true,
DisableAssociationUpserts: true,
}).Create(&user).Error; err != nil {
t.Fatalf("Failed to create user second time: %v", err)
}
// Verify profile was NOT updated despite FullSaveAssociations
var unchangedProfile Profile
if err := DB.First(&unchangedProfile, 3).Error; err != nil {
t.Fatalf("Failed to find profile after second create: %v", err)
}
if unchangedProfile.Name != "Original Profile 3" {
t.Errorf("Expected profile name to remain 'Original Profile 3' (DisableAssociationUpserts should override FullSaveAssociations), got '%s'", unchangedProfile.Name)
}
})
// Test 4: Global DisableAssociationUpserts configuration
t.Run("Global DisableAssociationUpserts configuration", func(t *testing.T) {
// Clean up
DB.Exec("DELETE FROM user_with_profiles")
DB.Exec("DELETE FROM profiles")
// Create a new DB instance with DisableAssociationUpserts enabled globally
globalDB := DB.Session(&gorm.Session{DisableAssociationUpserts: true})
profile := Profile{ID: 4, Name: "Original Profile 4"}
user := UserWithProfile{
ID: 4,
Name: "Test User 4",
ProfileID: 4,
Profile: profile,
}
// First create
if err := globalDB.Create(&user).Error; err != nil {
t.Fatalf("Failed to create user: %v", err)
}
// Second create with updated profile
user.Profile.Name = "Should Not Update Global"
if err := globalDB.Create(&user).Error; err != nil {
t.Fatalf("Failed to create user second time: %v", err)
}
// Verify profile was NOT updated
var unchangedProfile Profile
if err := DB.First(&unchangedProfile, 4).Error; err != nil {
t.Fatalf("Failed to find profile: %v", err)
}
if unchangedProfile.Name != "Original Profile 4" {
t.Errorf("Expected profile name to remain 'Original Profile 4', got '%s'", unchangedProfile.Name)
}
})
// Test 5: HasMany relationship
t.Run("HasMany relationships", func(t *testing.T) {
type Order struct {
ID uint
Amount int
UserID uint
}
type UserWithOrders struct {
ID uint
Name string
Orders []Order
}
// Clean up and migrate
DB.Migrator().DropTable(&UserWithOrders{}, &Order{})
if err := DB.AutoMigrate(&UserWithOrders{}, &Order{}); err != nil {
t.Fatalf("Failed to migrate tables: %v", err)
}
order := Order{ID: 1, Amount: 100}
user := UserWithOrders{
ID: 1,
Name: "User with Orders",
Orders: []Order{order},
}
// First create
if err := DB.Create(&user).Error; err != nil {
t.Fatalf("Failed to create user with orders: %v", err)
}
// Update order and create again with DisableAssociationUpserts
user.Orders[0].Amount = 200
if err := DB.Session(&gorm.Session{DisableAssociationUpserts: true}).Create(&user).Error; err != nil {
t.Fatalf("Failed to create user second time: %v", err)
}
// Verify order was NOT updated
var unchangedOrder Order
if err := DB.First(&unchangedOrder, 1).Error; err != nil {
t.Fatalf("Failed to find order: %v", err)
}
if unchangedOrder.Amount != 100 {
t.Errorf("Expected order amount to remain 100, got %d", unchangedOrder.Amount)
}
})
}