diff --git a/utils/utils.go b/utils/utils.go index ddbca60a..739869cd 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -89,19 +89,28 @@ func Contains(elems []string, elem string) bool { return false } -func AssertEqual(src, dst interface{}) bool { - if !reflect.DeepEqual(src, dst) { - if valuer, ok := src.(driver.Valuer); ok { - src, _ = valuer.Value() - } - - if valuer, ok := dst.(driver.Valuer); ok { - dst, _ = valuer.Value() - } - - return reflect.DeepEqual(src, dst) +func AssertEqual(x, y interface{}) bool { + if reflect.DeepEqual(x, y) { + return true } - return true + if x == nil || y == nil { + return false + } + xval := reflect.ValueOf(x) + yval := reflect.ValueOf(y) + if xval.Kind() == reflect.Ptr && xval.IsNil() || + yval.Kind() == reflect.Ptr && yval.IsNil() { + return false + } + + if valuer, ok := x.(driver.Valuer); ok { + x, _ = valuer.Value() + } + + if valuer, ok := y.(driver.Valuer); ok { + y, _ = valuer.Value() + } + return reflect.DeepEqual(x, y) } func ToString(value interface{}) string { diff --git a/utils/utils_test.go b/utils/utils_test.go index 71eef964..a4a30595 100644 --- a/utils/utils_test.go +++ b/utils/utils_test.go @@ -3,8 +3,10 @@ package utils import ( "database/sql" "database/sql/driver" + "encoding/json" "errors" "math" + "reflect" "strings" "testing" "time" @@ -87,7 +89,30 @@ func (n ModifyAt) Value() (driver.Value, error) { return n.Time.Unix(), nil } +type datatypesJSON json.RawMessage + +// Value return json value, implement driver.Valuer interface +func (j datatypesJSON) Value() (driver.Value, error) { + if len(j) == 0 { + return nil, nil + } + return []byte(j), nil +} + func TestAssertEqual(t *testing.T) { + type model struct { + Raw *datatypesJSON + } + + m1 := model{} + f1 := reflect.Indirect(reflect.ValueOf(m1)).Field(0) + i1 := f1.Interface() + + raw := datatypesJSON("dreggn") + m2 := model{Raw: &raw} + f2 := reflect.Indirect(reflect.ValueOf(m2)).Field(0) + i2 := f2.Interface() + now := time.Now() assertEqualTests := []struct { name string @@ -98,6 +123,7 @@ func TestAssertEqual(t *testing.T) { {"error not equal", errors.New("1"), errors.New("2"), false}, {"driver.Valuer equal", ModifyAt{Time: now, Valid: true}, ModifyAt{Time: now, Valid: true}, true}, {"driver.Valuer not equal", ModifyAt{Time: now, Valid: true}, ModifyAt{Time: now.Add(time.Second), Valid: true}, false}, + {"driver.Valuer equal (ptr to nil ptr)", i1, i2, false}, } for _, test := range assertEqualTests { t.Run(test.name, func(t *testing.T) {