Merge branch 'go-gorm:master' into master
This commit is contained in:
commit
da07675bae
1
.gitignore
vendored
1
.gitignore
vendored
@ -3,3 +3,4 @@ documents
|
||||
coverage.txt
|
||||
_book
|
||||
.idea
|
||||
vendor
|
@ -201,13 +201,15 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
|
||||
switch stmt.ReflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
rValLen := stmt.ReflectValue.Len()
|
||||
stmt.SQL.Grow(rValLen * 18)
|
||||
values.Values = make([][]interface{}, rValLen)
|
||||
if rValLen == 0 {
|
||||
stmt.AddError(gorm.ErrEmptySlice)
|
||||
return
|
||||
}
|
||||
|
||||
stmt.SQL.Grow(rValLen * 18)
|
||||
stmt.Vars = make([]interface{}, 0, rValLen*len(values.Columns))
|
||||
values.Values = make([][]interface{}, rValLen)
|
||||
|
||||
defaultValueFieldsHavingValue := map[*schema.Field][]interface{}{}
|
||||
for i := 0; i < rValLen; i++ {
|
||||
rv := reflect.Indirect(stmt.ReflectValue.Index(i))
|
||||
|
@ -255,6 +255,7 @@ func (tx *DB) assignInterfacesToValue(values ...interface{}) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// FirstOrInit gets the first matched record or initialize a new instance with given conditions (only works with struct or map conditions)
|
||||
func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) {
|
||||
queryTx := db.Limit(1).Order(clause.OrderByColumn{
|
||||
@ -599,11 +600,12 @@ func (db *DB) Begin(opts ...*sql.TxOptions) *DB {
|
||||
opt = opts[0]
|
||||
}
|
||||
|
||||
if beginner, ok := tx.Statement.ConnPool.(TxBeginner); ok {
|
||||
switch beginner := tx.Statement.ConnPool.(type) {
|
||||
case TxBeginner:
|
||||
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
|
||||
} else if beginner, ok := tx.Statement.ConnPool.(ConnPoolBeginner); ok {
|
||||
case ConnPoolBeginner:
|
||||
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
|
||||
} else {
|
||||
default:
|
||||
err = ErrInvalidTransaction
|
||||
}
|
||||
|
||||
|
2
gorm.go
2
gorm.go
@ -462,7 +462,7 @@ func (db *DB) Use(plugin Plugin) error {
|
||||
// .First(&User{})
|
||||
// })
|
||||
func (db *DB) ToSQL(queryFn func(tx *DB) *DB) string {
|
||||
tx := queryFn(db.Session(&Session{DryRun: true}))
|
||||
tx := queryFn(db.Session(&Session{DryRun: true, SkipDefaultTransaction: true}))
|
||||
stmt := tx.Statement
|
||||
|
||||
return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...)
|
||||
|
@ -56,6 +56,13 @@ type TxCommitter interface {
|
||||
Rollback() error
|
||||
}
|
||||
|
||||
// Tx sql.Tx interface
|
||||
type Tx interface {
|
||||
ConnPool
|
||||
TxCommitter
|
||||
StmtContext(ctx context.Context, stmt *sql.Stmt) *sql.Stmt
|
||||
}
|
||||
|
||||
// Valuer gorm valuer interface
|
||||
type Valuer interface {
|
||||
GormValue(context.Context, *DB) clause.Expr
|
||||
|
@ -115,7 +115,7 @@ func (db *PreparedStmtDB) QueryRowContext(ctx context.Context, query string, arg
|
||||
}
|
||||
|
||||
type PreparedStmtTX struct {
|
||||
*sql.Tx
|
||||
Tx
|
||||
PreparedStmtDB *PreparedStmtDB
|
||||
}
|
||||
|
||||
@ -151,7 +151,7 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ..
|
||||
func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) {
|
||||
stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query)
|
||||
if err == nil {
|
||||
rows, err = tx.Tx.Stmt(stmt.Stmt).QueryContext(ctx, args...)
|
||||
rows, err = tx.Tx.StmtContext(ctx, stmt.Stmt).QueryContext(ctx, args...)
|
||||
if err != nil {
|
||||
tx.PreparedStmtDB.Mux.Lock()
|
||||
defer tx.PreparedStmtDB.Mux.Unlock()
|
||||
|
111
scan.go
111
scan.go
@ -50,54 +50,40 @@ func scanIntoMap(mapValue map[string]interface{}, values []interface{}, columns
|
||||
}
|
||||
}
|
||||
|
||||
func (db *DB) scanIntoStruct(sch *schema.Schema, rows *sql.Rows, reflectValue reflect.Value, values []interface{}, columns []string, fields []*schema.Field, joinFields [][2]*schema.Field) {
|
||||
for idx, column := range columns {
|
||||
if sch == nil {
|
||||
values[idx] = reflectValue.Interface()
|
||||
} else if field := sch.LookUpField(column); field != nil && field.Readable {
|
||||
func (db *DB) scanIntoStruct(rows *sql.Rows, reflectValue reflect.Value, values []interface{}, fields []*schema.Field, joinFields [][2]*schema.Field) {
|
||||
for idx, field := range fields {
|
||||
if field != nil {
|
||||
values[idx] = field.NewValuePool.Get()
|
||||
defer field.NewValuePool.Put(values[idx])
|
||||
} else if names := strings.Split(column, "__"); len(names) > 1 {
|
||||
if rel, ok := sch.Relationships.Relations[names[0]]; ok {
|
||||
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
|
||||
values[idx] = field.NewValuePool.Get()
|
||||
defer field.NewValuePool.Put(values[idx])
|
||||
continue
|
||||
}
|
||||
} else if len(fields) == 1 {
|
||||
if reflectValue.CanAddr() {
|
||||
values[idx] = reflectValue.Addr().Interface()
|
||||
} else {
|
||||
values[idx] = reflectValue.Interface()
|
||||
}
|
||||
values[idx] = &sql.RawBytes{}
|
||||
} else if len(columns) == 1 {
|
||||
sch = nil
|
||||
values[idx] = reflectValue.Interface()
|
||||
} else {
|
||||
values[idx] = &sql.RawBytes{}
|
||||
}
|
||||
}
|
||||
|
||||
db.RowsAffected++
|
||||
db.AddError(rows.Scan(values...))
|
||||
|
||||
if sch != nil {
|
||||
for idx, column := range columns {
|
||||
if field := sch.LookUpField(column); field != nil && field.Readable {
|
||||
for idx, field := range fields {
|
||||
if field != nil {
|
||||
if len(joinFields) == 0 || joinFields[idx][0] == nil {
|
||||
field.Set(db.Statement.Context, reflectValue, values[idx])
|
||||
} else if names := strings.Split(column, "__"); len(names) > 1 {
|
||||
if rel, ok := sch.Relationships.Relations[names[0]]; ok {
|
||||
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
|
||||
relValue := rel.Field.ReflectValueOf(db.Statement.Context, reflectValue)
|
||||
|
||||
if relValue.Kind() == reflect.Ptr && relValue.IsNil() {
|
||||
if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() {
|
||||
continue
|
||||
}
|
||||
|
||||
relValue.Set(reflect.New(relValue.Type().Elem()))
|
||||
}
|
||||
|
||||
field.Set(db.Statement.Context, relValue, values[idx])
|
||||
} else {
|
||||
relValue := joinFields[idx][0].ReflectValueOf(db.Statement.Context, reflectValue)
|
||||
if relValue.Kind() == reflect.Ptr && relValue.IsNil() {
|
||||
if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() {
|
||||
return
|
||||
}
|
||||
|
||||
relValue.Set(reflect.New(relValue.Type().Elem()))
|
||||
}
|
||||
joinFields[idx][1].Set(db.Statement.Context, relValue, values[idx])
|
||||
}
|
||||
|
||||
// release data to pool
|
||||
field.NewValuePool.Put(values[idx])
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -176,7 +162,7 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) {
|
||||
reflectValue = db.Statement.ReflectValue
|
||||
)
|
||||
|
||||
for reflectValue.Kind() == reflect.Interface {
|
||||
if reflectValue.Kind() == reflect.Interface {
|
||||
reflectValue = reflectValue.Elem()
|
||||
}
|
||||
|
||||
@ -195,35 +181,38 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) {
|
||||
sch, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy)
|
||||
}
|
||||
|
||||
for idx, column := range columns {
|
||||
if field := sch.LookUpField(column); field != nil && field.Readable {
|
||||
fields[idx] = field
|
||||
} else if names := strings.Split(column, "__"); len(names) > 1 {
|
||||
if rel, ok := sch.Relationships.Relations[names[0]]; ok {
|
||||
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
|
||||
fields[idx] = field
|
||||
|
||||
if len(joinFields) == 0 {
|
||||
joinFields = make([][2]*schema.Field, len(columns))
|
||||
}
|
||||
joinFields[idx] = [2]*schema.Field{rel.Field, field}
|
||||
continue
|
||||
}
|
||||
}
|
||||
values[idx] = &sql.RawBytes{}
|
||||
} else {
|
||||
values[idx] = &sql.RawBytes{}
|
||||
}
|
||||
}
|
||||
|
||||
if len(columns) == 1 {
|
||||
// isPluck
|
||||
// Is Pluck
|
||||
if _, ok := reflect.New(reflectValueType).Interface().(sql.Scanner); (reflectValueType != sch.ModelType && ok) || // is scanner
|
||||
reflectValueType.Kind() != reflect.Struct || // is not struct
|
||||
sch.ModelType.ConvertibleTo(schema.TimeReflectType) { // is time
|
||||
sch = nil
|
||||
}
|
||||
}
|
||||
|
||||
// Not Pluck
|
||||
if sch != nil {
|
||||
for idx, column := range columns {
|
||||
if field := sch.LookUpField(column); field != nil && field.Readable {
|
||||
fields[idx] = field
|
||||
} else if names := strings.Split(column, "__"); len(names) > 1 {
|
||||
if rel, ok := sch.Relationships.Relations[names[0]]; ok {
|
||||
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
|
||||
fields[idx] = field
|
||||
|
||||
if len(joinFields) == 0 {
|
||||
joinFields = make([][2]*schema.Field, len(columns))
|
||||
}
|
||||
joinFields[idx] = [2]*schema.Field{rel.Field, field}
|
||||
continue
|
||||
}
|
||||
}
|
||||
values[idx] = &sql.RawBytes{}
|
||||
} else {
|
||||
values[idx] = &sql.RawBytes{}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
switch reflectValue.Kind() {
|
||||
@ -256,7 +245,7 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) {
|
||||
elem = reflect.New(reflectValueType)
|
||||
}
|
||||
|
||||
db.scanIntoStruct(sch, rows, elem, values, columns, fields, joinFields)
|
||||
db.scanIntoStruct(rows, elem, values, fields, joinFields)
|
||||
|
||||
if !update {
|
||||
if isPtr {
|
||||
@ -272,7 +261,7 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) {
|
||||
}
|
||||
case reflect.Struct, reflect.Ptr:
|
||||
if initialized || rows.Next() {
|
||||
db.scanIntoStruct(sch, rows, reflectValue, values, columns, fields, joinFields)
|
||||
db.scanIntoStruct(rows, reflectValue, values, fields, joinFields)
|
||||
}
|
||||
default:
|
||||
db.AddError(rows.Scan(dest))
|
||||
|
@ -465,24 +465,33 @@ func (field *Field) setupValuerAndSetter() {
|
||||
}
|
||||
|
||||
// ValueOf returns field's value and if it is zero
|
||||
field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) {
|
||||
v = reflect.Indirect(v)
|
||||
for _, fieldIdx := range field.StructField.Index {
|
||||
if fieldIdx >= 0 {
|
||||
v = v.Field(fieldIdx)
|
||||
} else {
|
||||
v = v.Field(-fieldIdx - 1)
|
||||
|
||||
if !v.IsNil() {
|
||||
v = v.Elem()
|
||||
fieldIndex := field.StructField.Index[0]
|
||||
switch {
|
||||
case len(field.StructField.Index) == 1 && fieldIndex > 0:
|
||||
field.ValueOf = func(ctx context.Context, value reflect.Value) (interface{}, bool) {
|
||||
fieldValue := reflect.Indirect(value).Field(fieldIndex)
|
||||
return fieldValue.Interface(), fieldValue.IsZero()
|
||||
}
|
||||
default:
|
||||
field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) {
|
||||
v = reflect.Indirect(v)
|
||||
for _, fieldIdx := range field.StructField.Index {
|
||||
if fieldIdx >= 0 {
|
||||
v = v.Field(fieldIdx)
|
||||
} else {
|
||||
return nil, true
|
||||
v = v.Field(-fieldIdx - 1)
|
||||
|
||||
if !v.IsNil() {
|
||||
v = v.Elem()
|
||||
} else {
|
||||
return nil, true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fv, zero := v.Interface(), v.IsZero()
|
||||
return fv, zero
|
||||
fv, zero := v.Interface(), v.IsZero()
|
||||
return fv, zero
|
||||
}
|
||||
}
|
||||
|
||||
if field.Serializer != nil {
|
||||
@ -509,24 +518,31 @@ func (field *Field) setupValuerAndSetter() {
|
||||
}
|
||||
|
||||
// ReflectValueOf returns field's reflect value
|
||||
field.ReflectValueOf = func(ctx context.Context, v reflect.Value) reflect.Value {
|
||||
v = reflect.Indirect(v)
|
||||
for idx, fieldIdx := range field.StructField.Index {
|
||||
if fieldIdx >= 0 {
|
||||
v = v.Field(fieldIdx)
|
||||
} else {
|
||||
v = v.Field(-fieldIdx - 1)
|
||||
switch {
|
||||
case len(field.StructField.Index) == 1 && fieldIndex > 0:
|
||||
field.ReflectValueOf = func(ctx context.Context, value reflect.Value) reflect.Value {
|
||||
return reflect.Indirect(value).Field(fieldIndex)
|
||||
}
|
||||
default:
|
||||
field.ReflectValueOf = func(ctx context.Context, v reflect.Value) reflect.Value {
|
||||
v = reflect.Indirect(v)
|
||||
for idx, fieldIdx := range field.StructField.Index {
|
||||
if fieldIdx >= 0 {
|
||||
v = v.Field(fieldIdx)
|
||||
} else {
|
||||
v = v.Field(-fieldIdx - 1)
|
||||
|
||||
if v.IsNil() {
|
||||
v.Set(reflect.New(v.Type().Elem()))
|
||||
}
|
||||
if v.IsNil() {
|
||||
v.Set(reflect.New(v.Type().Elem()))
|
||||
}
|
||||
|
||||
if idx < len(field.StructField.Index)-1 {
|
||||
v = v.Elem()
|
||||
if idx < len(field.StructField.Index)-1 {
|
||||
v = v.Elem()
|
||||
}
|
||||
}
|
||||
}
|
||||
return v
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
fallbackSetter := func(ctx context.Context, value reflect.Value, v interface{}, setter func(context.Context, reflect.Value, interface{}) error) (err error) {
|
||||
|
171
tests/connpool_test.go
Normal file
171
tests/connpool_test.go
Normal file
@ -0,0 +1,171 @@
|
||||
package tests_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"os"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"gorm.io/driver/mysql"
|
||||
"gorm.io/gorm"
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
type wrapperTx struct {
|
||||
*sql.Tx
|
||||
conn *wrapperConnPool
|
||||
}
|
||||
|
||||
func (c *wrapperTx) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) {
|
||||
c.conn.got = append(c.conn.got, query)
|
||||
return c.Tx.PrepareContext(ctx, query)
|
||||
}
|
||||
|
||||
func (c *wrapperTx) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
|
||||
c.conn.got = append(c.conn.got, query)
|
||||
return c.Tx.ExecContext(ctx, query, args...)
|
||||
}
|
||||
|
||||
func (c *wrapperTx) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
|
||||
c.conn.got = append(c.conn.got, query)
|
||||
return c.Tx.QueryContext(ctx, query, args...)
|
||||
}
|
||||
|
||||
func (c *wrapperTx) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
|
||||
c.conn.got = append(c.conn.got, query)
|
||||
return c.Tx.QueryRowContext(ctx, query, args...)
|
||||
}
|
||||
|
||||
type wrapperConnPool struct {
|
||||
db *sql.DB
|
||||
got []string
|
||||
expect []string
|
||||
}
|
||||
|
||||
func (c *wrapperConnPool) Ping() error {
|
||||
return c.db.Ping()
|
||||
}
|
||||
|
||||
// If you use BeginTx returned *sql.Tx as shown below then you can't record queries in a transaction.
|
||||
// func (c *wrapperConnPool) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) {
|
||||
// return c.db.BeginTx(ctx, opts)
|
||||
// }
|
||||
// You should use BeginTx returned gorm.Tx which could wrap *sql.Tx then you can record all queries.
|
||||
func (c *wrapperConnPool) BeginTx(ctx context.Context, opts *sql.TxOptions) (gorm.ConnPool, error) {
|
||||
tx, err := c.db.BeginTx(ctx, opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &wrapperTx{Tx: tx, conn: c}, nil
|
||||
}
|
||||
|
||||
func (c *wrapperConnPool) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) {
|
||||
c.got = append(c.got, query)
|
||||
return c.db.PrepareContext(ctx, query)
|
||||
}
|
||||
|
||||
func (c *wrapperConnPool) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
|
||||
c.got = append(c.got, query)
|
||||
return c.db.ExecContext(ctx, query, args...)
|
||||
}
|
||||
|
||||
func (c *wrapperConnPool) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
|
||||
c.got = append(c.got, query)
|
||||
return c.db.QueryContext(ctx, query, args...)
|
||||
}
|
||||
|
||||
func (c *wrapperConnPool) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
|
||||
c.got = append(c.got, query)
|
||||
return c.db.QueryRowContext(ctx, query, args...)
|
||||
}
|
||||
|
||||
func TestConnPoolWrapper(t *testing.T) {
|
||||
dialect := os.Getenv("GORM_DIALECT")
|
||||
if dialect != "mysql" {
|
||||
t.SkipNow()
|
||||
}
|
||||
|
||||
dbDSN := os.Getenv("GORM_DSN")
|
||||
if dbDSN == "" {
|
||||
dbDSN = "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True&loc=Local"
|
||||
}
|
||||
nativeDB, err := sql.Open("mysql", dbDSN)
|
||||
if err != nil {
|
||||
t.Fatalf("Should open db success, but got %v", err)
|
||||
}
|
||||
|
||||
conn := &wrapperConnPool{
|
||||
db: nativeDB,
|
||||
expect: []string{
|
||||
"SELECT VERSION()",
|
||||
"INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)",
|
||||
"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1",
|
||||
"INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)",
|
||||
"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1",
|
||||
"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1",
|
||||
"INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)",
|
||||
"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1",
|
||||
"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1",
|
||||
},
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if !reflect.DeepEqual(conn.got, conn.expect) {
|
||||
t.Errorf("expect %#v but got %#v", conn.expect, conn.got)
|
||||
}
|
||||
}()
|
||||
|
||||
db, err := gorm.Open(mysql.New(mysql.Config{Conn: conn}))
|
||||
if err != nil {
|
||||
t.Fatalf("Should open db success, but got %v", err)
|
||||
}
|
||||
|
||||
tx := db.Begin()
|
||||
user := *GetUser("transaction", Config{})
|
||||
|
||||
if err = tx.Save(&user).Error; err != nil {
|
||||
t.Fatalf("No error should raise, but got %v", err)
|
||||
}
|
||||
|
||||
if err = tx.First(&User{}, "name = ?", "transaction").Error; err != nil {
|
||||
t.Fatalf("Should find saved record, but got %v", err)
|
||||
}
|
||||
|
||||
user1 := *GetUser("transaction1-1", Config{})
|
||||
|
||||
if err = tx.Save(&user1).Error; err != nil {
|
||||
t.Fatalf("No error should raise, but got %v", err)
|
||||
}
|
||||
|
||||
if err = tx.First(&User{}, "name = ?", user1.Name).Error; err != nil {
|
||||
t.Fatalf("Should find saved record, but got %v", err)
|
||||
}
|
||||
|
||||
if sqlTx, ok := tx.Statement.ConnPool.(gorm.TxCommitter); !ok || sqlTx == nil {
|
||||
t.Fatalf("Should return the underlying sql.Tx")
|
||||
}
|
||||
|
||||
tx.Rollback()
|
||||
|
||||
if err = db.First(&User{}, "name = ?", "transaction").Error; err == nil {
|
||||
t.Fatalf("Should not find record after rollback, but got %v", err)
|
||||
}
|
||||
|
||||
txDB := db.Where("fake_name = ?", "fake_name")
|
||||
tx2 := txDB.Session(&gorm.Session{NewDB: true}).Begin()
|
||||
user2 := *GetUser("transaction-2", Config{})
|
||||
if err = tx2.Save(&user2).Error; err != nil {
|
||||
t.Fatalf("No error should raise, but got %v", err)
|
||||
}
|
||||
|
||||
if err = tx2.First(&User{}, "name = ?", "transaction-2").Error; err != nil {
|
||||
t.Fatalf("Should find saved record, but got %v", err)
|
||||
}
|
||||
|
||||
tx2.Commit()
|
||||
|
||||
if err = db.First(&User{}, "name = ?", "transaction-2").Error; err != nil {
|
||||
t.Fatalf("Should be able to find committed record, but got %v", err)
|
||||
}
|
||||
}
|
@ -3,11 +3,12 @@ module gorm.io/gorm/tests
|
||||
go 1.14
|
||||
|
||||
require (
|
||||
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect
|
||||
github.com/google/uuid v1.3.0
|
||||
github.com/jackc/pgx/v4 v4.15.0 // indirect
|
||||
github.com/jinzhu/now v1.1.4
|
||||
github.com/lib/pq v1.10.4
|
||||
github.com/mattn/go-sqlite3 v1.14.11 // indirect
|
||||
github.com/mattn/go-sqlite3 v1.14.12 // indirect
|
||||
golang.org/x/crypto v0.0.0-20220214200702-86341886e292 // indirect
|
||||
gorm.io/driver/mysql v1.3.2
|
||||
gorm.io/driver/postgres v1.3.1
|
||||
|
@ -1158,3 +1158,39 @@ func TestQueryWithTableAndConditionsAndAllFields(t *testing.T) {
|
||||
t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String())
|
||||
}
|
||||
}
|
||||
|
||||
type DoubleInt64 struct {
|
||||
data int64
|
||||
}
|
||||
|
||||
func (t *DoubleInt64) Scan(val interface{}) error {
|
||||
switch v := val.(type) {
|
||||
case int64:
|
||||
t.data = v * 2
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("DoubleInt64 cant not scan with:%v", v)
|
||||
}
|
||||
}
|
||||
|
||||
// https://github.com/go-gorm/gorm/issues/5091
|
||||
func TestQueryScannerWithSingleColumn(t *testing.T) {
|
||||
user := User{Name: "scanner_raw_1", Age: 10}
|
||||
DB.Create(&user)
|
||||
|
||||
var result1 DoubleInt64
|
||||
if err := DB.Model(&User{}).Where("name LIKE ?", "scanner_raw_%").Limit(1).Pluck(
|
||||
"age", &result1).Error; err != nil {
|
||||
t.Errorf("Failed, got error: %v", err)
|
||||
}
|
||||
|
||||
AssertEqual(t, result1.data, 20)
|
||||
|
||||
var result2 DoubleInt64
|
||||
if err := DB.Model(&User{}).Where("name LIKE ?", "scanner_raw_%").Limit(1).Select(
|
||||
"age").Scan(&result2).Error; err != nil {
|
||||
t.Errorf("Failed, got error: %v", err)
|
||||
}
|
||||
|
||||
AssertEqual(t, result2.data, 20)
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user