From c291c2f42cc66892198d5254592602e000c0dac6 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 30 May 2020 21:05:27 +0800 Subject: [PATCH] Add Scanner, Valuer tests --- clause/expression.go | 27 ++++-- logger/sql.go | 7 +- schema/field.go | 2 +- statement.go | 3 + tests/scanner_valuer_test.go | 175 +++++++++++++++++++++++++++++++++++ tests/utils.go | 14 ++- 6 files changed, 211 insertions(+), 17 deletions(-) create mode 100644 tests/scanner_valuer_test.go diff --git a/clause/expression.go b/clause/expression.go index e54da1af..ecf8ba85 100644 --- a/clause/expression.go +++ b/clause/expression.go @@ -1,6 +1,9 @@ package clause -import "reflect" +import ( + "database/sql/driver" + "reflect" +) // Expression expression interface type Expression interface { @@ -28,16 +31,20 @@ func (expr Expr) Build(builder Builder) { for _, v := range []byte(expr.SQL) { if v == '?' { if afterParenthesis { - switch rv := reflect.ValueOf(expr.Vars[idx]); rv.Kind() { - case reflect.Slice, reflect.Array: - for i := 0; i < rv.Len(); i++ { - if i > 0 { - builder.WriteByte(',') - } - builder.AddVar(builder, rv.Index(i).Interface()) - } - default: + if _, ok := expr.Vars[idx].(driver.Valuer); ok { builder.AddVar(builder, expr.Vars[idx]) + } else { + switch rv := reflect.ValueOf(expr.Vars[idx]); rv.Kind() { + case reflect.Slice, reflect.Array: + for i := 0; i < rv.Len(); i++ { + if i > 0 { + builder.WriteByte(',') + } + builder.AddVar(builder, rv.Index(i).Interface()) + } + default: + builder.AddVar(builder, expr.Vars[idx]) + } } } else { builder.AddVar(builder, expr.Vars[idx]) diff --git a/logger/sql.go b/logger/sql.go index bb4e3e06..dd502324 100644 --- a/logger/sql.go +++ b/logger/sql.go @@ -57,6 +57,9 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, v vars[idx] = "NULL" } else if rv.Kind() == reflect.Ptr && rv.IsNil() { vars[idx] = "NULL" + } else if valuer, ok := v.(driver.Valuer); ok { + v, _ = valuer.Value() + convertParams(v, idx) } else if rv.Kind() == reflect.Ptr && !rv.IsZero() { convertParams(reflect.Indirect(rv).Interface(), idx) } else { @@ -74,10 +77,6 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, v } for idx, v := range vars { - if valuer, ok := v.(driver.Valuer); ok { - v, _ = valuer.Value() - } - convertParams(v, idx) } diff --git a/schema/field.go b/schema/field.go index d435c928..57ba3ac7 100644 --- a/schema/field.go +++ b/schema/field.go @@ -207,7 +207,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { field.DBDataType = val } - switch fieldValue.Elem().Kind() { + switch reflect.Indirect(fieldValue).Kind() { case reflect.Bool: field.DataType = Bool if field.HasDefaultValue { diff --git a/statement.go b/statement.go index 42df148a..e0d92c5e 100644 --- a/statement.go +++ b/statement.go @@ -146,6 +146,9 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) { case clause.Expr: writer.WriteString(v.SQL) stmt.Vars = append(stmt.Vars, v.Vars...) + case driver.Valuer: + stmt.Vars = append(stmt.Vars, v) + stmt.DB.Dialector.BindVarTo(writer, stmt, v) case []interface{}: if len(v) > 0 { writer.WriteByte('(') diff --git a/tests/scanner_valuer_test.go b/tests/scanner_valuer_test.go new file mode 100644 index 00000000..38ffc919 --- /dev/null +++ b/tests/scanner_valuer_test.go @@ -0,0 +1,175 @@ +package tests_test + +import ( + "database/sql" + "database/sql/driver" + "encoding/json" + "errors" + "reflect" + "strconv" + "testing" + "time" + + "github.com/jinzhu/gorm" + . "github.com/jinzhu/gorm/tests" +) + +func TestScannerValuer(t *testing.T) { + DB.Migrator().DropTable(&ScannerValuerStruct{}) + if err := DB.Migrator().AutoMigrate(&ScannerValuerStruct{}); err != nil { + t.Errorf("no error should happen when migrate scanner, valuer struct") + } + + data := ScannerValuerStruct{ + Name: sql.NullString{String: "name", Valid: true}, + Gender: &sql.NullString{String: "M", Valid: true}, + Age: sql.NullInt64{Int64: 18, Valid: true}, + Male: sql.NullBool{Bool: true, Valid: true}, + Height: sql.NullFloat64{Float64: 1.8888, Valid: true}, + Birthday: sql.NullTime{Time: time.Now(), Valid: true}, + Password: EncryptedData("pass1"), + Num: 18, + Strings: StringsSlice{"a", "b", "c"}, + Structs: StructsSlice{ + {"name1", "value1"}, + {"name2", "value2"}, + }, + } + + if err := DB.Create(&data).Error; err != nil { + t.Errorf("No error should happend when create scanner valuer struct, but got %v", err) + } + + var result ScannerValuerStruct + + if err := DB.Find(&result).Error; err != nil { + t.Errorf("no error should happen when query scanner, valuer struct, but got %v", err) + } + + AssertObjEqual(t, data, result, "Name", "Gender", "Age", "Male", "Height", "Birthday", "Password", "Num", "Strings", "Structs") +} + +func TestInvalidValuer(t *testing.T) { + DB.Migrator().DropTable(&ScannerValuerStruct{}) + if err := DB.Migrator().AutoMigrate(&ScannerValuerStruct{}); err != nil { + t.Errorf("no error should happen when migrate scanner, valuer struct") + } + + data := ScannerValuerStruct{ + Password: EncryptedData("xpass1"), + } + + if err := DB.Create(&data).Error; err == nil { + t.Errorf("Should failed to create data with invalid data") + } + + data.Password = EncryptedData("pass1") + if err := DB.Create(&data).Error; err != nil { + t.Errorf("Should got no error when creating data, but got %v", err) + } + + if err := DB.Model(&data).Update("password", EncryptedData("xnewpass")).Error; err == nil { + t.Errorf("Should failed to update data with invalid data") + } + + if err := DB.Model(&data).Update("password", EncryptedData("newpass")).Error; err != nil { + t.Errorf("Should got no error update data with valid data, but got %v", err) + } + + AssertEqual(t, data.Password, EncryptedData("newpass")) +} + +type ScannerValuerStruct struct { + gorm.Model + Name sql.NullString + Gender *sql.NullString + Age sql.NullInt64 + Male sql.NullBool + Height sql.NullFloat64 + Birthday sql.NullTime + Password EncryptedData + Num Num + Strings StringsSlice + Structs StructsSlice +} + +type EncryptedData []byte + +func (data *EncryptedData) Scan(value interface{}) error { + if b, ok := value.([]byte); ok { + if len(b) < 3 || b[0] != '*' || b[1] != '*' || b[2] != '*' { + return errors.New("Too short") + } + + *data = b[3:] + return nil + } + + return errors.New("Bytes expected") +} + +func (data EncryptedData) Value() (driver.Value, error) { + if len(data) > 0 && data[0] == 'x' { + //needed to test failures + return nil, errors.New("Should not start with 'x'") + } + + //prepend asterisks + return append([]byte("***"), data...), nil +} + +type Num int64 + +func (i *Num) Scan(src interface{}) error { + switch s := src.(type) { + case []byte: + n, _ := strconv.Atoi(string(s)) + *i = Num(n) + case int64: + *i = Num(s) + default: + return errors.New("Cannot scan NamedInt from " + reflect.ValueOf(src).String()) + } + return nil +} + +type StringsSlice []string + +func (l StringsSlice) Value() (driver.Value, error) { + bytes, err := json.Marshal(l) + return string(bytes), err +} + +func (l *StringsSlice) Scan(input interface{}) error { + switch value := input.(type) { + case string: + return json.Unmarshal([]byte(value), l) + case []byte: + return json.Unmarshal(value, l) + default: + return errors.New("not supported") + } +} + +type ExampleStruct struct { + Name string + Value string +} + +type StructsSlice []ExampleStruct + +func (l StructsSlice) Value() (driver.Value, error) { + bytes, err := json.Marshal(l) + return string(bytes), err +} + +func (l *StructsSlice) Scan(input interface{}) error { + switch value := input.(type) { + case string: + return json.Unmarshal([]byte(value), l) + case []byte: + return json.Unmarshal(value, l) + default: + return errors.New("not supported") + } +} diff --git a/tests/utils.go b/tests/utils.go index 041dc9b1..dfddf848 100644 --- a/tests/utils.go +++ b/tests/utils.go @@ -1,6 +1,8 @@ package tests import ( + "database/sql/driver" + "fmt" "reflect" "sort" "strconv" @@ -89,12 +91,12 @@ func AssertEqual(t *testing.T, got, expect interface{}) { if curTime.Round(time.Second).Format(format) != expect.(time.Time).Round(time.Second).Format(format) { t.Errorf("%v: expect: %v, got %v after time round", utils.FileWithLineNum(), expect.(time.Time).Round(time.Second).Format(format), curTime.Round(time.Second).Format(format)) } - } else if got != expect { + } else if fmt.Sprint(got) != fmt.Sprint(expect) { t.Errorf("%v: expect: %#v, got %#v", utils.FileWithLineNum(), expect, got) } } - if got == expect { + if fmt.Sprint(got) == fmt.Sprint(expect) { return } @@ -103,6 +105,14 @@ func AssertEqual(t *testing.T, got, expect interface{}) { return } + if valuer, ok := got.(driver.Valuer); ok { + got, _ = valuer.Value() + } + + if valuer, ok := expect.(driver.Valuer); ok { + expect, _ = valuer.Value() + } + if got != nil { got = reflect.Indirect(reflect.ValueOf(got)).Interface() }