Merge branch 'go-gorm:master' into master
This commit is contained in:
commit
da07675bae
1
.gitignore
vendored
1
.gitignore
vendored
@ -3,3 +3,4 @@ documents
|
|||||||
coverage.txt
|
coverage.txt
|
||||||
_book
|
_book
|
||||||
.idea
|
.idea
|
||||||
|
vendor
|
@ -201,13 +201,15 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
|
|||||||
switch stmt.ReflectValue.Kind() {
|
switch stmt.ReflectValue.Kind() {
|
||||||
case reflect.Slice, reflect.Array:
|
case reflect.Slice, reflect.Array:
|
||||||
rValLen := stmt.ReflectValue.Len()
|
rValLen := stmt.ReflectValue.Len()
|
||||||
stmt.SQL.Grow(rValLen * 18)
|
|
||||||
values.Values = make([][]interface{}, rValLen)
|
|
||||||
if rValLen == 0 {
|
if rValLen == 0 {
|
||||||
stmt.AddError(gorm.ErrEmptySlice)
|
stmt.AddError(gorm.ErrEmptySlice)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
stmt.SQL.Grow(rValLen * 18)
|
||||||
|
stmt.Vars = make([]interface{}, 0, rValLen*len(values.Columns))
|
||||||
|
values.Values = make([][]interface{}, rValLen)
|
||||||
|
|
||||||
defaultValueFieldsHavingValue := map[*schema.Field][]interface{}{}
|
defaultValueFieldsHavingValue := map[*schema.Field][]interface{}{}
|
||||||
for i := 0; i < rValLen; i++ {
|
for i := 0; i < rValLen; i++ {
|
||||||
rv := reflect.Indirect(stmt.ReflectValue.Index(i))
|
rv := reflect.Indirect(stmt.ReflectValue.Index(i))
|
||||||
|
@ -255,6 +255,7 @@ func (tx *DB) assignInterfacesToValue(values ...interface{}) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// FirstOrInit gets the first matched record or initialize a new instance with given conditions (only works with struct or map conditions)
|
// FirstOrInit gets the first matched record or initialize a new instance with given conditions (only works with struct or map conditions)
|
||||||
func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) {
|
func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) {
|
||||||
queryTx := db.Limit(1).Order(clause.OrderByColumn{
|
queryTx := db.Limit(1).Order(clause.OrderByColumn{
|
||||||
@ -599,11 +600,12 @@ func (db *DB) Begin(opts ...*sql.TxOptions) *DB {
|
|||||||
opt = opts[0]
|
opt = opts[0]
|
||||||
}
|
}
|
||||||
|
|
||||||
if beginner, ok := tx.Statement.ConnPool.(TxBeginner); ok {
|
switch beginner := tx.Statement.ConnPool.(type) {
|
||||||
|
case TxBeginner:
|
||||||
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
|
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
|
||||||
} else if beginner, ok := tx.Statement.ConnPool.(ConnPoolBeginner); ok {
|
case ConnPoolBeginner:
|
||||||
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
|
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
|
||||||
} else {
|
default:
|
||||||
err = ErrInvalidTransaction
|
err = ErrInvalidTransaction
|
||||||
}
|
}
|
||||||
|
|
||||||
|
2
gorm.go
2
gorm.go
@ -462,7 +462,7 @@ func (db *DB) Use(plugin Plugin) error {
|
|||||||
// .First(&User{})
|
// .First(&User{})
|
||||||
// })
|
// })
|
||||||
func (db *DB) ToSQL(queryFn func(tx *DB) *DB) string {
|
func (db *DB) ToSQL(queryFn func(tx *DB) *DB) string {
|
||||||
tx := queryFn(db.Session(&Session{DryRun: true}))
|
tx := queryFn(db.Session(&Session{DryRun: true, SkipDefaultTransaction: true}))
|
||||||
stmt := tx.Statement
|
stmt := tx.Statement
|
||||||
|
|
||||||
return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...)
|
return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...)
|
||||||
|
@ -56,6 +56,13 @@ type TxCommitter interface {
|
|||||||
Rollback() error
|
Rollback() error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Tx sql.Tx interface
|
||||||
|
type Tx interface {
|
||||||
|
ConnPool
|
||||||
|
TxCommitter
|
||||||
|
StmtContext(ctx context.Context, stmt *sql.Stmt) *sql.Stmt
|
||||||
|
}
|
||||||
|
|
||||||
// Valuer gorm valuer interface
|
// Valuer gorm valuer interface
|
||||||
type Valuer interface {
|
type Valuer interface {
|
||||||
GormValue(context.Context, *DB) clause.Expr
|
GormValue(context.Context, *DB) clause.Expr
|
||||||
|
@ -115,7 +115,7 @@ func (db *PreparedStmtDB) QueryRowContext(ctx context.Context, query string, arg
|
|||||||
}
|
}
|
||||||
|
|
||||||
type PreparedStmtTX struct {
|
type PreparedStmtTX struct {
|
||||||
*sql.Tx
|
Tx
|
||||||
PreparedStmtDB *PreparedStmtDB
|
PreparedStmtDB *PreparedStmtDB
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -151,7 +151,7 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ..
|
|||||||
func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) {
|
func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) {
|
||||||
stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query)
|
stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
rows, err = tx.Tx.Stmt(stmt.Stmt).QueryContext(ctx, args...)
|
rows, err = tx.Tx.StmtContext(ctx, stmt.Stmt).QueryContext(ctx, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tx.PreparedStmtDB.Mux.Lock()
|
tx.PreparedStmtDB.Mux.Lock()
|
||||||
defer tx.PreparedStmtDB.Mux.Unlock()
|
defer tx.PreparedStmtDB.Mux.Unlock()
|
||||||
|
111
scan.go
111
scan.go
@ -50,54 +50,40 @@ func scanIntoMap(mapValue map[string]interface{}, values []interface{}, columns
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *DB) scanIntoStruct(sch *schema.Schema, rows *sql.Rows, reflectValue reflect.Value, values []interface{}, columns []string, fields []*schema.Field, joinFields [][2]*schema.Field) {
|
func (db *DB) scanIntoStruct(rows *sql.Rows, reflectValue reflect.Value, values []interface{}, fields []*schema.Field, joinFields [][2]*schema.Field) {
|
||||||
for idx, column := range columns {
|
for idx, field := range fields {
|
||||||
if sch == nil {
|
if field != nil {
|
||||||
values[idx] = reflectValue.Interface()
|
|
||||||
} else if field := sch.LookUpField(column); field != nil && field.Readable {
|
|
||||||
values[idx] = field.NewValuePool.Get()
|
values[idx] = field.NewValuePool.Get()
|
||||||
defer field.NewValuePool.Put(values[idx])
|
} else if len(fields) == 1 {
|
||||||
} else if names := strings.Split(column, "__"); len(names) > 1 {
|
if reflectValue.CanAddr() {
|
||||||
if rel, ok := sch.Relationships.Relations[names[0]]; ok {
|
values[idx] = reflectValue.Addr().Interface()
|
||||||
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
|
} else {
|
||||||
values[idx] = field.NewValuePool.Get()
|
values[idx] = reflectValue.Interface()
|
||||||
defer field.NewValuePool.Put(values[idx])
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
values[idx] = &sql.RawBytes{}
|
|
||||||
} else if len(columns) == 1 {
|
|
||||||
sch = nil
|
|
||||||
values[idx] = reflectValue.Interface()
|
|
||||||
} else {
|
|
||||||
values[idx] = &sql.RawBytes{}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
db.RowsAffected++
|
db.RowsAffected++
|
||||||
db.AddError(rows.Scan(values...))
|
db.AddError(rows.Scan(values...))
|
||||||
|
|
||||||
if sch != nil {
|
for idx, field := range fields {
|
||||||
for idx, column := range columns {
|
if field != nil {
|
||||||
if field := sch.LookUpField(column); field != nil && field.Readable {
|
if len(joinFields) == 0 || joinFields[idx][0] == nil {
|
||||||
field.Set(db.Statement.Context, reflectValue, values[idx])
|
field.Set(db.Statement.Context, reflectValue, values[idx])
|
||||||
} else if names := strings.Split(column, "__"); len(names) > 1 {
|
} else {
|
||||||
if rel, ok := sch.Relationships.Relations[names[0]]; ok {
|
relValue := joinFields[idx][0].ReflectValueOf(db.Statement.Context, reflectValue)
|
||||||
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
|
if relValue.Kind() == reflect.Ptr && relValue.IsNil() {
|
||||||
relValue := rel.Field.ReflectValueOf(db.Statement.Context, reflectValue)
|
if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() {
|
||||||
|
return
|
||||||
if relValue.Kind() == reflect.Ptr && relValue.IsNil() {
|
|
||||||
if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
relValue.Set(reflect.New(relValue.Type().Elem()))
|
|
||||||
}
|
|
||||||
|
|
||||||
field.Set(db.Statement.Context, relValue, values[idx])
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
relValue.Set(reflect.New(relValue.Type().Elem()))
|
||||||
}
|
}
|
||||||
|
joinFields[idx][1].Set(db.Statement.Context, relValue, values[idx])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// release data to pool
|
||||||
|
field.NewValuePool.Put(values[idx])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -176,7 +162,7 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) {
|
|||||||
reflectValue = db.Statement.ReflectValue
|
reflectValue = db.Statement.ReflectValue
|
||||||
)
|
)
|
||||||
|
|
||||||
for reflectValue.Kind() == reflect.Interface {
|
if reflectValue.Kind() == reflect.Interface {
|
||||||
reflectValue = reflectValue.Elem()
|
reflectValue = reflectValue.Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -195,35 +181,38 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) {
|
|||||||
sch, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy)
|
sch, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy)
|
||||||
}
|
}
|
||||||
|
|
||||||
for idx, column := range columns {
|
|
||||||
if field := sch.LookUpField(column); field != nil && field.Readable {
|
|
||||||
fields[idx] = field
|
|
||||||
} else if names := strings.Split(column, "__"); len(names) > 1 {
|
|
||||||
if rel, ok := sch.Relationships.Relations[names[0]]; ok {
|
|
||||||
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
|
|
||||||
fields[idx] = field
|
|
||||||
|
|
||||||
if len(joinFields) == 0 {
|
|
||||||
joinFields = make([][2]*schema.Field, len(columns))
|
|
||||||
}
|
|
||||||
joinFields[idx] = [2]*schema.Field{rel.Field, field}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
values[idx] = &sql.RawBytes{}
|
|
||||||
} else {
|
|
||||||
values[idx] = &sql.RawBytes{}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(columns) == 1 {
|
if len(columns) == 1 {
|
||||||
// isPluck
|
// Is Pluck
|
||||||
if _, ok := reflect.New(reflectValueType).Interface().(sql.Scanner); (reflectValueType != sch.ModelType && ok) || // is scanner
|
if _, ok := reflect.New(reflectValueType).Interface().(sql.Scanner); (reflectValueType != sch.ModelType && ok) || // is scanner
|
||||||
reflectValueType.Kind() != reflect.Struct || // is not struct
|
reflectValueType.Kind() != reflect.Struct || // is not struct
|
||||||
sch.ModelType.ConvertibleTo(schema.TimeReflectType) { // is time
|
sch.ModelType.ConvertibleTo(schema.TimeReflectType) { // is time
|
||||||
sch = nil
|
sch = nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Not Pluck
|
||||||
|
if sch != nil {
|
||||||
|
for idx, column := range columns {
|
||||||
|
if field := sch.LookUpField(column); field != nil && field.Readable {
|
||||||
|
fields[idx] = field
|
||||||
|
} else if names := strings.Split(column, "__"); len(names) > 1 {
|
||||||
|
if rel, ok := sch.Relationships.Relations[names[0]]; ok {
|
||||||
|
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
|
||||||
|
fields[idx] = field
|
||||||
|
|
||||||
|
if len(joinFields) == 0 {
|
||||||
|
joinFields = make([][2]*schema.Field, len(columns))
|
||||||
|
}
|
||||||
|
joinFields[idx] = [2]*schema.Field{rel.Field, field}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
values[idx] = &sql.RawBytes{}
|
||||||
|
} else {
|
||||||
|
values[idx] = &sql.RawBytes{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
switch reflectValue.Kind() {
|
switch reflectValue.Kind() {
|
||||||
@ -256,7 +245,7 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) {
|
|||||||
elem = reflect.New(reflectValueType)
|
elem = reflect.New(reflectValueType)
|
||||||
}
|
}
|
||||||
|
|
||||||
db.scanIntoStruct(sch, rows, elem, values, columns, fields, joinFields)
|
db.scanIntoStruct(rows, elem, values, fields, joinFields)
|
||||||
|
|
||||||
if !update {
|
if !update {
|
||||||
if isPtr {
|
if isPtr {
|
||||||
@ -272,7 +261,7 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) {
|
|||||||
}
|
}
|
||||||
case reflect.Struct, reflect.Ptr:
|
case reflect.Struct, reflect.Ptr:
|
||||||
if initialized || rows.Next() {
|
if initialized || rows.Next() {
|
||||||
db.scanIntoStruct(sch, rows, reflectValue, values, columns, fields, joinFields)
|
db.scanIntoStruct(rows, reflectValue, values, fields, joinFields)
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
db.AddError(rows.Scan(dest))
|
db.AddError(rows.Scan(dest))
|
||||||
|
@ -465,24 +465,33 @@ func (field *Field) setupValuerAndSetter() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ValueOf returns field's value and if it is zero
|
// ValueOf returns field's value and if it is zero
|
||||||
field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) {
|
fieldIndex := field.StructField.Index[0]
|
||||||
v = reflect.Indirect(v)
|
switch {
|
||||||
for _, fieldIdx := range field.StructField.Index {
|
case len(field.StructField.Index) == 1 && fieldIndex > 0:
|
||||||
if fieldIdx >= 0 {
|
field.ValueOf = func(ctx context.Context, value reflect.Value) (interface{}, bool) {
|
||||||
v = v.Field(fieldIdx)
|
fieldValue := reflect.Indirect(value).Field(fieldIndex)
|
||||||
} else {
|
return fieldValue.Interface(), fieldValue.IsZero()
|
||||||
v = v.Field(-fieldIdx - 1)
|
}
|
||||||
|
default:
|
||||||
if !v.IsNil() {
|
field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) {
|
||||||
v = v.Elem()
|
v = reflect.Indirect(v)
|
||||||
|
for _, fieldIdx := range field.StructField.Index {
|
||||||
|
if fieldIdx >= 0 {
|
||||||
|
v = v.Field(fieldIdx)
|
||||||
} else {
|
} else {
|
||||||
return nil, true
|
v = v.Field(-fieldIdx - 1)
|
||||||
|
|
||||||
|
if !v.IsNil() {
|
||||||
|
v = v.Elem()
|
||||||
|
} else {
|
||||||
|
return nil, true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
fv, zero := v.Interface(), v.IsZero()
|
fv, zero := v.Interface(), v.IsZero()
|
||||||
return fv, zero
|
return fv, zero
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if field.Serializer != nil {
|
if field.Serializer != nil {
|
||||||
@ -509,24 +518,31 @@ func (field *Field) setupValuerAndSetter() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ReflectValueOf returns field's reflect value
|
// ReflectValueOf returns field's reflect value
|
||||||
field.ReflectValueOf = func(ctx context.Context, v reflect.Value) reflect.Value {
|
switch {
|
||||||
v = reflect.Indirect(v)
|
case len(field.StructField.Index) == 1 && fieldIndex > 0:
|
||||||
for idx, fieldIdx := range field.StructField.Index {
|
field.ReflectValueOf = func(ctx context.Context, value reflect.Value) reflect.Value {
|
||||||
if fieldIdx >= 0 {
|
return reflect.Indirect(value).Field(fieldIndex)
|
||||||
v = v.Field(fieldIdx)
|
}
|
||||||
} else {
|
default:
|
||||||
v = v.Field(-fieldIdx - 1)
|
field.ReflectValueOf = func(ctx context.Context, v reflect.Value) reflect.Value {
|
||||||
|
v = reflect.Indirect(v)
|
||||||
|
for idx, fieldIdx := range field.StructField.Index {
|
||||||
|
if fieldIdx >= 0 {
|
||||||
|
v = v.Field(fieldIdx)
|
||||||
|
} else {
|
||||||
|
v = v.Field(-fieldIdx - 1)
|
||||||
|
|
||||||
if v.IsNil() {
|
if v.IsNil() {
|
||||||
v.Set(reflect.New(v.Type().Elem()))
|
v.Set(reflect.New(v.Type().Elem()))
|
||||||
}
|
}
|
||||||
|
|
||||||
if idx < len(field.StructField.Index)-1 {
|
if idx < len(field.StructField.Index)-1 {
|
||||||
v = v.Elem()
|
v = v.Elem()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return v
|
||||||
}
|
}
|
||||||
return v
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fallbackSetter := func(ctx context.Context, value reflect.Value, v interface{}, setter func(context.Context, reflect.Value, interface{}) error) (err error) {
|
fallbackSetter := func(ctx context.Context, value reflect.Value, v interface{}, setter func(context.Context, reflect.Value, interface{}) error) (err error) {
|
||||||
|
171
tests/connpool_test.go
Normal file
171
tests/connpool_test.go
Normal file
@ -0,0 +1,171 @@
|
|||||||
|
package tests_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"os"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"gorm.io/driver/mysql"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
. "gorm.io/gorm/utils/tests"
|
||||||
|
)
|
||||||
|
|
||||||
|
type wrapperTx struct {
|
||||||
|
*sql.Tx
|
||||||
|
conn *wrapperConnPool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *wrapperTx) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) {
|
||||||
|
c.conn.got = append(c.conn.got, query)
|
||||||
|
return c.Tx.PrepareContext(ctx, query)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *wrapperTx) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
|
||||||
|
c.conn.got = append(c.conn.got, query)
|
||||||
|
return c.Tx.ExecContext(ctx, query, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *wrapperTx) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
|
||||||
|
c.conn.got = append(c.conn.got, query)
|
||||||
|
return c.Tx.QueryContext(ctx, query, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *wrapperTx) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
|
||||||
|
c.conn.got = append(c.conn.got, query)
|
||||||
|
return c.Tx.QueryRowContext(ctx, query, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
type wrapperConnPool struct {
|
||||||
|
db *sql.DB
|
||||||
|
got []string
|
||||||
|
expect []string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *wrapperConnPool) Ping() error {
|
||||||
|
return c.db.Ping()
|
||||||
|
}
|
||||||
|
|
||||||
|
// If you use BeginTx returned *sql.Tx as shown below then you can't record queries in a transaction.
|
||||||
|
// func (c *wrapperConnPool) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) {
|
||||||
|
// return c.db.BeginTx(ctx, opts)
|
||||||
|
// }
|
||||||
|
// You should use BeginTx returned gorm.Tx which could wrap *sql.Tx then you can record all queries.
|
||||||
|
func (c *wrapperConnPool) BeginTx(ctx context.Context, opts *sql.TxOptions) (gorm.ConnPool, error) {
|
||||||
|
tx, err := c.db.BeginTx(ctx, opts)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &wrapperTx{Tx: tx, conn: c}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *wrapperConnPool) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) {
|
||||||
|
c.got = append(c.got, query)
|
||||||
|
return c.db.PrepareContext(ctx, query)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *wrapperConnPool) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
|
||||||
|
c.got = append(c.got, query)
|
||||||
|
return c.db.ExecContext(ctx, query, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *wrapperConnPool) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
|
||||||
|
c.got = append(c.got, query)
|
||||||
|
return c.db.QueryContext(ctx, query, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *wrapperConnPool) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
|
||||||
|
c.got = append(c.got, query)
|
||||||
|
return c.db.QueryRowContext(ctx, query, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConnPoolWrapper(t *testing.T) {
|
||||||
|
dialect := os.Getenv("GORM_DIALECT")
|
||||||
|
if dialect != "mysql" {
|
||||||
|
t.SkipNow()
|
||||||
|
}
|
||||||
|
|
||||||
|
dbDSN := os.Getenv("GORM_DSN")
|
||||||
|
if dbDSN == "" {
|
||||||
|
dbDSN = "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True&loc=Local"
|
||||||
|
}
|
||||||
|
nativeDB, err := sql.Open("mysql", dbDSN)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Should open db success, but got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
conn := &wrapperConnPool{
|
||||||
|
db: nativeDB,
|
||||||
|
expect: []string{
|
||||||
|
"SELECT VERSION()",
|
||||||
|
"INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)",
|
||||||
|
"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1",
|
||||||
|
"INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)",
|
||||||
|
"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1",
|
||||||
|
"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1",
|
||||||
|
"INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)",
|
||||||
|
"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1",
|
||||||
|
"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
if !reflect.DeepEqual(conn.got, conn.expect) {
|
||||||
|
t.Errorf("expect %#v but got %#v", conn.expect, conn.got)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
db, err := gorm.Open(mysql.New(mysql.Config{Conn: conn}))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Should open db success, but got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tx := db.Begin()
|
||||||
|
user := *GetUser("transaction", Config{})
|
||||||
|
|
||||||
|
if err = tx.Save(&user).Error; err != nil {
|
||||||
|
t.Fatalf("No error should raise, but got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = tx.First(&User{}, "name = ?", "transaction").Error; err != nil {
|
||||||
|
t.Fatalf("Should find saved record, but got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
user1 := *GetUser("transaction1-1", Config{})
|
||||||
|
|
||||||
|
if err = tx.Save(&user1).Error; err != nil {
|
||||||
|
t.Fatalf("No error should raise, but got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = tx.First(&User{}, "name = ?", user1.Name).Error; err != nil {
|
||||||
|
t.Fatalf("Should find saved record, but got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if sqlTx, ok := tx.Statement.ConnPool.(gorm.TxCommitter); !ok || sqlTx == nil {
|
||||||
|
t.Fatalf("Should return the underlying sql.Tx")
|
||||||
|
}
|
||||||
|
|
||||||
|
tx.Rollback()
|
||||||
|
|
||||||
|
if err = db.First(&User{}, "name = ?", "transaction").Error; err == nil {
|
||||||
|
t.Fatalf("Should not find record after rollback, but got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
txDB := db.Where("fake_name = ?", "fake_name")
|
||||||
|
tx2 := txDB.Session(&gorm.Session{NewDB: true}).Begin()
|
||||||
|
user2 := *GetUser("transaction-2", Config{})
|
||||||
|
if err = tx2.Save(&user2).Error; err != nil {
|
||||||
|
t.Fatalf("No error should raise, but got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = tx2.First(&User{}, "name = ?", "transaction-2").Error; err != nil {
|
||||||
|
t.Fatalf("Should find saved record, but got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tx2.Commit()
|
||||||
|
|
||||||
|
if err = db.First(&User{}, "name = ?", "transaction-2").Error; err != nil {
|
||||||
|
t.Fatalf("Should be able to find committed record, but got %v", err)
|
||||||
|
}
|
||||||
|
}
|
@ -3,11 +3,12 @@ module gorm.io/gorm/tests
|
|||||||
go 1.14
|
go 1.14
|
||||||
|
|
||||||
require (
|
require (
|
||||||
|
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect
|
||||||
github.com/google/uuid v1.3.0
|
github.com/google/uuid v1.3.0
|
||||||
github.com/jackc/pgx/v4 v4.15.0 // indirect
|
github.com/jackc/pgx/v4 v4.15.0 // indirect
|
||||||
github.com/jinzhu/now v1.1.4
|
github.com/jinzhu/now v1.1.4
|
||||||
github.com/lib/pq v1.10.4
|
github.com/lib/pq v1.10.4
|
||||||
github.com/mattn/go-sqlite3 v1.14.11 // indirect
|
github.com/mattn/go-sqlite3 v1.14.12 // indirect
|
||||||
golang.org/x/crypto v0.0.0-20220214200702-86341886e292 // indirect
|
golang.org/x/crypto v0.0.0-20220214200702-86341886e292 // indirect
|
||||||
gorm.io/driver/mysql v1.3.2
|
gorm.io/driver/mysql v1.3.2
|
||||||
gorm.io/driver/postgres v1.3.1
|
gorm.io/driver/postgres v1.3.1
|
||||||
|
@ -1158,3 +1158,39 @@ func TestQueryWithTableAndConditionsAndAllFields(t *testing.T) {
|
|||||||
t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String())
|
t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type DoubleInt64 struct {
|
||||||
|
data int64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *DoubleInt64) Scan(val interface{}) error {
|
||||||
|
switch v := val.(type) {
|
||||||
|
case int64:
|
||||||
|
t.data = v * 2
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("DoubleInt64 cant not scan with:%v", v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// https://github.com/go-gorm/gorm/issues/5091
|
||||||
|
func TestQueryScannerWithSingleColumn(t *testing.T) {
|
||||||
|
user := User{Name: "scanner_raw_1", Age: 10}
|
||||||
|
DB.Create(&user)
|
||||||
|
|
||||||
|
var result1 DoubleInt64
|
||||||
|
if err := DB.Model(&User{}).Where("name LIKE ?", "scanner_raw_%").Limit(1).Pluck(
|
||||||
|
"age", &result1).Error; err != nil {
|
||||||
|
t.Errorf("Failed, got error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
AssertEqual(t, result1.data, 20)
|
||||||
|
|
||||||
|
var result2 DoubleInt64
|
||||||
|
if err := DB.Model(&User{}).Where("name LIKE ?", "scanner_raw_%").Limit(1).Select(
|
||||||
|
"age").Scan(&result2).Error; err != nil {
|
||||||
|
t.Errorf("Failed, got error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
AssertEqual(t, result2.data, 20)
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user