From 63534145fda9a2ac9ba703650b1a44da6a03e45e Mon Sep 17 00:00:00 2001 From: aclich <71011237+aclich@users.noreply.github.com> Date: Mon, 15 May 2023 09:59:26 +0800 Subject: [PATCH 01/19] =?UTF-8?q?fix:=20=F0=9F=90=9B=20embedded=20struct?= =?UTF-8?q?=20test=20failed=20with=20custom=20datatypes=20(#6311)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: 🐛 embedded struct test failed with custom datatypes Fix the pointer embedded struct within custom datatypes and *time.time should be nil issue. * fix: 🐛 change test case to avoid mssql driver issue change test cases from bytes to string to avoid mssql driver issue --- schema/field.go | 18 +++----- tests/embedded_struct_test.go | 80 +++++++++++++++++++++++++++++------ 2 files changed, 75 insertions(+), 23 deletions(-) diff --git a/schema/field.go b/schema/field.go index 7d1a1789..dd08e056 100644 --- a/schema/field.go +++ b/schema/field.go @@ -846,7 +846,7 @@ func (field *Field) setupValuerAndSetter() { field.Set = func(ctx context.Context, value reflect.Value, v interface{}) error { switch data := v.(type) { case **time.Time: - if data != nil { + if data != nil && *data != nil { field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(*data)) } case time.Time: @@ -882,14 +882,12 @@ func (field *Field) setupValuerAndSetter() { reflectV := reflect.ValueOf(v) if !reflectV.IsValid() { field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem()) + } else if reflectV.Kind() == reflect.Ptr && reflectV.IsNil() { + return } else if reflectV.Type().AssignableTo(field.FieldType) { field.ReflectValueOf(ctx, value).Set(reflectV) } else if reflectV.Kind() == reflect.Ptr { - if reflectV.IsNil() || !reflectV.IsValid() { - field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem()) - } else { - return field.Set(ctx, value, reflectV.Elem().Interface()) - } + return field.Set(ctx, value, reflectV.Elem().Interface()) } else { fieldValue := field.ReflectValueOf(ctx, value) if fieldValue.IsNil() { @@ -910,14 +908,12 @@ func (field *Field) setupValuerAndSetter() { reflectV := reflect.ValueOf(v) if !reflectV.IsValid() { field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem()) + } else if reflectV.Kind() == reflect.Ptr && reflectV.IsNil() { + return } else if reflectV.Type().AssignableTo(field.FieldType) { field.ReflectValueOf(ctx, value).Set(reflectV) } else if reflectV.Kind() == reflect.Ptr { - if reflectV.IsNil() || !reflectV.IsValid() { - field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem()) - } else { - return field.Set(ctx, value, reflectV.Elem().Interface()) - } + return field.Set(ctx, value, reflectV.Elem().Interface()) } else { if valuer, ok := v.(driver.Valuer); ok { v, _ = valuer.Value() diff --git a/tests/embedded_struct_test.go b/tests/embedded_struct_test.go index 3747dad9..4314f88c 100644 --- a/tests/embedded_struct_test.go +++ b/tests/embedded_struct_test.go @@ -4,7 +4,9 @@ import ( "database/sql/driver" "encoding/json" "errors" + "reflect" "testing" + "time" "gorm.io/gorm" . "gorm.io/gorm/utils/tests" @@ -104,10 +106,14 @@ func TestEmbeddedPointerTypeStruct(t *testing.T) { } type Author struct { - ID string - Name string - Email string - Age int + ID string + Name string + Email string + Age int + Content Content + ContentPtr *Content + Birthday time.Time + BirthdayPtr *time.Time } type HNPost struct { @@ -135,6 +141,48 @@ func TestEmbeddedPointerTypeStruct(t *testing.T) { if hnPost.Author != nil { t.Errorf("Expected to get back a nil Author but got: %v", hnPost.Author) } + + now := time.Now().Round(time.Second) + NewPost := HNPost{ + BasePost: &BasePost{Title: "embedded_pointer_type2"}, + Author: &Author{ + Name: "test", + Content: Content{"test"}, + ContentPtr: nil, + Birthday: now, + BirthdayPtr: nil, + }, + } + DB.Create(&NewPost) + + hnPost = HNPost{} + if err := DB.First(&hnPost, "title = ?", NewPost.Title).Error; err != nil { + t.Errorf("No error should happen when find embedded pointer type, but got %v", err) + } + + if hnPost.Title != NewPost.Title { + t.Errorf("Should find correct value for embedded pointer type") + } + + if hnPost.Author.Name != NewPost.Author.Name { + t.Errorf("Expected to get Author name %v but got: %v", NewPost.Author.Name, hnPost.Author.Name) + } + + if !reflect.DeepEqual(NewPost.Author.Content, hnPost.Author.Content) { + t.Errorf("Expected to get Author content %v but got: %v", NewPost.Author.Content, hnPost.Author.Content) + } + + if hnPost.Author.ContentPtr != nil { + t.Errorf("Expected to get nil Author contentPtr but got: %v", hnPost.Author.ContentPtr) + } + + if NewPost.Author.Birthday.UnixMilli() != hnPost.Author.Birthday.UnixMilli() { + t.Errorf("Expected to get Author birthday with %+v but got: %+v", NewPost.Author.Birthday, hnPost.Author.Birthday) + } + + if hnPost.Author.BirthdayPtr != nil { + t.Errorf("Expected to get nil Author birthdayPtr but got: %+v", hnPost.Author.BirthdayPtr) + } } type Content struct { @@ -142,18 +190,26 @@ type Content struct { } func (c Content) Value() (driver.Value, error) { - return json.Marshal(c) + // mssql driver with issue on handling null bytes https://github.com/denisenkom/go-mssqldb/issues/530, + b, err := json.Marshal(c) + return string(b[:]), err } func (c *Content) Scan(src interface{}) error { - b, ok := src.([]byte) - if !ok { - return errors.New("Embedded.Scan byte assertion failed") - } - var value Content - if err := json.Unmarshal(b, &value); err != nil { - return err + str, ok := src.(string) + if !ok { + byt, ok := src.([]byte) + if !ok { + return errors.New("Embedded.Scan byte assertion failed") + } + if err := json.Unmarshal(byt, &value); err != nil { + return err + } + } else { + if err := json.Unmarshal([]byte(str), &value); err != nil { + return err + } } *c = value From c3d7d08b9a9f861e53e8b194fcc6b7cedc4191e1 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 15 May 2023 15:43:44 +0800 Subject: [PATCH 02/19] Clear SET clause after build SQL --- callbacks/update.go | 1 + tests/update_test.go | 6 +++++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/callbacks/update.go b/callbacks/update.go index 4eb75788..ff075dcf 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -72,6 +72,7 @@ func Update(config *Config) func(db *gorm.DB) { db.Statement.AddClauseIfNotExists(clause.Update{}) if _, ok := db.Statement.Clauses["SET"]; !ok { if set := ConvertToAssignments(db.Statement); len(set) != 0 { + defer delete(db.Statement.Clauses, "SET") db.Statement.AddClause(set) } else { return diff --git a/tests/update_test.go b/tests/update_test.go index 36ffa6a0..f7c36d74 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -208,13 +208,17 @@ func TestUpdateColumn(t *testing.T) { CheckUser(t, user1, *users[0]) CheckUser(t, user2, *users[1]) - DB.Model(users[1]).UpdateColumn("name", "update_column_02_newnew") + DB.Model(users[1]).UpdateColumn("name", "update_column_02_newnew").UpdateColumn("age", 19) AssertEqual(t, lastUpdatedAt.UnixNano(), users[1].UpdatedAt.UnixNano()) if users[1].Name != "update_column_02_newnew" { t.Errorf("user 2's name should be updated, but got %v", users[1].Name) } + if users[1].Age != 19 { + t.Errorf("user 2's name should be updated, but got %v", users[1].Age) + } + DB.Model(users[1]).UpdateColumn("age", gorm.Expr("age + 100 - 50")) var user3 User DB.First(&user3, users[1].ID) From f5837deef3d0c8edc881ca24b992689c71a5cc06 Mon Sep 17 00:00:00 2001 From: 201430098137 <1850396756@qq.com> Date: Wed, 17 May 2023 10:15:41 +0800 Subject: [PATCH 03/19] fix:clickhouse error not capture(#6277) (#6321) Co-authored-by: zhuangg --- finisher_api.go | 1 + 1 file changed, 1 insertion(+) diff --git a/finisher_api.go b/finisher_api.go index 0e26f181..ad14e298 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -533,6 +533,7 @@ func (db *DB) Scan(dest interface{}) (tx *DB) { tx.ScanRows(rows, dest) } else { tx.RowsAffected = 0 + tx.AddError(rows.Err()) } tx.AddError(rows.Close()) } From 6698ba709e1b08d72e6799c0fc5d0cb5a28cce4b Mon Sep 17 00:00:00 2001 From: Avinaba Bhattacharjee Date: Sun, 21 May 2023 18:54:00 +0530 Subject: [PATCH 04/19] renamed License to LICENSE (#6336) --- License => LICENSE | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename License => LICENSE (100%) diff --git a/License b/LICENSE similarity index 100% rename from License rename to LICENSE From 001738be49d26d341f88377d6006356573c9d045 Mon Sep 17 00:00:00 2001 From: Muhammad Amir Ejaz <37077032+codingamir@users.noreply.github.com> Date: Sun, 21 May 2023 18:27:22 +0500 Subject: [PATCH 05/19] Added support of "Violates Foreign Key Constraint" (#6329) * Added support of "Violates Foreign Key Constraint" Updated the translator and added the support of "foreign key constraint violation". For this, this error type is needed here. * changed the description of ErrForeignKeyViolated --- errors.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/errors.go b/errors.go index 57e3fc5e..cd76f1f5 100644 --- a/errors.go +++ b/errors.go @@ -47,4 +47,6 @@ var ( ErrPreloadNotAllowed = errors.New("preload is not allowed when count is used") // ErrDuplicatedKey occurs when there is a unique key constraint violation ErrDuplicatedKey = errors.New("duplicated key not allowed") + // ErrForeignKeyViolated occurs when there is a foreign key constraint violation + ErrForeignKeyViolated = errors.New("violates foreign key constraint") ) From 8197c00def30911fea0e987d03835fbc6cbbbb59 Mon Sep 17 00:00:00 2001 From: Saeid Date: Thu, 25 May 2023 05:10:00 +0200 Subject: [PATCH 06/19] refactor: error translator test (#6350) Co-authored-by: Saeid Saeidee --- tests/error_translator_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/error_translator_test.go b/tests/error_translator_test.go index ead26fce..ca985a09 100644 --- a/tests/error_translator_test.go +++ b/tests/error_translator_test.go @@ -15,8 +15,8 @@ func TestDialectorWithErrorTranslatorSupport(t *testing.T) { db, _ := gorm.Open(tests.DummyDialector{TranslatedErr: translatedErr}) err := db.AddError(untranslatedErr) - if errors.Is(err, translatedErr) { - t.Fatalf("expected err: %v got err: %v", translatedErr, err) + if !errors.Is(err, untranslatedErr) { + t.Fatalf("expected err: %v got err: %v", untranslatedErr, err) } // it should translate error when the TranslateError flag is true From 812bb20c34f0a67e6799269ce5ef635e78c0e0cf Mon Sep 17 00:00:00 2001 From: wangliuyang <54885906+wangliuyang520@users.noreply.github.com> Date: Fri, 26 May 2023 10:24:28 +0800 Subject: [PATCH 07/19] fix(nested transaction): SavePoint SQL Statement not support in Prepared Statements (#6220) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * test: add nested transaction and prepareStmt coexist test case note: please test in the MySQL environment Change-Id: I0db32adc5f74b0d443e98943d3b182236583b959 Signed-off-by: 王柳洋 * fix(nested transaction): SavePoint SQL Statement not support in Prepared Statements 1. SavetPoint SQL Statement not support in Prepared Statements e.g. see mysql8.0 doc: https://dev.mysql.com/doc/refman/8.0/en/sql-prepared-statements.html Change-Id: I082012db9b140e8ec69764c633724665cc802692 Signed-off-by: 王柳洋 * revert(transaction_api): remove savepoint name pool,meaningless Change-Id: I84aa9924fc54612005a81c83d66fdf8968ee56ad Signed-off-by: 王柳洋 --------- Signed-off-by: 王柳洋 Co-authored-by: 王柳洋 --- finisher_api.go | 46 +++++++++++++++++++++++++-------------- tests/transaction_test.go | 13 +++++++++++ 2 files changed, 43 insertions(+), 16 deletions(-) diff --git a/finisher_api.go b/finisher_api.go index ad14e298..6d0b4cd2 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -6,8 +6,6 @@ import ( "fmt" "reflect" "strings" - "sync" - "sync/atomic" "gorm.io/gorm/clause" "gorm.io/gorm/logger" @@ -612,15 +610,6 @@ func (db *DB) Connection(fc func(tx *DB) error) (err error) { return fc(tx) } -var ( - savepointIdx int64 - savepointNamePool = &sync.Pool{ - New: func() interface{} { - return fmt.Sprintf("gorm_%d", atomic.AddInt64(&savepointIdx, 1)) - }, - } -) - // Transaction start a transaction as a block, return error will rollback, otherwise to commit. Transaction executes an // arbitrary number of commands in fc within a transaction. On success the changes are committed; if an error occurs // they are rolled back. @@ -630,17 +619,14 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil { // nested transaction if !db.DisableNestedTransaction { - poolName := savepointNamePool.Get() - defer savepointNamePool.Put(poolName) - err = db.SavePoint(poolName.(string)).Error + err = db.SavePoint(fmt.Sprintf("sp%p", fc)).Error if err != nil { return } - defer func() { // Make sure to rollback when panic, Block error or Commit error if panicked || err != nil { - db.RollbackTo(poolName.(string)) + db.RollbackTo(fmt.Sprintf("sp%p", fc)) } }() } @@ -721,7 +707,21 @@ func (db *DB) Rollback() *DB { func (db *DB) SavePoint(name string) *DB { if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok { + // close prepared statement, because SavePoint not support prepared statement. + // e.g. mysql8.0 doc: https://dev.mysql.com/doc/refman/8.0/en/sql-prepared-statements.html + var ( + preparedStmtTx *PreparedStmtTX + isPreparedStmtTx bool + ) + // close prepared statement, because SavePoint not support prepared statement. + if preparedStmtTx, isPreparedStmtTx = db.Statement.ConnPool.(*PreparedStmtTX); isPreparedStmtTx { + db.Statement.ConnPool = preparedStmtTx.Tx + } db.AddError(savePointer.SavePoint(db, name)) + // restore prepared statement + if isPreparedStmtTx { + db.Statement.ConnPool = preparedStmtTx + } } else { db.AddError(ErrUnsupportedDriver) } @@ -730,7 +730,21 @@ func (db *DB) SavePoint(name string) *DB { func (db *DB) RollbackTo(name string) *DB { if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok { + // close prepared statement, because RollbackTo not support prepared statement. + // e.g. mysql8.0 doc: https://dev.mysql.com/doc/refman/8.0/en/sql-prepared-statements.html + var ( + preparedStmtTx *PreparedStmtTX + isPreparedStmtTx bool + ) + // close prepared statement, because SavePoint not support prepared statement. + if preparedStmtTx, isPreparedStmtTx = db.Statement.ConnPool.(*PreparedStmtTX); isPreparedStmtTx { + db.Statement.ConnPool = preparedStmtTx.Tx + } db.AddError(savePointer.RollbackTo(db, name)) + // restore prepared statement + if isPreparedStmtTx { + db.Statement.ConnPool = preparedStmtTx + } } else { db.AddError(ErrUnsupportedDriver) } diff --git a/tests/transaction_test.go b/tests/transaction_test.go index 5872da94..bfbd8699 100644 --- a/tests/transaction_test.go +++ b/tests/transaction_test.go @@ -57,6 +57,19 @@ func TestTransaction(t *testing.T) { if err := DB.First(&User{}, "name = ?", "transaction-2").Error; err != nil { t.Fatalf("Should be able to find committed record, but got %v", err) } + + t.Run("this is test nested transaction and prepareStmt coexist case", func(t *testing.T) { + // enable prepare statement + tx3 := DB.Session(&gorm.Session{PrepareStmt: true}) + if err := tx3.Transaction(func(tx4 *gorm.DB) error { + // nested transaction + return tx4.Transaction(func(tx5 *gorm.DB) error { + return tx5.First(&User{}, "name = ?", "transaction-2").Error + }) + }); err != nil { + t.Fatalf("prepare statement and nested transcation coexist" + err.Error()) + } + }) } func TestCancelTransaction(t *testing.T) { From 11fdf46a9fcc393e604ea6df22ce0355b2fe1afb Mon Sep 17 00:00:00 2001 From: black-06 Date: Fri, 26 May 2023 10:28:02 +0800 Subject: [PATCH 08/19] fix: save with hook (#6285) (#6294) --- finisher_api.go | 2 +- tests/update_test.go | 73 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 74 insertions(+), 1 deletion(-) diff --git a/finisher_api.go b/finisher_api.go index 6d0b4cd2..f80aa6c0 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -105,7 +105,7 @@ func (db *DB) Save(value interface{}) (tx *DB) { updateTx := tx.callbacks.Update().Execute(tx.Session(&Session{Initialized: true})) if updateTx.Error == nil && updateTx.RowsAffected == 0 && !updateTx.DryRun && !selectedUpdate { - return tx.Clauses(clause.OnConflict{UpdateAll: true}).Create(value) + return tx.Session(&Session{SkipHooks: true}).Clauses(clause.OnConflict{UpdateAll: true}).Create(value) } return updateTx diff --git a/tests/update_test.go b/tests/update_test.go index f7c36d74..c03d2d47 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -809,3 +809,76 @@ func TestUpdateWithDiffSchema(t *testing.T) { AssertEqual(t, err, nil) AssertEqual(t, "update-diff-schema-2", user.Name) } + +type TokenOwner struct { + ID int + Name string + Token Token `gorm:"foreignKey:UserID"` +} + +func (t *TokenOwner) BeforeSave(tx *gorm.DB) error { + t.Name += "_name" + return nil +} + +type Token struct { + UserID int `gorm:"primary_key"` + Content string `gorm:"type:varchar(100)"` +} + +func (t *Token) BeforeSave(tx *gorm.DB) error { + t.Content += "_encrypted" + return nil +} + +func TestSaveWithHooks(t *testing.T) { + DB.Migrator().DropTable(&Token{}, &TokenOwner{}) + DB.AutoMigrate(&Token{}, &TokenOwner{}) + + saveTokenOwner := func(owner *TokenOwner) (*TokenOwner, error) { + var newOwner TokenOwner + if err := DB.Transaction(func(tx *gorm.DB) error { + if err := tx.Debug().Session(&gorm.Session{FullSaveAssociations: true}).Save(owner).Error; err != nil { + return err + } + if err := tx.Preload("Token").First(&newOwner, owner.ID).Error; err != nil { + return err + } + return nil + }); err != nil { + return nil, err + } + return &newOwner, nil + } + + owner := TokenOwner{ + Name: "user", + Token: Token{Content: "token"}, + } + o1, err := saveTokenOwner(&owner) + if err != nil { + t.Errorf("failed to save token owner, got error: %v", err) + } + if o1.Name != "user_name" { + t.Errorf(`owner name should be "user_name", but got: "%s"`, o1.Name) + } + if o1.Token.Content != "token_encrypted" { + t.Errorf(`token content should be "token_encrypted", but got: "%s"`, o1.Token.Content) + } + + owner = TokenOwner{ + ID: owner.ID, + Name: "user", + Token: Token{Content: "token2"}, + } + o2, err := saveTokenOwner(&owner) + if err != nil { + t.Errorf("failed to save token owner, got error: %v", err) + } + if o2.Name != "user_name" { + t.Errorf(`owner name should be "user_name", but got: "%s"`, o2.Name) + } + if o2.Token.Content != "token2_encrypted" { + t.Errorf(`token content should be "token2_encrypted", but got: "%s"`, o2.Token.Content) + } +} From 26663ab9bf55c603ca69a97504a1dfe7d53a1bb3 Mon Sep 17 00:00:00 2001 From: mohammad ali <2018cs92@student.uet.edu.pk> Date: Mon, 29 May 2023 19:00:48 -0700 Subject: [PATCH 09/19] max identifier length changed to 63 (#6337) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * max identifier length changed to 63 * default maxIdentifierLength is 64 * renamed License to LICENSE (#6336) * Added support of "Violates Foreign Key Constraint" (#6329) * Added support of "Violates Foreign Key Constraint" Updated the translator and added the support of "foreign key constraint violation". For this, this error type is needed here. * changed the description of ErrForeignKeyViolated * refactor: error translator test (#6350) Co-authored-by: Saeid Saeidee * fix(nested transaction): SavePoint SQL Statement not support in Prepared Statements (#6220) * test: add nested transaction and prepareStmt coexist test case note: please test in the MySQL environment Change-Id: I0db32adc5f74b0d443e98943d3b182236583b959 Signed-off-by: 王柳洋 * fix(nested transaction): SavePoint SQL Statement not support in Prepared Statements 1. SavetPoint SQL Statement not support in Prepared Statements e.g. see mysql8.0 doc: https://dev.mysql.com/doc/refman/8.0/en/sql-prepared-statements.html Change-Id: I082012db9b140e8ec69764c633724665cc802692 Signed-off-by: 王柳洋 * revert(transaction_api): remove savepoint name pool,meaningless Change-Id: I84aa9924fc54612005a81c83d66fdf8968ee56ad Signed-off-by: 王柳洋 --------- Signed-off-by: 王柳洋 Co-authored-by: 王柳洋 * fix: save with hook (#6285) (#6294) --------- Signed-off-by: 王柳洋 Co-authored-by: Avinaba Bhattacharjee Co-authored-by: Muhammad Amir Ejaz <37077032+codingamir@users.noreply.github.com> Co-authored-by: Saeid Co-authored-by: Saeid Saeidee Co-authored-by: wangliuyang <54885906+wangliuyang520@users.noreply.github.com> Co-authored-by: 王柳洋 Co-authored-by: black-06 --- gorm.go | 2 +- schema/naming.go | 17 +++++++++++------ schema/naming_test.go | 11 ++++++++++- schema/relationship_test.go | 2 +- 4 files changed, 23 insertions(+), 9 deletions(-) diff --git a/gorm.go b/gorm.go index 07a913fc..84d4b433 100644 --- a/gorm.go +++ b/gorm.go @@ -146,7 +146,7 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) { } if config.NamingStrategy == nil { - config.NamingStrategy = schema.NamingStrategy{} + config.NamingStrategy = schema.NamingStrategy{IdentifierMaxLength: 64} // Default Identifier length is 64 } if config.Logger == nil { diff --git a/schema/naming.go b/schema/naming.go index a258beed..a2a0150a 100644 --- a/schema/naming.go +++ b/schema/naming.go @@ -28,10 +28,11 @@ type Replacer interface { // NamingStrategy tables, columns naming strategy type NamingStrategy struct { - TablePrefix string - SingularTable bool - NameReplacer Replacer - NoLowerCase bool + TablePrefix string + SingularTable bool + NameReplacer Replacer + NoLowerCase bool + IdentifierMaxLength int } // TableName convert string to table name @@ -89,12 +90,16 @@ func (ns NamingStrategy) formatName(prefix, table, name string) string { prefix, table, name, }, "_"), ".", "_") - if utf8.RuneCountInString(formattedName) > 64 { + if ns.IdentifierMaxLength == 0 { + ns.IdentifierMaxLength = 64 + } + + if utf8.RuneCountInString(formattedName) > ns.IdentifierMaxLength { h := sha1.New() h.Write([]byte(formattedName)) bs := h.Sum(nil) - formattedName = formattedName[0:56] + hex.EncodeToString(bs)[:8] + formattedName = formattedName[0:ns.IdentifierMaxLength-8] + hex.EncodeToString(bs)[:8] } return formattedName } diff --git a/schema/naming_test.go b/schema/naming_test.go index 3f598c33..ab7a5e31 100644 --- a/schema/naming_test.go +++ b/schema/naming_test.go @@ -189,8 +189,17 @@ func TestCustomReplacerWithNoLowerCase(t *testing.T) { } } +func TestFormatNameWithStringLongerThan63Characters(t *testing.T) { + ns := NamingStrategy{IdentifierMaxLength: 63} + + formattedName := ns.formatName("prefix", "table", "thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVeryLongString") + if formattedName != "prefix_table_thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVer180f2c67" { + t.Errorf("invalid formatted name generated, got %v", formattedName) + } +} + func TestFormatNameWithStringLongerThan64Characters(t *testing.T) { - ns := NamingStrategy{} + ns := NamingStrategy{IdentifierMaxLength: 64} formattedName := ns.formatName("prefix", "table", "thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVeryLongString") if formattedName != "prefix_table_thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVery180f2c67" { diff --git a/schema/relationship_test.go b/schema/relationship_test.go index 732f6f75..1eb66bb4 100644 --- a/schema/relationship_test.go +++ b/schema/relationship_test.go @@ -768,7 +768,7 @@ func TestParseConstraintNameWithSchemaQualifiedLongTableName(t *testing.T) { s, err := schema.Parse( &Book{}, &sync.Map{}, - schema.NamingStrategy{}, + schema.NamingStrategy{IdentifierMaxLength: 64}, ) if err != nil { t.Fatalf("Failed to parse schema") From 740f2be4535f7156752a5e6ce4bc94db59948a10 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=9C=E6=96=B9=E4=B8=8A=E4=BA=BA?= Date: Wed, 31 May 2023 19:21:51 +0800 Subject: [PATCH 10/19] fix: begin transaction fail, rollback panic (#6365) --- prepare_stmt.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/prepare_stmt.go b/prepare_stmt.go index e09fe814..4b3551c6 100644 --- a/prepare_stmt.go +++ b/prepare_stmt.go @@ -3,6 +3,7 @@ package gorm import ( "context" "database/sql" + "reflect" "sync" ) @@ -163,14 +164,14 @@ type PreparedStmtTX struct { } func (tx *PreparedStmtTX) Commit() error { - if tx.Tx != nil { + if tx.Tx != nil && !reflect.ValueOf(tx.Tx).IsNil() { return tx.Tx.Commit() } return ErrInvalidTransaction } func (tx *PreparedStmtTX) Rollback() error { - if tx.Tx != nil { + if tx.Tx != nil && !reflect.ValueOf(tx.Tx).IsNil() { return tx.Tx.Rollback() } return ErrInvalidTransaction From c1ea73036715018a1bb55cdb8690441044e13a76 Mon Sep 17 00:00:00 2001 From: black Date: Thu, 1 Jun 2023 10:54:36 +0800 Subject: [PATCH 11/19] fix: avoid panic when open fails --- gorm.go | 2 +- tests/gorm_test.go | 10 ++++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/gorm.go b/gorm.go index 84d4b433..46d1843d 100644 --- a/gorm.go +++ b/gorm.go @@ -181,7 +181,7 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) { err = config.Dialector.Initialize(db) if err != nil { - if db, err := db.DB(); err == nil { + if db, _ := db.DB(); db != nil { _ = db.Close() } } diff --git a/tests/gorm_test.go b/tests/gorm_test.go index 9827465c..4c31b88b 100644 --- a/tests/gorm_test.go +++ b/tests/gorm_test.go @@ -3,9 +3,19 @@ package tests_test import ( "testing" + "gorm.io/driver/mysql" + "gorm.io/gorm" ) +func TestOpen(t *testing.T) { + dsn := "gorm:gorm@tcp(localhost:9910)/gorm?loc=Asia%2FHongKong" // invalid loc + _, err := gorm.Open(mysql.Open(dsn), &gorm.Config{}) + if err == nil { + t.Fatalf("should returns error but got nil") + } +} + func TestReturningWithNullToZeroValues(t *testing.T) { dialect := DB.Dialector.Name() switch dialect { From 7a76c042e67a22e226179864dd0bdfedce0d53b9 Mon Sep 17 00:00:00 2001 From: Lev Zakharov Date: Mon, 5 Jun 2023 11:23:17 +0300 Subject: [PATCH 12/19] refactor: remove unnecessary prepared statement allocation (#6374) --- gorm.go | 48 ++++++++++++++++++++++++------------------------ prepare_stmt.go | 9 +++++++++ 2 files changed, 33 insertions(+), 24 deletions(-) diff --git a/gorm.go b/gorm.go index 46d1843d..ecdb700b 100644 --- a/gorm.go +++ b/gorm.go @@ -187,15 +187,9 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) { } } - preparedStmt := &PreparedStmtDB{ - ConnPool: db.ConnPool, - Stmts: make(map[string]*Stmt), - Mux: &sync.RWMutex{}, - PreparedSQL: make([]string, 0, 100), - } - db.cacheStore.Store(preparedStmtDBKey, preparedStmt) - if config.PrepareStmt { + preparedStmt := NewPreparedStmtDB(db.ConnPool) + db.cacheStore.Store(preparedStmtDBKey, preparedStmt) db.ConnPool = preparedStmt } @@ -256,24 +250,30 @@ func (db *DB) Session(config *Session) *DB { } if config.PrepareStmt { + var preparedStmt *PreparedStmtDB + if v, ok := db.cacheStore.Load(preparedStmtDBKey); ok { - preparedStmt := v.(*PreparedStmtDB) - switch t := tx.Statement.ConnPool.(type) { - case Tx: - tx.Statement.ConnPool = &PreparedStmtTX{ - Tx: t, - PreparedStmtDB: preparedStmt, - } - default: - tx.Statement.ConnPool = &PreparedStmtDB{ - ConnPool: db.Config.ConnPool, - Mux: preparedStmt.Mux, - Stmts: preparedStmt.Stmts, - } - } - txConfig.ConnPool = tx.Statement.ConnPool - txConfig.PrepareStmt = true + preparedStmt = v.(*PreparedStmtDB) + } else { + preparedStmt = NewPreparedStmtDB(db.ConnPool) + db.cacheStore.Store(preparedStmtDBKey, preparedStmt) } + + switch t := tx.Statement.ConnPool.(type) { + case Tx: + tx.Statement.ConnPool = &PreparedStmtTX{ + Tx: t, + PreparedStmtDB: preparedStmt, + } + default: + tx.Statement.ConnPool = &PreparedStmtDB{ + ConnPool: db.Config.ConnPool, + Mux: preparedStmt.Mux, + Stmts: preparedStmt.Stmts, + } + } + txConfig.ConnPool = tx.Statement.ConnPool + txConfig.PrepareStmt = true } if config.SkipHooks { diff --git a/prepare_stmt.go b/prepare_stmt.go index 4b3551c6..10fefc31 100644 --- a/prepare_stmt.go +++ b/prepare_stmt.go @@ -21,6 +21,15 @@ type PreparedStmtDB struct { ConnPool } +func NewPreparedStmtDB(connPool ConnPool) *PreparedStmtDB { + return &PreparedStmtDB{ + ConnPool: connPool, + Stmts: make(map[string]*Stmt), + Mux: &sync.RWMutex{}, + PreparedSQL: make([]string, 0, 100), + } +} + func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) { if dbConnector, ok := db.ConnPool.(GetDBConnector); ok && dbConnector != nil { return dbConnector.GetDBConn() From 5eaccaa624773441da797f93950117a44863df12 Mon Sep 17 00:00:00 2001 From: KantaHasegawa <66783124+KantaHasegawa@users.noreply.github.com> Date: Mon, 5 Jun 2023 17:24:00 +0900 Subject: [PATCH 13/19] reafactor: add nil detection when sqldb return (#6373) * reafactor: add null detection when sqldb return * refactor: Detecting nil in dbConnector.GetDBConn() * refactor: Revert partial code from c1ea73036715018a1bb55cdb8690441044e13a76 * fix: fix if statement --- gorm.go | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/gorm.go b/gorm.go index ecdb700b..21b289db 100644 --- a/gorm.go +++ b/gorm.go @@ -181,7 +181,7 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) { err = config.Dialector.Initialize(db) if err != nil { - if db, _ := db.DB(); db != nil { + if db, err := db.DB(); err == nil { _ = db.Close() } } @@ -376,10 +376,12 @@ func (db *DB) DB() (*sql.DB, error) { connPool := db.ConnPool if dbConnector, ok := connPool.(GetDBConnector); ok && dbConnector != nil { - return dbConnector.GetDBConn() + if sqldb, err := dbConnector.GetDBConn(); sqldb != nil || err != nil { + return sqldb, err + } } - if sqldb, ok := connPool.(*sql.DB); ok { + if sqldb, ok := connPool.(*sql.DB); ok && sqldb != nil { return sqldb, nil } From 661781a3d7a36b33d8a30d97499d7b428939a9ec Mon Sep 17 00:00:00 2001 From: Lev Zakharov Date: Mon, 5 Jun 2023 11:25:05 +0300 Subject: [PATCH 14/19] feat: add *sql.DB connector that uses database context (#6366) * feat: add SQLConnector * rename --- gorm.go | 4 ++++ interfaces.go | 6 ++++++ 2 files changed, 10 insertions(+) diff --git a/gorm.go b/gorm.go index 21b289db..9297850e 100644 --- a/gorm.go +++ b/gorm.go @@ -375,6 +375,10 @@ func (db *DB) AddError(err error) error { func (db *DB) DB() (*sql.DB, error) { connPool := db.ConnPool + if connector, ok := connPool.(GetDBConnectorWithContext); ok && connector != nil { + return connector.GetDBConnWithContext(db) + } + if dbConnector, ok := connPool.(GetDBConnector); ok && dbConnector != nil { if sqldb, err := dbConnector.GetDBConn(); sqldb != nil || err != nil { return sqldb, err diff --git a/interfaces.go b/interfaces.go index 3bcc3d57..1950d740 100644 --- a/interfaces.go +++ b/interfaces.go @@ -77,6 +77,12 @@ type GetDBConnector interface { GetDBConn() (*sql.DB, error) } +// GetDBConnectorWithContext represents SQL db connector which takes into +// account the current database context +type GetDBConnectorWithContext interface { + GetDBConnWithContext(db *DB) (*sql.DB, error) +} + // Rows rows interface type Rows interface { Columns() ([]string, error) From 7157b7e375c58851da2245c9f4a2b496f3b0ae71 Mon Sep 17 00:00:00 2001 From: Nuno Cruces Date: Wed, 7 Jun 2023 08:02:07 +0100 Subject: [PATCH 15/19] fix: database/sql.Scanner should not retain references (#6380) --- tests/scanner_valuer_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/scanner_valuer_test.go b/tests/scanner_valuer_test.go index 14121699..472434b4 100644 --- a/tests/scanner_valuer_test.go +++ b/tests/scanner_valuer_test.go @@ -170,10 +170,10 @@ func (data *EncryptedData) Scan(value interface{}) error { return errors.New("Too short") } - *data = b[3:] + *data = append((*data)[0:], b[3:]...) return nil } else if s, ok := value.(string); ok { - *data = []byte(s)[3:] + *data = []byte(s[3:]) return nil } From 7dd702d379bca5e28a2b65278b5a373f326c72ca Mon Sep 17 00:00:00 2001 From: Johannes Riecken Date: Wed, 7 Jun 2023 09:02:30 +0200 Subject: [PATCH 16/19] Fix incorrect documentation comment (has many -> has one) (#6382) --- utils/tests/models.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils/tests/models.go b/utils/tests/models.go index ec1651a3..a4bad2fc 100644 --- a/utils/tests/models.go +++ b/utils/tests/models.go @@ -11,7 +11,7 @@ import ( // He works in a Company (belongs to), he has a Manager (belongs to - single-table), and also managed a Team (has many - single-table) // He speaks many languages (many to many) and has many friends (many to many - single-table) // His pet also has one Toy (has one - polymorphic) -// NamedPet is a reference to a Named `Pets` (has many) +// NamedPet is a reference to a named `Pet` (has one) type User struct { gorm.Model Name string From c2d571cbc8ba7925a7438181cc3dcca10a89f81f Mon Sep 17 00:00:00 2001 From: Saeid Date: Sat, 10 Jun 2023 15:05:19 +0200 Subject: [PATCH 17/19] test: coverage for duplicated key err (#6389) * test: ErrDuplicatedKey coverage added * test: updated sqlserver version * test: removed sqlserver * test: support added for sqlserver --------- Co-authored-by: Saeid Saeidee --- tests/associations_many2many_test.go | 2 +- tests/error_translator_test.go | 31 ++++++++++++++++++++++++++++ tests/go.mod | 5 ++--- tests/prepared_stmt_test.go | 4 ++-- tests/tests_test.go | 14 ++++++------- tests/transaction_test.go | 2 +- 6 files changed, 44 insertions(+), 14 deletions(-) diff --git a/tests/associations_many2many_test.go b/tests/associations_many2many_test.go index b69d668a..39410aed 100644 --- a/tests/associations_many2many_test.go +++ b/tests/associations_many2many_test.go @@ -358,7 +358,7 @@ func TestDuplicateMany2ManyAssociation(t *testing.T) { } func TestConcurrentMany2ManyAssociation(t *testing.T) { - db, err := OpenTestConnection() + db, err := OpenTestConnection(&gorm.Config{}) if err != nil { t.Fatalf("open test connection failed, err: %+v", err) } diff --git a/tests/error_translator_test.go b/tests/error_translator_test.go index ca985a09..f6c70677 100644 --- a/tests/error_translator_test.go +++ b/tests/error_translator_test.go @@ -27,3 +27,34 @@ func TestDialectorWithErrorTranslatorSupport(t *testing.T) { t.Fatalf("expected err: %v got err: %v", translatedErr, err) } } + +func TestSupportedDialectorWithErrDuplicatedKey(t *testing.T) { + type City struct { + gorm.Model + Name string `gorm:"unique"` + } + + db, err := OpenTestConnection(&gorm.Config{TranslateError: true}) + if err != nil { + t.Fatalf("failed to connect database, got error %v", err) + } + + dialectors := map[string]bool{"sqlite": true, "postgres": true, "mysql": true, "sqlserver": true} + if supported, found := dialectors[db.Dialector.Name()]; !(found && supported) { + return + } + + if err = db.AutoMigrate(&City{}); err != nil { + t.Fatalf("failed to migrate cities table, got error: %v", err) + } + + err = db.Create(&City{Name: "Kabul"}).Error + if err != nil { + t.Fatalf("failed to create record: %v", err) + } + + err = db.Create(&City{Name: "Kabul"}).Error + if !errors.Is(err, gorm.ErrDuplicatedKey) { + t.Fatalf("expected err: %v got err: %v", gorm.ErrDuplicatedKey, err) + } +} diff --git a/tests/go.mod b/tests/go.mod index f47d175f..0b38b9d0 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -8,12 +8,11 @@ require ( github.com/jinzhu/now v1.1.5 github.com/lib/pq v1.10.8 github.com/mattn/go-sqlite3 v1.14.16 // indirect - golang.org/x/crypto v0.8.0 // indirect gorm.io/driver/mysql v1.5.0 gorm.io/driver/postgres v1.5.0 gorm.io/driver/sqlite v1.5.0 - gorm.io/driver/sqlserver v1.4.3 - gorm.io/gorm v1.25.0 + gorm.io/driver/sqlserver v1.5.1 + gorm.io/gorm v1.25.1 ) replace gorm.io/gorm => ../ diff --git a/tests/prepared_stmt_test.go b/tests/prepared_stmt_test.go index 64baa01b..b234c8bf 100644 --- a/tests/prepared_stmt_test.go +++ b/tests/prepared_stmt_test.go @@ -92,7 +92,7 @@ func TestPreparedStmtFromTransaction(t *testing.T) { } func TestPreparedStmtDeadlock(t *testing.T) { - tx, err := OpenTestConnection() + tx, err := OpenTestConnection(&gorm.Config{}) AssertEqual(t, err, nil) sqlDB, _ := tx.DB() @@ -127,7 +127,7 @@ func TestPreparedStmtDeadlock(t *testing.T) { } func TestPreparedStmtError(t *testing.T) { - tx, err := OpenTestConnection() + tx, err := OpenTestConnection(&gorm.Config{}) AssertEqual(t, err, nil) sqlDB, _ := tx.DB() diff --git a/tests/tests_test.go b/tests/tests_test.go index 90eb847f..0167d406 100644 --- a/tests/tests_test.go +++ b/tests/tests_test.go @@ -26,7 +26,7 @@ var ( func init() { var err error - if DB, err = OpenTestConnection(); err != nil { + if DB, err = OpenTestConnection(&gorm.Config{}); err != nil { log.Printf("failed to connect database, got error %v", err) os.Exit(1) } else { @@ -49,7 +49,7 @@ func init() { } } -func OpenTestConnection() (db *gorm.DB, err error) { +func OpenTestConnection(cfg *gorm.Config) (db *gorm.DB, err error) { dbDSN := os.Getenv("GORM_DSN") switch os.Getenv("GORM_DIALECT") { case "mysql": @@ -57,7 +57,7 @@ func OpenTestConnection() (db *gorm.DB, err error) { if dbDSN == "" { dbDSN = mysqlDSN } - db, err = gorm.Open(mysql.Open(dbDSN), &gorm.Config{}) + db, err = gorm.Open(mysql.Open(dbDSN), cfg) case "postgres": log.Println("testing postgres...") if dbDSN == "" { @@ -66,7 +66,7 @@ func OpenTestConnection() (db *gorm.DB, err error) { db, err = gorm.Open(postgres.New(postgres.Config{ DSN: dbDSN, PreferSimpleProtocol: true, - }), &gorm.Config{}) + }), cfg) case "sqlserver": // go install github.com/microsoft/go-sqlcmd/cmd/sqlcmd@latest // SQLCMDPASSWORD=LoremIpsum86 sqlcmd -U sa -S localhost:9930 @@ -80,16 +80,16 @@ func OpenTestConnection() (db *gorm.DB, err error) { if dbDSN == "" { dbDSN = sqlserverDSN } - db, err = gorm.Open(sqlserver.Open(dbDSN), &gorm.Config{}) + db, err = gorm.Open(sqlserver.Open(dbDSN), cfg) case "tidb": log.Println("testing tidb...") if dbDSN == "" { dbDSN = tidbDSN } - db, err = gorm.Open(mysql.Open(dbDSN), &gorm.Config{}) + db, err = gorm.Open(mysql.Open(dbDSN), cfg) default: log.Println("testing sqlite3...") - db, err = gorm.Open(sqlite.Open(filepath.Join(os.TempDir(), "gorm.db")), &gorm.Config{}) + db, err = gorm.Open(sqlite.Open(filepath.Join(os.TempDir(), "gorm.db")), cfg) } if err != nil { diff --git a/tests/transaction_test.go b/tests/transaction_test.go index bfbd8699..126ccb23 100644 --- a/tests/transaction_test.go +++ b/tests/transaction_test.go @@ -361,7 +361,7 @@ func TestDisabledNestedTransaction(t *testing.T) { } func TestTransactionOnClosedConn(t *testing.T) { - DB, err := OpenTestConnection() + DB, err := OpenTestConnection(&gorm.Config{}) if err != nil { t.Fatalf("failed to connect database, got error %v", err) } From 206613868439c5ee7e62e116a46503eddf55a548 Mon Sep 17 00:00:00 2001 From: Saeid Date: Sun, 11 Jun 2023 01:42:18 +0200 Subject: [PATCH 18/19] ci: fix mariadb mysqladmin (#6401) Co-authored-by: Saeid Saeidee --- .github/workflows/tests.yml | 46 +++++++++++++++++++++++++++++++++++-- 1 file changed, 44 insertions(+), 2 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index bf225d42..1191a8ea 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -41,7 +41,7 @@ jobs: mysql: strategy: matrix: - dbversion: ['mysql:latest', 'mysql:5.7', 'mariadb:latest'] + dbversion: ['mysql:latest', 'mysql:5.7'] go: ['1.19', '1.18'] platform: [ubuntu-latest] runs-on: ${{ matrix.platform }} @@ -72,7 +72,6 @@ jobs: - name: Check out code into the Go module directory uses: actions/checkout@v3 - - name: go mod package cache uses: actions/cache@v3 with: @@ -82,6 +81,49 @@ jobs: - name: Tests run: GITHUB_ACTION=true GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True" ./tests/tests_all.sh + mariadb: + strategy: + matrix: + dbversion: [ 'mariadb:latest' ] + go: [ '1.19', '1.18' ] + platform: [ ubuntu-latest ] + runs-on: ${{ matrix.platform }} + + services: + mysql: + image: ${{ matrix.dbversion }} + env: + MYSQL_DATABASE: gorm + MYSQL_USER: gorm + MYSQL_PASSWORD: gorm + MYSQL_RANDOM_ROOT_PASSWORD: "yes" + ports: + - 9910:3306 + options: >- + --health-cmd "mariadb-admin ping -ugorm -pgorm" + --health-interval 10s + --health-start-period 10s + --health-timeout 5s + --health-retries 10 + + steps: + - name: Set up Go 1.x + uses: actions/setup-go@v4 + with: + go-version: ${{ matrix.go }} + + - name: Check out code into the Go module directory + uses: actions/checkout@v3 + + - name: go mod package cache + uses: actions/cache@v3 + with: + path: ~/go/pkg/mod + key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} + + - name: Tests + run: GITHUB_ACTION=true GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True" ./tests/tests_all.sh + postgres: strategy: matrix: From c10f807d3c0b2b3204c80d62cf7650e21d7e3316 Mon Sep 17 00:00:00 2001 From: Saeid Date: Wed, 12 Jul 2023 16:21:22 +0300 Subject: [PATCH 19/19] test: coverage for foreign key violation err (#6403) * test: coverage for foreign key violation err * test: enabled foreign keys constraint for sqlite * test: enabled mysql& mssql for ErrForeignKeyViolate * test: disabled mysql & updated sqlserver driver version * test: skipped tidb --------- Co-authored-by: Saeid Saeidee --- tests/error_translator_test.go | 51 ++++++++++++++++++++++++++++++++++ tests/go.mod | 14 ++++------ tests/tests_test.go | 2 +- 3 files changed, 58 insertions(+), 9 deletions(-) diff --git a/tests/error_translator_test.go b/tests/error_translator_test.go index f6c70677..ee54300e 100644 --- a/tests/error_translator_test.go +++ b/tests/error_translator_test.go @@ -44,6 +44,8 @@ func TestSupportedDialectorWithErrDuplicatedKey(t *testing.T) { return } + DB.Migrator().DropTable(&City{}) + if err = db.AutoMigrate(&City{}); err != nil { t.Fatalf("failed to migrate cities table, got error: %v", err) } @@ -58,3 +60,52 @@ func TestSupportedDialectorWithErrDuplicatedKey(t *testing.T) { t.Fatalf("expected err: %v got err: %v", gorm.ErrDuplicatedKey, err) } } + +func TestSupportedDialectorWithErrForeignKeyViolated(t *testing.T) { + tidbSkip(t, "not support the foreign key feature") + + type City struct { + gorm.Model + Name string `gorm:"unique"` + } + + type Museum struct { + gorm.Model + Name string `gorm:"unique"` + CityID uint + City City `gorm:"Constraint:OnUpdate:CASCADE,OnDelete:CASCADE;FOREIGNKEY:CityID;References:ID"` + } + + db, err := OpenTestConnection(&gorm.Config{TranslateError: true}) + if err != nil { + t.Fatalf("failed to connect database, got error %v", err) + } + + dialectors := map[string]bool{"sqlite": true, "postgres": true, "mysql": true, "sqlserver": true} + if supported, found := dialectors[db.Dialector.Name()]; !(found && supported) { + return + } + + DB.Migrator().DropTable(&City{}, &Museum{}) + + if err = db.AutoMigrate(&City{}, &Museum{}); err != nil { + t.Fatalf("failed to migrate countries & cities tables, got error: %v", err) + } + + city := City{Name: "Amsterdam"} + + err = db.Create(&city).Error + if err != nil { + t.Fatalf("failed to create city: %v", err) + } + + err = db.Create(&Museum{Name: "Eye Filmmuseum", CityID: city.ID}).Error + if err != nil { + t.Fatalf("failed to create museum: %v", err) + } + + err = db.Create(&Museum{Name: "Dungeon", CityID: 123}).Error + if !errors.Is(err, gorm.ErrForeignKeyViolated) { + t.Fatalf("expected err: %v got err: %v", gorm.ErrForeignKeyViolated, err) + } +} diff --git a/tests/go.mod b/tests/go.mod index 0b38b9d0..aebe5a06 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -4,15 +4,13 @@ go 1.16 require ( github.com/google/uuid v1.3.0 - github.com/jackc/pgx/v5 v5.3.1 // indirect github.com/jinzhu/now v1.1.5 - github.com/lib/pq v1.10.8 - github.com/mattn/go-sqlite3 v1.14.16 // indirect - gorm.io/driver/mysql v1.5.0 - gorm.io/driver/postgres v1.5.0 - gorm.io/driver/sqlite v1.5.0 - gorm.io/driver/sqlserver v1.5.1 - gorm.io/gorm v1.25.1 + github.com/lib/pq v1.10.9 + gorm.io/driver/mysql v1.5.2-0.20230612053416-48b6526a21f0 + gorm.io/driver/postgres v1.5.3-0.20230607070428-18bc84b75196 + gorm.io/driver/sqlite v1.5.2 + gorm.io/driver/sqlserver v1.5.2-0.20230613072041-6e2cde390b0a + gorm.io/gorm v1.25.2-0.20230610234218-206613868439 ) replace gorm.io/gorm => ../ diff --git a/tests/tests_test.go b/tests/tests_test.go index 0167d406..47c2a7c1 100644 --- a/tests/tests_test.go +++ b/tests/tests_test.go @@ -89,7 +89,7 @@ func OpenTestConnection(cfg *gorm.Config) (db *gorm.DB, err error) { db, err = gorm.Open(mysql.Open(dbDSN), cfg) default: log.Println("testing sqlite3...") - db, err = gorm.Open(sqlite.Open(filepath.Join(os.TempDir(), "gorm.db")), cfg) + db, err = gorm.Open(sqlite.Open(filepath.Join(os.TempDir(), "gorm.db?_foreign_keys=on")), cfg) } if err != nil {