Add Scanner, Valuer tests
This commit is contained in:
		
							parent
							
								
									c422d75f4b
								
							
						
					
					
						commit
						c291c2f42c
					
				| @ -1,6 +1,9 @@ | |||||||
| package clause | package clause | ||||||
| 
 | 
 | ||||||
| import "reflect" | import ( | ||||||
|  | 	"database/sql/driver" | ||||||
|  | 	"reflect" | ||||||
|  | ) | ||||||
| 
 | 
 | ||||||
| // Expression expression interface
 | // Expression expression interface
 | ||||||
| type Expression interface { | type Expression interface { | ||||||
| @ -28,16 +31,20 @@ func (expr Expr) Build(builder Builder) { | |||||||
| 	for _, v := range []byte(expr.SQL) { | 	for _, v := range []byte(expr.SQL) { | ||||||
| 		if v == '?' { | 		if v == '?' { | ||||||
| 			if afterParenthesis { | 			if afterParenthesis { | ||||||
| 				switch rv := reflect.ValueOf(expr.Vars[idx]); rv.Kind() { | 				if _, ok := expr.Vars[idx].(driver.Valuer); ok { | ||||||
| 				case reflect.Slice, reflect.Array: |  | ||||||
| 					for i := 0; i < rv.Len(); i++ { |  | ||||||
| 						if i > 0 { |  | ||||||
| 							builder.WriteByte(',') |  | ||||||
| 						} |  | ||||||
| 						builder.AddVar(builder, rv.Index(i).Interface()) |  | ||||||
| 					} |  | ||||||
| 				default: |  | ||||||
| 					builder.AddVar(builder, expr.Vars[idx]) | 					builder.AddVar(builder, expr.Vars[idx]) | ||||||
|  | 				} else { | ||||||
|  | 					switch rv := reflect.ValueOf(expr.Vars[idx]); rv.Kind() { | ||||||
|  | 					case reflect.Slice, reflect.Array: | ||||||
|  | 						for i := 0; i < rv.Len(); i++ { | ||||||
|  | 							if i > 0 { | ||||||
|  | 								builder.WriteByte(',') | ||||||
|  | 							} | ||||||
|  | 							builder.AddVar(builder, rv.Index(i).Interface()) | ||||||
|  | 						} | ||||||
|  | 					default: | ||||||
|  | 						builder.AddVar(builder, expr.Vars[idx]) | ||||||
|  | 					} | ||||||
| 				} | 				} | ||||||
| 			} else { | 			} else { | ||||||
| 				builder.AddVar(builder, expr.Vars[idx]) | 				builder.AddVar(builder, expr.Vars[idx]) | ||||||
|  | |||||||
| @ -57,6 +57,9 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, v | |||||||
| 					vars[idx] = "NULL" | 					vars[idx] = "NULL" | ||||||
| 				} else if rv.Kind() == reflect.Ptr && rv.IsNil() { | 				} else if rv.Kind() == reflect.Ptr && rv.IsNil() { | ||||||
| 					vars[idx] = "NULL" | 					vars[idx] = "NULL" | ||||||
|  | 				} else if valuer, ok := v.(driver.Valuer); ok { | ||||||
|  | 					v, _ = valuer.Value() | ||||||
|  | 					convertParams(v, idx) | ||||||
| 				} else if rv.Kind() == reflect.Ptr && !rv.IsZero() { | 				} else if rv.Kind() == reflect.Ptr && !rv.IsZero() { | ||||||
| 					convertParams(reflect.Indirect(rv).Interface(), idx) | 					convertParams(reflect.Indirect(rv).Interface(), idx) | ||||||
| 				} else { | 				} else { | ||||||
| @ -74,10 +77,6 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, v | |||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	for idx, v := range vars { | 	for idx, v := range vars { | ||||||
| 		if valuer, ok := v.(driver.Valuer); ok { |  | ||||||
| 			v, _ = valuer.Value() |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		convertParams(v, idx) | 		convertParams(v, idx) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -207,7 +207,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { | |||||||
| 		field.DBDataType = val | 		field.DBDataType = val | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	switch fieldValue.Elem().Kind() { | 	switch reflect.Indirect(fieldValue).Kind() { | ||||||
| 	case reflect.Bool: | 	case reflect.Bool: | ||||||
| 		field.DataType = Bool | 		field.DataType = Bool | ||||||
| 		if field.HasDefaultValue { | 		if field.HasDefaultValue { | ||||||
|  | |||||||
| @ -146,6 +146,9 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) { | |||||||
| 		case clause.Expr: | 		case clause.Expr: | ||||||
| 			writer.WriteString(v.SQL) | 			writer.WriteString(v.SQL) | ||||||
| 			stmt.Vars = append(stmt.Vars, v.Vars...) | 			stmt.Vars = append(stmt.Vars, v.Vars...) | ||||||
|  | 		case driver.Valuer: | ||||||
|  | 			stmt.Vars = append(stmt.Vars, v) | ||||||
|  | 			stmt.DB.Dialector.BindVarTo(writer, stmt, v) | ||||||
| 		case []interface{}: | 		case []interface{}: | ||||||
| 			if len(v) > 0 { | 			if len(v) > 0 { | ||||||
| 				writer.WriteByte('(') | 				writer.WriteByte('(') | ||||||
|  | |||||||
							
								
								
									
										175
									
								
								tests/scanner_valuer_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										175
									
								
								tests/scanner_valuer_test.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,175 @@ | |||||||
|  | package tests_test | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"database/sql" | ||||||
|  | 	"database/sql/driver" | ||||||
|  | 	"encoding/json" | ||||||
|  | 	"errors" | ||||||
|  | 	"reflect" | ||||||
|  | 	"strconv" | ||||||
|  | 	"testing" | ||||||
|  | 	"time" | ||||||
|  | 
 | ||||||
|  | 	"github.com/jinzhu/gorm" | ||||||
|  | 	. "github.com/jinzhu/gorm/tests" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | func TestScannerValuer(t *testing.T) { | ||||||
|  | 	DB.Migrator().DropTable(&ScannerValuerStruct{}) | ||||||
|  | 	if err := DB.Migrator().AutoMigrate(&ScannerValuerStruct{}); err != nil { | ||||||
|  | 		t.Errorf("no error should happen when migrate scanner, valuer struct") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	data := ScannerValuerStruct{ | ||||||
|  | 		Name:     sql.NullString{String: "name", Valid: true}, | ||||||
|  | 		Gender:   &sql.NullString{String: "M", Valid: true}, | ||||||
|  | 		Age:      sql.NullInt64{Int64: 18, Valid: true}, | ||||||
|  | 		Male:     sql.NullBool{Bool: true, Valid: true}, | ||||||
|  | 		Height:   sql.NullFloat64{Float64: 1.8888, Valid: true}, | ||||||
|  | 		Birthday: sql.NullTime{Time: time.Now(), Valid: true}, | ||||||
|  | 		Password: EncryptedData("pass1"), | ||||||
|  | 		Num:      18, | ||||||
|  | 		Strings:  StringsSlice{"a", "b", "c"}, | ||||||
|  | 		Structs: StructsSlice{ | ||||||
|  | 			{"name1", "value1"}, | ||||||
|  | 			{"name2", "value2"}, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if err := DB.Create(&data).Error; err != nil { | ||||||
|  | 		t.Errorf("No error should happend when create scanner valuer struct, but got %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	var result ScannerValuerStruct | ||||||
|  | 
 | ||||||
|  | 	if err := DB.Find(&result).Error; err != nil { | ||||||
|  | 		t.Errorf("no error should happen when query scanner, valuer struct, but got %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	AssertObjEqual(t, data, result, "Name", "Gender", "Age", "Male", "Height", "Birthday", "Password", "Num", "Strings", "Structs") | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestInvalidValuer(t *testing.T) { | ||||||
|  | 	DB.Migrator().DropTable(&ScannerValuerStruct{}) | ||||||
|  | 	if err := DB.Migrator().AutoMigrate(&ScannerValuerStruct{}); err != nil { | ||||||
|  | 		t.Errorf("no error should happen when migrate scanner, valuer struct") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	data := ScannerValuerStruct{ | ||||||
|  | 		Password: EncryptedData("xpass1"), | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if err := DB.Create(&data).Error; err == nil { | ||||||
|  | 		t.Errorf("Should failed to create data with invalid data") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	data.Password = EncryptedData("pass1") | ||||||
|  | 	if err := DB.Create(&data).Error; err != nil { | ||||||
|  | 		t.Errorf("Should got no error when creating data, but got %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if err := DB.Model(&data).Update("password", EncryptedData("xnewpass")).Error; err == nil { | ||||||
|  | 		t.Errorf("Should failed to update data with invalid data") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if err := DB.Model(&data).Update("password", EncryptedData("newpass")).Error; err != nil { | ||||||
|  | 		t.Errorf("Should got no error update data with valid data, but got %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	AssertEqual(t, data.Password, EncryptedData("newpass")) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type ScannerValuerStruct struct { | ||||||
|  | 	gorm.Model | ||||||
|  | 	Name     sql.NullString | ||||||
|  | 	Gender   *sql.NullString | ||||||
|  | 	Age      sql.NullInt64 | ||||||
|  | 	Male     sql.NullBool | ||||||
|  | 	Height   sql.NullFloat64 | ||||||
|  | 	Birthday sql.NullTime | ||||||
|  | 	Password EncryptedData | ||||||
|  | 	Num      Num | ||||||
|  | 	Strings  StringsSlice | ||||||
|  | 	Structs  StructsSlice | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type EncryptedData []byte | ||||||
|  | 
 | ||||||
|  | func (data *EncryptedData) Scan(value interface{}) error { | ||||||
|  | 	if b, ok := value.([]byte); ok { | ||||||
|  | 		if len(b) < 3 || b[0] != '*' || b[1] != '*' || b[2] != '*' { | ||||||
|  | 			return errors.New("Too short") | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		*data = b[3:] | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return errors.New("Bytes expected") | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (data EncryptedData) Value() (driver.Value, error) { | ||||||
|  | 	if len(data) > 0 && data[0] == 'x' { | ||||||
|  | 		//needed to test failures
 | ||||||
|  | 		return nil, errors.New("Should not start with 'x'") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	//prepend asterisks
 | ||||||
|  | 	return append([]byte("***"), data...), nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type Num int64 | ||||||
|  | 
 | ||||||
|  | func (i *Num) Scan(src interface{}) error { | ||||||
|  | 	switch s := src.(type) { | ||||||
|  | 	case []byte: | ||||||
|  | 		n, _ := strconv.Atoi(string(s)) | ||||||
|  | 		*i = Num(n) | ||||||
|  | 	case int64: | ||||||
|  | 		*i = Num(s) | ||||||
|  | 	default: | ||||||
|  | 		return errors.New("Cannot scan NamedInt from " + reflect.ValueOf(src).String()) | ||||||
|  | 	} | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type StringsSlice []string | ||||||
|  | 
 | ||||||
|  | func (l StringsSlice) Value() (driver.Value, error) { | ||||||
|  | 	bytes, err := json.Marshal(l) | ||||||
|  | 	return string(bytes), err | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (l *StringsSlice) Scan(input interface{}) error { | ||||||
|  | 	switch value := input.(type) { | ||||||
|  | 	case string: | ||||||
|  | 		return json.Unmarshal([]byte(value), l) | ||||||
|  | 	case []byte: | ||||||
|  | 		return json.Unmarshal(value, l) | ||||||
|  | 	default: | ||||||
|  | 		return errors.New("not supported") | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type ExampleStruct struct { | ||||||
|  | 	Name  string | ||||||
|  | 	Value string | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type StructsSlice []ExampleStruct | ||||||
|  | 
 | ||||||
|  | func (l StructsSlice) Value() (driver.Value, error) { | ||||||
|  | 	bytes, err := json.Marshal(l) | ||||||
|  | 	return string(bytes), err | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (l *StructsSlice) Scan(input interface{}) error { | ||||||
|  | 	switch value := input.(type) { | ||||||
|  | 	case string: | ||||||
|  | 		return json.Unmarshal([]byte(value), l) | ||||||
|  | 	case []byte: | ||||||
|  | 		return json.Unmarshal(value, l) | ||||||
|  | 	default: | ||||||
|  | 		return errors.New("not supported") | ||||||
|  | 	} | ||||||
|  | } | ||||||
| @ -1,6 +1,8 @@ | |||||||
| package tests | package tests | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
|  | 	"database/sql/driver" | ||||||
|  | 	"fmt" | ||||||
| 	"reflect" | 	"reflect" | ||||||
| 	"sort" | 	"sort" | ||||||
| 	"strconv" | 	"strconv" | ||||||
| @ -89,12 +91,12 @@ func AssertEqual(t *testing.T, got, expect interface{}) { | |||||||
| 				if curTime.Round(time.Second).Format(format) != expect.(time.Time).Round(time.Second).Format(format) { | 				if curTime.Round(time.Second).Format(format) != expect.(time.Time).Round(time.Second).Format(format) { | ||||||
| 					t.Errorf("%v: expect: %v, got %v after time round", utils.FileWithLineNum(), expect.(time.Time).Round(time.Second).Format(format), curTime.Round(time.Second).Format(format)) | 					t.Errorf("%v: expect: %v, got %v after time round", utils.FileWithLineNum(), expect.(time.Time).Round(time.Second).Format(format), curTime.Round(time.Second).Format(format)) | ||||||
| 				} | 				} | ||||||
| 			} else if got != expect { | 			} else if fmt.Sprint(got) != fmt.Sprint(expect) { | ||||||
| 				t.Errorf("%v: expect: %#v, got %#v", utils.FileWithLineNum(), expect, got) | 				t.Errorf("%v: expect: %#v, got %#v", utils.FileWithLineNum(), expect, got) | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		if got == expect { | 		if fmt.Sprint(got) == fmt.Sprint(expect) { | ||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| @ -103,6 +105,14 @@ func AssertEqual(t *testing.T, got, expect interface{}) { | |||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
|  | 		if valuer, ok := got.(driver.Valuer); ok { | ||||||
|  | 			got, _ = valuer.Value() | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if valuer, ok := expect.(driver.Valuer); ok { | ||||||
|  | 			expect, _ = valuer.Value() | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
| 		if got != nil { | 		if got != nil { | ||||||
| 			got = reflect.Indirect(reflect.ValueOf(got)).Interface() | 			got = reflect.Indirect(reflect.ValueOf(got)).Interface() | ||||||
| 		} | 		} | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu