Merge branch 'master' into fix_drop_column

This commit is contained in:
Jinzhu 2023-07-24 20:16:07 +08:00 committed by GitHub
commit 39fa1636ba
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
22 changed files with 427 additions and 97 deletions

View File

@ -41,7 +41,7 @@ jobs:
mysql:
strategy:
matrix:
dbversion: ['mysql:latest', 'mysql:5.7', 'mariadb:latest']
dbversion: ['mysql:latest', 'mysql:5.7']
go: ['1.19', '1.18']
platform: [ubuntu-latest]
runs-on: ${{ matrix.platform }}
@ -72,6 +72,48 @@ jobs:
- name: Check out code into the Go module directory
uses: actions/checkout@v3
- name: go mod package cache
uses: actions/cache@v3
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
- name: Tests
run: GITHUB_ACTION=true GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True" ./tests/tests_all.sh
mariadb:
strategy:
matrix:
dbversion: [ 'mariadb:latest' ]
go: [ '1.19', '1.18' ]
platform: [ ubuntu-latest ]
runs-on: ${{ matrix.platform }}
services:
mysql:
image: ${{ matrix.dbversion }}
env:
MYSQL_DATABASE: gorm
MYSQL_USER: gorm
MYSQL_PASSWORD: gorm
MYSQL_RANDOM_ROOT_PASSWORD: "yes"
ports:
- 9910:3306
options: >-
--health-cmd "mariadb-admin ping -ugorm -pgorm"
--health-interval 10s
--health-start-period 10s
--health-timeout 5s
--health-retries 10
steps:
- name: Set up Go 1.x
uses: actions/setup-go@v4
with:
go-version: ${{ matrix.go }}
- name: Check out code into the Go module directory
uses: actions/checkout@v3
- name: go mod package cache
uses: actions/cache@v3

View File

View File

