From 397b583b8ecc5a31c838db5822fe1003b53a91ef Mon Sep 17 00:00:00 2001 From: chenrui Date: Fri, 25 Feb 2022 22:38:48 +0800 Subject: [PATCH 1/9] fix: query scanner in single column --- scan.go | 12 +++++++++++- tests/query_test.go | 36 ++++++++++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 1 deletion(-) diff --git a/scan.go b/scan.go index 0da12daf..a1cb582e 100644 --- a/scan.go +++ b/scan.go @@ -272,7 +272,17 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) { } case reflect.Struct, reflect.Ptr: if initialized || rows.Next() { - db.scanIntoStruct(sch, rows, reflectValue, values, columns, fields, joinFields) + if update { + db.scanIntoStruct(sch, rows, reflectValue, values, columns, fields, joinFields) + } else { + elem := reflect.New(reflectValueType) + db.scanIntoStruct(sch, rows, elem, values, columns, fields, joinFields) + if isPtr { + db.Statement.ReflectValue.Set(elem) + } else { + db.Statement.ReflectValue.Set(elem.Elem()) + } + } } default: db.AddError(rows.Scan(dest)) diff --git a/tests/query_test.go b/tests/query_test.go index d10df180..6542774a 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -1158,3 +1158,39 @@ func TestQueryWithTableAndConditionsAndAllFields(t *testing.T) { t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String()) } } + +type DoubleInt64 struct { + data int64 +} + +func (t *DoubleInt64) Scan(val interface{}) error { + switch v := val.(type) { + case int64: + t.data = v * 2 + return nil + default: + return fmt.Errorf("DoubleInt64 cant not scan with:%v", v) + } +} + +// https://github.com/go-gorm/gorm/issues/5091 +func TestQueryScannerWithSingleColumn(t *testing.T) { + user := User{Name: "scanner_raw_1", Age: 10} + DB.Create(&user) + + var result1 DoubleInt64 + if err := DB.Model(&User{}).Where("name LIKE ?", "scanner_raw_%").Limit(1).Pluck( + "age", &result1).Error; err != nil { + t.Errorf("Failed, got error: %v", err) + } + + AssertEqual(t, result1.data, 20) + + var result2 DoubleInt64 + if err := DB.Model(&User{}).Where("name LIKE ?", "scanner_raw_%").Limit(1).Select( + "age").Scan(&result2).Error; err != nil { + t.Errorf("Failed, got error: %v", err) + } + + AssertEqual(t, result2.data, 20) +} From f2edda50e11728e7aee6b1d4c961d575f7afbb2d Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 25 Feb 2022 10:48:23 +0800 Subject: [PATCH 2/9] Refactor check missing where condition --- callbacks/delete.go | 19 +++++++------------ callbacks/helper.go | 16 ++++++++++++++++ callbacks/update.go | 20 +++++++------------- soft_delete.go | 11 ++--------- tests/update_test.go | 2 +- 5 files changed, 33 insertions(+), 35 deletions(-) diff --git a/callbacks/delete.go b/callbacks/delete.go index 1fb5261c..84f446a3 100644 --- a/callbacks/delete.go +++ b/callbacks/delete.go @@ -118,6 +118,12 @@ func Delete(config *Config) func(db *gorm.DB) { return } + if db.Statement.Schema != nil { + for _, c := range db.Statement.Schema.DeleteClauses { + db.Statement.AddClause(c) + } + } + if db.Statement.SQL.Len() == 0 { db.Statement.SQL.Grow(100) db.Statement.AddClauseIfNotExists(clause.Delete{}) @@ -141,22 +147,11 @@ func Delete(config *Config) func(db *gorm.DB) { } db.Statement.AddClauseIfNotExists(clause.From{}) - } - if db.Statement.Schema != nil { - for _, c := range db.Statement.Schema.DeleteClauses { - db.Statement.AddClause(c) - } - } - - if db.Statement.SQL.Len() == 0 { db.Statement.Build(db.Statement.BuildClauses...) } - if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok && db.Error == nil { - db.AddError(gorm.ErrMissingWhereClause) - return - } + checkMissingWhereConditions(db) if !db.DryRun && db.Error == nil { ok, mode := hasReturning(db, supportReturning) diff --git a/callbacks/helper.go b/callbacks/helper.go index a59e1880..a5eb047e 100644 --- a/callbacks/helper.go +++ b/callbacks/helper.go @@ -104,3 +104,19 @@ func hasReturning(tx *gorm.DB, supportReturning bool) (bool, gorm.ScanMode) { } return false, 0 } + +func checkMissingWhereConditions(db *gorm.DB) { + if !db.AllowGlobalUpdate && db.Error == nil { + where, withCondition := db.Statement.Clauses["WHERE"] + if withCondition { + if _, withSoftDelete := db.Statement.Clauses["soft_delete_enabled"]; withSoftDelete { + whereClause, _ := where.Expression.(clause.Where) + withCondition = len(whereClause.Exprs) > 1 + } + } + if !withCondition { + db.AddError(gorm.ErrMissingWhereClause) + } + return + } +} diff --git a/callbacks/update.go b/callbacks/update.go index 4a2e5c79..da03261e 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -59,6 +59,12 @@ func Update(config *Config) func(db *gorm.DB) { return } + if db.Statement.Schema != nil { + for _, c := range db.Statement.Schema.UpdateClauses { + db.Statement.AddClause(c) + } + } + if db.Statement.SQL.Len() == 0 { db.Statement.SQL.Grow(180) db.Statement.AddClauseIfNotExists(clause.Update{}) @@ -68,22 +74,10 @@ func Update(config *Config) func(db *gorm.DB) { return } - } - - if db.Statement.Schema != nil { - for _, c := range db.Statement.Schema.UpdateClauses { - db.Statement.AddClause(c) - } - } - - if db.Statement.SQL.Len() == 0 { db.Statement.Build(db.Statement.BuildClauses...) } - if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok { - db.AddError(gorm.ErrMissingWhereClause) - return - } + checkMissingWhereConditions(db) if !db.DryRun && db.Error == nil { if ok, mode := hasReturning(db, supportReturning); ok { diff --git a/soft_delete.go b/soft_delete.go index ba6d2118..6d646288 100644 --- a/soft_delete.go +++ b/soft_delete.go @@ -104,9 +104,7 @@ func (sd SoftDeleteUpdateClause) MergeClause(*clause.Clause) { func (sd SoftDeleteUpdateClause) ModifyStatement(stmt *Statement) { if stmt.SQL.Len() == 0 && !stmt.Statement.Unscoped { - if _, ok := stmt.Clauses["WHERE"]; stmt.DB.AllowGlobalUpdate || ok { - SoftDeleteQueryClause(sd).ModifyStatement(stmt) - } + SoftDeleteQueryClause(sd).ModifyStatement(stmt) } } @@ -152,12 +150,7 @@ func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) { } } - if _, ok := stmt.Clauses["WHERE"]; !stmt.DB.AllowGlobalUpdate && !ok { - stmt.DB.AddError(ErrMissingWhereClause) - } else { - SoftDeleteQueryClause(sd).ModifyStatement(stmt) - } - + SoftDeleteQueryClause(sd).ModifyStatement(stmt) stmt.AddClauseIfNotExists(clause.Update{}) stmt.Build(stmt.DB.Callback().Update().Clauses...) } diff --git a/tests/update_test.go b/tests/update_test.go index b471ba9b..41ea5d27 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -645,7 +645,7 @@ func TestSave(t *testing.T) { dryDB := DB.Session(&gorm.Session{DryRun: true}) stmt := dryDB.Save(&user).Statement - if !regexp.MustCompile(`.id. = .* AND .users.\..deleted_at. IS NULL`).MatchString(stmt.SQL.String()) { + if !regexp.MustCompile(`.users.\..deleted_at. IS NULL`).MatchString(stmt.SQL.String()) { t.Fatalf("invalid updating SQL, got %v", stmt.SQL.String()) } From 68bb5379d91a7f7fae4dc65205db66004f515d0c Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 27 Feb 2022 09:09:29 +0800 Subject: [PATCH 3/9] Refactor scan into struct --- scan.go | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/scan.go b/scan.go index a1cb582e..e83390ca 100644 --- a/scan.go +++ b/scan.go @@ -68,7 +68,11 @@ func (db *DB) scanIntoStruct(sch *schema.Schema, rows *sql.Rows, reflectValue re values[idx] = &sql.RawBytes{} } else if len(columns) == 1 { sch = nil - values[idx] = reflectValue.Interface() + if reflectValue.CanAddr() { + values[idx] = reflectValue.Addr().Interface() + } else { + values[idx] = reflectValue.Interface() + } } else { values[idx] = &sql.RawBytes{} } @@ -272,17 +276,7 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) { } case reflect.Struct, reflect.Ptr: if initialized || rows.Next() { - if update { - db.scanIntoStruct(sch, rows, reflectValue, values, columns, fields, joinFields) - } else { - elem := reflect.New(reflectValueType) - db.scanIntoStruct(sch, rows, elem, values, columns, fields, joinFields) - if isPtr { - db.Statement.ReflectValue.Set(elem) - } else { - db.Statement.ReflectValue.Set(elem.Elem()) - } - } + db.scanIntoStruct(sch, rows, reflectValue, values, columns, fields, joinFields) } default: db.AddError(rows.Scan(dest)) From 530b0a12b4c63bb2dc7abef2934dc8406f1d0f13 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 27 Feb 2022 22:10:17 +0800 Subject: [PATCH 4/9] Add fast path for ValueOf, ReflectValueOf --- schema/field.go | 70 ++++++++++++++++++++++++++++++------------------- tests/go.mod | 1 + 2 files changed, 44 insertions(+), 27 deletions(-) diff --git a/schema/field.go b/schema/field.go index 8c793f93..826680c5 100644 --- a/schema/field.go +++ b/schema/field.go @@ -465,24 +465,33 @@ func (field *Field) setupValuerAndSetter() { } // ValueOf returns field's value and if it is zero - field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) { - v = reflect.Indirect(v) - for _, fieldIdx := range field.StructField.Index { - if fieldIdx >= 0 { - v = v.Field(fieldIdx) - } else { - v = v.Field(-fieldIdx - 1) - - if !v.IsNil() { - v = v.Elem() + fieldIndex := field.StructField.Index[0] + switch { + case len(field.StructField.Index) == 1 && fieldIndex > 0: + field.ValueOf = func(ctx context.Context, value reflect.Value) (interface{}, bool) { + fieldValue := reflect.Indirect(value).Field(fieldIndex) + return fieldValue.Interface(), fieldValue.IsZero() + } + default: + field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) { + v = reflect.Indirect(v) + for _, fieldIdx := range field.StructField.Index { + if fieldIdx >= 0 { + v = v.Field(fieldIdx) } else { - return nil, true + v = v.Field(-fieldIdx - 1) + + if !v.IsNil() { + v = v.Elem() + } else { + return nil, true + } } } - } - fv, zero := v.Interface(), v.IsZero() - return fv, zero + fv, zero := v.Interface(), v.IsZero() + return fv, zero + } } if field.Serializer != nil { @@ -509,24 +518,31 @@ func (field *Field) setupValuerAndSetter() { } // ReflectValueOf returns field's reflect value - field.ReflectValueOf = func(ctx context.Context, v reflect.Value) reflect.Value { - v = reflect.Indirect(v) - for idx, fieldIdx := range field.StructField.Index { - if fieldIdx >= 0 { - v = v.Field(fieldIdx) - } else { - v = v.Field(-fieldIdx - 1) + switch { + case len(field.StructField.Index) == 1 && fieldIndex > 0: + field.ReflectValueOf = func(ctx context.Context, value reflect.Value) reflect.Value { + return reflect.Indirect(value).Field(fieldIndex) + } + default: + field.ReflectValueOf = func(ctx context.Context, v reflect.Value) reflect.Value { + v = reflect.Indirect(v) + for idx, fieldIdx := range field.StructField.Index { + if fieldIdx >= 0 { + v = v.Field(fieldIdx) + } else { + v = v.Field(-fieldIdx - 1) - if v.IsNil() { - v.Set(reflect.New(v.Type().Elem())) - } + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) + } - if idx < len(field.StructField.Index)-1 { - v = v.Elem() + if idx < len(field.StructField.Index)-1 { + v = v.Elem() + } } } + return v } - return v } fallbackSetter := func(ctx context.Context, value reflect.Value, v interface{}, setter func(context.Context, reflect.Value, interface{}) error) (err error) { diff --git a/tests/go.mod b/tests/go.mod index cefe6f96..9e3453b7 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -3,6 +3,7 @@ module gorm.io/gorm/tests go 1.14 require ( + github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect github.com/google/uuid v1.3.0 github.com/jackc/pgx/v4 v4.15.0 // indirect github.com/jinzhu/now v1.1.4 From 43a72b369e670bd91e32784d063608931a59a66e Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 27 Feb 2022 22:54:43 +0800 Subject: [PATCH 5/9] Refactor Scan --- scan.go | 104 +++++++++++++++++++++++--------------------------------- 1 file changed, 43 insertions(+), 61 deletions(-) diff --git a/scan.go b/scan.go index e83390ca..d7b58e03 100644 --- a/scan.go +++ b/scan.go @@ -50,58 +50,37 @@ func scanIntoMap(mapValue map[string]interface{}, values []interface{}, columns } } -func (db *DB) scanIntoStruct(sch *schema.Schema, rows *sql.Rows, reflectValue reflect.Value, values []interface{}, columns []string, fields []*schema.Field, joinFields [][2]*schema.Field) { - for idx, column := range columns { - if sch == nil { - values[idx] = reflectValue.Interface() - } else if field := sch.LookUpField(column); field != nil && field.Readable { +func (db *DB) scanIntoStruct(rows *sql.Rows, reflectValue reflect.Value, values []interface{}, fields []*schema.Field, joinFields [][2]*schema.Field) { + for idx, field := range fields { + if field != nil { values[idx] = field.NewValuePool.Get() defer field.NewValuePool.Put(values[idx]) - } else if names := strings.Split(column, "__"); len(names) > 1 { - if rel, ok := sch.Relationships.Relations[names[0]]; ok { - if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { - values[idx] = field.NewValuePool.Get() - defer field.NewValuePool.Put(values[idx]) - continue - } + if len(joinFields) == 0 || joinFields[idx][0] == nil { + defer field.Set(db.Statement.Context, reflectValue, values[idx]) } - values[idx] = &sql.RawBytes{} - } else if len(columns) == 1 { - sch = nil + } else if len(fields) == 1 { if reflectValue.CanAddr() { values[idx] = reflectValue.Addr().Interface() } else { values[idx] = reflectValue.Interface() } - } else { - values[idx] = &sql.RawBytes{} } } db.RowsAffected++ db.AddError(rows.Scan(values...)) - if sch != nil { - for idx, column := range columns { - if field := sch.LookUpField(column); field != nil && field.Readable { - field.Set(db.Statement.Context, reflectValue, values[idx]) - } else if names := strings.Split(column, "__"); len(names) > 1 { - if rel, ok := sch.Relationships.Relations[names[0]]; ok { - if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { - relValue := rel.Field.ReflectValueOf(db.Statement.Context, reflectValue) - - if relValue.Kind() == reflect.Ptr && relValue.IsNil() { - if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() { - continue - } - - relValue.Set(reflect.New(relValue.Type().Elem())) - } - - field.Set(db.Statement.Context, relValue, values[idx]) - } + for idx, joinField := range joinFields { + if joinField[0] != nil { + relValue := joinField[0].ReflectValueOf(db.Statement.Context, reflectValue) + if relValue.Kind() == reflect.Ptr && relValue.IsNil() { + if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() { + return } + + relValue.Set(reflect.New(relValue.Type().Elem())) } + joinField[1].Set(db.Statement.Context, relValue, values[idx]) } } } @@ -180,7 +159,7 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) { reflectValue = db.Statement.ReflectValue ) - for reflectValue.Kind() == reflect.Interface { + if reflectValue.Kind() == reflect.Interface { reflectValue = reflectValue.Elem() } @@ -199,35 +178,38 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) { sch, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy) } - for idx, column := range columns { - if field := sch.LookUpField(column); field != nil && field.Readable { - fields[idx] = field - } else if names := strings.Split(column, "__"); len(names) > 1 { - if rel, ok := sch.Relationships.Relations[names[0]]; ok { - if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { - fields[idx] = field - - if len(joinFields) == 0 { - joinFields = make([][2]*schema.Field, len(columns)) - } - joinFields[idx] = [2]*schema.Field{rel.Field, field} - continue - } - } - values[idx] = &sql.RawBytes{} - } else { - values[idx] = &sql.RawBytes{} - } - } - if len(columns) == 1 { - // isPluck + // Is Pluck if _, ok := reflect.New(reflectValueType).Interface().(sql.Scanner); (reflectValueType != sch.ModelType && ok) || // is scanner reflectValueType.Kind() != reflect.Struct || // is not struct sch.ModelType.ConvertibleTo(schema.TimeReflectType) { // is time sch = nil } } + + // Not Pluck + if sch != nil { + for idx, column := range columns { + if field := sch.LookUpField(column); field != nil && field.Readable { + fields[idx] = field + } else if names := strings.Split(column, "__"); len(names) > 1 { + if rel, ok := sch.Relationships.Relations[names[0]]; ok { + if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { + fields[idx] = field + + if len(joinFields) == 0 { + joinFields = make([][2]*schema.Field, len(columns)) + } + joinFields[idx] = [2]*schema.Field{rel.Field, field} + continue + } + } + values[idx] = &sql.RawBytes{} + } else { + values[idx] = &sql.RawBytes{} + } + } + } } switch reflectValue.Kind() { @@ -260,7 +242,7 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) { elem = reflect.New(reflectValueType) } - db.scanIntoStruct(sch, rows, elem, values, columns, fields, joinFields) + db.scanIntoStruct(rows, elem, values, fields, joinFields) if !update { if isPtr { @@ -276,7 +258,7 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) { } case reflect.Struct, reflect.Ptr: if initialized || rows.Next() { - db.scanIntoStruct(sch, rows, reflectValue, values, columns, fields, joinFields) + db.scanIntoStruct(rows, reflectValue, values, fields, joinFields) } default: db.AddError(rows.Scan(dest)) From e2e802b837a234ede6dc122dbb26de965e35e55f Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 28 Feb 2022 09:28:19 +0800 Subject: [PATCH 6/9] Refactor Scan --- callbacks/create.go | 6 ++++-- scan.go | 29 ++++++++++++++++------------- tests/go.mod | 2 +- 3 files changed, 21 insertions(+), 16 deletions(-) diff --git a/callbacks/create.go b/callbacks/create.go index b0964e2b..6e2883f7 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -201,13 +201,15 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { switch stmt.ReflectValue.Kind() { case reflect.Slice, reflect.Array: rValLen := stmt.ReflectValue.Len() - stmt.SQL.Grow(rValLen * 18) - values.Values = make([][]interface{}, rValLen) if rValLen == 0 { stmt.AddError(gorm.ErrEmptySlice) return } + stmt.SQL.Grow(rValLen * 18) + stmt.Vars = make([]interface{}, 0, rValLen*len(values.Columns)) + values.Values = make([][]interface{}, rValLen) + defaultValueFieldsHavingValue := map[*schema.Field][]interface{}{} for i := 0; i < rValLen; i++ { rv := reflect.Indirect(stmt.ReflectValue.Index(i)) diff --git a/scan.go b/scan.go index d7b58e03..a4243d12 100644 --- a/scan.go +++ b/scan.go @@ -54,10 +54,6 @@ func (db *DB) scanIntoStruct(rows *sql.Rows, reflectValue reflect.Value, values for idx, field := range fields { if field != nil { values[idx] = field.NewValuePool.Get() - defer field.NewValuePool.Put(values[idx]) - if len(joinFields) == 0 || joinFields[idx][0] == nil { - defer field.Set(db.Statement.Context, reflectValue, values[idx]) - } } else if len(fields) == 1 { if reflectValue.CanAddr() { values[idx] = reflectValue.Addr().Interface() @@ -70,17 +66,24 @@ func (db *DB) scanIntoStruct(rows *sql.Rows, reflectValue reflect.Value, values db.RowsAffected++ db.AddError(rows.Scan(values...)) - for idx, joinField := range joinFields { - if joinField[0] != nil { - relValue := joinField[0].ReflectValueOf(db.Statement.Context, reflectValue) - if relValue.Kind() == reflect.Ptr && relValue.IsNil() { - if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() { - return - } + for idx, field := range fields { + if field != nil { + if len(joinFields) == 0 || joinFields[idx][0] == nil { + field.Set(db.Statement.Context, reflectValue, values[idx]) + } else { + relValue := joinFields[idx][0].ReflectValueOf(db.Statement.Context, reflectValue) + if relValue.Kind() == reflect.Ptr && relValue.IsNil() { + if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() { + return + } - relValue.Set(reflect.New(relValue.Type().Elem())) + relValue.Set(reflect.New(relValue.Type().Elem())) + } + joinFields[idx][1].Set(db.Statement.Context, relValue, values[idx]) } - joinField[1].Set(db.Statement.Context, relValue, values[idx]) + + // release data to pool + field.NewValuePool.Put(values[idx]) } } } diff --git a/tests/go.mod b/tests/go.mod index 9e3453b7..c65ea953 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -8,7 +8,7 @@ require ( github.com/jackc/pgx/v4 v4.15.0 // indirect github.com/jinzhu/now v1.1.4 github.com/lib/pq v1.10.4 - github.com/mattn/go-sqlite3 v1.14.11 // indirect + github.com/mattn/go-sqlite3 v1.14.12 // indirect golang.org/x/crypto v0.0.0-20220214200702-86341886e292 // indirect gorm.io/driver/mysql v1.3.2 gorm.io/driver/postgres v1.3.1 From 996b96e81268335b22faf694dfb4674f84177f17 Mon Sep 17 00:00:00 2001 From: lianghuan Date: Mon, 28 Feb 2022 17:12:09 +0800 Subject: [PATCH 7/9] Add TxConnPoolBeginner and Tx interface --- .gitignore | 1 + finisher_api.go | 3 + interfaces.go | 13 +++ prepare_stmt.go | 7 +- tests/connpool_test.go | 181 +++++++++++++++++++++++++++++++++++++++++ 5 files changed, 203 insertions(+), 2 deletions(-) create mode 100644 tests/connpool_test.go diff --git a/.gitignore b/.gitignore index e1b9ecea..45505cc9 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,4 @@ documents coverage.txt _book .idea +vendor \ No newline at end of file diff --git a/finisher_api.go b/finisher_api.go index f994ec31..5d49ddf9 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -255,6 +255,7 @@ func (tx *DB) assignInterfacesToValue(values ...interface{}) { } } } + // FirstOrInit gets the first matched record or initialize a new instance with given conditions (only works with struct or map conditions) func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { queryTx := db.Limit(1).Order(clause.OrderByColumn{ @@ -603,6 +604,8 @@ func (db *DB) Begin(opts ...*sql.TxOptions) *DB { tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) } else if beginner, ok := tx.Statement.ConnPool.(ConnPoolBeginner); ok { tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) + } else if beginner, ok := tx.Statement.ConnPool.(TxConnPoolBeginner); ok { + tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) } else { err = ErrInvalidTransaction } diff --git a/interfaces.go b/interfaces.go index 44a85cb5..ed7112f2 100644 --- a/interfaces.go +++ b/interfaces.go @@ -50,12 +50,25 @@ type ConnPoolBeginner interface { BeginTx(ctx context.Context, opts *sql.TxOptions) (ConnPool, error) } +// TxConnPoolBeginner tx conn pool beginner +type TxConnPoolBeginner interface { + BeginTx(ctx context.Context, opts *sql.TxOptions) (Tx, error) +} + // TxCommitter tx committer type TxCommitter interface { Commit() error Rollback() error } +// Tx sql.Tx interface +type Tx interface { + ConnPool + Commit() error + Rollback() error + StmtContext(ctx context.Context, stmt *sql.Stmt) *sql.Stmt +} + // Valuer gorm valuer interface type Valuer interface { GormValue(context.Context, *DB) clause.Expr diff --git a/prepare_stmt.go b/prepare_stmt.go index 88bec4e9..94282fad 100644 --- a/prepare_stmt.go +++ b/prepare_stmt.go @@ -73,6 +73,9 @@ func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (Conn if beginner, ok := db.ConnPool.(TxBeginner); ok { tx, err := beginner.BeginTx(ctx, opt) return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err + } else if beginner, ok := db.ConnPool.(TxConnPoolBeginner); ok { + tx, err := beginner.BeginTx(ctx, opt) + return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err } return nil, ErrInvalidTransaction } @@ -115,7 +118,7 @@ func (db *PreparedStmtDB) QueryRowContext(ctx context.Context, query string, arg } type PreparedStmtTX struct { - *sql.Tx + Tx PreparedStmtDB *PreparedStmtDB } @@ -151,7 +154,7 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args .. func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) { stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query) if err == nil { - rows, err = tx.Tx.Stmt(stmt.Stmt).QueryContext(ctx, args...) + rows, err = tx.Tx.StmtContext(ctx, stmt.Stmt).QueryContext(ctx, args...) if err != nil { tx.PreparedStmtDB.Mux.Lock() defer tx.PreparedStmtDB.Mux.Unlock() diff --git a/tests/connpool_test.go b/tests/connpool_test.go new file mode 100644 index 00000000..3713ad7c --- /dev/null +++ b/tests/connpool_test.go @@ -0,0 +1,181 @@ +package tests_test + +import ( + "context" + "database/sql" + "log" + "os" + "reflect" + "testing" + "time" + + "gorm.io/driver/mysql" + "gorm.io/gorm" + "gorm.io/gorm/logger" + . "gorm.io/gorm/utils/tests" +) + +type wrapperTx struct { + *sql.Tx + conn *wrapperConnPool +} + +func (c *wrapperTx) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) { + c.conn.got = append(c.conn.got, query) + return c.Tx.PrepareContext(ctx, query) +} + +func (c *wrapperTx) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { + c.conn.got = append(c.conn.got, query) + return c.Tx.ExecContext(ctx, query, args...) +} + +func (c *wrapperTx) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { + c.conn.got = append(c.conn.got, query) + return c.Tx.QueryContext(ctx, query, args...) +} + +func (c *wrapperTx) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { + c.conn.got = append(c.conn.got, query) + return c.Tx.QueryRowContext(ctx, query, args...) +} + +type wrapperConnPool struct { + db *sql.DB + got []string + expect []string +} + +func (c *wrapperConnPool) Ping() error { + return c.db.Ping() +} + +// If you use BeginTx returned *sql.Tx as shown below then you can't record queries in a transaction. +// func (c *wrapperConnPool) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) { +// return c.db.BeginTx(ctx, opts) +// } +// You should use BeginTx returned gorm.Tx which could wrap *sql.Tx then you can record all queries. +func (c *wrapperConnPool) BeginTx(ctx context.Context, opts *sql.TxOptions) (gorm.Tx, error) { + tx, err := c.db.BeginTx(ctx, opts) + if err != nil { + return nil, err + } + return &wrapperTx{Tx: tx, conn: c}, nil +} + +func (c *wrapperConnPool) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) { + c.got = append(c.got, query) + return c.db.PrepareContext(ctx, query) +} + +func (c *wrapperConnPool) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { + c.got = append(c.got, query) + return c.db.ExecContext(ctx, query, args...) +} + +func (c *wrapperConnPool) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { + c.got = append(c.got, query) + return c.db.QueryContext(ctx, query, args...) +} + +func (c *wrapperConnPool) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { + c.got = append(c.got, query) + return c.db.QueryRowContext(ctx, query, args...) +} + +func TestConnPoolWrapper(t *testing.T) { + dialect := os.Getenv("GORM_DIALECT") + if dialect != "mysql" { + t.SkipNow() + } + + dbDSN := os.Getenv("GORM_DSN") + if dbDSN == "" { + dbDSN = "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True&loc=Local" + } + nativeDB, err := sql.Open("mysql", dbDSN) + if err != nil { + t.Fatalf("Should open db success, but got %v", err) + } + + conn := &wrapperConnPool{ + db: nativeDB, + expect: []string{ + "SELECT VERSION()", + "INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)", + "SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1", + "INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)", + "SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1", + "SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1", + "INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)", + "SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1", + "SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1", + }, + } + + defer func() { + if !reflect.DeepEqual(conn.got, conn.expect) { + t.Errorf("expect %#v but got %#v", conn.expect, conn.got) + } + }() + + l := logger.New(log.New(os.Stdout, "\r\n", log.LstdFlags), logger.Config{ + SlowThreshold: 200 * time.Millisecond, + LogLevel: logger.Info, + IgnoreRecordNotFoundError: false, + Colorful: true, + }) + + db, err := gorm.Open(mysql.New(mysql.Config{Conn: conn}), &gorm.Config{Logger: l}) + if err != nil { + t.Fatalf("Should open db success, but got %v", err) + } + + tx := db.Begin() + user := *GetUser("transaction", Config{}) + + if err = tx.Save(&user).Error; err != nil { + t.Fatalf("No error should raise, but got %v", err) + } + + if err = tx.First(&User{}, "name = ?", "transaction").Error; err != nil { + t.Fatalf("Should find saved record, but got %v", err) + } + + user1 := *GetUser("transaction1-1", Config{}) + + if err = tx.Save(&user1).Error; err != nil { + t.Fatalf("No error should raise, but got %v", err) + } + + if err = tx.First(&User{}, "name = ?", user1.Name).Error; err != nil { + t.Fatalf("Should find saved record, but got %v", err) + } + + if sqlTx, ok := tx.Statement.ConnPool.(gorm.TxCommitter); !ok || sqlTx == nil { + t.Fatalf("Should return the underlying sql.Tx") + } + + tx.Rollback() + + if err = db.First(&User{}, "name = ?", "transaction").Error; err == nil { + t.Fatalf("Should not find record after rollback, but got %v", err) + } + + txDB := db.Where("fake_name = ?", "fake_name") + tx2 := txDB.Session(&gorm.Session{NewDB: true}).Begin() + user2 := *GetUser("transaction-2", Config{}) + if err = tx2.Save(&user2).Error; err != nil { + t.Fatalf("No error should raise, but got %v", err) + } + + if err = tx2.First(&User{}, "name = ?", "transaction-2").Error; err != nil { + t.Fatalf("Should find saved record, but got %v", err) + } + + tx2.Commit() + + if err = db.First(&User{}, "name = ?", "transaction-2").Error; err != nil { + t.Fatalf("Should be able to find committed record, but got %v", err) + } +} From 4e523499d191d02e032b126774efd26daa8697a8 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 1 Mar 2022 16:48:46 +0800 Subject: [PATCH 8/9] Refactor Tx interface --- finisher_api.go | 9 ++++----- interfaces.go | 8 +------- prepare_stmt.go | 3 --- tests/connpool_test.go | 14 ++------------ 4 files changed, 7 insertions(+), 27 deletions(-) diff --git a/finisher_api.go b/finisher_api.go index 5d49ddf9..4b428a59 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -600,13 +600,12 @@ func (db *DB) Begin(opts ...*sql.TxOptions) *DB { opt = opts[0] } - if beginner, ok := tx.Statement.ConnPool.(TxBeginner); ok { + switch beginner := tx.Statement.ConnPool.(type) { + case TxBeginner: tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) - } else if beginner, ok := tx.Statement.ConnPool.(ConnPoolBeginner); ok { + case ConnPoolBeginner: tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) - } else if beginner, ok := tx.Statement.ConnPool.(TxConnPoolBeginner); ok { - tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) - } else { + default: err = ErrInvalidTransaction } diff --git a/interfaces.go b/interfaces.go index ed7112f2..84dc94bb 100644 --- a/interfaces.go +++ b/interfaces.go @@ -50,11 +50,6 @@ type ConnPoolBeginner interface { BeginTx(ctx context.Context, opts *sql.TxOptions) (ConnPool, error) } -// TxConnPoolBeginner tx conn pool beginner -type TxConnPoolBeginner interface { - BeginTx(ctx context.Context, opts *sql.TxOptions) (Tx, error) -} - // TxCommitter tx committer type TxCommitter interface { Commit() error @@ -64,8 +59,7 @@ type TxCommitter interface { // Tx sql.Tx interface type Tx interface { ConnPool - Commit() error - Rollback() error + TxCommitter StmtContext(ctx context.Context, stmt *sql.Stmt) *sql.Stmt } diff --git a/prepare_stmt.go b/prepare_stmt.go index 94282fad..b062b0d6 100644 --- a/prepare_stmt.go +++ b/prepare_stmt.go @@ -73,9 +73,6 @@ func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (Conn if beginner, ok := db.ConnPool.(TxBeginner); ok { tx, err := beginner.BeginTx(ctx, opt) return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err - } else if beginner, ok := db.ConnPool.(TxConnPoolBeginner); ok { - tx, err := beginner.BeginTx(ctx, opt) - return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err } return nil, ErrInvalidTransaction } diff --git a/tests/connpool_test.go b/tests/connpool_test.go index 3713ad7c..fbae2294 100644 --- a/tests/connpool_test.go +++ b/tests/connpool_test.go @@ -3,15 +3,12 @@ package tests_test import ( "context" "database/sql" - "log" "os" "reflect" "testing" - "time" "gorm.io/driver/mysql" "gorm.io/gorm" - "gorm.io/gorm/logger" . "gorm.io/gorm/utils/tests" ) @@ -55,7 +52,7 @@ func (c *wrapperConnPool) Ping() error { // return c.db.BeginTx(ctx, opts) // } // You should use BeginTx returned gorm.Tx which could wrap *sql.Tx then you can record all queries. -func (c *wrapperConnPool) BeginTx(ctx context.Context, opts *sql.TxOptions) (gorm.Tx, error) { +func (c *wrapperConnPool) BeginTx(ctx context.Context, opts *sql.TxOptions) (gorm.ConnPool, error) { tx, err := c.db.BeginTx(ctx, opts) if err != nil { return nil, err @@ -119,14 +116,7 @@ func TestConnPoolWrapper(t *testing.T) { } }() - l := logger.New(log.New(os.Stdout, "\r\n", log.LstdFlags), logger.Config{ - SlowThreshold: 200 * time.Millisecond, - LogLevel: logger.Info, - IgnoreRecordNotFoundError: false, - Colorful: true, - }) - - db, err := gorm.Open(mysql.New(mysql.Config{Conn: conn}), &gorm.Config{Logger: l}) + db, err := gorm.Open(mysql.New(mysql.Config{Conn: conn})) if err != nil { t.Fatalf("Should open db success, but got %v", err) } From 29a8557384b060bf5d99b4b8824cb75c8a8b9917 Mon Sep 17 00:00:00 2001 From: Cao Manh Dat Date: Thu, 3 Mar 2022 09:17:29 +0700 Subject: [PATCH 9/9] ToSQL should enable SkipDefaultTransaction by default --- gorm.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gorm.go b/gorm.go index 7967b094..aca7cb5e 100644 --- a/gorm.go +++ b/gorm.go @@ -462,7 +462,7 @@ func (db *DB) Use(plugin Plugin) error { // .First(&User{}) // }) func (db *DB) ToSQL(queryFn func(tx *DB) *DB) string { - tx := queryFn(db.Session(&Session{DryRun: true})) + tx := queryFn(db.Session(&Session{DryRun: true, SkipDefaultTransaction: true})) stmt := tx.Statement return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...)