
Use golangci/golangci-lint-action instead of reviewdog/action-golangci-lint, as the latter was not reporting any failures even when there were some. Report code coverage with codecov/codecov-action; I have set some flags per dialect and Go version. Several linter findings have been fixed, and some linters have been disabled, so that the build can pass.
394 lines
11 KiB
Go
394 lines
11 KiB
Go
package tests_test
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"database/sql/driver"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"reflect"
|
|
"regexp"
|
|
"strconv"
|
|
"testing"
|
|
"time"
|
|
|
|
"gorm.io/gorm"
|
|
"gorm.io/gorm/clause"
|
|
. "gorm.io/gorm/utils/tests"
|
|
)
|
|
|
|
func TestScannerValuer(t *testing.T) {
|
|
DB.Migrator().DropTable(&ScannerValuerStruct{})
|
|
if err := DB.Migrator().AutoMigrate(&ScannerValuerStruct{}); err != nil {
|
|
t.Fatalf("no error should happen when migrate scanner, valuer struct, got error %v", err)
|
|
}
|
|
|
|
data := ScannerValuerStruct{
|
|
Name: sql.NullString{String: "name", Valid: true},
|
|
Gender: &sql.NullString{String: "M", Valid: true},
|
|
Age: sql.NullInt64{Int64: 18, Valid: true},
|
|
Male: sql.NullBool{Bool: true, Valid: true},
|
|
Height: sql.NullFloat64{Float64: 1.8888, Valid: true},
|
|
Birthday: sql.NullTime{Time: time.Now(), Valid: true},
|
|
Allergen: NullString{sql.NullString{String: "Allergen", Valid: true}},
|
|
Password: EncryptedData("pass1"),
|
|
Bytes: []byte("byte"),
|
|
Num: 18,
|
|
Strings: StringsSlice{"a", "b", "c"},
|
|
Structs: StructsSlice{
|
|
{"name1", "value1"},
|
|
{"name2", "value2"},
|
|
},
|
|
Role: Role{Name: "admin"},
|
|
ExampleStruct: ExampleStruct{"name", "value1"},
|
|
ExampleStructPtr: &ExampleStruct{"name", "value2"},
|
|
}
|
|
|
|
if err := DB.Create(&data).Error; err != nil {
|
|
t.Fatalf("No error should happened when create scanner valuer struct, but got %v", err)
|
|
}
|
|
|
|
var result ScannerValuerStruct
|
|
|
|
if err := DB.Find(&result, "id = ?", data.ID).Error; err != nil {
|
|
t.Fatalf("no error should happen when query scanner, valuer struct, but got %v", err)
|
|
}
|
|
|
|
if result.ExampleStructPtr.Val != "value2" {
|
|
t.Errorf(`ExampleStructPtr.Val should equal to "value2", but got %v`, result.ExampleStructPtr.Val)
|
|
}
|
|
|
|
if result.ExampleStruct.Val != "value1" {
|
|
t.Errorf(`ExampleStruct.Val should equal to "value1", but got %#v`, result.ExampleStruct)
|
|
}
|
|
AssertObjEqual(t, data, result, "Name", "Gender", "Age", "Male", "Height", "Birthday", "Password", "Bytes", "Num", "Strings", "Structs")
|
|
}
|
|
|
|
func TestScannerValuerWithFirstOrCreate(t *testing.T) {
|
|
DB.Migrator().DropTable(&ScannerValuerStruct{})
|
|
if err := DB.Migrator().AutoMigrate(&ScannerValuerStruct{}); err != nil {
|
|
t.Errorf("no error should happen when migrate scanner, valuer struct")
|
|
}
|
|
|
|
data := ScannerValuerStruct{
|
|
Name: sql.NullString{String: "name", Valid: true},
|
|
Gender: &sql.NullString{String: "M", Valid: true},
|
|
Age: sql.NullInt64{Int64: 18, Valid: true},
|
|
ExampleStruct: ExampleStruct{"name", "value1"},
|
|
ExampleStructPtr: &ExampleStruct{"name", "value2"},
|
|
}
|
|
|
|
var result ScannerValuerStruct
|
|
tx := DB.Where(data).FirstOrCreate(&result)
|
|
|
|
if tx.RowsAffected != 1 {
|
|
t.Errorf("RowsAffected should be 1 after create some record")
|
|
}
|
|
|
|
if tx.Error != nil {
|
|
t.Errorf("Should not raise any error, but got %v", tx.Error)
|
|
}
|
|
|
|
AssertObjEqual(t, result, data, "Name", "Gender", "Age")
|
|
|
|
if err := DB.Where(data).Assign(ScannerValuerStruct{Age: sql.NullInt64{Int64: 18, Valid: true}}).FirstOrCreate(&result).Error; err != nil {
|
|
t.Errorf("Should not raise any error, but got %v", err)
|
|
}
|
|
|
|
if result.Age.Int64 != 18 {
|
|
t.Errorf("should update age to 18")
|
|
}
|
|
|
|
var result2 ScannerValuerStruct
|
|
if err := DB.First(&result2, result.ID).Error; err != nil {
|
|
t.Errorf("got error %v when query with %v", err, result.ID)
|
|
}
|
|
|
|
AssertObjEqual(t, result2, result, "ID", "CreatedAt", "UpdatedAt", "Name", "Gender", "Age")
|
|
}
|
|
|
|
func TestInvalidValuer(t *testing.T) {
|
|
DB.Migrator().DropTable(&ScannerValuerStruct{})
|
|
if err := DB.Migrator().AutoMigrate(&ScannerValuerStruct{}); err != nil {
|
|
t.Errorf("no error should happen when migrate scanner, valuer struct")
|
|
}
|
|
|
|
data := ScannerValuerStruct{
|
|
Password: EncryptedData("xpass1"),
|
|
ExampleStruct: ExampleStruct{"name", "value1"},
|
|
ExampleStructPtr: &ExampleStruct{"name", "value2"},
|
|
}
|
|
|
|
if err := DB.Create(&data).Error; err == nil {
|
|
t.Errorf("Should failed to create data with invalid data")
|
|
}
|
|
|
|
data.Password = EncryptedData("pass1")
|
|
if err := DB.Create(&data).Error; err != nil {
|
|
t.Errorf("Should got no error when creating data, but got %v", err)
|
|
}
|
|
|
|
if err := DB.Model(&data).Update("password", EncryptedData("xnewpass")).Error; err == nil {
|
|
t.Errorf("Should failed to update data with invalid data")
|
|
}
|
|
|
|
if err := DB.Model(&data).Update("password", EncryptedData("newpass")).Error; err != nil {
|
|
t.Errorf("Should got no error update data with valid data, but got %v", err)
|
|
}
|
|
|
|
AssertEqual(t, data.Password, EncryptedData("newpass"))
|
|
}
|
|
|
|
type ScannerValuerStruct struct {
|
|
gorm.Model
|
|
Name sql.NullString
|
|
Gender *sql.NullString
|
|
Age sql.NullInt64
|
|
Male sql.NullBool
|
|
Height sql.NullFloat64
|
|
Birthday sql.NullTime
|
|
Allergen NullString
|
|
Password EncryptedData
|
|
Bytes []byte
|
|
Num Num
|
|
Strings StringsSlice
|
|
Structs StructsSlice
|
|
Role Role
|
|
UserID *sql.NullInt64
|
|
User User
|
|
EmptyTime EmptyTime
|
|
ExampleStruct ExampleStruct
|
|
ExampleStructPtr *ExampleStruct
|
|
}
|
|
|
|
// encryptedDataPrefix is the marker EncryptedData.Value prepends to the stored
// bytes and EncryptedData.Scan requires and strips again.
const encryptedDataPrefix = "***"

