Merge branch 'go-gorm:master' into master
This commit is contained in:
		
						commit
						da07675bae
					
				
							
								
								
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							@ -3,3 +3,4 @@ documents
 | 
			
		||||
coverage.txt
 | 
			
		||||
_book
 | 
			
		||||
.idea
 | 
			
		||||
vendor
 | 
			
		||||
@ -201,13 +201,15 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
 | 
			
		||||
		switch stmt.ReflectValue.Kind() {
 | 
			
		||||
		case reflect.Slice, reflect.Array:
 | 
			
		||||
			rValLen := stmt.ReflectValue.Len()
 | 
			
		||||
			stmt.SQL.Grow(rValLen * 18)
 | 
			
		||||
			values.Values = make([][]interface{}, rValLen)
 | 
			
		||||
			if rValLen == 0 {
 | 
			
		||||
				stmt.AddError(gorm.ErrEmptySlice)
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			stmt.SQL.Grow(rValLen * 18)
 | 
			
		||||
			stmt.Vars = make([]interface{}, 0, rValLen*len(values.Columns))
 | 
			
		||||
			values.Values = make([][]interface{}, rValLen)
 | 
			
		||||
 | 
			
		||||
			defaultValueFieldsHavingValue := map[*schema.Field][]interface{}{}
 | 
			
		||||
			for i := 0; i < rValLen; i++ {
 | 
			
		||||
				rv := reflect.Indirect(stmt.ReflectValue.Index(i))
 | 
			
		||||
 | 
			
		||||
@ -255,6 +255,7 @@ func (tx *DB) assignInterfacesToValue(values ...interface{}) {
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// FirstOrInit gets the first matched record or initialize a new instance with given conditions (only works with struct or map conditions)
 | 
			
		||||
func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) {
 | 
			
		||||
	queryTx := db.Limit(1).Order(clause.OrderByColumn{
 | 
			
		||||
@ -599,11 +600,12 @@ func (db *DB) Begin(opts ...*sql.TxOptions) *DB {
 | 
			
		||||
		opt = opts[0]
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if beginner, ok := tx.Statement.ConnPool.(TxBeginner); ok {
 | 
			
		||||
	switch beginner := tx.Statement.ConnPool.(type) {
 | 
			
		||||
	case TxBeginner:
 | 
			
		||||
		tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
 | 
			
		||||
	} else if beginner, ok := tx.Statement.ConnPool.(ConnPoolBeginner); ok {
 | 
			
		||||
	case ConnPoolBeginner:
 | 
			
		||||
		tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
 | 
			
		||||
	} else {
 | 
			
		||||
	default:
 | 
			
		||||
		err = ErrInvalidTransaction
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										2
									
								
								gorm.go
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								gorm.go
									
									
									
									
									
								
							@ -462,7 +462,7 @@ func (db *DB) Use(plugin Plugin) error {
 | 
			
		||||
//			.First(&User{})
 | 
			
		||||
// })
 | 
			
		||||
func (db *DB) ToSQL(queryFn func(tx *DB) *DB) string {
 | 
			
		||||
	tx := queryFn(db.Session(&Session{DryRun: true}))
 | 
			
		||||
	tx := queryFn(db.Session(&Session{DryRun: true, SkipDefaultTransaction: true}))
 | 
			
		||||
	stmt := tx.Statement
 | 
			
		||||
 | 
			
		||||
	return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...)
 | 
			
		||||
 | 
			
		||||
@ -56,6 +56,13 @@ type TxCommitter interface {
 | 
			
		||||
	Rollback() error
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Tx sql.Tx interface
 | 
			
		||||
type Tx interface {
 | 
			
		||||
	ConnPool
 | 
			
		||||
	TxCommitter
 | 
			
		||||
	StmtContext(ctx context.Context, stmt *sql.Stmt) *sql.Stmt
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Valuer gorm valuer interface
 | 
			
		||||
type Valuer interface {
 | 
			
		||||
	GormValue(context.Context, *DB) clause.Expr
 | 
			
		||||
 | 
			
		||||
@ -115,7 +115,7 @@ func (db *PreparedStmtDB) QueryRowContext(ctx context.Context, query string, arg
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type PreparedStmtTX struct {
 | 
			
		||||
	*sql.Tx
 | 
			
		||||
	Tx
 | 
			
		||||
	PreparedStmtDB *PreparedStmtDB
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -151,7 +151,7 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ..
 | 
			
		||||
func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) {
 | 
			
		||||
	stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query)
 | 
			
		||||
	if err == nil {
 | 
			
		||||
		rows, err = tx.Tx.Stmt(stmt.Stmt).QueryContext(ctx, args...)
 | 
			
		||||
		rows, err = tx.Tx.StmtContext(ctx, stmt.Stmt).QueryContext(ctx, args...)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			tx.PreparedStmtDB.Mux.Lock()
 | 
			
		||||
			defer tx.PreparedStmtDB.Mux.Unlock()
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										75
									
								
								scan.go
									
									
									
									
									
								
							
							
						
						
									
										75
									
								
								scan.go
									
									
									
									
									
								
							@ -50,54 +50,40 @@ func scanIntoMap(mapValue map[string]interface{}, values []interface{}, columns
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (db *DB) scanIntoStruct(sch *schema.Schema, rows *sql.Rows, reflectValue reflect.Value, values []interface{}, columns []string, fields []*schema.Field, joinFields [][2]*schema.Field) {
 | 
			
		||||
	for idx, column := range columns {
 | 
			
		||||
		if sch == nil {
 | 
			
		||||
			values[idx] = reflectValue.Interface()
 | 
			
		||||
		} else if field := sch.LookUpField(column); field != nil && field.Readable {
 | 
			
		||||
func (db *DB) scanIntoStruct(rows *sql.Rows, reflectValue reflect.Value, values []interface{}, fields []*schema.Field, joinFields [][2]*schema.Field) {
 | 
			
		||||
	for idx, field := range fields {
 | 
			
		||||
		if field != nil {
 | 
			
		||||
			values[idx] = field.NewValuePool.Get()
 | 
			
		||||
			defer field.NewValuePool.Put(values[idx])
 | 
			
		||||
		} else if names := strings.Split(column, "__"); len(names) > 1 {
 | 
			
		||||
			if rel, ok := sch.Relationships.Relations[names[0]]; ok {
 | 
			
		||||
				if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
 | 
			
		||||
					values[idx] = field.NewValuePool.Get()
 | 
			
		||||
					defer field.NewValuePool.Put(values[idx])
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			values[idx] = &sql.RawBytes{}
 | 
			
		||||
		} else if len(columns) == 1 {
 | 
			
		||||
			sch = nil
 | 
			
		||||
			values[idx] = reflectValue.Interface()
 | 
			
		||||
		} else if len(fields) == 1 {
 | 
			
		||||
			if reflectValue.CanAddr() {
 | 
			
		||||
				values[idx] = reflectValue.Addr().Interface()
 | 
			
		||||
			} else {
 | 
			
		||||
			values[idx] = &sql.RawBytes{}
 | 
			
		||||
				values[idx] = reflectValue.Interface()
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	db.RowsAffected++
 | 
			
		||||
	db.AddError(rows.Scan(values...))
 | 
			
		||||
 | 
			
		||||
	if sch != nil {
 | 
			
		||||
		for idx, column := range columns {
 | 
			
		||||
			if field := sch.LookUpField(column); field != nil && field.Readable {
 | 
			
		||||
	for idx, field := range fields {
 | 
			
		||||
		if field != nil {
 | 
			
		||||
			if len(joinFields) == 0 || joinFields[idx][0] == nil {
 | 
			
		||||
				field.Set(db.Statement.Context, reflectValue, values[idx])
 | 
			
		||||
			} else if names := strings.Split(column, "__"); len(names) > 1 {
 | 
			
		||||
				if rel, ok := sch.Relationships.Relations[names[0]]; ok {
 | 
			
		||||
					if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
 | 
			
		||||
						relValue := rel.Field.ReflectValueOf(db.Statement.Context, reflectValue)
 | 
			
		||||
 | 
			
		||||
			} else {
 | 
			
		||||
				relValue := joinFields[idx][0].ReflectValueOf(db.Statement.Context, reflectValue)
 | 
			
		||||
				if relValue.Kind() == reflect.Ptr && relValue.IsNil() {
 | 
			
		||||
					if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() {
 | 
			
		||||
								continue
 | 
			
		||||
						return
 | 
			
		||||
					}
 | 
			
		||||
 | 
			
		||||
					relValue.Set(reflect.New(relValue.Type().Elem()))
 | 
			
		||||
				}
 | 
			
		||||
				joinFields[idx][1].Set(db.Statement.Context, relValue, values[idx])
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
						field.Set(db.Statement.Context, relValue, values[idx])
 | 
			
		||||
					}
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			// release data to pool
 | 
			
		||||
			field.NewValuePool.Put(values[idx])
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
@ -176,7 +162,7 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) {
 | 
			
		||||
			reflectValue = db.Statement.ReflectValue
 | 
			
		||||
		)
 | 
			
		||||
 | 
			
		||||
		for reflectValue.Kind() == reflect.Interface {
 | 
			
		||||
		if reflectValue.Kind() == reflect.Interface {
 | 
			
		||||
			reflectValue = reflectValue.Elem()
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
@ -195,6 +181,17 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) {
 | 
			
		||||
				sch, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if len(columns) == 1 {
 | 
			
		||||
				// Is Pluck
 | 
			
		||||
				if _, ok := reflect.New(reflectValueType).Interface().(sql.Scanner); (reflectValueType != sch.ModelType && ok) || // is scanner
 | 
			
		||||
					reflectValueType.Kind() != reflect.Struct || // is not struct
 | 
			
		||||
					sch.ModelType.ConvertibleTo(schema.TimeReflectType) { // is time
 | 
			
		||||
					sch = nil
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// Not Pluck
 | 
			
		||||
			if sch != nil {
 | 
			
		||||
				for idx, column := range columns {
 | 
			
		||||
					if field := sch.LookUpField(column); field != nil && field.Readable {
 | 
			
		||||
						fields[idx] = field
 | 
			
		||||
@ -215,14 +212,6 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) {
 | 
			
		||||
						values[idx] = &sql.RawBytes{}
 | 
			
		||||
					}
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
			if len(columns) == 1 {
 | 
			
		||||
				// isPluck
 | 
			
		||||
				if _, ok := reflect.New(reflectValueType).Interface().(sql.Scanner); (reflectValueType != sch.ModelType && ok) || // is scanner
 | 
			
		||||
					reflectValueType.Kind() != reflect.Struct || // is not struct
 | 
			
		||||
					sch.ModelType.ConvertibleTo(schema.TimeReflectType) { // is time
 | 
			
		||||
					sch = nil
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
@ -256,7 +245,7 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) {
 | 
			
		||||
					elem = reflect.New(reflectValueType)
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				db.scanIntoStruct(sch, rows, elem, values, columns, fields, joinFields)
 | 
			
		||||
				db.scanIntoStruct(rows, elem, values, fields, joinFields)
 | 
			
		||||
 | 
			
		||||
				if !update {
 | 
			
		||||
					if isPtr {
 | 
			
		||||
@ -272,7 +261,7 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) {
 | 
			
		||||
			}
 | 
			
		||||
		case reflect.Struct, reflect.Ptr:
 | 
			
		||||
			if initialized || rows.Next() {
 | 
			
		||||
				db.scanIntoStruct(sch, rows, reflectValue, values, columns, fields, joinFields)
 | 
			
		||||
				db.scanIntoStruct(rows, reflectValue, values, fields, joinFields)
 | 
			
		||||
			}
 | 
			
		||||
		default:
 | 
			
		||||
			db.AddError(rows.Scan(dest))
 | 
			
		||||
 | 
			
		||||
@ -465,6 +465,14 @@ func (field *Field) setupValuerAndSetter() {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// ValueOf returns field's value and if it is zero
 | 
			
		||||
	fieldIndex := field.StructField.Index[0]
 | 
			
		||||
	switch {
 | 
			
		||||
	case len(field.StructField.Index) == 1 && fieldIndex > 0:
 | 
			
		||||
		field.ValueOf = func(ctx context.Context, value reflect.Value) (interface{}, bool) {
 | 
			
		||||
			fieldValue := reflect.Indirect(value).Field(fieldIndex)
 | 
			
		||||
			return fieldValue.Interface(), fieldValue.IsZero()
 | 
			
		||||
		}
 | 
			
		||||
	default:
 | 
			
		||||
		field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) {
 | 
			
		||||
			v = reflect.Indirect(v)
 | 
			
		||||
			for _, fieldIdx := range field.StructField.Index {
 | 
			
		||||
@ -484,6 +492,7 @@ func (field *Field) setupValuerAndSetter() {
 | 
			
		||||
			fv, zero := v.Interface(), v.IsZero()
 | 
			
		||||
			return fv, zero
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if field.Serializer != nil {
 | 
			
		||||
		oldValuerOf := field.ValueOf
 | 
			
		||||
@ -509,6 +518,12 @@ func (field *Field) setupValuerAndSetter() {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// ReflectValueOf returns field's reflect value
 | 
			
		||||
	switch {
 | 
			
		||||
	case len(field.StructField.Index) == 1 && fieldIndex > 0:
 | 
			
		||||
		field.ReflectValueOf = func(ctx context.Context, value reflect.Value) reflect.Value {
 | 
			
		||||
			return reflect.Indirect(value).Field(fieldIndex)
 | 
			
		||||
		}
 | 
			
		||||
	default:
 | 
			
		||||
		field.ReflectValueOf = func(ctx context.Context, v reflect.Value) reflect.Value {
 | 
			
		||||
			v = reflect.Indirect(v)
 | 
			
		||||
			for idx, fieldIdx := range field.StructField.Index {
 | 
			
		||||
@ -528,6 +543,7 @@ func (field *Field) setupValuerAndSetter() {
 | 
			
		||||
			}
 | 
			
		||||
			return v
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	fallbackSetter := func(ctx context.Context, value reflect.Value, v interface{}, setter func(context.Context, reflect.Value, interface{}) error) (err error) {
 | 
			
		||||
		if v == nil {
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										171
									
								
								tests/connpool_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										171
									
								
								tests/connpool_test.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,171 @@
 | 
			
		||||
package tests_test
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"database/sql"
 | 
			
		||||
	"os"
 | 
			
		||||
	"reflect"
 | 
			
		||||
	"testing"
 | 
			
		||||
 | 
			
		||||
	"gorm.io/driver/mysql"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
	. "gorm.io/gorm/utils/tests"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type wrapperTx struct {
 | 
			
		||||
	*sql.Tx
 | 
			
		||||
	conn *wrapperConnPool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *wrapperTx) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) {
 | 
			
		||||
	c.conn.got = append(c.conn.got, query)
 | 
			
		||||
	return c.Tx.PrepareContext(ctx, query)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *wrapperTx) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
 | 
			
		||||
	c.conn.got = append(c.conn.got, query)
 | 
			
		||||
	return c.Tx.ExecContext(ctx, query, args...)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *wrapperTx) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
 | 
			
		||||
	c.conn.got = append(c.conn.got, query)
 | 
			
		||||
	return c.Tx.QueryContext(ctx, query, args...)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *wrapperTx) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
 | 
			
		||||
	c.conn.got = append(c.conn.got, query)
 | 
			
		||||
	return c.Tx.QueryRowContext(ctx, query, args...)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type wrapperConnPool struct {
 | 
			
		||||
	db     *sql.DB
 | 
			
		||||
	got    []string
 | 
			
		||||
	expect []string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *wrapperConnPool) Ping() error {
 | 
			
		||||
	return c.db.Ping()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// If you use BeginTx returned *sql.Tx as shown below then you can't record queries in a transaction.
 | 
			
		||||
// func (c *wrapperConnPool) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) {
 | 
			
		||||
//	 return c.db.BeginTx(ctx, opts)
 | 
			
		||||
// }
 | 
			
		||||
// You should use BeginTx returned gorm.Tx which could wrap *sql.Tx then you can record all queries.
 | 
			
		||||
func (c *wrapperConnPool) BeginTx(ctx context.Context, opts *sql.TxOptions) (gorm.ConnPool, error) {
 | 
			
		||||
	tx, err := c.db.BeginTx(ctx, opts)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	return &wrapperTx{Tx: tx, conn: c}, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *wrapperConnPool) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) {
 | 
			
		||||
	c.got = append(c.got, query)
 | 
			
		||||
	return c.db.PrepareContext(ctx, query)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *wrapperConnPool) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
 | 
			
		||||
	c.got = append(c.got, query)
 | 
			
		||||
	return c.db.ExecContext(ctx, query, args...)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *wrapperConnPool) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
 | 
			
		||||
	c.got = append(c.got, query)
 | 
			
		||||
	return c.db.QueryContext(ctx, query, args...)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *wrapperConnPool) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
 | 
			
		||||
	c.got = append(c.got, query)
 | 
			
		||||
	return c.db.QueryRowContext(ctx, query, args...)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestConnPoolWrapper(t *testing.T) {
 | 
			
		||||
	dialect := os.Getenv("GORM_DIALECT")
 | 
			
		||||
	if dialect != "mysql" {
 | 
			
		||||
		t.SkipNow()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	dbDSN := os.Getenv("GORM_DSN")
 | 
			
		||||
	if dbDSN == "" {
 | 
			
		||||
		dbDSN = "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True&loc=Local"
 | 
			
		||||
	}
 | 
			
		||||
	nativeDB, err := sql.Open("mysql", dbDSN)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatalf("Should open db success, but got %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	conn := &wrapperConnPool{
 | 
			
		||||
		db: nativeDB,
 | 
			
		||||
		expect: []string{
 | 
			
		||||
			"SELECT VERSION()",
 | 
			
		||||
			"INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)",
 | 
			
		||||
			"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1",
 | 
			
		||||
			"INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)",
 | 
			
		||||
			"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1",
 | 
			
		||||
			"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1",
 | 
			
		||||
			"INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)",
 | 
			
		||||
			"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1",
 | 
			
		||||
			"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1",
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	defer func() {
 | 
			
		||||
		if !reflect.DeepEqual(conn.got, conn.expect) {
 | 
			
		||||
			t.Errorf("expect %#v but got %#v", conn.expect, conn.got)
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	db, err := gorm.Open(mysql.New(mysql.Config{Conn: conn}))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatalf("Should open db success, but got %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	tx := db.Begin()
 | 
			
		||||
	user := *GetUser("transaction", Config{})
 | 
			
		||||
 | 
			
		||||
	if err = tx.Save(&user).Error; err != nil {
 | 
			
		||||
		t.Fatalf("No error should raise, but got %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err = tx.First(&User{}, "name = ?", "transaction").Error; err != nil {
 | 
			
		||||
		t.Fatalf("Should find saved record, but got %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	user1 := *GetUser("transaction1-1", Config{})
 | 
			
		||||
 | 
			
		||||
	if err = tx.Save(&user1).Error; err != nil {
 | 
			
		||||
		t.Fatalf("No error should raise, but got %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err = tx.First(&User{}, "name = ?", user1.Name).Error; err != nil {
 | 
			
		||||
		t.Fatalf("Should find saved record, but got %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if sqlTx, ok := tx.Statement.ConnPool.(gorm.TxCommitter); !ok || sqlTx == nil {
 | 
			
		||||
		t.Fatalf("Should return the underlying sql.Tx")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	tx.Rollback()
 | 
			
		||||
 | 
			
		||||
	if err = db.First(&User{}, "name = ?", "transaction").Error; err == nil {
 | 
			
		||||
		t.Fatalf("Should not find record after rollback, but got %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	txDB := db.Where("fake_name = ?", "fake_name")
 | 
			
		||||
	tx2 := txDB.Session(&gorm.Session{NewDB: true}).Begin()
 | 
			
		||||
	user2 := *GetUser("transaction-2", Config{})
 | 
			
		||||
	if err = tx2.Save(&user2).Error; err != nil {
 | 
			
		||||
		t.Fatalf("No error should raise, but got %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err = tx2.First(&User{}, "name = ?", "transaction-2").Error; err != nil {
 | 
			
		||||
		t.Fatalf("Should find saved record, but got %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	tx2.Commit()
 | 
			
		||||
 | 
			
		||||
	if err = db.First(&User{}, "name = ?", "transaction-2").Error; err != nil {
 | 
			
		||||
		t.Fatalf("Should be able to find committed record, but got %v", err)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
@ -3,11 +3,12 @@ module gorm.io/gorm/tests
 | 
			
		||||
go 1.14
 | 
			
		||||
 | 
			
		||||
require (
 | 
			
		||||
	github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect
 | 
			
		||||
	github.com/google/uuid v1.3.0
 | 
			
		||||
	github.com/jackc/pgx/v4 v4.15.0 // indirect
 | 
			
		||||
	github.com/jinzhu/now v1.1.4
 | 
			
		||||
	github.com/lib/pq v1.10.4
 | 
			
		||||
	github.com/mattn/go-sqlite3 v1.14.11 // indirect
 | 
			
		||||
	github.com/mattn/go-sqlite3 v1.14.12 // indirect
 | 
			
		||||
	golang.org/x/crypto v0.0.0-20220214200702-86341886e292 // indirect
 | 
			
		||||
	gorm.io/driver/mysql v1.3.2
 | 
			
		||||
	gorm.io/driver/postgres v1.3.1
 | 
			
		||||
 | 
			
		||||
@ -1158,3 +1158,39 @@ func TestQueryWithTableAndConditionsAndAllFields(t *testing.T) {
 | 
			
		||||
		t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String())
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type DoubleInt64 struct {
 | 
			
		||||
	data int64
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (t *DoubleInt64) Scan(val interface{}) error {
 | 
			
		||||
	switch v := val.(type) {
 | 
			
		||||
	case int64:
 | 
			
		||||
		t.data = v * 2
 | 
			
		||||
		return nil
 | 
			
		||||
	default:
 | 
			
		||||
		return fmt.Errorf("DoubleInt64 cant not scan with:%v", v)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// https://github.com/go-gorm/gorm/issues/5091
 | 
			
		||||
func TestQueryScannerWithSingleColumn(t *testing.T) {
 | 
			
		||||
	user := User{Name: "scanner_raw_1", Age: 10}
 | 
			
		||||
	DB.Create(&user)
 | 
			
		||||
 | 
			
		||||
	var result1 DoubleInt64
 | 
			
		||||
	if err := DB.Model(&User{}).Where("name LIKE ?", "scanner_raw_%").Limit(1).Pluck(
 | 
			
		||||
		"age", &result1).Error; err != nil {
 | 
			
		||||
		t.Errorf("Failed, got error: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	AssertEqual(t, result1.data, 20)
 | 
			
		||||
 | 
			
		||||
	var result2 DoubleInt64
 | 
			
		||||
	if err := DB.Model(&User{}).Where("name LIKE ?", "scanner_raw_%").Limit(1).Select(
 | 
			
		||||
		"age").Scan(&result2).Error; err != nil {
 | 
			
		||||
		t.Errorf("Failed, got error: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	AssertEqual(t, result2.data, 20)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user