diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index bf225d42..1191a8ea 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -41,7 +41,7 @@ jobs: mysql: strategy: matrix: - dbversion: ['mysql:latest', 'mysql:5.7', 'mariadb:latest'] + dbversion: ['mysql:latest', 'mysql:5.7'] go: ['1.19', '1.18'] platform: [ubuntu-latest] runs-on: ${{ matrix.platform }} @@ -72,7 +72,6 @@ jobs: - name: Check out code into the Go module directory uses: actions/checkout@v3 - - name: go mod package cache uses: actions/cache@v3 with: @@ -82,6 +81,49 @@ jobs: - name: Tests run: GITHUB_ACTION=true GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True" ./tests/tests_all.sh + mariadb: + strategy: + matrix: + dbversion: [ 'mariadb:latest' ] + go: [ '1.19', '1.18' ] + platform: [ ubuntu-latest ] + runs-on: ${{ matrix.platform }} + + services: + mysql: + image: ${{ matrix.dbversion }} + env: + MYSQL_DATABASE: gorm + MYSQL_USER: gorm + MYSQL_PASSWORD: gorm + MYSQL_RANDOM_ROOT_PASSWORD: "yes" + ports: + - 9910:3306 + options: >- + --health-cmd "mariadb-admin ping -ugorm -pgorm" + --health-interval 10s + --health-start-period 10s + --health-timeout 5s + --health-retries 10 + + steps: + - name: Set up Go 1.x + uses: actions/setup-go@v4 + with: + go-version: ${{ matrix.go }} + + - name: Check out code into the Go module directory + uses: actions/checkout@v3 + + - name: go mod package cache + uses: actions/cache@v3 + with: + path: ~/go/pkg/mod + key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} + + - name: Tests + run: GITHUB_ACTION=true GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True" ./tests/tests_all.sh + postgres: strategy: matrix: diff --git a/License b/LICENSE similarity index 100% rename from License rename to LICENSE diff --git a/callbacks/update.go b/callbacks/update.go index 4eb75788..ff075dcf 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -72,6 +72,7 @@ func Update(config *Config) func(db *gorm.DB) { db.Statement.AddClauseIfNotExists(clause.Update{}) if _, ok := db.Statement.Clauses["SET"]; !ok { if set := ConvertToAssignments(db.Statement); len(set) != 0 { + defer delete(db.Statement.Clauses, "SET") db.Statement.AddClause(set) } else { return diff --git a/errors.go b/errors.go index 57e3fc5e..cd76f1f5 100644 --- a/errors.go +++ b/errors.go @@ -47,4 +47,6 @@ var ( ErrPreloadNotAllowed = errors.New("preload is not allowed when count is used") // ErrDuplicatedKey occurs when there is a unique key constraint violation ErrDuplicatedKey = errors.New("duplicated key not allowed") + // ErrForeignKeyViolated occurs when there is a foreign key constraint violation + ErrForeignKeyViolated = errors.New("violates foreign key constraint") ) diff --git a/finisher_api.go b/finisher_api.go index 0e26f181..f80aa6c0 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -6,8 +6,6 @@ import ( "fmt" "reflect" "strings" - "sync" - "sync/atomic" "gorm.io/gorm/clause" "gorm.io/gorm/logger" @@ -107,7 +105,7 @@ func (db *DB) Save(value interface{}) (tx *DB) { updateTx := tx.callbacks.Update().Execute(tx.Session(&Session{Initialized: true})) if updateTx.Error == nil && updateTx.RowsAffected == 0 && !updateTx.DryRun && !selectedUpdate { - return tx.Clauses(clause.OnConflict{UpdateAll: true}).Create(value) + return tx.Session(&Session{SkipHooks: true}).Clauses(clause.OnConflict{UpdateAll: true}).Create(value) } return updateTx @@ -533,6 +531,7 @@ func (db *DB) Scan(dest interface{}) (tx *DB) { tx.ScanRows(rows, dest) } else { tx.RowsAffected = 0 + tx.AddError(rows.Err()) } tx.AddError(rows.Close()) } @@ -611,15 +610,6 @@ func (db *DB) Connection(fc func(tx *DB) error) (err error) { return fc(tx) } -var ( - savepointIdx int64 - savepointNamePool = &sync.Pool{ - New: func() interface{} { - return fmt.Sprintf("gorm_%d", atomic.AddInt64(&savepointIdx, 1)) - }, - } -) - // Transaction start a transaction as a block, return error will rollback, otherwise to commit. Transaction executes an // arbitrary number of commands in fc within a transaction. On success the changes are committed; if an error occurs // they are rolled back. @@ -629,17 +619,14 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil { // nested transaction if !db.DisableNestedTransaction { - poolName := savepointNamePool.Get() - defer savepointNamePool.Put(poolName) - err = db.SavePoint(poolName.(string)).Error + err = db.SavePoint(fmt.Sprintf("sp%p", fc)).Error if err != nil { return } - defer func() { // Make sure to rollback when panic, Block error or Commit error if panicked || err != nil { - db.RollbackTo(poolName.(string)) + db.RollbackTo(fmt.Sprintf("sp%p", fc)) } }() } @@ -720,7 +707,21 @@ func (db *DB) Rollback() *DB { func (db *DB) SavePoint(name string) *DB { if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok { + // close prepared statement, because SavePoint not support prepared statement. + // e.g. mysql8.0 doc: https://dev.mysql.com/doc/refman/8.0/en/sql-prepared-statements.html + var ( + preparedStmtTx *PreparedStmtTX + isPreparedStmtTx bool + ) + // close prepared statement, because SavePoint not support prepared statement. + if preparedStmtTx, isPreparedStmtTx = db.Statement.ConnPool.(*PreparedStmtTX); isPreparedStmtTx { + db.Statement.ConnPool = preparedStmtTx.Tx + } db.AddError(savePointer.SavePoint(db, name)) + // restore prepared statement + if isPreparedStmtTx { + db.Statement.ConnPool = preparedStmtTx + } } else { db.AddError(ErrUnsupportedDriver) } @@ -729,7 +730,21 @@ func (db *DB) SavePoint(name string) *DB { func (db *DB) RollbackTo(name string) *DB { if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok { + // close prepared statement, because RollbackTo not support prepared statement. + // e.g. mysql8.0 doc: https://dev.mysql.com/doc/refman/8.0/en/sql-prepared-statements.html + var ( + preparedStmtTx *PreparedStmtTX + isPreparedStmtTx bool + ) + // close prepared statement, because SavePoint not support prepared statement. + if preparedStmtTx, isPreparedStmtTx = db.Statement.ConnPool.(*PreparedStmtTX); isPreparedStmtTx { + db.Statement.ConnPool = preparedStmtTx.Tx + } db.AddError(savePointer.RollbackTo(db, name)) + // restore prepared statement + if isPreparedStmtTx { + db.Statement.ConnPool = preparedStmtTx + } } else { db.AddError(ErrUnsupportedDriver) } diff --git a/gorm.go b/gorm.go index 07a913fc..9297850e 100644 --- a/gorm.go +++ b/gorm.go @@ -146,7 +146,7 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) { } if config.NamingStrategy == nil { - config.NamingStrategy = schema.NamingStrategy{} + config.NamingStrategy = schema.NamingStrategy{IdentifierMaxLength: 64} // Default Identifier length is 64 } if config.Logger == nil { @@ -187,15 +187,9 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) { } } - preparedStmt := &PreparedStmtDB{ - ConnPool: db.ConnPool, - Stmts: make(map[string]*Stmt), - Mux: &sync.RWMutex{}, - PreparedSQL: make([]string, 0, 100), - } - db.cacheStore.Store(preparedStmtDBKey, preparedStmt) - if config.PrepareStmt { + preparedStmt := NewPreparedStmtDB(db.ConnPool) + db.cacheStore.Store(preparedStmtDBKey, preparedStmt) db.ConnPool = preparedStmt } @@ -256,24 +250,30 @@ func (db *DB) Session(config *Session) *DB { } if config.PrepareStmt { + var preparedStmt *PreparedStmtDB + if v, ok := db.cacheStore.Load(preparedStmtDBKey); ok { - preparedStmt := v.(*PreparedStmtDB) - switch t := tx.Statement.ConnPool.(type) { - case Tx: - tx.Statement.ConnPool = &PreparedStmtTX{ - Tx: t, - PreparedStmtDB: preparedStmt, - } - default: - tx.Statement.ConnPool = &PreparedStmtDB{ - ConnPool: db.Config.ConnPool, - Mux: preparedStmt.Mux, - Stmts: preparedStmt.Stmts, - } - } - txConfig.ConnPool = tx.Statement.ConnPool - txConfig.PrepareStmt = true + preparedStmt = v.(*PreparedStmtDB) + } else { + preparedStmt = NewPreparedStmtDB(db.ConnPool) + db.cacheStore.Store(preparedStmtDBKey, preparedStmt) } + + switch t := tx.Statement.ConnPool.(type) { + case Tx: + tx.Statement.ConnPool = &PreparedStmtTX{ + Tx: t, + PreparedStmtDB: preparedStmt, + } + default: + tx.Statement.ConnPool = &PreparedStmtDB{ + ConnPool: db.Config.ConnPool, + Mux: preparedStmt.Mux, + Stmts: preparedStmt.Stmts, + } + } + txConfig.ConnPool = tx.Statement.ConnPool + txConfig.PrepareStmt = true } if config.SkipHooks { @@ -375,11 +375,17 @@ func (db *DB) AddError(err error) error { func (db *DB) DB() (*sql.DB, error) { connPool := db.ConnPool - if dbConnector, ok := connPool.(GetDBConnector); ok && dbConnector != nil { - return dbConnector.GetDBConn() + if connector, ok := connPool.(GetDBConnectorWithContext); ok && connector != nil { + return connector.GetDBConnWithContext(db) } - if sqldb, ok := connPool.(*sql.DB); ok { + if dbConnector, ok := connPool.(GetDBConnector); ok && dbConnector != nil { + if sqldb, err := dbConnector.GetDBConn(); sqldb != nil || err != nil { + return sqldb, err + } + } + + if sqldb, ok := connPool.(*sql.DB); ok && sqldb != nil { return sqldb, nil } diff --git a/interfaces.go b/interfaces.go index 3bcc3d57..1950d740 100644 --- a/interfaces.go +++ b/interfaces.go @@ -77,6 +77,12 @@ type GetDBConnector interface { GetDBConn() (*sql.DB, error) } +// GetDBConnectorWithContext represents SQL db connector which takes into +// account the current database context +type GetDBConnectorWithContext interface { + GetDBConnWithContext(db *DB) (*sql.DB, error) +} + // Rows rows interface type Rows interface { Columns() ([]string, error) diff --git a/prepare_stmt.go b/prepare_stmt.go index e09fe814..10fefc31 100644 --- a/prepare_stmt.go +++ b/prepare_stmt.go @@ -3,6 +3,7 @@ package gorm import ( "context" "database/sql" + "reflect" "sync" ) @@ -20,6 +21,15 @@ type PreparedStmtDB struct { ConnPool } +func NewPreparedStmtDB(connPool ConnPool) *PreparedStmtDB { + return &PreparedStmtDB{ + ConnPool: connPool, + Stmts: make(map[string]*Stmt), + Mux: &sync.RWMutex{}, + PreparedSQL: make([]string, 0, 100), + } +} + func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) { if dbConnector, ok := db.ConnPool.(GetDBConnector); ok && dbConnector != nil { return dbConnector.GetDBConn() @@ -163,14 +173,14 @@ type PreparedStmtTX struct { } func (tx *PreparedStmtTX) Commit() error { - if tx.Tx != nil { + if tx.Tx != nil && !reflect.ValueOf(tx.Tx).IsNil() { return tx.Tx.Commit() } return ErrInvalidTransaction } func (tx *PreparedStmtTX) Rollback() error { - if tx.Tx != nil { + if tx.Tx != nil && !reflect.ValueOf(tx.Tx).IsNil() { return tx.Tx.Rollback() } return ErrInvalidTransaction diff --git a/schema/field.go b/schema/field.go index 7d1a1789..dd08e056 100644 --- a/schema/field.go +++ b/schema/field.go @@ -846,7 +846,7 @@ func (field *Field) setupValuerAndSetter() { field.Set = func(ctx context.Context, value reflect.Value, v interface{}) error { switch data := v.(type) { case **time.Time: - if data != nil { + if data != nil && *data != nil { field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(*data)) } case time.Time: @@ -882,14 +882,12 @@ func (field *Field) setupValuerAndSetter() { reflectV := reflect.ValueOf(v) if !reflectV.IsValid() { field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem()) + } else if reflectV.Kind() == reflect.Ptr && reflectV.IsNil() { + return } else if reflectV.Type().AssignableTo(field.FieldType) { field.ReflectValueOf(ctx, value).Set(reflectV) } else if reflectV.Kind() == reflect.Ptr { - if reflectV.IsNil() || !reflectV.IsValid() { - field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem()) - } else { - return field.Set(ctx, value, reflectV.Elem().Interface()) - } + return field.Set(ctx, value, reflectV.Elem().Interface()) } else { fieldValue := field.ReflectValueOf(ctx, value) if fieldValue.IsNil() { @@ -910,14 +908,12 @@ func (field *Field) setupValuerAndSetter() { reflectV := reflect.ValueOf(v) if !reflectV.IsValid() { field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem()) + } else if reflectV.Kind() == reflect.Ptr && reflectV.IsNil() { + return } else if reflectV.Type().AssignableTo(field.FieldType) { field.ReflectValueOf(ctx, value).Set(reflectV) } else if reflectV.Kind() == reflect.Ptr { - if reflectV.IsNil() || !reflectV.IsValid() { - field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem()) - } else { - return field.Set(ctx, value, reflectV.Elem().Interface()) - } + return field.Set(ctx, value, reflectV.Elem().Interface()) } else { if valuer, ok := v.(driver.Valuer); ok { v, _ = valuer.Value() diff --git a/schema/naming.go b/schema/naming.go index a258beed..a2a0150a 100644 --- a/schema/naming.go +++ b/schema/naming.go @@ -28,10 +28,11 @@ type Replacer interface { // NamingStrategy tables, columns naming strategy type NamingStrategy struct { - TablePrefix string - SingularTable bool - NameReplacer Replacer - NoLowerCase bool + TablePrefix string + SingularTable bool + NameReplacer Replacer + NoLowerCase bool + IdentifierMaxLength int } // TableName convert string to table name @@ -89,12 +90,16 @@ func (ns NamingStrategy) formatName(prefix, table, name string) string { prefix, table, name, }, "_"), ".", "_") - if utf8.RuneCountInString(formattedName) > 64 { + if ns.IdentifierMaxLength == 0 { + ns.IdentifierMaxLength = 64 + } + + if utf8.RuneCountInString(formattedName) > ns.IdentifierMaxLength { h := sha1.New() h.Write([]byte(formattedName)) bs := h.Sum(nil) - formattedName = formattedName[0:56] + hex.EncodeToString(bs)[:8] + formattedName = formattedName[0:ns.IdentifierMaxLength-8] + hex.EncodeToString(bs)[:8] } return formattedName } diff --git a/schema/naming_test.go b/schema/naming_test.go index 3f598c33..ab7a5e31 100644 --- a/schema/naming_test.go +++ b/schema/naming_test.go @@ -189,8 +189,17 @@ func TestCustomReplacerWithNoLowerCase(t *testing.T) { } } +func TestFormatNameWithStringLongerThan63Characters(t *testing.T) { + ns := NamingStrategy{IdentifierMaxLength: 63} + + formattedName := ns.formatName("prefix", "table", "thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVeryLongString") + if formattedName != "prefix_table_thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVer180f2c67" { + t.Errorf("invalid formatted name generated, got %v", formattedName) + } +} + func TestFormatNameWithStringLongerThan64Characters(t *testing.T) { - ns := NamingStrategy{} + ns := NamingStrategy{IdentifierMaxLength: 64} formattedName := ns.formatName("prefix", "table", "thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVeryLongString") if formattedName != "prefix_table_thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVery180f2c67" { diff --git a/schema/relationship_test.go b/schema/relationship_test.go index 732f6f75..1eb66bb4 100644 --- a/schema/relationship_test.go +++ b/schema/relationship_test.go @@ -768,7 +768,7 @@ func TestParseConstraintNameWithSchemaQualifiedLongTableName(t *testing.T) { s, err := schema.Parse( &Book{}, &sync.Map{}, - schema.NamingStrategy{}, + schema.NamingStrategy{IdentifierMaxLength: 64}, ) if err != nil { t.Fatalf("Failed to parse schema") diff --git a/tests/associations_many2many_test.go b/tests/associations_many2many_test.go index b69d668a..39410aed 100644 --- a/tests/associations_many2many_test.go +++ b/tests/associations_many2many_test.go @@ -358,7 +358,7 @@ func TestDuplicateMany2ManyAssociation(t *testing.T) { } func TestConcurrentMany2ManyAssociation(t *testing.T) { - db, err := OpenTestConnection() + db, err := OpenTestConnection(&gorm.Config{}) if err != nil { t.Fatalf("open test connection failed, err: %+v", err) } diff --git a/tests/embedded_struct_test.go b/tests/embedded_struct_test.go index 3747dad9..4314f88c 100644 --- a/tests/embedded_struct_test.go +++ b/tests/embedded_struct_test.go @@ -4,7 +4,9 @@ import ( "database/sql/driver" "encoding/json" "errors" + "reflect" "testing" + "time" "gorm.io/gorm" . "gorm.io/gorm/utils/tests" @@ -104,10 +106,14 @@ func TestEmbeddedPointerTypeStruct(t *testing.T) { } type Author struct { - ID string - Name string - Email string - Age int + ID string + Name string + Email string + Age int + Content Content + ContentPtr *Content + Birthday time.Time + BirthdayPtr *time.Time } type HNPost struct { @@ -135,6 +141,48 @@ func TestEmbeddedPointerTypeStruct(t *testing.T) { if hnPost.Author != nil { t.Errorf("Expected to get back a nil Author but got: %v", hnPost.Author) } + + now := time.Now().Round(time.Second) + NewPost := HNPost{ + BasePost: &BasePost{Title: "embedded_pointer_type2"}, + Author: &Author{ + Name: "test", + Content: Content{"test"}, + ContentPtr: nil, + Birthday: now, + BirthdayPtr: nil, + }, + } + DB.Create(&NewPost) + + hnPost = HNPost{} + if err := DB.First(&hnPost, "title = ?", NewPost.Title).Error; err != nil { + t.Errorf("No error should happen when find embedded pointer type, but got %v", err) + } + + if hnPost.Title != NewPost.Title { + t.Errorf("Should find correct value for embedded pointer type") + } + + if hnPost.Author.Name != NewPost.Author.Name { + t.Errorf("Expected to get Author name %v but got: %v", NewPost.Author.Name, hnPost.Author.Name) + } + + if !reflect.DeepEqual(NewPost.Author.Content, hnPost.Author.Content) { + t.Errorf("Expected to get Author content %v but got: %v", NewPost.Author.Content, hnPost.Author.Content) + } + + if hnPost.Author.ContentPtr != nil { + t.Errorf("Expected to get nil Author contentPtr but got: %v", hnPost.Author.ContentPtr) + } + + if NewPost.Author.Birthday.UnixMilli() != hnPost.Author.Birthday.UnixMilli() { + t.Errorf("Expected to get Author birthday with %+v but got: %+v", NewPost.Author.Birthday, hnPost.Author.Birthday) + } + + if hnPost.Author.BirthdayPtr != nil { + t.Errorf("Expected to get nil Author birthdayPtr but got: %+v", hnPost.Author.BirthdayPtr) + } } type Content struct { @@ -142,18 +190,26 @@ type Content struct { } func (c Content) Value() (driver.Value, error) { - return json.Marshal(c) + // mssql driver with issue on handling null bytes https://github.com/denisenkom/go-mssqldb/issues/530, + b, err := json.Marshal(c) + return string(b[:]), err } func (c *Content) Scan(src interface{}) error { - b, ok := src.([]byte) - if !ok { - return errors.New("Embedded.Scan byte assertion failed") - } - var value Content - if err := json.Unmarshal(b, &value); err != nil { - return err + str, ok := src.(string) + if !ok { + byt, ok := src.([]byte) + if !ok { + return errors.New("Embedded.Scan byte assertion failed") + } + if err := json.Unmarshal(byt, &value); err != nil { + return err + } + } else { + if err := json.Unmarshal([]byte(str), &value); err != nil { + return err + } } *c = value diff --git a/tests/error_translator_test.go b/tests/error_translator_test.go index ead26fce..ee54300e 100644 --- a/tests/error_translator_test.go +++ b/tests/error_translator_test.go @@ -15,8 +15,8 @@ func TestDialectorWithErrorTranslatorSupport(t *testing.T) { db, _ := gorm.Open(tests.DummyDialector{TranslatedErr: translatedErr}) err := db.AddError(untranslatedErr) - if errors.Is(err, translatedErr) { - t.Fatalf("expected err: %v got err: %v", translatedErr, err) + if !errors.Is(err, untranslatedErr) { + t.Fatalf("expected err: %v got err: %v", untranslatedErr, err) } // it should translate error when the TranslateError flag is true @@ -27,3 +27,85 @@ func TestDialectorWithErrorTranslatorSupport(t *testing.T) { t.Fatalf("expected err: %v got err: %v", translatedErr, err) } } + +func TestSupportedDialectorWithErrDuplicatedKey(t *testing.T) { + type City struct { + gorm.Model + Name string `gorm:"unique"` + } + + db, err := OpenTestConnection(&gorm.Config{TranslateError: true}) + if err != nil { + t.Fatalf("failed to connect database, got error %v", err) + } + + dialectors := map[string]bool{"sqlite": true, "postgres": true, "mysql": true, "sqlserver": true} + if supported, found := dialectors[db.Dialector.Name()]; !(found && supported) { + return + } + + DB.Migrator().DropTable(&City{}) + + if err = db.AutoMigrate(&City{}); err != nil { + t.Fatalf("failed to migrate cities table, got error: %v", err) + } + + err = db.Create(&City{Name: "Kabul"}).Error + if err != nil { + t.Fatalf("failed to create record: %v", err) + } + + err = db.Create(&City{Name: "Kabul"}).Error + if !errors.Is(err, gorm.ErrDuplicatedKey) { + t.Fatalf("expected err: %v got err: %v", gorm.ErrDuplicatedKey, err) + } +} + +func TestSupportedDialectorWithErrForeignKeyViolated(t *testing.T) { + tidbSkip(t, "not support the foreign key feature") + + type City struct { + gorm.Model + Name string `gorm:"unique"` + } + + type Museum struct { + gorm.Model + Name string `gorm:"unique"` + CityID uint + City City `gorm:"Constraint:OnUpdate:CASCADE,OnDelete:CASCADE;FOREIGNKEY:CityID;References:ID"` + } + + db, err := OpenTestConnection(&gorm.Config{TranslateError: true}) + if err != nil { + t.Fatalf("failed to connect database, got error %v", err) + } + + dialectors := map[string]bool{"sqlite": true, "postgres": true, "mysql": true, "sqlserver": true} + if supported, found := dialectors[db.Dialector.Name()]; !(found && supported) { + return + } + + DB.Migrator().DropTable(&City{}, &Museum{}) + + if err = db.AutoMigrate(&City{}, &Museum{}); err != nil { + t.Fatalf("failed to migrate countries & cities tables, got error: %v", err) + } + + city := City{Name: "Amsterdam"} + + err = db.Create(&city).Error + if err != nil { + t.Fatalf("failed to create city: %v", err) + } + + err = db.Create(&Museum{Name: "Eye Filmmuseum", CityID: city.ID}).Error + if err != nil { + t.Fatalf("failed to create museum: %v", err) + } + + err = db.Create(&Museum{Name: "Dungeon", CityID: 123}).Error + if !errors.Is(err, gorm.ErrForeignKeyViolated) { + t.Fatalf("expected err: %v got err: %v", gorm.ErrForeignKeyViolated, err) + } +} diff --git a/tests/gorm_test.go b/tests/gorm_test.go index 9827465c..4c31b88b 100644 --- a/tests/gorm_test.go +++ b/tests/gorm_test.go @@ -3,9 +3,19 @@ package tests_test import ( "testing" + "gorm.io/driver/mysql" + "gorm.io/gorm" ) +func TestOpen(t *testing.T) { + dsn := "gorm:gorm@tcp(localhost:9910)/gorm?loc=Asia%2FHongKong" // invalid loc + _, err := gorm.Open(mysql.Open(dsn), &gorm.Config{}) + if err == nil { + t.Fatalf("should returns error but got nil") + } +} + func TestReturningWithNullToZeroValues(t *testing.T) { dialect := DB.Dialector.Name() switch dialect { diff --git a/tests/prepared_stmt_test.go b/tests/prepared_stmt_test.go index 64baa01b..b234c8bf 100644 --- a/tests/prepared_stmt_test.go +++ b/tests/prepared_stmt_test.go @@ -92,7 +92,7 @@ func TestPreparedStmtFromTransaction(t *testing.T) { } func TestPreparedStmtDeadlock(t *testing.T) { - tx, err := OpenTestConnection() + tx, err := OpenTestConnection(&gorm.Config{}) AssertEqual(t, err, nil) sqlDB, _ := tx.DB() @@ -127,7 +127,7 @@ func TestPreparedStmtDeadlock(t *testing.T) { } func TestPreparedStmtError(t *testing.T) { - tx, err := OpenTestConnection() + tx, err := OpenTestConnection(&gorm.Config{}) AssertEqual(t, err, nil) sqlDB, _ := tx.DB() diff --git a/tests/scanner_valuer_test.go b/tests/scanner_valuer_test.go index 14121699..472434b4 100644 --- a/tests/scanner_valuer_test.go +++ b/tests/scanner_valuer_test.go @@ -170,10 +170,10 @@ func (data *EncryptedData) Scan(value interface{}) error { return errors.New("Too short") } - *data = b[3:] + *data = append((*data)[0:], b[3:]...) return nil } else if s, ok := value.(string); ok { - *data = []byte(s)[3:] + *data = []byte(s[3:]) return nil } diff --git a/tests/tests_test.go b/tests/tests_test.go index 90eb847f..47c2a7c1 100644 --- a/tests/tests_test.go +++ b/tests/tests_test.go @@ -26,7 +26,7 @@ var ( func init() { var err error - if DB, err = OpenTestConnection(); err != nil { + if DB, err = OpenTestConnection(&gorm.Config{}); err != nil { log.Printf("failed to connect database, got error %v", err) os.Exit(1) } else { @@ -49,7 +49,7 @@ func init() { } } -func OpenTestConnection() (db *gorm.DB, err error) { +func OpenTestConnection(cfg *gorm.Config) (db *gorm.DB, err error) { dbDSN := os.Getenv("GORM_DSN") switch os.Getenv("GORM_DIALECT") { case "mysql": @@ -57,7 +57,7 @@ func OpenTestConnection() (db *gorm.DB, err error) { if dbDSN == "" { dbDSN = mysqlDSN } - db, err = gorm.Open(mysql.Open(dbDSN), &gorm.Config{}) + db, err = gorm.Open(mysql.Open(dbDSN), cfg) case "postgres": log.Println("testing postgres...") if dbDSN == "" { @@ -66,7 +66,7 @@ func OpenTestConnection() (db *gorm.DB, err error) { db, err = gorm.Open(postgres.New(postgres.Config{ DSN: dbDSN, PreferSimpleProtocol: true, - }), &gorm.Config{}) + }), cfg) case "sqlserver": // go install github.com/microsoft/go-sqlcmd/cmd/sqlcmd@latest // SQLCMDPASSWORD=LoremIpsum86 sqlcmd -U sa -S localhost:9930 @@ -80,16 +80,16 @@ func OpenTestConnection() (db *gorm.DB, err error) { if dbDSN == "" { dbDSN = sqlserverDSN } - db, err = gorm.Open(sqlserver.Open(dbDSN), &gorm.Config{}) + db, err = gorm.Open(sqlserver.Open(dbDSN), cfg) case "tidb": log.Println("testing tidb...") if dbDSN == "" { dbDSN = tidbDSN } - db, err = gorm.Open(mysql.Open(dbDSN), &gorm.Config{}) + db, err = gorm.Open(mysql.Open(dbDSN), cfg) default: log.Println("testing sqlite3...") - db, err = gorm.Open(sqlite.Open(filepath.Join(os.TempDir(), "gorm.db")), &gorm.Config{}) + db, err = gorm.Open(sqlite.Open(filepath.Join(os.TempDir(), "gorm.db?_foreign_keys=on")), cfg) } if err != nil { diff --git a/tests/transaction_test.go b/tests/transaction_test.go index 5872da94..126ccb23 100644 --- a/tests/transaction_test.go +++ b/tests/transaction_test.go @@ -57,6 +57,19 @@ func TestTransaction(t *testing.T) { if err := DB.First(&User{}, "name = ?", "transaction-2").Error; err != nil { t.Fatalf("Should be able to find committed record, but got %v", err) } + + t.Run("this is test nested transaction and prepareStmt coexist case", func(t *testing.T) { + // enable prepare statement + tx3 := DB.Session(&gorm.Session{PrepareStmt: true}) + if err := tx3.Transaction(func(tx4 *gorm.DB) error { + // nested transaction + return tx4.Transaction(func(tx5 *gorm.DB) error { + return tx5.First(&User{}, "name = ?", "transaction-2").Error + }) + }); err != nil { + t.Fatalf("prepare statement and nested transcation coexist" + err.Error()) + } + }) } func TestCancelTransaction(t *testing.T) { @@ -348,7 +361,7 @@ func TestDisabledNestedTransaction(t *testing.T) { } func TestTransactionOnClosedConn(t *testing.T) { - DB, err := OpenTestConnection() + DB, err := OpenTestConnection(&gorm.Config{}) if err != nil { t.Fatalf("failed to connect database, got error %v", err) } diff --git a/tests/update_test.go b/tests/update_test.go index 36ffa6a0..c03d2d47 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -208,13 +208,17 @@ func TestUpdateColumn(t *testing.T) { CheckUser(t, user1, *users[0]) CheckUser(t, user2, *users[1]) - DB.Model(users[1]).UpdateColumn("name", "update_column_02_newnew") + DB.Model(users[1]).UpdateColumn("name", "update_column_02_newnew").UpdateColumn("age", 19) AssertEqual(t, lastUpdatedAt.UnixNano(), users[1].UpdatedAt.UnixNano()) if users[1].Name != "update_column_02_newnew" { t.Errorf("user 2's name should be updated, but got %v", users[1].Name) } + if users[1].Age != 19 { + t.Errorf("user 2's name should be updated, but got %v", users[1].Age) + } + DB.Model(users[1]).UpdateColumn("age", gorm.Expr("age + 100 - 50")) var user3 User DB.First(&user3, users[1].ID) @@ -805,3 +809,76 @@ func TestUpdateWithDiffSchema(t *testing.T) { AssertEqual(t, err, nil) AssertEqual(t, "update-diff-schema-2", user.Name) } + +type TokenOwner struct { + ID int + Name string + Token Token `gorm:"foreignKey:UserID"` +} + +func (t *TokenOwner) BeforeSave(tx *gorm.DB) error { + t.Name += "_name" + return nil +} + +type Token struct { + UserID int `gorm:"primary_key"` + Content string `gorm:"type:varchar(100)"` +} + +func (t *Token) BeforeSave(tx *gorm.DB) error { + t.Content += "_encrypted" + return nil +} + +func TestSaveWithHooks(t *testing.T) { + DB.Migrator().DropTable(&Token{}, &TokenOwner{}) + DB.AutoMigrate(&Token{}, &TokenOwner{}) + + saveTokenOwner := func(owner *TokenOwner) (*TokenOwner, error) { + var newOwner TokenOwner + if err := DB.Transaction(func(tx *gorm.DB) error { + if err := tx.Debug().Session(&gorm.Session{FullSaveAssociations: true}).Save(owner).Error; err != nil { + return err + } + if err := tx.Preload("Token").First(&newOwner, owner.ID).Error; err != nil { + return err + } + return nil + }); err != nil { + return nil, err + } + return &newOwner, nil + } + + owner := TokenOwner{ + Name: "user", + Token: Token{Content: "token"}, + } + o1, err := saveTokenOwner(&owner) + if err != nil { + t.Errorf("failed to save token owner, got error: %v", err) + } + if o1.Name != "user_name" { + t.Errorf(`owner name should be "user_name", but got: "%s"`, o1.Name) + } + if o1.Token.Content != "token_encrypted" { + t.Errorf(`token content should be "token_encrypted", but got: "%s"`, o1.Token.Content) + } + + owner = TokenOwner{ + ID: owner.ID, + Name: "user", + Token: Token{Content: "token2"}, + } + o2, err := saveTokenOwner(&owner) + if err != nil { + t.Errorf("failed to save token owner, got error: %v", err) + } + if o2.Name != "user_name" { + t.Errorf(`owner name should be "user_name", but got: "%s"`, o2.Name) + } + if o2.Token.Content != "token2_encrypted" { + t.Errorf(`token content should be "token2_encrypted", but got: "%s"`, o2.Token.Content) + } +} diff --git a/utils/tests/models.go b/utils/tests/models.go index ec1651a3..a4bad2fc 100644 --- a/utils/tests/models.go +++ b/utils/tests/models.go @@ -11,7 +11,7 @@ import ( // He works in a Company (belongs to), he has a Manager (belongs to - single-table), and also managed a Team (has many - single-table) // He speaks many languages (many to many) and has many friends (many to many - single-table) // His pet also has one Toy (has one - polymorphic) -// NamedPet is a reference to a Named `Pets` (has many) +// NamedPet is a reference to a named `Pet` (has one) type User struct { gorm.Model Name string