From cbd98687af390959267d4cd22f23358aae214b27 Mon Sep 17 00:00:00 2001 From: zaneli Date: Tue, 5 Nov 2019 08:40:41 +0900 Subject: [PATCH] Add `gorm:bytes_value_as_array` tag --- scope.go | 10 +++++++++- scope_test.go | 55 +++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 64 insertions(+), 1 deletion(-) diff --git a/scope.go b/scope.go index eb7525b8..a76ebd0d 100644 --- a/scope.go +++ b/scope.go @@ -602,6 +602,14 @@ func (scope *Scope) buildCondition(clause map[string]interface{}, include bool) replacements := []string{} args := clause["args"].([]interface{}) + + var bytesValueAsArray bool + if v, ok := scope.Get("gorm:bytes_value_as_array"); ok { + if b, ok := v.(bool); ok { + bytesValueAsArray = b + } + } + for _, arg := range args { var err error switch reflect.ValueOf(arg).Kind() { @@ -609,7 +617,7 @@ func (scope *Scope) buildCondition(clause map[string]interface{}, include bool) if scanner, ok := interface{}(arg).(driver.Valuer); ok { arg, err = scanner.Value() replacements = append(replacements, scope.AddToVars(arg)) - } else if b, ok := arg.([]byte); ok { + } else if b, ok := arg.([]byte); ok && !bytesValueAsArray { replacements = append(replacements, scope.AddToVars(b)) } else if as, ok := arg.([][]interface{}); ok { var tempMarks []string diff --git a/scope_test.go b/scope_test.go index f7f1ed08..f9377d6c 100644 --- a/scope_test.go +++ b/scope_test.go @@ -3,6 +3,7 @@ package gorm_test import ( "encoding/hex" "math/rand" + "os" "strings" "testing" @@ -91,3 +92,57 @@ func TestDropTableWithTableOptions(t *testing.T) { t.Errorf("Table must be dropped, got error %s", err) } } + +func TestBytesValueAsArray(t *testing.T) { + if dialect := os.Getenv("GORM_DIALECT"); dialect != "postgres" && dialect != "mysql" { + t.Skip("Skipping this because only postgres and mysql support bytes as string") + } + + users := []*User{ + &User{Name: "bytesValueAsArray1", Age: 1}, + &User{Name: "bytesValueAsArray2", Age: 2}, + } + + for _, user := range users { + if err := DB.Save(user).Error; err != nil { + t.Fatal(err) + } + } + defer func() { + for _, user := range users { + if err := DB.Delete(user).Error; err != nil { + t.Fatal(err) + } + } + }() + + var user User + db := DB.New() + if err := db.Where("name = ?", []byte(users[0].Name)).First(&user).Error; err != nil { + t.Error(err) + } + if user.Id != users[0].Id { + t.Errorf("user.Id expected %d, but got %d", users[0].Id, user.Id) + } + + var ids []int64 + db = DB.New() + if err := db.Model(&User{}). + Set("gorm:bytes_value_as_array", true). + Where("name LIKE ?", "bytesValueAsArray%"). + Where("age IN (?)", []byte{byte(users[0].Age), byte(users[1].Age)}). + Order("age"). + Pluck("id", &ids). + Error; err != nil { + t.Error(err) + } + if len(ids) != 2 { + t.Errorf("ids length expected 2, but got %d", len(ids)) + } + if ids[0] != users[0].Id { + t.Errorf("ids[0] expected %d, but got %d", users[0].Id, ids[0]) + } + if ids[1] != users[1].Id { + t.Errorf("ids[1] expected %d, but got %d", users[0].Id, ids[0]) + } +}