Merge branch 'master' into fix_drop_column

This commit is contained in:
Jinzhu 2023-07-24 20:16:07 +08:00 committed by GitHub
commit 39fa1636ba
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
22 changed files with 427 additions and 97 deletions

View File

@ -41,7 +41,7 @@ jobs:
mysql: mysql:
strategy: strategy:
matrix: matrix:
dbversion: ['mysql:latest', 'mysql:5.7', 'mariadb:latest'] dbversion: ['mysql:latest', 'mysql:5.7']
go: ['1.19', '1.18'] go: ['1.19', '1.18']
platform: [ubuntu-latest] platform: [ubuntu-latest]
runs-on: ${{ matrix.platform }} runs-on: ${{ matrix.platform }}
@ -72,7 +72,6 @@ jobs:
- name: Check out code into the Go module directory - name: Check out code into the Go module directory
uses: actions/checkout@v3 uses: actions/checkout@v3
- name: go mod package cache - name: go mod package cache
uses: actions/cache@v3 uses: actions/cache@v3
with: with:
@ -82,6 +81,49 @@ jobs:
- name: Tests - name: Tests
run: GITHUB_ACTION=true GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True" ./tests/tests_all.sh run: GITHUB_ACTION=true GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True" ./tests/tests_all.sh
mariadb:
strategy:
matrix:
dbversion: [ 'mariadb:latest' ]
go: [ '1.19', '1.18' ]
platform: [ ubuntu-latest ]
runs-on: ${{ matrix.platform }}
services:
mysql:
image: ${{ matrix.dbversion }}
env:
MYSQL_DATABASE: gorm
MYSQL_USER: gorm
MYSQL_PASSWORD: gorm
MYSQL_RANDOM_ROOT_PASSWORD: "yes"
ports:
- 9910:3306
options: >-
--health-cmd "mariadb-admin ping -ugorm -pgorm"
--health-interval 10s
--health-start-period 10s
--health-timeout 5s
--health-retries 10
steps:
- name: Set up Go 1.x
uses: actions/setup-go@v4
with:
go-version: ${{ matrix.go }}
- name: Check out code into the Go module directory
uses: actions/checkout@v3
- name: go mod package cache
uses: actions/cache@v3
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
- name: Tests
run: GITHUB_ACTION=true GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True" ./tests/tests_all.sh
postgres: postgres:
strategy: strategy:
matrix: matrix:

View File

View File

@ -72,6 +72,7 @@ func Update(config *Config) func(db *gorm.DB) {
db.Statement.AddClauseIfNotExists(clause.Update{}) db.Statement.AddClauseIfNotExists(clause.Update{})
if _, ok := db.Statement.Clauses["SET"]; !ok { if _, ok := db.Statement.Clauses["SET"]; !ok {
if set := ConvertToAssignments(db.Statement); len(set) != 0 { if set := ConvertToAssignments(db.Statement); len(set) != 0 {
defer delete(db.Statement.Clauses, "SET")
db.Statement.AddClause(set) db.Statement.AddClause(set)
} else { } else {
return return

View File

@ -47,4 +47,6 @@ var (
ErrPreloadNotAllowed = errors.New("preload is not allowed when count is used") ErrPreloadNotAllowed = errors.New("preload is not allowed when count is used")
// ErrDuplicatedKey occurs when there is a unique key constraint violation // ErrDuplicatedKey occurs when there is a unique key constraint violation
ErrDuplicatedKey = errors.New("duplicated key not allowed") ErrDuplicatedKey = errors.New("duplicated key not allowed")
// ErrForeignKeyViolated occurs when there is a foreign key constraint violation
ErrForeignKeyViolated = errors.New("violates foreign key constraint")
) )

View File

@ -6,8 +6,6 @@ import (
"fmt" "fmt"
"reflect" "reflect"
"strings" "strings"
"sync"
"sync/atomic"
"gorm.io/gorm/clause" "gorm.io/gorm/clause"
"gorm.io/gorm/logger" "gorm.io/gorm/logger"
@ -107,7 +105,7 @@ func (db *DB) Save(value interface{}) (tx *DB) {
updateTx := tx.callbacks.Update().Execute(tx.Session(&Session{Initialized: true})) updateTx := tx.callbacks.Update().Execute(tx.Session(&Session{Initialized: true}))
if updateTx.Error == nil && updateTx.RowsAffected == 0 && !updateTx.DryRun && !selectedUpdate { if updateTx.Error == nil && updateTx.RowsAffected == 0 && !updateTx.DryRun && !selectedUpdate {
return tx.Clauses(clause.OnConflict{UpdateAll: true}).Create(value) return tx.Session(&Session{SkipHooks: true}).Clauses(clause.OnConflict{UpdateAll: true}).Create(value)
} }
return updateTx return updateTx
@ -533,6 +531,7 @@ func (db *DB) Scan(dest interface{}) (tx *DB) {
tx.ScanRows(rows, dest) tx.ScanRows(rows, dest)
} else { } else {
tx.RowsAffected = 0 tx.RowsAffected = 0
tx.AddError(rows.Err())
} }
tx.AddError(rows.Close()) tx.AddError(rows.Close())
} }
@ -611,15 +610,6 @@ func (db *DB) Connection(fc func(tx *DB) error) (err error) {
return fc(tx) return fc(tx)
} }
var (
savepointIdx int64
savepointNamePool = &sync.Pool{
New: func() interface{} {
return fmt.Sprintf("gorm_%d", atomic.AddInt64(&savepointIdx, 1))
},
}
)
// Transaction start a transaction as a block, return error will rollback, otherwise to commit. Transaction executes an // Transaction start a transaction as a block, return error will rollback, otherwise to commit. Transaction executes an
// arbitrary number of commands in fc within a transaction. On success the changes are committed; if an error occurs // arbitrary number of commands in fc within a transaction. On success the changes are committed; if an error occurs
// they are rolled back. // they are rolled back.
@ -629,17 +619,14 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil { if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil {
// nested transaction // nested transaction
if !db.DisableNestedTransaction { if !db.DisableNestedTransaction {
poolName := savepointNamePool.Get() err = db.SavePoint(fmt.Sprintf("sp%p", fc)).Error
defer savepointNamePool.Put(poolName)
err = db.SavePoint(poolName.(string)).Error
if err != nil { if err != nil {
return return
} }
defer func() { defer func() {
// Make sure to rollback when panic, Block error or Commit error // Make sure to rollback when panic, Block error or Commit error
if panicked || err != nil { if panicked || err != nil {
db.RollbackTo(poolName.(string)) db.RollbackTo(fmt.Sprintf("sp%p", fc))
} }
}() }()
} }
@ -720,7 +707,21 @@ func (db *DB) Rollback() *DB {
func (db *DB) SavePoint(name string) *DB { func (db *DB) SavePoint(name string) *DB {
if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok { if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok {
// close prepared statement, because SavePoint not support prepared statement.
// e.g. mysql8.0 doc: https://dev.mysql.com/doc/refman/8.0/en/sql-prepared-statements.html
var (
preparedStmtTx *PreparedStmtTX
isPreparedStmtTx bool
)
// close prepared statement, because SavePoint not support prepared statement.
if preparedStmtTx, isPreparedStmtTx = db.Statement.ConnPool.(*PreparedStmtTX); isPreparedStmtTx {
db.Statement.ConnPool = preparedStmtTx.Tx
}
db.AddError(savePointer.SavePoint(db, name)) db.AddError(savePointer.SavePoint(db, name))
// restore prepared statement
if isPreparedStmtTx {
db.Statement.ConnPool = preparedStmtTx
}
} else { } else {
db.AddError(ErrUnsupportedDriver) db.AddError(ErrUnsupportedDriver)
} }
@ -729,7 +730,21 @@ func (db *DB) SavePoint(name string) *DB {
func (db *DB) RollbackTo(name string) *DB { func (db *DB) RollbackTo(name string) *DB {
if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok { if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok {
// close prepared statement, because RollbackTo not support prepared statement.
// e.g. mysql8.0 doc: https://dev.mysql.com/doc/refman/8.0/en/sql-prepared-statements.html
var (
preparedStmtTx *PreparedStmtTX
isPreparedStmtTx bool
)
// close prepared statement, because SavePoint not support prepared statement.
if preparedStmtTx, isPreparedStmtTx = db.Statement.ConnPool.(*PreparedStmtTX); isPreparedStmtTx {
db.Statement.ConnPool = preparedStmtTx.Tx
}
db.AddError(savePointer.RollbackTo(db, name)) db.AddError(savePointer.RollbackTo(db, name))
// restore prepared statement
if isPreparedStmtTx {
db.Statement.ConnPool = preparedStmtTx
}
} else { } else {
db.AddError(ErrUnsupportedDriver) db.AddError(ErrUnsupportedDriver)
} }

62
gorm.go
View File

@ -146,7 +146,7 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) {
} }
if config.NamingStrategy == nil { if config.NamingStrategy == nil {
config.NamingStrategy = schema.NamingStrategy{} config.NamingStrategy = schema.NamingStrategy{IdentifierMaxLength: 64} // Default Identifier length is 64
} }
if config.Logger == nil { if config.Logger == nil {
@ -187,15 +187,9 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) {
} }
} }
preparedStmt := &PreparedStmtDB{
ConnPool: db.ConnPool,
Stmts: make(map[string]*Stmt),
Mux: &sync.RWMutex{},
PreparedSQL: make([]string, 0, 100),
}
db.cacheStore.Store(preparedStmtDBKey, preparedStmt)
if config.PrepareStmt { if config.PrepareStmt {
preparedStmt := NewPreparedStmtDB(db.ConnPool)
db.cacheStore.Store(preparedStmtDBKey, preparedStmt)
db.ConnPool = preparedStmt db.ConnPool = preparedStmt
} }
@ -256,24 +250,30 @@ func (db *DB) Session(config *Session) *DB {
} }
if config.PrepareStmt { if config.PrepareStmt {
var preparedStmt *PreparedStmtDB
if v, ok := db.cacheStore.Load(preparedStmtDBKey); ok { if v, ok := db.cacheStore.Load(preparedStmtDBKey); ok {
preparedStmt := v.(*PreparedStmtDB) preparedStmt = v.(*PreparedStmtDB)
switch t := tx.Statement.ConnPool.(type) { } else {
case Tx: preparedStmt = NewPreparedStmtDB(db.ConnPool)
tx.Statement.ConnPool = &PreparedStmtTX{ db.cacheStore.Store(preparedStmtDBKey, preparedStmt)
Tx: t,
PreparedStmtDB: preparedStmt,
}
default:
tx.Statement.ConnPool = &PreparedStmtDB{
ConnPool: db.Config.ConnPool,
Mux: preparedStmt.Mux,
Stmts: preparedStmt.Stmts,
}
}
txConfig.ConnPool = tx.Statement.ConnPool
txConfig.PrepareStmt = true
} }
switch t := tx.Statement.ConnPool.(type) {
case Tx:
tx.Statement.ConnPool = &PreparedStmtTX{
Tx: t,
PreparedStmtDB: preparedStmt,
}
default:
tx.Statement.ConnPool = &PreparedStmtDB{
ConnPool: db.Config.ConnPool,
Mux: preparedStmt.Mux,
Stmts: preparedStmt.Stmts,
}
}
txConfig.ConnPool = tx.Statement.ConnPool
txConfig.PrepareStmt = true
} }
if config.SkipHooks { if config.SkipHooks {
@ -375,11 +375,17 @@ func (db *DB) AddError(err error) error {
func (db *DB) DB() (*sql.DB, error) { func (db *DB) DB() (*sql.DB, error) {
connPool := db.ConnPool connPool := db.ConnPool
if dbConnector, ok := connPool.(GetDBConnector); ok && dbConnector != nil { if connector, ok := connPool.(GetDBConnectorWithContext); ok && connector != nil {
return dbConnector.GetDBConn() return connector.GetDBConnWithContext(db)
} }
if sqldb, ok := connPool.(*sql.DB); ok { if dbConnector, ok := connPool.(GetDBConnector); ok && dbConnector != nil {
if sqldb, err := dbConnector.GetDBConn(); sqldb != nil || err != nil {
return sqldb, err
}
}
if sqldb, ok := connPool.(*sql.DB); ok && sqldb != nil {
return sqldb, nil return sqldb, nil
} }

View File

@ -77,6 +77,12 @@ type GetDBConnector interface {
GetDBConn() (*sql.DB, error) GetDBConn() (*sql.DB, error)
} }
// GetDBConnectorWithContext represents SQL db connector which takes into
// account the current database context
type GetDBConnectorWithContext interface {
GetDBConnWithContext(db *DB) (*sql.DB, error)
}
// Rows rows interface // Rows rows interface
type Rows interface { type Rows interface {
Columns() ([]string, error) Columns() ([]string, error)

View File

@ -3,6 +3,7 @@ package gorm
import ( import (
"context" "context"
"database/sql" "database/sql"
"reflect"
"sync" "sync"
) )
@ -20,6 +21,15 @@ type PreparedStmtDB struct {
ConnPool ConnPool
} }
func NewPreparedStmtDB(connPool ConnPool) *PreparedStmtDB {
return &PreparedStmtDB{
ConnPool: connPool,
Stmts: make(map[string]*Stmt),
Mux: &sync.RWMutex{},
PreparedSQL: make([]string, 0, 100),
}
}
func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) { func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) {
if dbConnector, ok := db.ConnPool.(GetDBConnector); ok && dbConnector != nil { if dbConnector, ok := db.ConnPool.(GetDBConnector); ok && dbConnector != nil {
return dbConnector.GetDBConn() return dbConnector.GetDBConn()
@ -163,14 +173,14 @@ type PreparedStmtTX struct {
} }
func (tx *PreparedStmtTX) Commit() error { func (tx *PreparedStmtTX) Commit() error {
if tx.Tx != nil { if tx.Tx != nil && !reflect.ValueOf(tx.Tx).IsNil() {
return tx.Tx.Commit() return tx.Tx.Commit()
} }
return ErrInvalidTransaction return ErrInvalidTransaction
} }
func (tx *PreparedStmtTX) Rollback() error { func (tx *PreparedStmtTX) Rollback() error {
if tx.Tx != nil { if tx.Tx != nil && !reflect.ValueOf(tx.Tx).IsNil() {
return tx.Tx.Rollback() return tx.Tx.Rollback()
} }
return ErrInvalidTransaction return ErrInvalidTransaction

View File

@ -846,7 +846,7 @@ func (field *Field) setupValuerAndSetter() {
field.Set = func(ctx context.Context, value reflect.Value, v interface{}) error { field.Set = func(ctx context.Context, value reflect.Value, v interface{}) error {
switch data := v.(type) { switch data := v.(type) {
case **time.Time: case **time.Time:
if data != nil { if data != nil && *data != nil {
field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(*data)) field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(*data))
} }
case time.Time: case time.Time:
@ -882,14 +882,12 @@ func (field *Field) setupValuerAndSetter() {
reflectV := reflect.ValueOf(v) reflectV := reflect.ValueOf(v)
if !reflectV.IsValid() { if !reflectV.IsValid() {
field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem()) field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem())
} else if reflectV.Kind() == reflect.Ptr && reflectV.IsNil() {
return
} else if reflectV.Type().AssignableTo(field.FieldType) { } else if reflectV.Type().AssignableTo(field.FieldType) {
field.ReflectValueOf(ctx, value).Set(reflectV) field.ReflectValueOf(ctx, value).Set(reflectV)
} else if reflectV.Kind() == reflect.Ptr { } else if reflectV.Kind() == reflect.Ptr {
if reflectV.IsNil() || !reflectV.IsValid() { return field.Set(ctx, value, reflectV.Elem().Interface())
field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem())
} else {
return field.Set(ctx, value, reflectV.Elem().Interface())
}
} else { } else {
fieldValue := field.ReflectValueOf(ctx, value) fieldValue := field.ReflectValueOf(ctx, value)
if fieldValue.IsNil() { if fieldValue.IsNil() {
@ -910,14 +908,12 @@ func (field *Field) setupValuerAndSetter() {
reflectV := reflect.ValueOf(v) reflectV := reflect.ValueOf(v)
if !reflectV.IsValid() { if !reflectV.IsValid() {
field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem()) field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem())
} else if reflectV.Kind() == reflect.Ptr && reflectV.IsNil() {
return
} else if reflectV.Type().AssignableTo(field.FieldType) { } else if reflectV.Type().AssignableTo(field.FieldType) {
field.ReflectValueOf(ctx, value).Set(reflectV) field.ReflectValueOf(ctx, value).Set(reflectV)
} else if reflectV.Kind() == reflect.Ptr { } else if reflectV.Kind() == reflect.Ptr {
if reflectV.IsNil() || !reflectV.IsValid() { return field.Set(ctx, value, reflectV.Elem().Interface())
field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem())
} else {
return field.Set(ctx, value, reflectV.Elem().Interface())
}
} else { } else {
if valuer, ok := v.(driver.Valuer); ok { if valuer, ok := v.(driver.Valuer); ok {
v, _ = valuer.Value() v, _ = valuer.Value()

View File

@ -28,10 +28,11 @@ type Replacer interface {
// NamingStrategy tables, columns naming strategy // NamingStrategy tables, columns naming strategy
type NamingStrategy struct { type NamingStrategy struct {
TablePrefix string TablePrefix string
SingularTable bool SingularTable bool
NameReplacer Replacer NameReplacer Replacer
NoLowerCase bool NoLowerCase bool
IdentifierMaxLength int
} }
// TableName convert string to table name // TableName convert string to table name
@ -89,12 +90,16 @@ func (ns NamingStrategy) formatName(prefix, table, name string) string {
prefix, table, name, prefix, table, name,
}, "_"), ".", "_") }, "_"), ".", "_")
if utf8.RuneCountInString(formattedName) > 64 { if ns.IdentifierMaxLength == 0 {
ns.IdentifierMaxLength = 64
}
if utf8.RuneCountInString(formattedName) > ns.IdentifierMaxLength {
h := sha1.New() h := sha1.New()
h.Write([]byte(formattedName)) h.Write([]byte(formattedName))
bs := h.Sum(nil) bs := h.Sum(nil)
formattedName = formattedName[0:56] + hex.EncodeToString(bs)[:8] formattedName = formattedName[0:ns.IdentifierMaxLength-8] + hex.EncodeToString(bs)[:8]
} }
return formattedName return formattedName
} }

View File

@ -189,8 +189,17 @@ func TestCustomReplacerWithNoLowerCase(t *testing.T) {
} }
} }
func TestFormatNameWithStringLongerThan63Characters(t *testing.T) {
ns := NamingStrategy{IdentifierMaxLength: 63}
formattedName := ns.formatName("prefix", "table", "thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVeryLongString")
if formattedName != "prefix_table_thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVer180f2c67" {
t.Errorf("invalid formatted name generated, got %v", formattedName)
}
}
func TestFormatNameWithStringLongerThan64Characters(t *testing.T) { func TestFormatNameWithStringLongerThan64Characters(t *testing.T) {
ns := NamingStrategy{} ns := NamingStrategy{IdentifierMaxLength: 64}
formattedName := ns.formatName("prefix", "table", "thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVeryLongString") formattedName := ns.formatName("prefix", "table", "thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVeryLongString")
if formattedName != "prefix_table_thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVery180f2c67" { if formattedName != "prefix_table_thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVery180f2c67" {

View File

@ -768,7 +768,7 @@ func TestParseConstraintNameWithSchemaQualifiedLongTableName(t *testing.T) {
s, err := schema.Parse( s, err := schema.Parse(
&Book{}, &Book{},
&sync.Map{}, &sync.Map{},
schema.NamingStrategy{}, schema.NamingStrategy{IdentifierMaxLength: 64},
) )
if err != nil { if err != nil {
t.Fatalf("Failed to parse schema") t.Fatalf("Failed to parse schema")

View File

@ -358,7 +358,7 @@ func TestDuplicateMany2ManyAssociation(t *testing.T) {
} }
func TestConcurrentMany2ManyAssociation(t *testing.T) { func TestConcurrentMany2ManyAssociation(t *testing.T) {
db, err := OpenTestConnection() db, err := OpenTestConnection(&gorm.Config{})
if err != nil { if err != nil {
t.Fatalf("open test connection failed, err: %+v", err) t.Fatalf("open test connection failed, err: %+v", err)
} }

View File

@ -4,7 +4,9 @@ import (
"database/sql/driver" "database/sql/driver"
"encoding/json" "encoding/json"
"errors" "errors"
"reflect"
"testing" "testing"
"time"
"gorm.io/gorm" "gorm.io/gorm"
. "gorm.io/gorm/utils/tests" . "gorm.io/gorm/utils/tests"
@ -104,10 +106,14 @@ func TestEmbeddedPointerTypeStruct(t *testing.T) {
} }
type Author struct { type Author struct {
ID string ID string
Name string Name string
Email string Email string
Age int Age int
Content Content
ContentPtr *Content
Birthday time.Time
BirthdayPtr *time.Time
} }
type HNPost struct { type HNPost struct {
@ -135,6 +141,48 @@ func TestEmbeddedPointerTypeStruct(t *testing.T) {
if hnPost.Author != nil { if hnPost.Author != nil {
t.Errorf("Expected to get back a nil Author but got: %v", hnPost.Author) t.Errorf("Expected to get back a nil Author but got: %v", hnPost.Author)
} }
now := time.Now().Round(time.Second)
NewPost := HNPost{
BasePost: &BasePost{Title: "embedded_pointer_type2"},
Author: &Author{
Name: "test",
Content: Content{"test"},
ContentPtr: nil,
Birthday: now,
BirthdayPtr: nil,
},
}
DB.Create(&NewPost)
hnPost = HNPost{}
if err := DB.First(&hnPost, "title = ?", NewPost.Title).Error; err != nil {
t.Errorf("No error should happen when find embedded pointer type, but got %v", err)
}
if hnPost.Title != NewPost.Title {
t.Errorf("Should find correct value for embedded pointer type")
}
if hnPost.Author.Name != NewPost.Author.Name {
t.Errorf("Expected to get Author name %v but got: %v", NewPost.Author.Name, hnPost.Author.Name)
}
if !reflect.DeepEqual(NewPost.Author.Content, hnPost.Author.Content) {
t.Errorf("Expected to get Author content %v but got: %v", NewPost.Author.Content, hnPost.Author.Content)
}
if hnPost.Author.ContentPtr != nil {
t.Errorf("Expected to get nil Author contentPtr but got: %v", hnPost.Author.ContentPtr)
}
if NewPost.Author.Birthday.UnixMilli() != hnPost.Author.Birthday.UnixMilli() {
t.Errorf("Expected to get Author birthday with %+v but got: %+v", NewPost.Author.Birthday, hnPost.Author.Birthday)
}
if hnPost.Author.BirthdayPtr != nil {
t.Errorf("Expected to get nil Author birthdayPtr but got: %+v", hnPost.Author.BirthdayPtr)
}
} }
type Content struct { type Content struct {
@ -142,18 +190,26 @@ type Content struct {
} }
func (c Content) Value() (driver.Value, error) { func (c Content) Value() (driver.Value, error) {
return json.Marshal(c) // mssql driver with issue on handling null bytes https://github.com/denisenkom/go-mssqldb/issues/530,
b, err := json.Marshal(c)
return string(b[:]), err
} }
func (c *Content) Scan(src interface{}) error { func (c *Content) Scan(src interface{}) error {
b, ok := src.([]byte)
if !ok {
return errors.New("Embedded.Scan byte assertion failed")
}
var value Content var value Content
if err := json.Unmarshal(b, &value); err != nil { str, ok := src.(string)
return err if !ok {
byt, ok := src.([]byte)
if !ok {
return errors.New("Embedded.Scan byte assertion failed")
}
if err := json.Unmarshal(byt, &value); err != nil {
return err
}
} else {
if err := json.Unmarshal([]byte(str), &value); err != nil {
return err
}
} }
*c = value *c = value

View File

@ -15,8 +15,8 @@ func TestDialectorWithErrorTranslatorSupport(t *testing.T) {
db, _ := gorm.Open(tests.DummyDialector{TranslatedErr: translatedErr}) db, _ := gorm.Open(tests.DummyDialector{TranslatedErr: translatedErr})
err := db.AddError(untranslatedErr) err := db.AddError(untranslatedErr)
if errors.Is(err, translatedErr) { if !errors.Is(err, untranslatedErr) {
t.Fatalf("expected err: %v got err: %v", translatedErr, err) t.Fatalf("expected err: %v got err: %v", untranslatedErr, err)
} }
// it should translate error when the TranslateError flag is true // it should translate error when the TranslateError flag is true
@ -27,3 +27,85 @@ func TestDialectorWithErrorTranslatorSupport(t *testing.T) {
t.Fatalf("expected err: %v got err: %v", translatedErr, err) t.Fatalf("expected err: %v got err: %v", translatedErr, err)
} }
} }
func TestSupportedDialectorWithErrDuplicatedKey(t *testing.T) {
type City struct {
gorm.Model
Name string `gorm:"unique"`
}
db, err := OpenTestConnection(&gorm.Config{TranslateError: true})
if err != nil {
t.Fatalf("failed to connect database, got error %v", err)
}
dialectors := map[string]bool{"sqlite": true, "postgres": true, "mysql": true, "sqlserver": true}
if supported, found := dialectors[db.Dialector.Name()]; !(found && supported) {
return
}
DB.Migrator().DropTable(&City{})
if err = db.AutoMigrate(&City{}); err != nil {
t.Fatalf("failed to migrate cities table, got error: %v", err)
}
err = db.Create(&City{Name: "Kabul"}).Error
if err != nil {
t.Fatalf("failed to create record: %v", err)
}
err = db.Create(&City{Name: "Kabul"}).Error
if !errors.Is(err, gorm.ErrDuplicatedKey) {
t.Fatalf("expected err: %v got err: %v", gorm.ErrDuplicatedKey, err)
}
}
func TestSupportedDialectorWithErrForeignKeyViolated(t *testing.T) {
tidbSkip(t, "not support the foreign key feature")
type City struct {
gorm.Model
Name string `gorm:"unique"`
}
type Museum struct {
gorm.Model
Name string `gorm:"unique"`
CityID uint
City City `gorm:"Constraint:OnUpdate:CASCADE,OnDelete:CASCADE;FOREIGNKEY:CityID;References:ID"`
}
db, err := OpenTestConnection(&gorm.Config{TranslateError: true})
if err != nil {
t.Fatalf("failed to connect database, got error %v", err)
}
dialectors := map[string]bool{"sqlite": true, "postgres": true, "mysql": true, "sqlserver": true}
if supported, found := dialectors[db.Dialector.Name()]; !(found && supported) {
return
}
DB.Migrator().DropTable(&City{}, &Museum{})
if err = db.AutoMigrate(&City{}, &Museum{}); err != nil {
t.Fatalf("failed to migrate countries & cities tables, got error: %v", err)
}
city := City{Name: "Amsterdam"}
err = db.Create(&city).Error
if err != nil {
t.Fatalf("failed to create city: %v", err)
}
err = db.Create(&Museum{Name: "Eye Filmmuseum", CityID: city.ID}).Error
if err != nil {
t.Fatalf("failed to create museum: %v", err)
}
err = db.Create(&Museum{Name: "Dungeon", CityID: 123}).Error
if !errors.Is(err, gorm.ErrForeignKeyViolated) {
t.Fatalf("expected err: %v got err: %v", gorm.ErrForeignKeyViolated, err)
}
}

View File

@ -3,9 +3,19 @@ package tests_test
import ( import (
"testing" "testing"
"gorm.io/driver/mysql"
"gorm.io/gorm" "gorm.io/gorm"
) )
func TestOpen(t *testing.T) {
dsn := "gorm:gorm@tcp(localhost:9910)/gorm?loc=Asia%2FHongKong" // invalid loc
_, err := gorm.Open(mysql.Open(dsn), &gorm.Config{})
if err == nil {
t.Fatalf("should returns error but got nil")
}
}
func TestReturningWithNullToZeroValues(t *testing.T) { func TestReturningWithNullToZeroValues(t *testing.T) {
dialect := DB.Dialector.Name() dialect := DB.Dialector.Name()
switch dialect { switch dialect {

View File

@ -92,7 +92,7 @@ func TestPreparedStmtFromTransaction(t *testing.T) {
} }
func TestPreparedStmtDeadlock(t *testing.T) { func TestPreparedStmtDeadlock(t *testing.T) {
tx, err := OpenTestConnection() tx, err := OpenTestConnection(&gorm.Config{})
AssertEqual(t, err, nil) AssertEqual(t, err, nil)
sqlDB, _ := tx.DB() sqlDB, _ := tx.DB()
@ -127,7 +127,7 @@ func TestPreparedStmtDeadlock(t *testing.T) {
} }
func TestPreparedStmtError(t *testing.T) { func TestPreparedStmtError(t *testing.T) {
tx, err := OpenTestConnection() tx, err := OpenTestConnection(&gorm.Config{})
AssertEqual(t, err, nil) AssertEqual(t, err, nil)
sqlDB, _ := tx.DB() sqlDB, _ := tx.DB()

View File

@ -170,10 +170,10 @@ func (data *EncryptedData) Scan(value interface{}) error {
return errors.New("Too short") return errors.New("Too short")
} }
*data = b[3:] *data = append((*data)[0:], b[3:]...)
return nil return nil
} else if s, ok := value.(string); ok { } else if s, ok := value.(string); ok {
*data = []byte(s)[3:] *data = []byte(s[3:])
return nil return nil
} }

View File

@ -26,7 +26,7 @@ var (
func init() { func init() {
var err error var err error
if DB, err = OpenTestConnection(); err != nil { if DB, err = OpenTestConnection(&gorm.Config{}); err != nil {
log.Printf("failed to connect database, got error %v", err) log.Printf("failed to connect database, got error %v", err)
os.Exit(1) os.Exit(1)
} else { } else {
@ -49,7 +49,7 @@ func init() {
} }
} }
func OpenTestConnection() (db *gorm.DB, err error) { func OpenTestConnection(cfg *gorm.Config) (db *gorm.DB, err error) {
dbDSN := os.Getenv("GORM_DSN") dbDSN := os.Getenv("GORM_DSN")
switch os.Getenv("GORM_DIALECT") { switch os.Getenv("GORM_DIALECT") {
case "mysql": case "mysql":
@ -57,7 +57,7 @@ func OpenTestConnection() (db *gorm.DB, err error) {
if dbDSN == "" { if dbDSN == "" {
dbDSN = mysqlDSN dbDSN = mysqlDSN
} }
db, err = gorm.Open(mysql.Open(dbDSN), &gorm.Config{}) db, err = gorm.Open(mysql.Open(dbDSN), cfg)
case "postgres": case "postgres":
log.Println("testing postgres...") log.Println("testing postgres...")
if dbDSN == "" { if dbDSN == "" {
@ -66,7 +66,7 @@ func OpenTestConnection() (db *gorm.DB, err error) {
db, err = gorm.Open(postgres.New(postgres.Config{ db, err = gorm.Open(postgres.New(postgres.Config{
DSN: dbDSN, DSN: dbDSN,
PreferSimpleProtocol: true, PreferSimpleProtocol: true,
}), &gorm.Config{}) }), cfg)
case "sqlserver": case "sqlserver":
// go install github.com/microsoft/go-sqlcmd/cmd/sqlcmd@latest // go install github.com/microsoft/go-sqlcmd/cmd/sqlcmd@latest
// SQLCMDPASSWORD=LoremIpsum86 sqlcmd -U sa -S localhost:9930 // SQLCMDPASSWORD=LoremIpsum86 sqlcmd -U sa -S localhost:9930
@ -80,16 +80,16 @@ func OpenTestConnection() (db *gorm.DB, err error) {
if dbDSN == "" { if dbDSN == "" {
dbDSN = sqlserverDSN dbDSN = sqlserverDSN
} }
db, err = gorm.Open(sqlserver.Open(dbDSN), &gorm.Config{}) db, err = gorm.Open(sqlserver.Open(dbDSN), cfg)
case "tidb": case "tidb":
log.Println("testing tidb...") log.Println("testing tidb...")
if dbDSN == "" { if dbDSN == "" {
dbDSN = tidbDSN dbDSN = tidbDSN
} }
db, err = gorm.Open(mysql.Open(dbDSN), &gorm.Config{}) db, err = gorm.Open(mysql.Open(dbDSN), cfg)
default: default:
log.Println("testing sqlite3...") log.Println("testing sqlite3...")
db, err = gorm.Open(sqlite.Open(filepath.Join(os.TempDir(), "gorm.db")), &gorm.Config{}) db, err = gorm.Open(sqlite.Open(filepath.Join(os.TempDir(), "gorm.db?_foreign_keys=on")), cfg)
} }
if err != nil { if err != nil {

View File

@ -57,6 +57,19 @@ func TestTransaction(t *testing.T) {
if err := DB.First(&User{}, "name = ?", "transaction-2").Error; err != nil { if err := DB.First(&User{}, "name = ?", "transaction-2").Error; err != nil {
t.Fatalf("Should be able to find committed record, but got %v", err) t.Fatalf("Should be able to find committed record, but got %v", err)
} }
t.Run("this is test nested transaction and prepareStmt coexist case", func(t *testing.T) {
// enable prepare statement
tx3 := DB.Session(&gorm.Session{PrepareStmt: true})
if err := tx3.Transaction(func(tx4 *gorm.DB) error {
// nested transaction
return tx4.Transaction(func(tx5 *gorm.DB) error {
return tx5.First(&User{}, "name = ?", "transaction-2").Error
})
}); err != nil {
t.Fatalf("prepare statement and nested transcation coexist" + err.Error())
}
})
} }
func TestCancelTransaction(t *testing.T) { func TestCancelTransaction(t *testing.T) {
@ -348,7 +361,7 @@ func TestDisabledNestedTransaction(t *testing.T) {
} }
func TestTransactionOnClosedConn(t *testing.T) { func TestTransactionOnClosedConn(t *testing.T) {
DB, err := OpenTestConnection() DB, err := OpenTestConnection(&gorm.Config{})
if err != nil { if err != nil {
t.Fatalf("failed to connect database, got error %v", err) t.Fatalf("failed to connect database, got error %v", err)
} }

View File

@ -208,13 +208,17 @@ func TestUpdateColumn(t *testing.T) {
CheckUser(t, user1, *users[0]) CheckUser(t, user1, *users[0])
CheckUser(t, user2, *users[1]) CheckUser(t, user2, *users[1])
DB.Model(users[1]).UpdateColumn("name", "update_column_02_newnew") DB.Model(users[1]).UpdateColumn("name", "update_column_02_newnew").UpdateColumn("age", 19)
AssertEqual(t, lastUpdatedAt.UnixNano(), users[1].UpdatedAt.UnixNano()) AssertEqual(t, lastUpdatedAt.UnixNano(), users[1].UpdatedAt.UnixNano())
if users[1].Name != "update_column_02_newnew" { if users[1].Name != "update_column_02_newnew" {
t.Errorf("user 2's name should be updated, but got %v", users[1].Name) t.Errorf("user 2's name should be updated, but got %v", users[1].Name)
} }
if users[1].Age != 19 {
t.Errorf("user 2's name should be updated, but got %v", users[1].Age)
}
DB.Model(users[1]).UpdateColumn("age", gorm.Expr("age + 100 - 50")) DB.Model(users[1]).UpdateColumn("age", gorm.Expr("age + 100 - 50"))
var user3 User var user3 User
DB.First(&user3, users[1].ID) DB.First(&user3, users[1].ID)
@ -805,3 +809,76 @@ func TestUpdateWithDiffSchema(t *testing.T) {
AssertEqual(t, err, nil) AssertEqual(t, err, nil)
AssertEqual(t, "update-diff-schema-2", user.Name) AssertEqual(t, "update-diff-schema-2", user.Name)
} }
type TokenOwner struct {
ID int
Name string
Token Token `gorm:"foreignKey:UserID"`
}
func (t *TokenOwner) BeforeSave(tx *gorm.DB) error {
t.Name += "_name"
return nil
}
type Token struct {
UserID int `gorm:"primary_key"`
Content string `gorm:"type:varchar(100)"`
}
func (t *Token) BeforeSave(tx *gorm.DB) error {
t.Content += "_encrypted"
return nil
}
func TestSaveWithHooks(t *testing.T) {
DB.Migrator().DropTable(&Token{}, &TokenOwner{})
DB.AutoMigrate(&Token{}, &TokenOwner{})
saveTokenOwner := func(owner *TokenOwner) (*TokenOwner, error) {
var newOwner TokenOwner
if err := DB.Transaction(func(tx *gorm.DB) error {
if err := tx.Debug().Session(&gorm.Session{FullSaveAssociations: true}).Save(owner).Error; err != nil {
return err
}
if err := tx.Preload("Token").First(&newOwner, owner.ID).Error; err != nil {
return err
}
return nil
}); err != nil {
return nil, err
}
return &newOwner, nil
}
owner := TokenOwner{
Name: "user",
Token: Token{Content: "token"},
}
o1, err := saveTokenOwner(&owner)
if err != nil {
t.Errorf("failed to save token owner, got error: %v", err)
}
if o1.Name != "user_name" {
t.Errorf(`owner name should be "user_name", but got: "%s"`, o1.Name)
}
if o1.Token.Content != "token_encrypted" {
t.Errorf(`token content should be "token_encrypted", but got: "%s"`, o1.Token.Content)
}
owner = TokenOwner{
ID: owner.ID,
Name: "user",
Token: Token{Content: "token2"},
}
o2, err := saveTokenOwner(&owner)
if err != nil {
t.Errorf("failed to save token owner, got error: %v", err)
}
if o2.Name != "user_name" {
t.Errorf(`owner name should be "user_name", but got: "%s"`, o2.Name)
}
if o2.Token.Content != "token2_encrypted" {
t.Errorf(`token content should be "token2_encrypted", but got: "%s"`, o2.Token.Content)
}
}

View File

@ -11,7 +11,7 @@ import (
// He works in a Company (belongs to), he has a Manager (belongs to - single-table), and also managed a Team (has many - single-table) // He works in a Company (belongs to), he has a Manager (belongs to - single-table), and also managed a Team (has many - single-table)
// He speaks many languages (many to many) and has many friends (many to many - single-table) // He speaks many languages (many to many) and has many friends (many to many - single-table)
// His pet also has one Toy (has one - polymorphic) // His pet also has one Toy (has one - polymorphic)
// NamedPet is a reference to a Named `Pets` (has many) // NamedPet is a reference to a named `Pet` (has one)
type User struct { type User struct {
gorm.Model gorm.Model
Name string Name string