@ -72,6 +72,7 @@ func Update(config *Config) func(db *gorm.DB) {
db.Statement.AddClauseIfNotExists(clause.Update{})
if _, ok := db.Statement.Clauses["SET"]; !ok {
if set := ConvertToAssignments(db.Statement); len(set) != 0 {
defer delete(db.Statement.Clauses, "SET")
db.Statement.AddClause(set)
} else {
return

View File

@ -47,4 +47,6 @@ var (
ErrPreloadNotAllowed = errors.New("preload is not allowed when count is used")
// ErrDuplicatedKey occurs when there is a unique key constraint violation
ErrDuplicatedKey = errors.New("duplicated key not allowed")
// ErrForeignKeyViolated occurs when there is a foreign key constraint violation
ErrForeignKeyViolated = errors.New("violates foreign key constraint")
)

View File

@ -6,8 +6,6 @@ import (
"fmt"
"reflect"
"strings"
"sync"
"sync/atomic"
"gorm.io/gorm/clause"
"gorm.io/gorm/logger"
@ -107,7 +105,7 @@ func (db *DB) Save(value interface{}) (tx *DB) {
updateTx := tx.callbacks.Update().Execute(tx.Session(&Session{Initialized: true}))
if updateTx.Error == nil && updateTx.RowsAffected == 0 && !updateTx.DryRun && !selectedUpdate {
return tx.Clauses(clause.OnConflict{UpdateAll: true}).Create(value)
return tx.Session(&Session{SkipHooks: true}).Clauses(clause.OnConflict{UpdateAll: true}).Create(value)
}
return updateTx
@ -533,6 +531,7 @@ func (db *DB) Scan(dest interface{}) (tx *DB) {
tx.ScanRows(rows, dest)
} else {
tx.RowsAffected = 0
tx.AddError(rows.Err())
}
tx.AddError(rows.Close())
}
@ -611,15 +610,6 @@ func (db *DB) Connection(fc func(tx *DB) error) (err error) {
return fc(tx)
}
var (
savepointIdx int64
savepointNamePool = &sync.Pool{
New: func() interface{} {
return fmt.Sprintf("gorm_%d", atomic.AddInt64(&savepointIdx, 1))
},
}
)
// Transaction start a transaction as a block, return error will rollback, otherwise to commit. Transaction executes an
// arbitrary number of commands in fc within a transaction. On success the changes are committed; if an error occurs
// they are rolled back.
@ -629,17 +619,14 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil {
// nested transaction
if !db.DisableNestedTransaction {
poolName := savepointNamePool.Get()
defer savepointNamePool.Put(poolName)
err = db.SavePoint(poolName.(string)).Error
err = db.SavePoint(fmt.Sprintf("sp%p", fc)).Error
if err != nil {
return
}
defer func() {
// Make sure to rollback when panic, Block error or Commit error
if panicked || err != nil {
db.RollbackTo(poolName.(string))
db.RollbackTo(fmt.Sprintf("sp%p", fc))
}
}()
}
@ -720,7 +707,21 @@ func (db *DB) Rollback() *DB {
func (db *DB) SavePoint(name string) *DB {
if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok {
// close prepared statement, because SavePoint not support prepared statement.
// e.g. mysql8.0 doc: https://dev.mysql.com/doc/refman/8.0/en/sql-prepared-statements.html
var (
preparedStmtTx *PreparedStmtTX
isPreparedStmtTx bool
)
// close prepared statement, because SavePoint not support prepared statement.
if preparedStmtTx, isPreparedStmtTx = db.Statement.ConnPool.(*PreparedStmtTX); isPreparedStmtTx {
db.Statement.ConnPool = preparedStmtTx.Tx
}
db.AddError(savePointer.SavePoint(db, name))
// restore prepared statement
if isPreparedStmtTx {
db.Statement.ConnPool = preparedStmtTx
}
} else {
db.AddError(ErrUnsupportedDriver)
}
@ -729,7 +730,21 @@ func (db *DB) SavePoint(name string) *DB {
func (db *DB) RollbackTo(name string) *DB {
if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok {
// close prepared statement, because RollbackTo not support prepared statement.
// e.g. mysql8.0 doc: https://dev.mysql.com/doc/refman/8.0/en/sql-prepared-statements.html
var (
preparedStmtTx *PreparedStmtTX
isPreparedStmtTx bool
)
// close prepared statement, because SavePoint not support prepared statement.
if preparedStmtTx, isPreparedStmtTx = db.Statement.ConnPool.(*PreparedStmtTX); isPreparedStmtTx {
db.Statement.ConnPool = preparedStmtTx.Tx
}
db.AddError(savePointer.RollbackTo(db, name))
// restore prepared statement
if isPreparedStmtTx {
db.Statement.ConnPool = preparedStmtTx
}
} else {
db.AddError(ErrUnsupportedDriver)
}

34
gorm.go
View File

@ -146,7 +146,7 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) {
}
if config.NamingStrategy == nil {
config.NamingStrategy = schema.NamingStrategy{}
config.NamingStrategy = schema.NamingStrategy{IdentifierMaxLength: 64} // Default Identifier length is 64
}
if config.Logger == nil {
@ -187,15 +187,9 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) {
}
}
preparedStmt := &PreparedStmtDB{
ConnPool: db.ConnPool,
Stmts: make(map[string]*Stmt),
Mux: &sync.RWMutex{},
PreparedSQL: make([]string, 0, 100),
}
db.cacheStore.Store(preparedStmtDBKey, preparedStmt)
if config.PrepareStmt {
preparedStmt := NewPreparedStmtDB(db.ConnPool)
db.cacheStore.Store(preparedStmtDBKey, preparedStmt)
db.ConnPool = preparedStmt
}
@ -256,8 +250,15 @@ func (db *DB) Session(config *Session) *DB {
}
if config.PrepareStmt {
var preparedStmt *PreparedStmtDB
if v, ok := db.cacheStore.Load(preparedStmtDBKey); ok {
preparedStmt := v.(*PreparedStmtDB)
preparedStmt = v.(*PreparedStmtDB)
} else {
preparedStmt = NewPreparedStmtDB(db.ConnPool)
db.cacheStore.Store(preparedStmtDBKey, preparedStmt)
}
switch t := tx.Statement.ConnPool.(type) {
case Tx:
tx.Statement.ConnPool = &PreparedStmtTX{
@ -274,7 +275,6 @@ func (db *DB) Session(config *Session) *DB {
txConfig.ConnPool = tx.Statement.ConnPool
txConfig.PrepareStmt = true
}
}
if config.SkipHooks {
tx.Statement.SkipHooks = true
@ -375,11 +375,17 @@ func (db *DB) AddError(err error) error {
func (db *DB) DB() (*sql.DB, error) {
connPool := db.ConnPool
if dbConnector, ok := connPool.(GetDBConnector); ok && dbConnector != nil {
return dbConnector.GetDBConn()
if connector, ok := connPool.(GetDBConnectorWithContext); ok && connector != nil {
return connector.GetDBConnWithContext(db)
}
if sqldb, ok := connPool.(*sql.DB); ok {
if dbConnector, ok := connPool.(GetDBConnector); ok && dbConnector != nil {
if sqldb, err := dbConnector.GetDBConn(); sqldb != nil || err != nil {
return sqldb, err
}
}
if sqldb, ok := connPool.(*sql.DB); ok && sqldb != nil {
return sqldb, nil
}

View File

@ -77,6 +77,12 @@ type GetDBConnector interface {
GetDBConn() (*sql.DB, error)
}
// GetDBConnectorWithContext represents SQL db connector which takes into
// account the current database context
type GetDBConnectorWithContext interface {
GetDBConnWithContext(db *DB) (*sql.DB, error)
}
// Rows rows interface
type Rows interface {
Columns() ([]string, error)

View File

@ -3,6 +3,7 @@ package gorm
import (
"context"
"database/sql"
"reflect"
"sync"
)
@ -20,6 +21,15 @@ type PreparedStmtDB struct {
ConnPool
}
func NewPreparedStmtDB(connPool ConnPool) *PreparedStmtDB {
return &PreparedStmtDB{
ConnPool: connPool,
Stmts: make(map[string]*Stmt),
Mux: &sync.RWMutex{},
PreparedSQL: make([]string, 0, 100),
}
}
func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) {
if dbConnector, ok := db.ConnPool.(GetDBConnector); ok && dbConnector != nil {
return dbConnector.GetDBConn()
@ -163,14 +173,14 @@ type PreparedStmtTX struct {
}
func (tx *PreparedStmtTX) Commit() error {
if tx.Tx != nil {
if tx.Tx != nil && !reflect.ValueOf(tx.Tx).IsNil() {
return tx.Tx.Commit()
}
return ErrInvalidTransaction
}
func (tx *PreparedStmtTX) Rollback() error {
if tx.Tx != nil {
if tx.Tx != nil && !reflect.ValueOf(tx.Tx).IsNil() {
return tx.Tx.Rollback()
}
return ErrInvalidTransaction

View File

@ -846,7 +846,7 @@ func (field *Field) setupValuerAndSetter() {
field.Set = func(ctx context.Context, value reflect.Value, v interface{}) error {
switch data := v.(type) {
case **time.Time:
if data != nil {
if data != nil && *data != nil {
field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(*data))
}
case time.Time:
@ -882,14 +882,12 @@ func (field *Field) setupValuerAndSetter() {
reflectV := reflect.ValueOf(v)
if !reflectV.IsValid() {
field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem())
} else if reflectV.Kind() == reflect.Ptr && reflectV.IsNil() {
return
} else if reflectV.Type().AssignableTo(field.FieldType) {
field.ReflectValueOf(ctx, value).Set(reflectV)
} else if reflectV.Kind() == reflect.Ptr {
if reflectV.IsNil() || !reflectV.IsValid() {
field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem())
} else {
return field.Set(ctx, value, reflectV.Elem().Interface())
}
} else {
fieldValue := field.ReflectValueOf(ctx, value)
if fieldValue.IsNil() {
@ -910,14 +908,12 @@ func (field *Field) setupValuerAndSetter() {
reflectV := reflect.ValueOf(v)
if !reflectV.IsValid() {
field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem())
} else if reflectV.Kind() == reflect.Ptr && reflectV.IsNil() {
return
} else if reflectV.Type().AssignableTo(field.FieldType) {
field.ReflectValueOf(ctx, value).Set(reflectV)
} else if reflectV.Kind() == reflect.Ptr {
if reflectV.IsNil() || !reflectV.IsValid() {
field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem())
} else {
return field.Set(ctx, value, reflectV.Elem().Interface())
}
} else {
if valuer, ok := v.(driver.Valuer); ok {
v, _ = valuer.Value()

View File

@ -32,6 +32,7 @@ type NamingStrategy struct {
SingularTable bool
NameReplacer Replacer
NoLowerCase bool
IdentifierMaxLength int
}
// TableName convert string to table name
@ -89,12 +90,16 @@ func (ns NamingStrategy) formatName(prefix, table, name string) string {
prefix, table, name,
}, "_"), ".", "_")
if utf8.RuneCountInString(formattedName) > 64 {
if ns.IdentifierMaxLength == 0 {
ns.IdentifierMaxLength = 64
}
if utf8.RuneCountInString(formattedName) > ns.IdentifierMaxLength {
h := sha1.New()
h.Write([]byte(formattedName))
bs := h.Sum(nil)
formattedName = formattedName[0:56] + hex.EncodeToString(bs)[:8]
formattedName = formattedName[0:ns.IdentifierMaxLength-8] + hex.EncodeToString(bs)[:8]
}
return formattedName
}

View File

@ -189,8 +189,17 @@ func TestCustomReplacerWithNoLowerCase(t *testing.T) {
}
}
func TestFormatNameWithStringLongerThan63Characters(t *testing.T) {
ns := NamingStrategy{IdentifierMaxLength: 63}
formattedName := ns.formatName("prefix", "table", "thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVeryLongString")
if formattedName != "prefix_table_thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVer180f2c67" {
t.Errorf("invalid formatted name generated, got %v", formattedName)
}
}
func TestFormatNameWithStringLongerThan64Characters(t *testing.T) {
ns := NamingStrategy{}
ns := NamingStrategy{IdentifierMaxLength: 64}
formattedName := ns.formatName("prefix", "table", "thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVeryLongString")
if formattedName != "prefix_table_thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVery180f2c67" {

View File

@ -768,7 +768,7 @@ func TestParseConstraintNameWithSchemaQualifiedLongTableName(t *testing.T) {
s, err := schema.Parse(
&Book{},
&sync.Map{},
schema.NamingStrategy{},
schema.NamingStrategy{IdentifierMaxLength: 64},
)
if err != nil {
t.Fatalf("Failed to parse schema")

View File

@ -358,7 +358,7 @@ func TestDuplicateMany2ManyAssociation(t *testing.T) {
}
func TestConcurrentMany2ManyAssociation(t *testing.T) {
db, err := OpenTestConnection()
db, err := OpenTestConnection(&gorm.Config{})
if err != nil {
t.Fatalf("open test connection failed, err: %+v", err)
}

View File

@ -4,7 +4,9 @@ import (
"database/sql/driver"
"encoding/json"
"errors"
"reflect"
"testing"
"time"
"gorm.io/gorm"
. "gorm.io/gorm/utils/tests"
@ -108,6 +110,10 @@ func TestEmbeddedPointerTypeStruct(t *testing.T) {
Name string
Email string
Age int
Content Content
ContentPtr *Content
Birthday time.Time
BirthdayPtr *time.Time
}
type HNPost struct {
@ -135,6 +141,48 @@ func TestEmbeddedPointerTypeStruct(t *testing.T) {
if hnPost.Author != nil {
t.Errorf("Expected to get back a nil Author but got: %v", hnPost.Author)
}
now := time.Now().Round(time.Second)
NewPost := HNPost{
BasePost: &BasePost{Title: "embedded_pointer_type2"},
Author: &Author{
Name: "test",
Content: Content{"test"},
ContentPtr: nil,
Birthday: now,
BirthdayPtr: nil,
},
}
DB.Create(&NewPost)
hnPost = HNPost{}
if err := DB.First(&hnPost, "title = ?", NewPost.Title).Error; err != nil {
t.Errorf("No error should happen when find embedded pointer type, but got %v", err)
}
if hnPost.Title != NewPost.Title {
t.Errorf("Should find correct value for embedded pointer type")
}
if hnPost.Author.Name != NewPost.Author.Name {
t.Errorf("Expected to get Author name %v but got: %v", NewPost.Author.Name, hnPost.Author.Name)
}
if !reflect.DeepEqual(NewPost.Author.Content, hnPost.Author.Content) {
t.Errorf("Expected to get Author content %v but got: %v", NewPost.Author.Content, hnPost.Author.Content)
}
if hnPost.Author.ContentPtr != nil {
t.Errorf("Expected to get nil Author contentPtr but got: %v", hnPost.Author.ContentPtr)
}
if NewPost.Author.Birthday.UnixMilli() != hnPost.Author.Birthday.UnixMilli() {
t.Errorf("Expected to get Author birthday with %+v but got: %+v", NewPost.Author.Birthday, hnPost.Author.Birthday)
}
if hnPost.Author.BirthdayPtr != nil {
t.Errorf("Expected to get nil Author birthdayPtr but got: %+v", hnPost.Author.BirthdayPtr)
}
}
type Content struct {
@ -142,19 +190,27 @@ type Content struct {
}
func (c Content) Value() (driver.Value, error) {
return json.Marshal(c)
// mssql driver with issue on handling null bytes https://github.com/denisenkom/go-mssqldb/issues/530,
b, err := json.Marshal(c)
return string(b[:]), err
}
func (c *Content) Scan(src interface{}) error {
b, ok := src.([]byte)
var value Content
str, ok := src.(string)
if !ok {
byt, ok := src.([]byte)
if !ok {
return errors.New("Embedded.Scan byte assertion failed")
}
var value Content
if err := json.Unmarshal(b, &value); err != nil {
if err := json.Unmarshal(byt, &value); err != nil {
return err
}
} else {
if err := json.Unmarshal([]byte(str), &value); err != nil {
return err
}
}
*c = value

View File

@ -15,8 +15,8 @@ func TestDialectorWithErrorTranslatorSupport(t *testing.T) {
db, _ := gorm.Open(tests.DummyDialector{TranslatedErr: translatedErr})
err := db.AddError(untranslatedErr)
if errors.Is(err, translatedErr) {
t.Fatalf("expected err: %v got err: %v", translatedErr, err)
if !errors.Is(err, untranslatedErr) {
t.Fatalf("expected err: %v got err: %v", untranslatedErr, err)
}
// it should translate error when the TranslateError flag is true
@ -27,3 +27,85 @@ func TestDialectorWithErrorTranslatorSupport(t *testing.T) {
t.Fatalf("expected err: %v got err: %v", translatedErr, err)
}
}
func TestSupportedDialectorWithErrDuplicatedKey(t *testing.T) {
type City struct {
gorm.Model
Name string `gorm:"unique"`
}
db, err := OpenTestConnection(&gorm.Config{TranslateError: true})
if err != nil {
t.Fatalf("failed to connect database, got error %v", err)
}
dialectors := map[string]bool{"sqlite": true, "postgres": true, "mysql": true, "sqlserver": true}
if supported, found := dialectors[db.Dialector.Name()]; !(found && supported) {
return
}
DB.Migrator().DropTable(&City{})
if err = db.AutoMigrate(&City{}); err != nil {
t.Fatalf("failed to migrate cities table, got error: %v", err)
}
err = db.Create(&City{Name: "Kabul"}).Error
if err != nil {
t.Fatalf("failed to create record: %v", err)
}
err = db.Create(&City{Name: "Kabul"}).Error
if !errors.Is(err, gorm.ErrDuplicatedKey) {
t.Fatalf("expected err: %v got err: %v", gorm.ErrDuplicatedKey, err)
}
}
func TestSupportedDialectorWithErrForeignKeyViolated(t *testing.T) {
tidbSkip(t, "not support the foreign key feature")
type City struct {
gorm.Model
Name string `gorm:"unique"`
}
type Museum struct {
gorm.Model
Name string `gorm:"unique"`
CityID uint
City City `gorm:"Constraint:OnUpdate:CASCADE,OnDelete:CASCADE;FOREIGNKEY:CityID;References:ID"`
}
db, err := OpenTestConnection(&gorm.Config{TranslateError: true})
if err != nil {
t.Fatalf("failed to connect database, got error %v", err)
}
dialectors := map[string]bool{"sqlite": true, "postgres": true, "mysql": true, "sqlserver": true}
if supported, found := dialectors[db.Dialector.Name()]; !(found && supported) {
return
}
DB.Migrator().DropTable(&City{}, &Museum{})
if err = db.AutoMigrate(&City{}, &Museum{}); err != nil {
t.Fatalf("failed to migrate countries & cities tables, got error: %v", err)
}
city := City{Name: "Amsterdam"}
err = db.Create(&city).Error
if err != nil {
t.Fatalf("failed to create city: %v", err)
}
err = db.Create(&Museum{Name: "Eye Filmmuseum", CityID: city.ID}).Error
if err != nil {
t.Fatalf("failed to create museum: %v", err)
}
err = db.Create(&Museum{Name: "Dungeon", CityID: 123}).Error
if !errors.Is(err, gorm.ErrForeignKeyViolated) {
t.Fatalf("expected err: %v got err: %v", gorm.ErrForeignKeyViolated, err)
}
}

View File

@ -3,9 +3,19 @@ package tests_test
import (
"testing"
"gorm.io/driver/mysql"
"gorm.io/gorm"
)
func TestOpen(t *testing.T) {
dsn := "gorm:gorm@tcp(localhost:9910)/gorm?loc=Asia%2FHongKong" // invalid loc
_, err := gorm.Open(mysql.Open(dsn), &gorm.Config{})
if err == nil {
t.Fatalf("should returns error but got nil")
}
}
func TestReturningWithNullToZeroValues(t *testing.T) {
dialect := DB.Dialector.Name()
switch dialect {

View File

@ -92,7 +92,7 @@ func TestPreparedStmtFromTransaction(t *testing.T) {
}
func TestPreparedStmtDeadlock(t *testing.T) {
tx, err := OpenTestConnection()
tx, err := OpenTestConnection(&gorm.Config{})
AssertEqual(t, err, nil)
sqlDB, _ := tx.DB()
@ -127,7 +127,7 @@ func TestPreparedStmtDeadlock(t *testing.T) {
}
func TestPreparedStmtError(t *testing.T) {
tx, err := OpenTestConnection()
tx, err := OpenTestConnection(&gorm.Config{})
AssertEqual(t, err, nil)
sqlDB, _ := tx.DB()

View File

@ -170,10 +170,10 @@ func (data *EncryptedData) Scan(value interface{}) error {
return errors.New("Too short")
}
*data = b[3:]
*data = append((*data)[0:], b[3:]...)
return nil
} else if s, ok := value.(string); ok {
*data = []byte(s)[3:]
*data = []byte(s[3:])
return nil
}

View File

@ -26,7 +26,7 @@ var (
func init() {
var err error
if DB, err = OpenTestConnection(); err != nil {
if DB, err = OpenTestConnection(&gorm.Config{}); err != nil {
log.Printf("failed to connect database, got error %v", err)
os.Exit(1)
} else {
@ -49,7 +49,7 @@ func init() {
}
}
func OpenTestConnection() (db *gorm.DB, err error) {
func OpenTestConnection(cfg *gorm.Config) (db *gorm.DB, err error) {
dbDSN := os.Getenv("GORM_DSN")
switch os.Getenv("GORM_DIALECT") {
case "mysql":
@ -57,7 +57,7 @@ func OpenTestConnection() (db *gorm.DB, err error) {
if dbDSN == "" {
dbDSN = mysqlDSN
}
db, err = gorm.Open(mysql.Open(dbDSN), &gorm.Config{})
db, err = gorm.Open(mysql.Open(dbDSN), cfg)
case "postgres":
log.Println("testing postgres...")
if dbDSN == "" {
@ -66,7 +66,7 @@ func OpenTestConnection() (db *gorm.DB, err error) {
db, err = gorm.Open(postgres.New(postgres.Config{
DSN: dbDSN,
PreferSimpleProtocol: true,
}), &gorm.Config{})
}), cfg)
case "sqlserver":
// go install github.com/microsoft/go-sqlcmd/cmd/sqlcmd@latest
// SQLCMDPASSWORD=LoremIpsum86 sqlcmd -U sa -S localhost:9930
@ -80,16 +80,16 @@ func OpenTestConnection() (db *gorm.DB, err error) {
if dbDSN == "" {
dbDSN = sqlserverDSN
}
db, err = gorm.Open(sqlserver.Open(dbDSN), &gorm.Config{})
db, err = gorm.Open(sqlserver.Open(dbDSN), cfg)
case "tidb":
log.Println("testing tidb...")
if dbDSN == "" {
dbDSN = tidbDSN
}
db, err = gorm.Open(mysql.Open(dbDSN), &gorm.Config{})
db, err = gorm.Open(mysql.Open(dbDSN), cfg)
default:
log.Println("testing sqlite3...")
db, err = gorm.Open(sqlite.Open(filepath.Join(os.TempDir(), "gorm.db")), &gorm.Config{})
db, err = gorm.Open(sqlite.Open(filepath.Join(os.TempDir(), "gorm.db?_foreign_keys=on")), cfg)
}
if err != nil {

View File

@ -57,6 +57,19 @@ func TestTransaction(t *testing.T) {
if err := DB.First(&User{}, "name = ?", "transaction-2").Error; err != nil {
t.Fatalf("Should be able to find committed record, but got %v", err)
}
t.Run("this is test nested transaction and prepareStmt coexist case", func(t *testing.T) {
// enable prepare statement
tx3 := DB.Session(&gorm.Session{PrepareStmt: true})
if err := tx3.Transaction(func(tx4 *gorm.DB) error {
// nested transaction
return tx4.Transaction(func(tx5 *gorm.DB) error {
return tx5.First(&User{}, "name = ?", "transaction-2").Error
})
}); err != nil {
t.Fatalf("prepare statement and nested transcation coexist" + err.Error())
}
})
}
func TestCancelTransaction(t *testing.T) {
@ -348,7 +361,7 @@ func TestDisabledNestedTransaction(t *testing.T) {
}
func TestTransactionOnClosedConn(t *testing.T) {
DB, err := OpenTestConnection()
DB, err := OpenTestConnection(&gorm.Config{})
if err != nil {
t.Fatalf("failed to connect database, got error %v", err)
}

View File

@ -208,13 +208,17 @@ func TestUpdateColumn(t *testing.T) {
CheckUser(t, user1, *users[0])
CheckUser(t, user2, *users[1])
DB.Model(users[1]).UpdateColumn("name", "update_column_02_newnew")
DB.Model(users[1]).UpdateColumn("name", "update_column_02_newnew").UpdateColumn("age", 19)
AssertEqual(t, lastUpdatedAt.UnixNano(), users[1].UpdatedAt.UnixNano())
if users[1].Name != "update_column_02_newnew" {
t.Errorf("user 2's name should be updated, but got %v", users[1].Name)
}
if users[1].Age != 19 {
t.Errorf("user 2's name should be updated, but got %v", users[1].Age)
}
DB.Model(users[1]).UpdateColumn("age", gorm.Expr("age + 100 - 50"))
var user3 User
DB.First(&user3, users[1].ID)
@ -805,3 +809,76 @@ func TestUpdateWithDiffSchema(t *testing.T) {
AssertEqual(t, err, nil)
AssertEqual(t, "update-diff-schema-2", user.Name)
}
type TokenOwner struct {
ID int
Name string
Token Token `gorm:"foreignKey:UserID"`
}
func (t *TokenOwner) BeforeSave(tx *gorm.DB) error {
t.Name += "_name"
return nil
}
type Token struct {
UserID int `gorm:"primary_key"`
Content string `gorm:"type:varchar(100)"`
}
func (t *Token) BeforeSave(tx *gorm.DB) error {
t.Content += "_encrypted"
return nil
}
func TestSaveWithHooks(t *testing.T) {
DB.Migrator().DropTable(&Token{}, &TokenOwner{})
DB.AutoMigrate(&Token{}, &TokenOwner{})
saveTokenOwner := func(owner *TokenOwner) (*TokenOwner, error) {
var newOwner TokenOwner
if err := DB.Transaction(func(tx *gorm.DB) error {
if err := tx.Debug().Session(&gorm.Session{FullSaveAssociations: true}).Save(owner).Error; err != nil {
return err
}
if err := tx.Preload("Token").First(&newOwner, owner.ID).Error; err != nil {
return err
}
return nil
}); err != nil {
return nil, err
}
return &newOwner, nil
}
owner := TokenOwner{
Name: "user",
Token: Token{Content: "token"},
}
o1, err := saveTokenOwner(&owner)
if err != nil {
t.Errorf("failed to save token owner, got error: %v", err)
}
if o1.Name != "user_name" {
t.Errorf(`owner name should be "user_name", but got: "%s"`, o1.Name)
}
if o1.Token.Content != "token_encrypted" {
t.Errorf(`token content should be "token_encrypted", but got: "%s"`, o1.Token.Content)
}
owner = TokenOwner{
ID: owner.ID,
Name: "user",
Token: Token{Content: "token2"},
}
o2, err := saveTokenOwner(&owner)
if err != nil {
t.Errorf("failed to save token owner, got error: %v", err)
}
if o2.Name != "user_name" {
t.Errorf(`owner name should be "user_name", but got: "%s"`, o2.Name)
}
if o2.Token.Content != "token2_encrypted" {
t.Errorf(`token content should be "token2_encrypted", but got: "%s"`, o2.Token.Content)
}
}

View File

@ -11,7 +11,7 @@ import (
// He works in a Company (belongs to), he has a Manager (belongs to - single-table), and also managed a Team (has many - single-table)
// He speaks many languages (many to many) and has many friends (many to many - single-table)
// His pet also has one Toy (has one - polymorphic)
// NamedPet is a reference to a Named `Pets` (has many)
// NamedPet is a reference to a named `Pet` (has one)
type User struct {
gorm.Model
Name string