Merge branch 'master' into fix_drop_column
This commit is contained in:
commit
39fa1636ba
46
.github/workflows/tests.yml
vendored
46
.github/workflows/tests.yml
vendored
@ -41,7 +41,7 @@ jobs:
|
||||
mysql:
|
||||
strategy:
|
||||
matrix:
|
||||
dbversion: ['mysql:latest', 'mysql:5.7', 'mariadb:latest']
|
||||
dbversion: ['mysql:latest', 'mysql:5.7']
|
||||
go: ['1.19', '1.18']
|
||||
platform: [ubuntu-latest]
|
||||
runs-on: ${{ matrix.platform }}
|
||||
@ -72,7 +72,6 @@ jobs:
|
||||
- name: Check out code into the Go module directory
|
||||
uses: actions/checkout@v3
|
||||
|
||||
|
||||
- name: go mod package cache
|
||||
uses: actions/cache@v3
|
||||
with:
|
||||
@ -82,6 +81,49 @@ jobs:
|
||||
- name: Tests
|
||||
run: GITHUB_ACTION=true GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True" ./tests/tests_all.sh
|
||||
|
||||
mariadb:
|
||||
strategy:
|
||||
matrix:
|
||||
dbversion: [ 'mariadb:latest' ]
|
||||
go: [ '1.19', '1.18' ]
|
||||
platform: [ ubuntu-latest ]
|
||||
runs-on: ${{ matrix.platform }}
|
||||
|
||||
services:
|
||||
mysql:
|
||||
image: ${{ matrix.dbversion }}
|
||||
env:
|
||||
MYSQL_DATABASE: gorm
|
||||
MYSQL_USER: gorm
|
||||
MYSQL_PASSWORD: gorm
|
||||
MYSQL_RANDOM_ROOT_PASSWORD: "yes"
|
||||
ports:
|
||||
- 9910:3306
|
||||
options: >-
|
||||
--health-cmd "mariadb-admin ping -ugorm -pgorm"
|
||||
--health-interval 10s
|
||||
--health-start-period 10s
|
||||
--health-timeout 5s
|
||||
--health-retries 10
|
||||
|
||||
steps:
|
||||
- name: Set up Go 1.x
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: ${{ matrix.go }}
|
||||
|
||||
- name: Check out code into the Go module directory
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: go mod package cache
|
||||
uses: actions/cache@v3
|
||||
with:
|
||||
path: ~/go/pkg/mod
|
||||
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
|
||||
|
||||
- name: Tests
|
||||
run: GITHUB_ACTION=true GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True" ./tests/tests_all.sh
|
||||
|
||||
postgres:
|
||||
strategy:
|
||||
matrix:
|
||||
|
@ -72,6 +72,7 @@ func Update(config *Config) func(db *gorm.DB) {
|
||||
db.Statement.AddClauseIfNotExists(clause.Update{})
|
||||
if _, ok := db.Statement.Clauses["SET"]; !ok {
|
||||
if set := ConvertToAssignments(db.Statement); len(set) != 0 {
|
||||
defer delete(db.Statement.Clauses, "SET")
|
||||
db.Statement.AddClause(set)
|
||||
} else {
|
||||
return
|
||||
|
@ -47,4 +47,6 @@ var (
|
||||
ErrPreloadNotAllowed = errors.New("preload is not allowed when count is used")
|
||||
// ErrDuplicatedKey occurs when there is a unique key constraint violation
|
||||
ErrDuplicatedKey = errors.New("duplicated key not allowed")
|
||||
// ErrForeignKeyViolated occurs when there is a foreign key constraint violation
|
||||
ErrForeignKeyViolated = errors.New("violates foreign key constraint")
|
||||
)
|
||||
|
@ -6,8 +6,6 @@ import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/logger"
|
||||
@ -107,7 +105,7 @@ func (db *DB) Save(value interface{}) (tx *DB) {
|
||||
updateTx := tx.callbacks.Update().Execute(tx.Session(&Session{Initialized: true}))
|
||||
|
||||
if updateTx.Error == nil && updateTx.RowsAffected == 0 && !updateTx.DryRun && !selectedUpdate {
|
||||
return tx.Clauses(clause.OnConflict{UpdateAll: true}).Create(value)
|
||||
return tx.Session(&Session{SkipHooks: true}).Clauses(clause.OnConflict{UpdateAll: true}).Create(value)
|
||||
}
|
||||
|
||||
return updateTx
|
||||
@ -533,6 +531,7 @@ func (db *DB) Scan(dest interface{}) (tx *DB) {
|
||||
tx.ScanRows(rows, dest)
|
||||
} else {
|
||||
tx.RowsAffected = 0
|
||||
tx.AddError(rows.Err())
|
||||
}
|
||||
tx.AddError(rows.Close())
|
||||
}
|
||||
@ -611,15 +610,6 @@ func (db *DB) Connection(fc func(tx *DB) error) (err error) {
|
||||
return fc(tx)
|
||||
}
|
||||
|
||||
var (
|
||||
savepointIdx int64
|
||||
savepointNamePool = &sync.Pool{
|
||||
New: func() interface{} {
|
||||
return fmt.Sprintf("gorm_%d", atomic.AddInt64(&savepointIdx, 1))
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
// Transaction start a transaction as a block, return error will rollback, otherwise to commit. Transaction executes an
|
||||
// arbitrary number of commands in fc within a transaction. On success the changes are committed; if an error occurs
|
||||
// they are rolled back.
|
||||
@ -629,17 +619,14 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er
|
||||
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil {
|
||||
// nested transaction
|
||||
if !db.DisableNestedTransaction {
|
||||
poolName := savepointNamePool.Get()
|
||||
defer savepointNamePool.Put(poolName)
|
||||
err = db.SavePoint(poolName.(string)).Error
|
||||
err = db.SavePoint(fmt.Sprintf("sp%p", fc)).Error
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
// Make sure to rollback when panic, Block error or Commit error
|
||||
if panicked || err != nil {
|
||||
db.RollbackTo(poolName.(string))
|
||||
db.RollbackTo(fmt.Sprintf("sp%p", fc))
|
||||
}
|
||||
}()
|
||||
}
|
||||
@ -720,7 +707,21 @@ func (db *DB) Rollback() *DB {
|
||||
|
||||
func (db *DB) SavePoint(name string) *DB {
|
||||
if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok {
|
||||
// close prepared statement, because SavePoint not support prepared statement.
|
||||
// e.g. mysql8.0 doc: https://dev.mysql.com/doc/refman/8.0/en/sql-prepared-statements.html
|
||||
var (
|
||||
preparedStmtTx *PreparedStmtTX
|
||||
isPreparedStmtTx bool
|
||||
)
|
||||
// close prepared statement, because SavePoint not support prepared statement.
|
||||
if preparedStmtTx, isPreparedStmtTx = db.Statement.ConnPool.(*PreparedStmtTX); isPreparedStmtTx {
|
||||
db.Statement.ConnPool = preparedStmtTx.Tx
|
||||
}
|
||||
db.AddError(savePointer.SavePoint(db, name))
|
||||
// restore prepared statement
|
||||
if isPreparedStmtTx {
|
||||
db.Statement.ConnPool = preparedStmtTx
|
||||
}
|
||||
} else {
|
||||
db.AddError(ErrUnsupportedDriver)
|
||||
}
|
||||
@ -729,7 +730,21 @@ func (db *DB) SavePoint(name string) *DB {
|
||||
|
||||
func (db *DB) RollbackTo(name string) *DB {
|
||||
if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok {
|
||||
// close prepared statement, because RollbackTo not support prepared statement.
|
||||
// e.g. mysql8.0 doc: https://dev.mysql.com/doc/refman/8.0/en/sql-prepared-statements.html
|
||||
var (
|
||||
preparedStmtTx *PreparedStmtTX
|
||||
isPreparedStmtTx bool
|
||||
)
|
||||
// close prepared statement, because SavePoint not support prepared statement.
|
||||
if preparedStmtTx, isPreparedStmtTx = db.Statement.ConnPool.(*PreparedStmtTX); isPreparedStmtTx {
|
||||
db.Statement.ConnPool = preparedStmtTx.Tx
|
||||
}
|
||||
db.AddError(savePointer.RollbackTo(db, name))
|
||||
// restore prepared statement
|
||||
if isPreparedStmtTx {
|
||||
db.Statement.ConnPool = preparedStmtTx
|
||||
}
|
||||
} else {
|
||||
db.AddError(ErrUnsupportedDriver)
|
||||
}
|
||||
|
62
gorm.go
62
gorm.go
@ -146,7 +146,7 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) {
|
||||
}
|
||||
|
||||
if config.NamingStrategy == nil {
|
||||
config.NamingStrategy = schema.NamingStrategy{}
|
||||
config.NamingStrategy = schema.NamingStrategy{IdentifierMaxLength: 64} // Default Identifier length is 64
|
||||
}
|
||||
|
||||
if config.Logger == nil {
|
||||
@ -187,15 +187,9 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) {
|
||||
}
|
||||
}
|
||||
|
||||
preparedStmt := &PreparedStmtDB{
|
||||
ConnPool: db.ConnPool,
|
||||
Stmts: make(map[string]*Stmt),
|
||||
Mux: &sync.RWMutex{},
|
||||
PreparedSQL: make([]string, 0, 100),
|
||||
}
|
||||
db.cacheStore.Store(preparedStmtDBKey, preparedStmt)
|
||||
|
||||
if config.PrepareStmt {
|
||||
preparedStmt := NewPreparedStmtDB(db.ConnPool)
|
||||
db.cacheStore.Store(preparedStmtDBKey, preparedStmt)
|
||||
db.ConnPool = preparedStmt
|
||||
}
|
||||
|
||||
@ -256,24 +250,30 @@ func (db *DB) Session(config *Session) *DB {
|
||||
}
|
||||
|
||||
if config.PrepareStmt {
|
||||
var preparedStmt *PreparedStmtDB
|
||||
|
||||
if v, ok := db.cacheStore.Load(preparedStmtDBKey); ok {
|
||||
preparedStmt := v.(*PreparedStmtDB)
|
||||
switch t := tx.Statement.ConnPool.(type) {
|
||||
case Tx:
|
||||
tx.Statement.ConnPool = &PreparedStmtTX{
|
||||
Tx: t,
|
||||
PreparedStmtDB: preparedStmt,
|
||||
}
|
||||
default:
|
||||
tx.Statement.ConnPool = &PreparedStmtDB{
|
||||
ConnPool: db.Config.ConnPool,
|
||||
Mux: preparedStmt.Mux,
|
||||
Stmts: preparedStmt.Stmts,
|
||||
}
|
||||
}
|
||||
txConfig.ConnPool = tx.Statement.ConnPool
|
||||
txConfig.PrepareStmt = true
|
||||
preparedStmt = v.(*PreparedStmtDB)
|
||||
} else {
|
||||
preparedStmt = NewPreparedStmtDB(db.ConnPool)
|
||||
db.cacheStore.Store(preparedStmtDBKey, preparedStmt)
|
||||
}
|
||||
|
||||
switch t := tx.Statement.ConnPool.(type) {
|
||||
case Tx:
|
||||
tx.Statement.ConnPool = &PreparedStmtTX{
|
||||
Tx: t,
|
||||
PreparedStmtDB: preparedStmt,
|
||||
}
|
||||
default:
|
||||
tx.Statement.ConnPool = &PreparedStmtDB{
|
||||
ConnPool: db.Config.ConnPool,
|
||||
Mux: preparedStmt.Mux,
|
||||
Stmts: preparedStmt.Stmts,
|
||||
}
|
||||
}
|
||||
txConfig.ConnPool = tx.Statement.ConnPool
|
||||
txConfig.PrepareStmt = true
|
||||
}
|
||||
|
||||
if config.SkipHooks {
|
||||
@ -375,11 +375,17 @@ func (db *DB) AddError(err error) error {
|
||||
func (db *DB) DB() (*sql.DB, error) {
|
||||
connPool := db.ConnPool
|
||||
|
||||
if dbConnector, ok := connPool.(GetDBConnector); ok && dbConnector != nil {
|
||||
return dbConnector.GetDBConn()
|
||||
if connector, ok := connPool.(GetDBConnectorWithContext); ok && connector != nil {
|
||||
return connector.GetDBConnWithContext(db)
|
||||
}
|
||||
|
||||
if sqldb, ok := connPool.(*sql.DB); ok {
|
||||
if dbConnector, ok := connPool.(GetDBConnector); ok && dbConnector != nil {
|
||||
if sqldb, err := dbConnector.GetDBConn(); sqldb != nil || err != nil {
|
||||
return sqldb, err
|
||||
}
|
||||
}
|
||||
|
||||
if sqldb, ok := connPool.(*sql.DB); ok && sqldb != nil {
|
||||
return sqldb, nil
|
||||
}
|
||||
|
||||
|
@ -77,6 +77,12 @@ type GetDBConnector interface {
|
||||
GetDBConn() (*sql.DB, error)
|
||||
}
|
||||
|
||||
// GetDBConnectorWithContext represents SQL db connector which takes into
|
||||
// account the current database context
|
||||
type GetDBConnectorWithContext interface {
|
||||
GetDBConnWithContext(db *DB) (*sql.DB, error)
|
||||
}
|
||||
|
||||
// Rows rows interface
|
||||
type Rows interface {
|
||||
Columns() ([]string, error)
|
||||
|
@ -3,6 +3,7 @@ package gorm
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"reflect"
|
||||
"sync"
|
||||
)
|
||||
|
||||
@ -20,6 +21,15 @@ type PreparedStmtDB struct {
|
||||
ConnPool
|
||||
}
|
||||
|
||||
func NewPreparedStmtDB(connPool ConnPool) *PreparedStmtDB {
|
||||
return &PreparedStmtDB{
|
||||
ConnPool: connPool,
|
||||
Stmts: make(map[string]*Stmt),
|
||||
Mux: &sync.RWMutex{},
|
||||
PreparedSQL: make([]string, 0, 100),
|
||||
}
|
||||
}
|
||||
|
||||
func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) {
|
||||
if dbConnector, ok := db.ConnPool.(GetDBConnector); ok && dbConnector != nil {
|
||||
return dbConnector.GetDBConn()
|
||||
@ -163,14 +173,14 @@ type PreparedStmtTX struct {
|
||||
}
|
||||
|
||||
func (tx *PreparedStmtTX) Commit() error {
|
||||
if tx.Tx != nil {
|
||||
if tx.Tx != nil && !reflect.ValueOf(tx.Tx).IsNil() {
|
||||
return tx.Tx.Commit()
|
||||
}
|
||||
return ErrInvalidTransaction
|
||||
}
|
||||
|
||||
func (tx *PreparedStmtTX) Rollback() error {
|
||||
if tx.Tx != nil {
|
||||
if tx.Tx != nil && !reflect.ValueOf(tx.Tx).IsNil() {
|
||||
return tx.Tx.Rollback()
|
||||
}
|
||||
return ErrInvalidTransaction
|
||||
|
@ -846,7 +846,7 @@ func (field *Field) setupValuerAndSetter() {
|
||||
field.Set = func(ctx context.Context, value reflect.Value, v interface{}) error {
|
||||
switch data := v.(type) {
|
||||
case **time.Time:
|
||||
if data != nil {
|
||||
if data != nil && *data != nil {
|
||||
field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(*data))
|
||||
}
|
||||
case time.Time:
|
||||
@ -882,14 +882,12 @@ func (field *Field) setupValuerAndSetter() {
|
||||
reflectV := reflect.ValueOf(v)
|
||||
if !reflectV.IsValid() {
|
||||
field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem())
|
||||
} else if reflectV.Kind() == reflect.Ptr && reflectV.IsNil() {
|
||||
return
|
||||
} else if reflectV.Type().AssignableTo(field.FieldType) {
|
||||
field.ReflectValueOf(ctx, value).Set(reflectV)
|
||||
} else if reflectV.Kind() == reflect.Ptr {
|
||||
if reflectV.IsNil() || !reflectV.IsValid() {
|
||||
field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem())
|
||||
} else {
|
||||
return field.Set(ctx, value, reflectV.Elem().Interface())
|
||||
}
|
||||
return field.Set(ctx, value, reflectV.Elem().Interface())
|
||||
} else {
|
||||
fieldValue := field.ReflectValueOf(ctx, value)
|
||||
if fieldValue.IsNil() {
|
||||
@ -910,14 +908,12 @@ func (field *Field) setupValuerAndSetter() {
|
||||
reflectV := reflect.ValueOf(v)
|
||||
if !reflectV.IsValid() {
|
||||
field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem())
|
||||
} else if reflectV.Kind() == reflect.Ptr && reflectV.IsNil() {
|
||||
return
|
||||
} else if reflectV.Type().AssignableTo(field.FieldType) {
|
||||
field.ReflectValueOf(ctx, value).Set(reflectV)
|
||||
} else if reflectV.Kind() == reflect.Ptr {
|
||||
if reflectV.IsNil() || !reflectV.IsValid() {
|
||||
field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem())
|
||||
} else {
|
||||
return field.Set(ctx, value, reflectV.Elem().Interface())
|
||||
}
|
||||
return field.Set(ctx, value, reflectV.Elem().Interface())
|
||||
} else {
|
||||
if valuer, ok := v.(driver.Valuer); ok {
|
||||
v, _ = valuer.Value()
|
||||
|
@ -28,10 +28,11 @@ type Replacer interface {
|
||||
|
||||
// NamingStrategy tables, columns naming strategy
|
||||
type NamingStrategy struct {
|
||||
TablePrefix string
|
||||
SingularTable bool
|
||||
NameReplacer Replacer
|
||||
NoLowerCase bool
|
||||
TablePrefix string
|
||||
SingularTable bool
|
||||
NameReplacer Replacer
|
||||
NoLowerCase bool
|
||||
IdentifierMaxLength int
|
||||
}
|
||||
|
||||
// TableName convert string to table name
|
||||
@ -89,12 +90,16 @@ func (ns NamingStrategy) formatName(prefix, table, name string) string {
|
||||
prefix, table, name,
|
||||
}, "_"), ".", "_")
|
||||
|
||||
if utf8.RuneCountInString(formattedName) > 64 {
|
||||
if ns.IdentifierMaxLength == 0 {
|
||||
ns.IdentifierMaxLength = 64
|
||||
}
|
||||
|
||||
if utf8.RuneCountInString(formattedName) > ns.IdentifierMaxLength {
|
||||
h := sha1.New()
|
||||
h.Write([]byte(formattedName))
|
||||
bs := h.Sum(nil)
|
||||
|
||||
formattedName = formattedName[0:56] + hex.EncodeToString(bs)[:8]
|
||||
formattedName = formattedName[0:ns.IdentifierMaxLength-8] + hex.EncodeToString(bs)[:8]
|
||||
}
|
||||
return formattedName
|
||||
}
|
||||
|
@ -189,8 +189,17 @@ func TestCustomReplacerWithNoLowerCase(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatNameWithStringLongerThan63Characters(t *testing.T) {
|
||||
ns := NamingStrategy{IdentifierMaxLength: 63}
|
||||
|
||||
formattedName := ns.formatName("prefix", "table", "thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVeryLongString")
|
||||
if formattedName != "prefix_table_thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVer180f2c67" {
|
||||
t.Errorf("invalid formatted name generated, got %v", formattedName)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatNameWithStringLongerThan64Characters(t *testing.T) {
|
||||
ns := NamingStrategy{}
|
||||
ns := NamingStrategy{IdentifierMaxLength: 64}
|
||||
|
||||
formattedName := ns.formatName("prefix", "table", "thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVeryLongString")
|
||||
if formattedName != "prefix_table_thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVery180f2c67" {
|
||||
|
@ -768,7 +768,7 @@ func TestParseConstraintNameWithSchemaQualifiedLongTableName(t *testing.T) {
|
||||
s, err := schema.Parse(
|
||||
&Book{},
|
||||
&sync.Map{},
|
||||
schema.NamingStrategy{},
|
||||
schema.NamingStrategy{IdentifierMaxLength: 64},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse schema")
|
||||
|
@ -358,7 +358,7 @@ func TestDuplicateMany2ManyAssociation(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConcurrentMany2ManyAssociation(t *testing.T) {
|
||||
db, err := OpenTestConnection()
|
||||
db, err := OpenTestConnection(&gorm.Config{})
|
||||
if err != nil {
|
||||
t.Fatalf("open test connection failed, err: %+v", err)
|
||||
}
|
||||
|
@ -4,7 +4,9 @@ import (
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
@ -104,10 +106,14 @@ func TestEmbeddedPointerTypeStruct(t *testing.T) {
|
||||
}
|
||||
|
||||
type Author struct {
|
||||
ID string
|
||||
Name string
|
||||
Email string
|
||||
Age int
|
||||
ID string
|
||||
Name string
|
||||
Email string
|
||||
Age int
|
||||
Content Content
|
||||
ContentPtr *Content
|
||||
Birthday time.Time
|
||||
BirthdayPtr *time.Time
|
||||
}
|
||||
|
||||
type HNPost struct {
|
||||
@ -135,6 +141,48 @@ func TestEmbeddedPointerTypeStruct(t *testing.T) {
|
||||
if hnPost.Author != nil {
|
||||
t.Errorf("Expected to get back a nil Author but got: %v", hnPost.Author)
|
||||
}
|
||||
|
||||
now := time.Now().Round(time.Second)
|
||||
NewPost := HNPost{
|
||||
BasePost: &BasePost{Title: "embedded_pointer_type2"},
|
||||
Author: &Author{
|
||||
Name: "test",
|
||||
Content: Content{"test"},
|
||||
ContentPtr: nil,
|
||||
Birthday: now,
|
||||
BirthdayPtr: nil,
|
||||
},
|
||||
}
|
||||
DB.Create(&NewPost)
|
||||
|
||||
hnPost = HNPost{}
|
||||
if err := DB.First(&hnPost, "title = ?", NewPost.Title).Error; err != nil {
|
||||
t.Errorf("No error should happen when find embedded pointer type, but got %v", err)
|
||||
}
|
||||
|
||||
if hnPost.Title != NewPost.Title {
|
||||
t.Errorf("Should find correct value for embedded pointer type")
|
||||
}
|
||||
|
||||
if hnPost.Author.Name != NewPost.Author.Name {
|
||||
t.Errorf("Expected to get Author name %v but got: %v", NewPost.Author.Name, hnPost.Author.Name)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(NewPost.Author.Content, hnPost.Author.Content) {
|
||||
t.Errorf("Expected to get Author content %v but got: %v", NewPost.Author.Content, hnPost.Author.Content)
|
||||
}
|
||||
|
||||
if hnPost.Author.ContentPtr != nil {
|
||||
t.Errorf("Expected to get nil Author contentPtr but got: %v", hnPost.Author.ContentPtr)
|
||||
}
|
||||
|
||||
if NewPost.Author.Birthday.UnixMilli() != hnPost.Author.Birthday.UnixMilli() {
|
||||
t.Errorf("Expected to get Author birthday with %+v but got: %+v", NewPost.Author.Birthday, hnPost.Author.Birthday)
|
||||
}
|
||||
|
||||
if hnPost.Author.BirthdayPtr != nil {
|
||||
t.Errorf("Expected to get nil Author birthdayPtr but got: %+v", hnPost.Author.BirthdayPtr)
|
||||
}
|
||||
}
|
||||
|
||||
type Content struct {
|
||||
@ -142,18 +190,26 @@ type Content struct {
|
||||
}
|
||||
|
||||
func (c Content) Value() (driver.Value, error) {
|
||||
return json.Marshal(c)
|
||||
// mssql driver with issue on handling null bytes https://github.com/denisenkom/go-mssqldb/issues/530,
|
||||
b, err := json.Marshal(c)
|
||||
return string(b[:]), err
|
||||
}
|
||||
|
||||
func (c *Content) Scan(src interface{}) error {
|
||||
b, ok := src.([]byte)
|
||||
if !ok {
|
||||
return errors.New("Embedded.Scan byte assertion failed")
|
||||
}
|
||||
|
||||
var value Content
|
||||
if err := json.Unmarshal(b, &value); err != nil {
|
||||
return err
|
||||
str, ok := src.(string)
|
||||
if !ok {
|
||||
byt, ok := src.([]byte)
|
||||
if !ok {
|
||||
return errors.New("Embedded.Scan byte assertion failed")
|
||||
}
|
||||
if err := json.Unmarshal(byt, &value); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if err := json.Unmarshal([]byte(str), &value); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
*c = value
|
||||
|
@ -15,8 +15,8 @@ func TestDialectorWithErrorTranslatorSupport(t *testing.T) {
|
||||
db, _ := gorm.Open(tests.DummyDialector{TranslatedErr: translatedErr})
|
||||
|
||||
err := db.AddError(untranslatedErr)
|
||||
if errors.Is(err, translatedErr) {
|
||||
t.Fatalf("expected err: %v got err: %v", translatedErr, err)
|
||||
if !errors.Is(err, untranslatedErr) {
|
||||
t.Fatalf("expected err: %v got err: %v", untranslatedErr, err)
|
||||
}
|
||||
|
||||
// it should translate error when the TranslateError flag is true
|
||||
@ -27,3 +27,85 @@ func TestDialectorWithErrorTranslatorSupport(t *testing.T) {
|
||||
t.Fatalf("expected err: %v got err: %v", translatedErr, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSupportedDialectorWithErrDuplicatedKey(t *testing.T) {
|
||||
type City struct {
|
||||
gorm.Model
|
||||
Name string `gorm:"unique"`
|
||||
}
|
||||
|
||||
db, err := OpenTestConnection(&gorm.Config{TranslateError: true})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect database, got error %v", err)
|
||||
}
|
||||
|
||||
dialectors := map[string]bool{"sqlite": true, "postgres": true, "mysql": true, "sqlserver": true}
|
||||
if supported, found := dialectors[db.Dialector.Name()]; !(found && supported) {
|
||||
return
|
||||
}
|
||||
|
||||
DB.Migrator().DropTable(&City{})
|
||||
|
||||
if err = db.AutoMigrate(&City{}); err != nil {
|
||||
t.Fatalf("failed to migrate cities table, got error: %v", err)
|
||||
}
|
||||
|
||||
err = db.Create(&City{Name: "Kabul"}).Error
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create record: %v", err)
|
||||
}
|
||||
|
||||
err = db.Create(&City{Name: "Kabul"}).Error
|
||||
if !errors.Is(err, gorm.ErrDuplicatedKey) {
|
||||
t.Fatalf("expected err: %v got err: %v", gorm.ErrDuplicatedKey, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSupportedDialectorWithErrForeignKeyViolated(t *testing.T) {
|
||||
tidbSkip(t, "not support the foreign key feature")
|
||||
|
||||
type City struct {
|
||||
gorm.Model
|
||||
Name string `gorm:"unique"`
|
||||
}
|
||||
|
||||
type Museum struct {
|
||||
gorm.Model
|
||||
Name string `gorm:"unique"`
|
||||
CityID uint
|
||||
City City `gorm:"Constraint:OnUpdate:CASCADE,OnDelete:CASCADE;FOREIGNKEY:CityID;References:ID"`
|
||||
}
|
||||
|
||||
db, err := OpenTestConnection(&gorm.Config{TranslateError: true})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect database, got error %v", err)
|
||||
}
|
||||
|
||||
dialectors := map[string]bool{"sqlite": true, "postgres": true, "mysql": true, "sqlserver": true}
|
||||
if supported, found := dialectors[db.Dialector.Name()]; !(found && supported) {
|
||||
return
|
||||
}
|
||||
|
||||
DB.Migrator().DropTable(&City{}, &Museum{})
|
||||
|
||||
if err = db.AutoMigrate(&City{}, &Museum{}); err != nil {
|
||||
t.Fatalf("failed to migrate countries & cities tables, got error: %v", err)
|
||||
}
|
||||
|
||||
city := City{Name: "Amsterdam"}
|
||||
|
||||
err = db.Create(&city).Error
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create city: %v", err)
|
||||
}
|
||||
|
||||
err = db.Create(&Museum{Name: "Eye Filmmuseum", CityID: city.ID}).Error
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create museum: %v", err)
|
||||
}
|
||||
|
||||
err = db.Create(&Museum{Name: "Dungeon", CityID: 123}).Error
|
||||
if !errors.Is(err, gorm.ErrForeignKeyViolated) {
|
||||
t.Fatalf("expected err: %v got err: %v", gorm.ErrForeignKeyViolated, err)
|
||||
}
|
||||
}
|
||||
|
@ -3,9 +3,19 @@ package tests_test
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"gorm.io/driver/mysql"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func TestOpen(t *testing.T) {
|
||||
dsn := "gorm:gorm@tcp(localhost:9910)/gorm?loc=Asia%2FHongKong" // invalid loc
|
||||
_, err := gorm.Open(mysql.Open(dsn), &gorm.Config{})
|
||||
if err == nil {
|
||||
t.Fatalf("should returns error but got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReturningWithNullToZeroValues(t *testing.T) {
|
||||
dialect := DB.Dialector.Name()
|
||||
switch dialect {
|
||||
|
@ -92,7 +92,7 @@ func TestPreparedStmtFromTransaction(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestPreparedStmtDeadlock(t *testing.T) {
|
||||
tx, err := OpenTestConnection()
|
||||
tx, err := OpenTestConnection(&gorm.Config{})
|
||||
AssertEqual(t, err, nil)
|
||||
|
||||
sqlDB, _ := tx.DB()
|
||||
@ -127,7 +127,7 @@ func TestPreparedStmtDeadlock(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestPreparedStmtError(t *testing.T) {
|
||||
tx, err := OpenTestConnection()
|
||||
tx, err := OpenTestConnection(&gorm.Config{})
|
||||
AssertEqual(t, err, nil)
|
||||
|
||||
sqlDB, _ := tx.DB()
|
||||
|
@ -170,10 +170,10 @@ func (data *EncryptedData) Scan(value interface{}) error {
|
||||
return errors.New("Too short")
|
||||
}
|
||||
|
||||
*data = b[3:]
|
||||
*data = append((*data)[0:], b[3:]...)
|
||||
return nil
|
||||
} else if s, ok := value.(string); ok {
|
||||
*data = []byte(s)[3:]
|
||||
*data = []byte(s[3:])
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -26,7 +26,7 @@ var (
|
||||
|
||||
func init() {
|
||||
var err error
|
||||
if DB, err = OpenTestConnection(); err != nil {
|
||||
if DB, err = OpenTestConnection(&gorm.Config{}); err != nil {
|
||||
log.Printf("failed to connect database, got error %v", err)
|
||||
os.Exit(1)
|
||||
} else {
|
||||
@ -49,7 +49,7 @@ func init() {
|
||||
}
|
||||
}
|
||||
|
||||
func OpenTestConnection() (db *gorm.DB, err error) {
|
||||
func OpenTestConnection(cfg *gorm.Config) (db *gorm.DB, err error) {
|
||||
dbDSN := os.Getenv("GORM_DSN")
|
||||
switch os.Getenv("GORM_DIALECT") {
|
||||
case "mysql":
|
||||
@ -57,7 +57,7 @@ func OpenTestConnection() (db *gorm.DB, err error) {
|
||||
if dbDSN == "" {
|
||||
dbDSN = mysqlDSN
|
||||
}
|
||||
db, err = gorm.Open(mysql.Open(dbDSN), &gorm.Config{})
|
||||
db, err = gorm.Open(mysql.Open(dbDSN), cfg)
|
||||
case "postgres":
|
||||
log.Println("testing postgres...")
|
||||
if dbDSN == "" {
|
||||
@ -66,7 +66,7 @@ func OpenTestConnection() (db *gorm.DB, err error) {
|
||||
db, err = gorm.Open(postgres.New(postgres.Config{
|
||||
DSN: dbDSN,
|
||||
PreferSimpleProtocol: true,
|
||||
}), &gorm.Config{})
|
||||
}), cfg)
|
||||
case "sqlserver":
|
||||
// go install github.com/microsoft/go-sqlcmd/cmd/sqlcmd@latest
|
||||
// SQLCMDPASSWORD=LoremIpsum86 sqlcmd -U sa -S localhost:9930
|
||||
@ -80,16 +80,16 @@ func OpenTestConnection() (db *gorm.DB, err error) {
|
||||
if dbDSN == "" {
|
||||
dbDSN = sqlserverDSN
|
||||
}
|
||||
db, err = gorm.Open(sqlserver.Open(dbDSN), &gorm.Config{})
|
||||
db, err = gorm.Open(sqlserver.Open(dbDSN), cfg)
|
||||
case "tidb":
|
||||
log.Println("testing tidb...")
|
||||
if dbDSN == "" {
|
||||
dbDSN = tidbDSN
|
||||
}
|
||||
db, err = gorm.Open(mysql.Open(dbDSN), &gorm.Config{})
|
||||
db, err = gorm.Open(mysql.Open(dbDSN), cfg)
|
||||
default:
|
||||
log.Println("testing sqlite3...")
|
||||
db, err = gorm.Open(sqlite.Open(filepath.Join(os.TempDir(), "gorm.db")), &gorm.Config{})
|
||||
db, err = gorm.Open(sqlite.Open(filepath.Join(os.TempDir(), "gorm.db?_foreign_keys=on")), cfg)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
|
@ -57,6 +57,19 @@ func TestTransaction(t *testing.T) {
|
||||
if err := DB.First(&User{}, "name = ?", "transaction-2").Error; err != nil {
|
||||
t.Fatalf("Should be able to find committed record, but got %v", err)
|
||||
}
|
||||
|
||||
t.Run("this is test nested transaction and prepareStmt coexist case", func(t *testing.T) {
|
||||
// enable prepare statement
|
||||
tx3 := DB.Session(&gorm.Session{PrepareStmt: true})
|
||||
if err := tx3.Transaction(func(tx4 *gorm.DB) error {
|
||||
// nested transaction
|
||||
return tx4.Transaction(func(tx5 *gorm.DB) error {
|
||||
return tx5.First(&User{}, "name = ?", "transaction-2").Error
|
||||
})
|
||||
}); err != nil {
|
||||
t.Fatalf("prepare statement and nested transcation coexist" + err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestCancelTransaction(t *testing.T) {
|
||||
@ -348,7 +361,7 @@ func TestDisabledNestedTransaction(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestTransactionOnClosedConn(t *testing.T) {
|
||||
DB, err := OpenTestConnection()
|
||||
DB, err := OpenTestConnection(&gorm.Config{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect database, got error %v", err)
|
||||
}
|
||||
|
@ -208,13 +208,17 @@ func TestUpdateColumn(t *testing.T) {
|
||||
CheckUser(t, user1, *users[0])
|
||||
CheckUser(t, user2, *users[1])
|
||||
|
||||
DB.Model(users[1]).UpdateColumn("name", "update_column_02_newnew")
|
||||
DB.Model(users[1]).UpdateColumn("name", "update_column_02_newnew").UpdateColumn("age", 19)
|
||||
AssertEqual(t, lastUpdatedAt.UnixNano(), users[1].UpdatedAt.UnixNano())
|
||||
|
||||
if users[1].Name != "update_column_02_newnew" {
|
||||
t.Errorf("user 2's name should be updated, but got %v", users[1].Name)
|
||||
}
|
||||
|
||||
if users[1].Age != 19 {
|
||||
t.Errorf("user 2's name should be updated, but got %v", users[1].Age)
|
||||
}
|
||||
|
||||
DB.Model(users[1]).UpdateColumn("age", gorm.Expr("age + 100 - 50"))
|
||||
var user3 User
|
||||
DB.First(&user3, users[1].ID)
|
||||
@ -805,3 +809,76 @@ func TestUpdateWithDiffSchema(t *testing.T) {
|
||||
AssertEqual(t, err, nil)
|
||||
AssertEqual(t, "update-diff-schema-2", user.Name)
|
||||
}
|
||||
|
||||
type TokenOwner struct {
|
||||
ID int
|
||||
Name string
|
||||
Token Token `gorm:"foreignKey:UserID"`
|
||||
}
|
||||
|
||||
func (t *TokenOwner) BeforeSave(tx *gorm.DB) error {
|
||||
t.Name += "_name"
|
||||
return nil
|
||||
}
|
||||
|
||||
type Token struct {
|
||||
UserID int `gorm:"primary_key"`
|
||||
Content string `gorm:"type:varchar(100)"`
|
||||
}
|
||||
|
||||
func (t *Token) BeforeSave(tx *gorm.DB) error {
|
||||
t.Content += "_encrypted"
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestSaveWithHooks(t *testing.T) {
|
||||
DB.Migrator().DropTable(&Token{}, &TokenOwner{})
|
||||
DB.AutoMigrate(&Token{}, &TokenOwner{})
|
||||
|
||||
saveTokenOwner := func(owner *TokenOwner) (*TokenOwner, error) {
|
||||
var newOwner TokenOwner
|
||||
if err := DB.Transaction(func(tx *gorm.DB) error {
|
||||
if err := tx.Debug().Session(&gorm.Session{FullSaveAssociations: true}).Save(owner).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
if err := tx.Preload("Token").First(&newOwner, owner.ID).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &newOwner, nil
|
||||
}
|
||||
|
||||
owner := TokenOwner{
|
||||
Name: "user",
|
||||
Token: Token{Content: "token"},
|
||||
}
|
||||
o1, err := saveTokenOwner(&owner)
|
||||
if err != nil {
|
||||
t.Errorf("failed to save token owner, got error: %v", err)
|
||||
}
|
||||
if o1.Name != "user_name" {
|
||||
t.Errorf(`owner name should be "user_name", but got: "%s"`, o1.Name)
|
||||
}
|
||||
if o1.Token.Content != "token_encrypted" {
|
||||
t.Errorf(`token content should be "token_encrypted", but got: "%s"`, o1.Token.Content)
|
||||
}
|
||||
|
||||
owner = TokenOwner{
|
||||
ID: owner.ID,
|
||||
Name: "user",
|
||||
Token: Token{Content: "token2"},
|
||||
}
|
||||
o2, err := saveTokenOwner(&owner)
|
||||
if err != nil {
|
||||
t.Errorf("failed to save token owner, got error: %v", err)
|
||||
}
|
||||
if o2.Name != "user_name" {
|
||||
t.Errorf(`owner name should be "user_name", but got: "%s"`, o2.Name)
|
||||
}
|
||||
if o2.Token.Content != "token2_encrypted" {
|
||||
t.Errorf(`token content should be "token2_encrypted", but got: "%s"`, o2.Token.Content)
|
||||
}
|
||||
}
|
||||
|
@ -11,7 +11,7 @@ import (
|
||||
// He works in a Company (belongs to), he has a Manager (belongs to - single-table), and also managed a Team (has many - single-table)
|
||||
// He speaks many languages (many to many) and has many friends (many to many - single-table)
|
||||
// His pet also has one Toy (has one - polymorphic)
|
||||
// NamedPet is a reference to a Named `Pets` (has many)
|
||||
// NamedPet is a reference to a named `Pet` (has one)
|
||||
type User struct {
|
||||
gorm.Model
|
||||
Name string
|
||||
|
Loading…
x
Reference in New Issue
Block a user