Add Scanner, Valuer tests
This commit is contained in:
		
							parent
							
								
									c422d75f4b
								
							
						
					
					
						commit
						c291c2f42c
					
				| @ -1,6 +1,9 @@ | ||||
| package clause | ||||
| 
 | ||||
| import "reflect" | ||||
| import ( | ||||
| 	"database/sql/driver" | ||||
| 	"reflect" | ||||
| ) | ||||
| 
 | ||||
| // Expression expression interface
 | ||||
| type Expression interface { | ||||
| @ -28,16 +31,20 @@ func (expr Expr) Build(builder Builder) { | ||||
| 	for _, v := range []byte(expr.SQL) { | ||||
| 		if v == '?' { | ||||
| 			if afterParenthesis { | ||||
| 				switch rv := reflect.ValueOf(expr.Vars[idx]); rv.Kind() { | ||||
| 				case reflect.Slice, reflect.Array: | ||||
| 					for i := 0; i < rv.Len(); i++ { | ||||
| 						if i > 0 { | ||||
| 							builder.WriteByte(',') | ||||
| 						} | ||||
| 						builder.AddVar(builder, rv.Index(i).Interface()) | ||||
| 					} | ||||
| 				default: | ||||
| 				if _, ok := expr.Vars[idx].(driver.Valuer); ok { | ||||
| 					builder.AddVar(builder, expr.Vars[idx]) | ||||
| 				} else { | ||||
| 					switch rv := reflect.ValueOf(expr.Vars[idx]); rv.Kind() { | ||||
| 					case reflect.Slice, reflect.Array: | ||||
| 						for i := 0; i < rv.Len(); i++ { | ||||
| 							if i > 0 { | ||||
| 								builder.WriteByte(',') | ||||
| 							} | ||||
| 							builder.AddVar(builder, rv.Index(i).Interface()) | ||||
| 						} | ||||
| 					default: | ||||
| 						builder.AddVar(builder, expr.Vars[idx]) | ||||
| 					} | ||||
| 				} | ||||
| 			} else { | ||||
| 				builder.AddVar(builder, expr.Vars[idx]) | ||||
|  | ||||
| @ -57,6 +57,9 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, v | ||||
| 					vars[idx] = "NULL" | ||||
| 				} else if rv.Kind() == reflect.Ptr && rv.IsNil() { | ||||
| 					vars[idx] = "NULL" | ||||
| 				} else if valuer, ok := v.(driver.Valuer); ok { | ||||
| 					v, _ = valuer.Value() | ||||
| 					convertParams(v, idx) | ||||
| 				} else if rv.Kind() == reflect.Ptr && !rv.IsZero() { | ||||
| 					convertParams(reflect.Indirect(rv).Interface(), idx) | ||||
| 				} else { | ||||
| @ -74,10 +77,6 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, v | ||||
| 	} | ||||
| 
 | ||||
| 	for idx, v := range vars { | ||||
| 		if valuer, ok := v.(driver.Valuer); ok { | ||||
| 			v, _ = valuer.Value() | ||||
| 		} | ||||
| 
 | ||||
| 		convertParams(v, idx) | ||||
| 	} | ||||
| 
 | ||||
|  | ||||
| @ -207,7 +207,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { | ||||
| 		field.DBDataType = val | ||||
| 	} | ||||
| 
 | ||||
| 	switch fieldValue.Elem().Kind() { | ||||
| 	switch reflect.Indirect(fieldValue).Kind() { | ||||
| 	case reflect.Bool: | ||||
| 		field.DataType = Bool | ||||
| 		if field.HasDefaultValue { | ||||
|  | ||||
| @ -146,6 +146,9 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) { | ||||
| 		case clause.Expr: | ||||
| 			writer.WriteString(v.SQL) | ||||
| 			stmt.Vars = append(stmt.Vars, v.Vars...) | ||||
| 		case driver.Valuer: | ||||
| 			stmt.Vars = append(stmt.Vars, v) | ||||
| 			stmt.DB.Dialector.BindVarTo(writer, stmt, v) | ||||
| 		case []interface{}: | ||||
| 			if len(v) > 0 { | ||||
| 				writer.WriteByte('(') | ||||
|  | ||||
							
								
								
									
										175
									
								
								tests/scanner_valuer_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										175
									
								
								tests/scanner_valuer_test.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,175 @@ | ||||
| package tests_test | ||||
| 
 | ||||
| import ( | ||||
| 	"database/sql" | ||||
| 	"database/sql/driver" | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"reflect" | ||||
| 	"strconv" | ||||
| 	"testing" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/jinzhu/gorm" | ||||
| 	. "github.com/jinzhu/gorm/tests" | ||||
| ) | ||||
| 
 | ||||
| func TestScannerValuer(t *testing.T) { | ||||
| 	DB.Migrator().DropTable(&ScannerValuerStruct{}) | ||||
| 	if err := DB.Migrator().AutoMigrate(&ScannerValuerStruct{}); err != nil { | ||||
| 		t.Errorf("no error should happen when migrate scanner, valuer struct") | ||||
| 	} | ||||
| 
 | ||||
| 	data := ScannerValuerStruct{ | ||||
| 		Name:     sql.NullString{String: "name", Valid: true}, | ||||
| 		Gender:   &sql.NullString{String: "M", Valid: true}, | ||||
| 		Age:      sql.NullInt64{Int64: 18, Valid: true}, | ||||
| 		Male:     sql.NullBool{Bool: true, Valid: true}, | ||||
| 		Height:   sql.NullFloat64{Float64: 1.8888, Valid: true}, | ||||
| 		Birthday: sql.NullTime{Time: time.Now(), Valid: true}, | ||||
| 		Password: EncryptedData("pass1"), | ||||
| 		Num:      18, | ||||
| 		Strings:  StringsSlice{"a", "b", "c"}, | ||||
| 		Structs: StructsSlice{ | ||||
| 			{"name1", "value1"}, | ||||
| 			{"name2", "value2"}, | ||||
| 		}, | ||||
| 	} | ||||
| 
 | ||||
| 	if err := DB.Create(&data).Error; err != nil { | ||||
| 		t.Errorf("No error should happend when create scanner valuer struct, but got %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	var result ScannerValuerStruct | ||||
| 
 | ||||
| 	if err := DB.Find(&result).Error; err != nil { | ||||
| 		t.Errorf("no error should happen when query scanner, valuer struct, but got %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	AssertObjEqual(t, data, result, "Name", "Gender", "Age", "Male", "Height", "Birthday", "Password", "Num", "Strings", "Structs") | ||||
| } | ||||
| 
 | ||||
| func TestInvalidValuer(t *testing.T) { | ||||
| 	DB.Migrator().DropTable(&ScannerValuerStruct{}) | ||||
| 	if err := DB.Migrator().AutoMigrate(&ScannerValuerStruct{}); err != nil { | ||||
| 		t.Errorf("no error should happen when migrate scanner, valuer struct") | ||||
| 	} | ||||
| 
 | ||||
| 	data := ScannerValuerStruct{ | ||||
| 		Password: EncryptedData("xpass1"), | ||||
| 	} | ||||
| 
 | ||||
| 	if err := DB.Create(&data).Error; err == nil { | ||||
| 		t.Errorf("Should failed to create data with invalid data") | ||||
| 	} | ||||
| 
 | ||||
| 	data.Password = EncryptedData("pass1") | ||||
| 	if err := DB.Create(&data).Error; err != nil { | ||||
| 		t.Errorf("Should got no error when creating data, but got %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if err := DB.Model(&data).Update("password", EncryptedData("xnewpass")).Error; err == nil { | ||||
| 		t.Errorf("Should failed to update data with invalid data") | ||||
| 	} | ||||
| 
 | ||||
| 	if err := DB.Model(&data).Update("password", EncryptedData("newpass")).Error; err != nil { | ||||
| 		t.Errorf("Should got no error update data with valid data, but got %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	AssertEqual(t, data.Password, EncryptedData("newpass")) | ||||
| } | ||||
| 
 | ||||
| type ScannerValuerStruct struct { | ||||
| 	gorm.Model | ||||
| 	Name     sql.NullString | ||||
| 	Gender   *sql.NullString | ||||
| 	Age      sql.NullInt64 | ||||
| 	Male     sql.NullBool | ||||
| 	Height   sql.NullFloat64 | ||||
| 	Birthday sql.NullTime | ||||
| 	Password EncryptedData | ||||
| 	Num      Num | ||||
| 	Strings  StringsSlice | ||||
| 	Structs  StructsSlice | ||||
| } | ||||
| 
 | ||||
| type EncryptedData []byte | ||||
| 
 | ||||
| func (data *EncryptedData) Scan(value interface{}) error { | ||||
| 	if b, ok := value.([]byte); ok { | ||||
| 		if len(b) < 3 || b[0] != '*' || b[1] != '*' || b[2] != '*' { | ||||
| 			return errors.New("Too short") | ||||
| 		} | ||||
| 
 | ||||
| 		*data = b[3:] | ||||
| 		return nil | ||||
| 	} | ||||
| 
 | ||||
| 	return errors.New("Bytes expected") | ||||
| } | ||||
| 
 | ||||
| func (data EncryptedData) Value() (driver.Value, error) { | ||||
| 	if len(data) > 0 && data[0] == 'x' { | ||||
| 		//needed to test failures
 | ||||
| 		return nil, errors.New("Should not start with 'x'") | ||||
| 	} | ||||
| 
 | ||||
| 	//prepend asterisks
 | ||||
| 	return append([]byte("***"), data...), nil | ||||
| } | ||||
| 
 | ||||
| type Num int64 | ||||
| 
 | ||||
| func (i *Num) Scan(src interface{}) error { | ||||
| 	switch s := src.(type) { | ||||
| 	case []byte: | ||||
| 		n, _ := strconv.Atoi(string(s)) | ||||
| 		*i = Num(n) | ||||
| 	case int64: | ||||
| 		*i = Num(s) | ||||
| 	default: | ||||
| 		return errors.New("Cannot scan NamedInt from " + reflect.ValueOf(src).String()) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| type StringsSlice []string | ||||
| 
 | ||||
| func (l StringsSlice) Value() (driver.Value, error) { | ||||
| 	bytes, err := json.Marshal(l) | ||||
| 	return string(bytes), err | ||||
| } | ||||
| 
 | ||||
| func (l *StringsSlice) Scan(input interface{}) error { | ||||
| 	switch value := input.(type) { | ||||
| 	case string: | ||||
| 		return json.Unmarshal([]byte(value), l) | ||||
| 	case []byte: | ||||
| 		return json.Unmarshal(value, l) | ||||
| 	default: | ||||
| 		return errors.New("not supported") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| type ExampleStruct struct { | ||||
| 	Name  string | ||||
| 	Value string | ||||
| } | ||||
| 
 | ||||
| type StructsSlice []ExampleStruct | ||||
| 
 | ||||
| func (l StructsSlice) Value() (driver.Value, error) { | ||||
| 	bytes, err := json.Marshal(l) | ||||
| 	return string(bytes), err | ||||
| } | ||||
| 
 | ||||
| func (l *StructsSlice) Scan(input interface{}) error { | ||||
| 	switch value := input.(type) { | ||||
| 	case string: | ||||
| 		return json.Unmarshal([]byte(value), l) | ||||
| 	case []byte: | ||||
| 		return json.Unmarshal(value, l) | ||||
| 	default: | ||||
| 		return errors.New("not supported") | ||||
| 	} | ||||
| } | ||||
| @ -1,6 +1,8 @@ | ||||
| package tests | ||||
| 
 | ||||
| import ( | ||||
| 	"database/sql/driver" | ||||
| 	"fmt" | ||||
| 	"reflect" | ||||
| 	"sort" | ||||
| 	"strconv" | ||||
| @ -89,12 +91,12 @@ func AssertEqual(t *testing.T, got, expect interface{}) { | ||||
| 				if curTime.Round(time.Second).Format(format) != expect.(time.Time).Round(time.Second).Format(format) { | ||||
| 					t.Errorf("%v: expect: %v, got %v after time round", utils.FileWithLineNum(), expect.(time.Time).Round(time.Second).Format(format), curTime.Round(time.Second).Format(format)) | ||||
| 				} | ||||
| 			} else if got != expect { | ||||
| 			} else if fmt.Sprint(got) != fmt.Sprint(expect) { | ||||
| 				t.Errorf("%v: expect: %#v, got %#v", utils.FileWithLineNum(), expect, got) | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		if got == expect { | ||||
| 		if fmt.Sprint(got) == fmt.Sprint(expect) { | ||||
| 			return | ||||
| 		} | ||||
| 
 | ||||
| @ -103,6 +105,14 @@ func AssertEqual(t *testing.T, got, expect interface{}) { | ||||
| 			return | ||||
| 		} | ||||
| 
 | ||||
| 		if valuer, ok := got.(driver.Valuer); ok { | ||||
| 			got, _ = valuer.Value() | ||||
| 		} | ||||
| 
 | ||||
| 		if valuer, ok := expect.(driver.Valuer); ok { | ||||
| 			expect, _ = valuer.Value() | ||||
| 		} | ||||
| 
 | ||||
| 		if got != nil { | ||||
| 			got = reflect.Indirect(reflect.ValueOf(got)).Interface() | ||||
| 		} | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu