From a0c1da1160a4c6c80106d8d751672067a0b21e02 Mon Sep 17 00:00:00 2001 From: Martin Munilla Date: Wed, 17 Apr 2024 11:55:40 +0200 Subject: [PATCH] chore: add tests --- schema/schema.go | 10 +++---- schema/schema_test.go | 45 ++++++++++++++++++++++++++++ tests/primary_key_uuid_test.go | 54 ++++++++++++++++++++++++++++++++++ tests/upsert_test.go | 4 ++- 4 files changed, 106 insertions(+), 7 deletions(-) create mode 100644 tests/primary_key_uuid_test.go diff --git a/schema/schema.go b/schema/schema.go index 6d882e86..bbc8397d 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -300,13 +300,11 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam field.AutoIncrement = true } case String: - if _, ok := field.TagSettings["PRIMARYKEY"]; !ok { - if !field.HasDefaultValue || field.DefaultValueInterface != nil { - schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field) - } - - field.HasDefaultValue = true + if !field.HasDefaultValue || field.DefaultValueInterface != nil { + schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field) } + + field.HasDefaultValue = true } } diff --git a/schema/schema_test.go b/schema/schema_test.go index 45e152e9..9c85ee9e 100644 --- a/schema/schema_test.go +++ b/schema/schema_test.go @@ -334,3 +334,48 @@ func TestCompositePrimaryKeyWithAutoIncrement(t *testing.T) { t.Fatalf("PrioritizedPrimaryField of non autoincrement composite key should be nil") } } + +func TestStringPrimaryKeyDefault(t *testing.T) { + type Product struct { + ID string + Code string + Name string + } + type ProductWithNamedPrimaryKey struct { + ProductID string `gorm:"primaryKey"` + Code string + Name string + } + + product, err := schema.Parse(&Product{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("failed to parse product struct with composite primary key, got error %v", err) + } + + isInDefault := false + for _, field := range product.FieldsWithDefaultDBValue { + if field.Name == "ID" { + isInDefault = true + break + } + } + if !isInDefault { + t.Errorf("ID should be fields with default") + } + + productWithNamedPrimaryKey, err := schema.Parse(&ProductWithNamedPrimaryKey{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("failed to parse product struct with composite primary key, got error %v", err) + } + + isInDefault = false + for _, field := range productWithNamedPrimaryKey.FieldsWithDefaultDBValue { + if field.Name == "ProductID" { + isInDefault = true + break + } + } + if !isInDefault { + t.Errorf("ProductID should be fields with default") + } +} diff --git a/tests/primary_key_uuid_test.go b/tests/primary_key_uuid_test.go new file mode 100644 index 00000000..99baddf9 --- /dev/null +++ b/tests/primary_key_uuid_test.go @@ -0,0 +1,54 @@ +package tests_test + +import ( + "sync" + "testing" + + "github.com/google/uuid" + "gorm.io/gorm/schema" +) + +func TestStringPrimaryKeyDefault(t *testing.T) { + type Product struct { + ID uuid.UUID + Code string + Name string + } + type ProductWithNamedPrimaryKey struct { + ProductID uuid.UUID `gorm:"primaryKey"` + Code string + Name string + } + + product, err := schema.Parse(&Product{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("failed to parse product struct with composite primary key, got error %v", err) + } + + isInDefault := false + for _, field := range product.FieldsWithDefaultDBValue { + if field.Name == "ID" { + isInDefault = true + break + } + } + if !isInDefault { + t.Errorf("ID should be fields with default") + } + + productWithNamedPrimaryKey, err := schema.Parse(&ProductWithNamedPrimaryKey{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("failed to parse product struct with composite primary key, got error %v", err) + } + + isInDefault = false + for _, field := range productWithNamedPrimaryKey.FieldsWithDefaultDBValue { + if field.Name == "ProductID" { + isInDefault = true + break + } + } + if !isInDefault { + t.Errorf("ProductID should be fields with default") + } +} diff --git a/tests/upsert_test.go b/tests/upsert_test.go index e84dc14a..b8342c88 100644 --- a/tests/upsert_test.go +++ b/tests/upsert_test.go @@ -1,6 +1,7 @@ package tests_test import ( + "fmt" "regexp" "testing" "time" @@ -62,7 +63,8 @@ func TestUpsert(t *testing.T) { } r := DB.Session(&gorm.Session{DryRun: true}).Clauses(clause.OnConflict{UpdateAll: true}).Create(&RestrictedLanguage{Code: "upsert_code", Name: "upsert_name", Lang: "upsert_lang"}) - if !regexp.MustCompile(`INTO .restricted_languages. .*\(.code.,.name.,.lang.\) .* (SET|UPDATE) .name.=.*.name.\W*$`).MatchString(r.Statement.SQL.String()) { + fmt.Println(r.Statement.SQL.String()) + if !regexp.MustCompile(`INTO .restricted_languages. .*\(.name.,.lang.,.code.\) .* (SET|UPDATE) .name.=.*.name. RETURNING .code.\W*$`).MatchString(r.Statement.SQL.String()) { t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) } }