Merge branch 'go-gorm:master' into master

This commit is contained in:
Qt 2022-03-11 22:17:17 +08:00 committed by GitHub
commit da07675bae
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 322 additions and 97 deletions

1
.gitignore vendored
View File

@ -3,3 +3,4 @@ documents
coverage.txt
_book
.idea
vendor

View File

@ -201,13 +201,15 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
switch stmt.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
rValLen := stmt.ReflectValue.Len()
stmt.SQL.Grow(rValLen * 18)
values.Values = make([][]interface{}, rValLen)
if rValLen == 0 {
stmt.AddError(gorm.ErrEmptySlice)
return
}
stmt.SQL.Grow(rValLen * 18)
stmt.Vars = make([]interface{}, 0, rValLen*len(values.Columns))
values.Values = make([][]interface{}, rValLen)
defaultValueFieldsHavingValue := map[*schema.Field][]interface{}{}
for i := 0; i < rValLen; i++ {
rv := reflect.Indirect(stmt.ReflectValue.Index(i))

View File

@ -255,6 +255,7 @@ func (tx *DB) assignInterfacesToValue(values ...interface{}) {
}
}
}
// FirstOrInit gets the first matched record or initialize a new instance with given conditions (only works with struct or map conditions)
func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) {
queryTx := db.Limit(1).Order(clause.OrderByColumn{
@ -599,11 +600,12 @@ func (db *DB) Begin(opts ...*sql.TxOptions) *DB {
opt = opts[0]
}
if beginner, ok := tx.Statement.ConnPool.(TxBeginner); ok {
switch beginner := tx.Statement.ConnPool.(type) {
case TxBeginner:
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
} else if beginner, ok := tx.Statement.ConnPool.(ConnPoolBeginner); ok {
case ConnPoolBeginner:
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
} else {
default:
err = ErrInvalidTransaction
}

View File

@ -462,7 +462,7 @@ func (db *DB) Use(plugin Plugin) error {
// .First(&User{})
// })
func (db *DB) ToSQL(queryFn func(tx *DB) *DB) string {
tx := queryFn(db.Session(&Session{DryRun: true}))
tx := queryFn(db.Session(&Session{DryRun: true, SkipDefaultTransaction: true}))
stmt := tx.Statement
return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...)

View File

@ -56,6 +56,13 @@ type TxCommitter interface {
Rollback() error
}
// Tx sql.Tx interface
type Tx interface {
ConnPool
TxCommitter
StmtContext(ctx context.Context, stmt *sql.Stmt) *sql.Stmt
}
// Valuer gorm valuer interface
type Valuer interface {
GormValue(context.Context, *DB) clause.Expr

View File

@ -115,7 +115,7 @@ func (db *PreparedStmtDB) QueryRowContext(ctx context.Context, query string, arg
}
type PreparedStmtTX struct {
*sql.Tx
Tx
PreparedStmtDB *PreparedStmtDB
}
@ -151,7 +151,7 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ..
func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) {
stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query)
if err == nil {
rows, err = tx.Tx.Stmt(stmt.Stmt).QueryContext(ctx, args...)
rows, err = tx.Tx.StmtContext(ctx, stmt.Stmt).QueryContext(ctx, args...)
if err != nil {
tx.PreparedStmtDB.Mux.Lock()
defer tx.PreparedStmtDB.Mux.Unlock()

111
scan.go
View File

@ -50,54 +50,40 @@ func scanIntoMap(mapValue map[string]interface{}, values []interface{}, columns
}
}
func (db *DB) scanIntoStruct(sch *schema.Schema, rows *sql.Rows, reflectValue reflect.Value, values []interface{}, columns []string, fields []*schema.Field, joinFields [][2]*schema.Field) {
for idx, column := range columns {
if sch == nil {
values[idx] = reflectValue.Interface()
} else if field := sch.LookUpField(column); field != nil && field.Readable {
func (db *DB) scanIntoStruct(rows *sql.Rows, reflectValue reflect.Value, values []interface{}, fields []*schema.Field, joinFields [][2]*schema.Field) {
for idx, field := range fields {
if field != nil {
values[idx] = field.NewValuePool.Get()
defer field.NewValuePool.Put(values[idx])
} else if names := strings.Split(column, "__"); len(names) > 1 {
if rel, ok := sch.Relationships.Relations[names[0]]; ok {
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
values[idx] = field.NewValuePool.Get()
defer field.NewValuePool.Put(values[idx])
continue
}
} else if len(fields) == 1 {
if reflectValue.CanAddr() {
values[idx] = reflectValue.Addr().Interface()
} else {
values[idx] = reflectValue.Interface()
}
values[idx] = &sql.RawBytes{}
} else if len(columns) == 1 {
sch = nil
values[idx] = reflectValue.Interface()
} else {
values[idx] = &sql.RawBytes{}
}
}
db.RowsAffected++
db.AddError(rows.Scan(values...))
if sch != nil {
for idx, column := range columns {
if field := sch.LookUpField(column); field != nil && field.Readable {
for idx, field := range fields {
if field != nil {
if len(joinFields) == 0 || joinFields[idx][0] == nil {
field.Set(db.Statement.Context, reflectValue, values[idx])
} else if names := strings.Split(column, "__"); len(names) > 1 {
if rel, ok := sch.Relationships.Relations[names[0]]; ok {
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
relValue := rel.Field.ReflectValueOf(db.Statement.Context, reflectValue)
if relValue.Kind() == reflect.Ptr && relValue.IsNil() {
if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() {
continue
}
relValue.Set(reflect.New(relValue.Type().Elem()))
}
field.Set(db.Statement.Context, relValue, values[idx])
} else {
relValue := joinFields[idx][0].ReflectValueOf(db.Statement.Context, reflectValue)
if relValue.Kind() == reflect.Ptr && relValue.IsNil() {
if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() {
return
}
relValue.Set(reflect.New(relValue.Type().Elem()))
}
joinFields[idx][1].Set(db.Statement.Context, relValue, values[idx])
}
// release data to pool
field.NewValuePool.Put(values[idx])
}
}
}
@ -176,7 +162,7 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) {
reflectValue = db.Statement.ReflectValue
)
for reflectValue.Kind() == reflect.Interface {
if reflectValue.Kind() == reflect.Interface {
reflectValue = reflectValue.Elem()
}
@ -195,35 +181,38 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) {
sch, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy)
}
for idx, column := range columns {
if field := sch.LookUpField(column); field != nil && field.Readable {
fields[idx] = field
} else if names := strings.Split(column, "__"); len(names) > 1 {
if rel, ok := sch.Relationships.Relations[names[0]]; ok {
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
fields[idx] = field
if len(joinFields) == 0 {
joinFields = make([][2]*schema.Field, len(columns))
}
joinFields[idx] = [2]*schema.Field{rel.Field, field}
continue
}
}
values[idx] = &sql.RawBytes{}
} else {
values[idx] = &sql.RawBytes{}
}
}
if len(columns) == 1 {
// isPluck
// Is Pluck
if _, ok := reflect.New(reflectValueType).Interface().(sql.Scanner); (reflectValueType != sch.ModelType && ok) || // is scanner
reflectValueType.Kind() != reflect.Struct || // is not struct
sch.ModelType.ConvertibleTo(schema.TimeReflectType) { // is time
sch = nil
}
}
// Not Pluck
if sch != nil {
for idx, column := range columns {
if field := sch.LookUpField(column); field != nil && field.Readable {
fields[idx] = field
} else if names := strings.Split(column, "__"); len(names) > 1 {
if rel, ok := sch.Relationships.Relations[names[0]]; ok {
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
fields[idx] = field
if len(joinFields) == 0 {
joinFields = make([][2]*schema.Field, len(columns))
}
joinFields[idx] = [2]*schema.Field{rel.Field, field}
continue
}
}
values[idx] = &sql.RawBytes{}
} else {
values[idx] = &sql.RawBytes{}
}
}
}
}
switch reflectValue.Kind() {
@ -256,7 +245,7 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) {
elem = reflect.New(reflectValueType)
}
db.scanIntoStruct(sch, rows, elem, values, columns, fields, joinFields)
db.scanIntoStruct(rows, elem, values, fields, joinFields)
if !update {
if isPtr {
@ -272,7 +261,7 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) {
}
case reflect.Struct, reflect.Ptr:
if initialized || rows.Next() {
db.scanIntoStruct(sch, rows, reflectValue, values, columns, fields, joinFields)
db.scanIntoStruct(rows, reflectValue, values, fields, joinFields)
}
default:
db.AddError(rows.Scan(dest))

View File

@ -465,24 +465,33 @@ func (field *Field) setupValuerAndSetter() {
}
// ValueOf returns field's value and if it is zero
field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) {
v = reflect.Indirect(v)
for _, fieldIdx := range field.StructField.Index {
if fieldIdx >= 0 {
v = v.Field(fieldIdx)
} else {
v = v.Field(-fieldIdx - 1)
if !v.IsNil() {
v = v.Elem()
fieldIndex := field.StructField.Index[0]
switch {
case len(field.StructField.Index) == 1 && fieldIndex > 0:
field.ValueOf = func(ctx context.Context, value reflect.Value) (interface{}, bool) {
fieldValue := reflect.Indirect(value).Field(fieldIndex)
return fieldValue.Interface(), fieldValue.IsZero()
}
default:
field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) {
v = reflect.Indirect(v)
for _, fieldIdx := range field.StructField.Index {
if fieldIdx >= 0 {
v = v.Field(fieldIdx)
} else {
return nil, true
v = v.Field(-fieldIdx - 1)
if !v.IsNil() {
v = v.Elem()
} else {
return nil, true
}
}
}
}
fv, zero := v.Interface(), v.IsZero()
return fv, zero
fv, zero := v.Interface(), v.IsZero()
return fv, zero
}
}
if field.Serializer != nil {
@ -509,24 +518,31 @@ func (field *Field) setupValuerAndSetter() {
}
// ReflectValueOf returns field's reflect value
field.ReflectValueOf = func(ctx context.Context, v reflect.Value) reflect.Value {
v = reflect.Indirect(v)
for idx, fieldIdx := range field.StructField.Index {
if fieldIdx >= 0 {
v = v.Field(fieldIdx)
} else {
v = v.Field(-fieldIdx - 1)
switch {
case len(field.StructField.Index) == 1 && fieldIndex > 0:
field.ReflectValueOf = func(ctx context.Context, value reflect.Value) reflect.Value {
return reflect.Indirect(value).Field(fieldIndex)
}
default:
field.ReflectValueOf = func(ctx context.Context, v reflect.Value) reflect.Value {
v = reflect.Indirect(v)
for idx, fieldIdx := range field.StructField.Index {
if fieldIdx >= 0 {
v = v.Field(fieldIdx)
} else {
v = v.Field(-fieldIdx - 1)
if v.IsNil() {
v.Set(reflect.New(v.Type().Elem()))
}
if v.IsNil() {
v.Set(reflect.New(v.Type().Elem()))
}
if idx < len(field.StructField.Index)-1 {
v = v.Elem()
if idx < len(field.StructField.Index)-1 {
v = v.Elem()
}
}
}
return v
}
return v
}
fallbackSetter := func(ctx context.Context, value reflect.Value, v interface{}, setter func(context.Context, reflect.Value, interface{}) error) (err error) {

171
tests/connpool_test.go Normal file
View File

@ -0,0 +1,171 @@
package tests_test
import (
"context"
"database/sql"
"os"
"reflect"
"testing"
"gorm.io/driver/mysql"
"gorm.io/gorm"
. "gorm.io/gorm/utils/tests"
)
type wrapperTx struct {
*sql.Tx
conn *wrapperConnPool
}
func (c *wrapperTx) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) {
c.conn.got = append(c.conn.got, query)
return c.Tx.PrepareContext(ctx, query)
}
func (c *wrapperTx) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
c.conn.got = append(c.conn.got, query)
return c.Tx.ExecContext(ctx, query, args...)
}
func (c *wrapperTx) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
c.conn.got = append(c.conn.got, query)
return c.Tx.QueryContext(ctx, query, args...)
}
func (c *wrapperTx) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
c.conn.got = append(c.conn.got, query)
return c.Tx.QueryRowContext(ctx, query, args...)
}
type wrapperConnPool struct {
db *sql.DB
got []string
expect []string
}
func (c *wrapperConnPool) Ping() error {
return c.db.Ping()
}
// If you use BeginTx returned *sql.Tx as shown below then you can't record queries in a transaction.
// func (c *wrapperConnPool) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) {
// return c.db.BeginTx(ctx, opts)
// }
// You should use BeginTx returned gorm.Tx which could wrap *sql.Tx then you can record all queries.
func (c *wrapperConnPool) BeginTx(ctx context.Context, opts *sql.TxOptions) (gorm.ConnPool, error) {
tx, err := c.db.BeginTx(ctx, opts)
if err != nil {
return nil, err
}
return &wrapperTx{Tx: tx, conn: c}, nil
}
func (c *wrapperConnPool) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) {
c.got = append(c.got, query)
return c.db.PrepareContext(ctx, query)
}
func (c *wrapperConnPool) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
c.got = append(c.got, query)
return c.db.ExecContext(ctx, query, args...)
}
func (c *wrapperConnPool) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
c.got = append(c.got, query)
return c.db.QueryContext(ctx, query, args...)
}
func (c *wrapperConnPool) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
c.got = append(c.got, query)
return c.db.QueryRowContext(ctx, query, args...)
}
func TestConnPoolWrapper(t *testing.T) {
dialect := os.Getenv("GORM_DIALECT")
if dialect != "mysql" {
t.SkipNow()
}
dbDSN := os.Getenv("GORM_DSN")
if dbDSN == "" {
dbDSN = "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True&loc=Local"
}
nativeDB, err := sql.Open("mysql", dbDSN)
if err != nil {
t.Fatalf("Should open db success, but got %v", err)
}
conn := &wrapperConnPool{
db: nativeDB,
expect: []string{
"SELECT VERSION()",
"INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)",
"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1",
"INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)",
"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1",
"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1",
"INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)",
"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1",
"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1",
},
}
defer func() {
if !reflect.DeepEqual(conn.got, conn.expect) {
t.Errorf("expect %#v but got %#v", conn.expect, conn.got)
}
}()
db, err := gorm.Open(mysql.New(mysql.Config{Conn: conn}))
if err != nil {
t.Fatalf("Should open db success, but got %v", err)
}
tx := db.Begin()
user := *GetUser("transaction", Config{})
if err = tx.Save(&user).Error; err != nil {
t.Fatalf("No error should raise, but got %v", err)
}
if err = tx.First(&User{}, "name = ?", "transaction").Error; err != nil {
t.Fatalf("Should find saved record, but got %v", err)
}
user1 := *GetUser("transaction1-1", Config{})
if err = tx.Save(&user1).Error; err != nil {
t.Fatalf("No error should raise, but got %v", err)
}
if err = tx.First(&User{}, "name = ?", user1.Name).Error; err != nil {
t.Fatalf("Should find saved record, but got %v", err)
}
if sqlTx, ok := tx.Statement.ConnPool.(gorm.TxCommitter); !ok || sqlTx == nil {
t.Fatalf("Should return the underlying sql.Tx")
}
tx.Rollback()
if err = db.First(&User{}, "name = ?", "transaction").Error; err == nil {
t.Fatalf("Should not find record after rollback, but got %v", err)
}
txDB := db.Where("fake_name = ?", "fake_name")
tx2 := txDB.Session(&gorm.Session{NewDB: true}).Begin()
user2 := *GetUser("transaction-2", Config{})
if err = tx2.Save(&user2).Error; err != nil {
t.Fatalf("No error should raise, but got %v", err)
}
if err = tx2.First(&User{}, "name = ?", "transaction-2").Error; err != nil {
t.Fatalf("Should find saved record, but got %v", err)
}
tx2.Commit()
if err = db.First(&User{}, "name = ?", "transaction-2").Error; err != nil {
t.Fatalf("Should be able to find committed record, but got %v", err)
}
}

View File

@ -3,11 +3,12 @@ module gorm.io/gorm/tests
go 1.14
require (
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect
github.com/google/uuid v1.3.0
github.com/jackc/pgx/v4 v4.15.0 // indirect
github.com/jinzhu/now v1.1.4
github.com/lib/pq v1.10.4
github.com/mattn/go-sqlite3 v1.14.11 // indirect
github.com/mattn/go-sqlite3 v1.14.12 // indirect
golang.org/x/crypto v0.0.0-20220214200702-86341886e292 // indirect
gorm.io/driver/mysql v1.3.2
gorm.io/driver/postgres v1.3.1

View File

@ -1158,3 +1158,39 @@ func TestQueryWithTableAndConditionsAndAllFields(t *testing.T) {
t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String())
}
}
type DoubleInt64 struct {
data int64
}
func (t *DoubleInt64) Scan(val interface{}) error {
switch v := val.(type) {
case int64:
t.data = v * 2
return nil
default:
return fmt.Errorf("DoubleInt64 cant not scan with:%v", v)
}
}
// https://github.com/go-gorm/gorm/issues/5091
func TestQueryScannerWithSingleColumn(t *testing.T) {
user := User{Name: "scanner_raw_1", Age: 10}
DB.Create(&user)
var result1 DoubleInt64
if err := DB.Model(&User{}).Where("name LIKE ?", "scanner_raw_%").Limit(1).Pluck(
"age", &result1).Error; err != nil {
t.Errorf("Failed, got error: %v", err)
}
AssertEqual(t, result1.data, 20)
var result2 DoubleInt64
if err := DB.Model(&User{}).Where("name LIKE ?", "scanner_raw_%").Limit(1).Select(
"age").Scan(&result2).Error; err != nil {
t.Errorf("Failed, got error: %v", err)
}
AssertEqual(t, result2.data, 20)
}