// EncryptedData is a byte slice that is written to the database with a "***"
// prefix and read back with that prefix removed. It exists to exercise custom
// Scanner/Valuer handling, including failure paths.
type EncryptedData []byte

// Scan implements sql.Scanner. It accepts a []byte or string that starts with
// "***" and stores the remainder in data.
//
// It returns "Too short" when the input is shorter than the prefix or does not
// start with it, and "Bytes expected" for any other input type. The string
// form is validated exactly like the []byte form, so a short or unprefixed
// string yields an error instead of a slice-bounds panic.
func (data *EncryptedData) Scan(value interface{}) error {
	var raw []byte
	switch v := value.(type) {
	case []byte:
		raw = v
	case string:
		raw = []byte(v)
	default:
		return errors.New("Bytes expected")
	}

	n := len(encryptedDataPrefix)
	if len(raw) < n || string(raw[:n]) != encryptedDataPrefix {
		return errors.New("Too short")
	}

	// Copy the payload: a []byte handed to Scan is owned by the driver and is
	// only valid until the next Scan call, so it must not be retained.
	payload := make(EncryptedData, len(raw)-n)
	copy(payload, raw[n:])
	*data = payload
	return nil
}

// Value implements driver.Valuer. It returns the data prefixed with "***" as
// a new []byte, or an error when the data starts with 'x'.
func (data EncryptedData) Value() (driver.Value, error) {
	if len(data) > 0 && data[0] == 'x' {
		// needed to test failures
		return nil, errors.New("Should not start with 'x'")
	}

	// prepend asterisks
	return append([]byte(encryptedDataPrefix), data...), nil
}
|
|
|
|
// Num is an int64-based column type with a custom Scan implementation.
type Num int64

// Scan implements sql.Scanner. It accepts an int64, or a []byte holding a
// base-10 integer that fits in 64 bits.
//
// A []byte that does not parse is reported as an error and leaves the
// receiver unchanged, instead of being silently stored as 0. Any other source
// type is rejected.
func (i *Num) Scan(src interface{}) error {
	switch s := src.(type) {
	case []byte:
		// ParseInt with bitSize 64 matches the underlying int64 on every
		// platform, unlike Atoi, which parses the platform-sized int.
		n, err := strconv.ParseInt(string(s), 10, 64)
		if err != nil {
			return err
		}
		*i = Num(n)
	case int64:
		*i = Num(s)
	default:
		return errors.New("Cannot scan NamedInt from " + reflect.ValueOf(src).String())
	}
	return nil
}
|
|
|
|
// StringsSlice is a list of strings stored in a single column as JSON text.
type StringsSlice []string

// Value implements driver.Valuer by encoding the slice as a JSON string.
func (l StringsSlice) Value() (driver.Value, error) {
	encoded, err := json.Marshal(l)
	return string(encoded), err
}

// Scan implements sql.Scanner. It decodes JSON supplied as a string or a
// []byte into the receiver and rejects every other source type.
func (l *StringsSlice) Scan(input interface{}) error {
	if text, ok := input.(string); ok {
		return json.Unmarshal([]byte(text), l)
	}
	if raw, ok := input.([]byte); ok {
		return json.Unmarshal(raw, l)
	}
	return errors.New("not supported")
}
|
|
|
|
// ExampleStruct is a two-field struct stored in a single column as JSON.
type ExampleStruct struct {
	Name string
	Val  string
}

// GormDataType reports "bytes" as the column data type for this struct.
func (ExampleStruct) GormDataType() string { return "bytes" }

// Value implements driver.Valuer. A struct with an empty Name is written as
// NULL; otherwise the struct is JSON-encoded with Name blanked out.
func (s ExampleStruct) Value() (driver.Value, error) {
	if s.Name == "" {
		return nil, nil
	}
	// Dropping Name is test-only behaviour with no practical meaning.
	stripped := ExampleStruct{Val: s.Val}
	return json.Marshal(stripped)
}

// Scan implements sql.Scanner. It decodes JSON supplied as a string or a
// []byte into the receiver and rejects every other source type.
func (s *ExampleStruct) Scan(src interface{}) error {
	if text, ok := src.(string); ok {
		return json.Unmarshal([]byte(text), s)
	}
	if raw, ok := src.([]byte); ok {
		return json.Unmarshal(raw, s)
	}
	return errors.New("not supported")
}
|
|
|
|
type StructsSlice []ExampleStruct
|
|
|
|
func (l StructsSlice) Value() (driver.Value, error) {
|
|
bytes, err := json.Marshal(l)
|
|
return string(bytes), err
|
|
}
|
|
|
|
func (l *StructsSlice) Scan(input interface{}) error {
|
|
switch value := input.(type) {
|
|
case string:
|
|
return json.Unmarshal([]byte(value), l)
|
|
case []byte:
|
|
return json.Unmarshal(value, l)
|
|
default:
|
|
return errors.New("not supported")
|
|
}
|
|
}
|
|
|
|
// Role is a struct column persisted as its Name string.
type Role struct {
	Name string `gorm:"size:256"`
}

// Scan implements sql.Scanner. It accepts a []byte or a string as the role
// name and treats nil (SQL NULL) as an empty name.
//
// Any other source type returns an error and leaves the receiver unchanged,
// rather than panicking on a failed type assertion.
func (role *Role) Scan(value interface{}) error {
	switch v := value.(type) {
	case []byte:
		role.Name = string(v)
	case string:
		role.Name = v
	case nil:
		role.Name = ""
	default:
		return fmt.Errorf("cannot scan %T into Role", value)
	}
	return nil
}

// Value implements driver.Valuer by returning the role name.
func (role Role) Value() (driver.Value, error) {
	return role.Name, nil
}

// IsAdmin reports whether the role name is exactly "admin".
func (role Role) IsAdmin() bool {
	return role.Name == "admin"
}
|
|
|
|
// EmptyTime wraps time.Time with custom Scan/Value behaviour for the tests.
type EmptyTime struct {
	time.Time
}

// Scan implements sql.Scanner by delegating to sql.NullTime. The embedded
// Time is always overwritten with whatever NullTime holds afterwards (the
// zero time for nil or on failure), and NullTime's error is returned.
func (t *EmptyTime) Scan(v interface{}) error {
	var scanned sql.NullTime
	scanErr := scanned.Scan(v)
	t.Time = scanned.Time
	return scanErr
}

// Value implements driver.Valuer. It ignores the stored time and returns the
// current time so the tests pass on MySQL 8, which does not support
// 0000-00-00 by default.
func (t EmptyTime) Value() (driver.Value, error) {
	now := time.Now()
	return now, nil
}
|
|
|
|
// NullString embeds sql.NullString, so its fields and its Scan and Value
// methods are promoted to this type.
type NullString struct{ sql.NullString }
|
|
|
|
type Point struct {
|
|
X, Y int
|
|
}
|
|
|
|
func (point Point) GormDataType() string {
|
|
return "geo"
|
|
}
|
|
|
|
func (point Point) GormValue(ctx context.Context, db *gorm.DB) clause.Expr {
|
|
return clause.Expr{
|
|
SQL: "ST_PointFromText(?)",
|
|
Vars: []interface{}{fmt.Sprintf("POINT(%d %d)", point.X, point.Y)},
|
|
}
|
|
}
|
|
|
|
func TestGORMValuer(t *testing.T) {
|
|
type UserWithPoint struct {
|
|
Name string
|
|
Point Point
|
|
}
|
|
|
|
dryRunDB := DB.Session(&gorm.Session{DryRun: true})
|
|
|
|
stmt := dryRunDB.Create(&UserWithPoint{
|
|
Name: "jinzhu",
|
|
Point: Point{X: 100, Y: 100},
|
|
}).Statement
|
|
|
|
if stmt.SQL.String() == "" || len(stmt.Vars) != 2 {
|
|
t.Errorf("Failed to generate sql, got %v", stmt.SQL.String())
|
|
}
|
|
|
|
if !regexp.MustCompile(`INSERT INTO .user_with_points. \(.name.,.point.\) VALUES \(.+,ST_PointFromText\(.+\)\)`).MatchString(stmt.SQL.String()) {
|
|
t.Errorf("insert with sql.Expr, but got %v", stmt.SQL.String())
|
|
}
|
|
|
|
if !reflect.DeepEqual([]interface{}{"jinzhu", "POINT(100 100)"}, stmt.Vars) {
|
|
t.Errorf("generated vars is not equal, got %v", stmt.Vars)
|
|
}
|
|
|
|
stmt = dryRunDB.Model(UserWithPoint{}).Create(map[string]interface{}{
|
|
"Name": "jinzhu",
|
|
"Point": clause.Expr{SQL: "ST_PointFromText(?)", Vars: []interface{}{"POINT(100 100)"}},
|
|
}).Statement
|
|
|
|
if !regexp.MustCompile(`INSERT INTO .user_with_points. \(.name.,.point.\) VALUES \(.+,ST_PointFromText\(.+\)\)`).MatchString(stmt.SQL.String()) {
|
|
t.Errorf("insert with sql.Expr, but got %v", stmt.SQL.String())
|
|
}
|
|
|
|
if !reflect.DeepEqual([]interface{}{"jinzhu", "POINT(100 100)"}, stmt.Vars) {
|
|
t.Errorf("generated vars is not equal, got %v", stmt.Vars)
|
|
}
|
|
|
|
stmt = dryRunDB.Table("user_with_points").Create(&map[string]interface{}{
|
|
"Name": "jinzhu",
|
|
"Point": clause.Expr{SQL: "ST_PointFromText(?)", Vars: []interface{}{"POINT(100 100)"}},
|
|
}).Statement
|
|
|
|
if !regexp.MustCompile(`INSERT INTO .user_with_points. \(.Name.,.Point.\) VALUES \(.+,ST_PointFromText\(.+\)\)`).MatchString(stmt.SQL.String()) {
|
|
t.Errorf("insert with sql.Expr, but got %v", stmt.SQL.String())
|
|
}
|
|
|
|
if !reflect.DeepEqual([]interface{}{"jinzhu", "POINT(100 100)"}, stmt.Vars) {
|
|
t.Errorf("generated vars is not equal, got %v", stmt.Vars)
|
|
}
|
|
|
|
stmt = dryRunDB.Session(&gorm.Session{
|
|
AllowGlobalUpdate: true,
|
|
}).Model(&UserWithPoint{}).Updates(UserWithPoint{
|
|
Name: "jinzhu",
|
|
Point: Point{X: 100, Y: 100},
|
|
}).Statement
|
|
|
|
if !regexp.MustCompile(`UPDATE .user_with_points. SET .name.=.+,.point.=ST_PointFromText\(.+\)`).MatchString(stmt.SQL.String()) {
|
|
t.Errorf("update with sql.Expr, but got %v", stmt.SQL.String())
|
|
}
|
|
|
|
if !reflect.DeepEqual([]interface{}{"jinzhu", "POINT(100 100)"}, stmt.Vars) {
|
|
t.Errorf("generated vars is not equal, got %v", stmt.Vars)
|
|
}
|
|
}
|