Handle postgres array
This commit is contained in:
parent
f482f25c71
commit
4465ee6c90
@ -20,6 +20,10 @@ type Dialector interface {
|
|||||||
Explain(sql string, vars ...interface{}) string
|
Explain(sql string, vars ...interface{}) string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type ArrayValueHandler interface {
|
||||||
|
HandleArray(field *schema.Field) error
|
||||||
|
}
|
||||||
|
|
||||||
// Plugin GORM plugin interface
|
// Plugin GORM plugin interface
|
||||||
type Plugin interface {
|
type Plugin interface {
|
||||||
Name() string
|
Name() string
|
||||||
|
@ -47,6 +47,7 @@ const (
|
|||||||
String DataType = "string"
|
String DataType = "string"
|
||||||
Time DataType = "time"
|
Time DataType = "time"
|
||||||
Bytes DataType = "bytes"
|
Bytes DataType = "bytes"
|
||||||
|
Array DataType = "array"
|
||||||
)
|
)
|
||||||
|
|
||||||
const DefaultAutoIncrementIncrement int64 = 1
|
const DefaultAutoIncrementIncrement int64 = 1
|
||||||
@ -282,6 +283,10 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
|||||||
case reflect.Array, reflect.Slice:
|
case reflect.Array, reflect.Slice:
|
||||||
if reflect.Indirect(fieldValue).Type().Elem() == ByteReflectType && field.DataType == "" {
|
if reflect.Indirect(fieldValue).Type().Elem() == ByteReflectType && field.DataType == "" {
|
||||||
field.DataType = Bytes
|
field.DataType = Bytes
|
||||||
|
} else {
|
||||||
|
elemType := reflect.Indirect(fieldValue).Type().Elem()
|
||||||
|
field.DataType = Array
|
||||||
|
field.TagSettings["ELEM_TYPE"] = elemType.Kind().String()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -977,6 +982,10 @@ func (field *Field) setupValuerAndSetter() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if field.DataType != "" && field.FieldType.Kind() == reflect.Slice && field.FieldType.Elem().Kind() != reflect.Uint8 {
|
||||||
|
field.TagSettings["ARRAY_FIELD"] = "true"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (field *Field) setupNewValuePool() {
|
func (field *Field) setupNewValuePool() {
|
||||||
|
@ -7,6 +7,8 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"os"
|
||||||
"reflect"
|
"reflect"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strconv"
|
"strconv"
|
||||||
@ -65,6 +67,54 @@ func TestScannerValuer(t *testing.T) {
|
|||||||
AssertObjEqual(t, data, result, "Name", "Gender", "Age", "Male", "Height", "Birthday", "Password", "Bytes", "Num", "Strings", "Structs")
|
AssertObjEqual(t, data, result, "Name", "Gender", "Age", "Male", "Height", "Birthday", "Password", "Bytes", "Num", "Strings", "Structs")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestScannerValuerArray(t *testing.T) {
|
||||||
|
// Use custom dialector to enable array handler
|
||||||
|
os.Setenv("GORM_DIALECT", "postgres")
|
||||||
|
os.Setenv("GORM_ENABLE_ARRAY_HANDLER", "true")
|
||||||
|
var err error
|
||||||
|
if DB, err = OpenTestConnection(&gorm.Config{}); err != nil {
|
||||||
|
log.Printf("failed to connect database, got error %v", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Migrator().DropTable(&ScannerValuerStructOfArrays{})
|
||||||
|
if err := DB.Migrator().AutoMigrate(&ScannerValuerStructOfArrays{}); err != nil {
|
||||||
|
t.Fatalf("no error should happen when migrate scanner, valuer struct, got error %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
data := ScannerValuerStructOfArrays{
|
||||||
|
StringArray: []string{"a", "b", "c"},
|
||||||
|
IntArray: []int{1, 2, 3},
|
||||||
|
Int8Array: []int8{1, 2, 3},
|
||||||
|
Int16Array: []int16{1, 2, 3},
|
||||||
|
Int32Array: []int32{1, 2, 3},
|
||||||
|
Int64Array: []int64{1, 2, 3},
|
||||||
|
UintArray: []uint{1, 2, 3},
|
||||||
|
Uint16Array: []uint16{1, 2, 3},
|
||||||
|
Uint32Array: []uint32{1, 2, 3},
|
||||||
|
Uint64Array: []uint64{1, 2, 3},
|
||||||
|
Float32Array: []float32{
|
||||||
|
1.1, 2.2, 3.3,
|
||||||
|
},
|
||||||
|
Float64Array: []float64{
|
||||||
|
1.1, 2.2, 3.3,
|
||||||
|
},
|
||||||
|
BoolArray: []bool{true, false, true},
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.Create(&data).Error; err != nil {
|
||||||
|
t.Fatalf("No error should happened when create scanner valuer struct, but got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result ScannerValuerStructOfArrays
|
||||||
|
|
||||||
|
if err := DB.Find(&result, "id = ?", data.ID).Error; err != nil {
|
||||||
|
t.Fatalf("no error should happen when query scanner, valuer struct, but got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
AssertObjEqual(t, data, result, "StringArray", "IntArray", "Int8Array", "Int16Array", "Int32Array", "Int64Array", "UintArray", "Uint16Array", "Uint32Array", "Uint64Array", "Float32Array", "Float64Array", "BoolArray")
|
||||||
|
}
|
||||||
|
|
||||||
func TestScannerValuerWithFirstOrCreate(t *testing.T) {
|
func TestScannerValuerWithFirstOrCreate(t *testing.T) {
|
||||||
DB.Migrator().DropTable(&ScannerValuerStruct{})
|
DB.Migrator().DropTable(&ScannerValuerStruct{})
|
||||||
if err := DB.Migrator().AutoMigrate(&ScannerValuerStruct{}); err != nil {
|
if err := DB.Migrator().AutoMigrate(&ScannerValuerStruct{}); err != nil {
|
||||||
@ -162,6 +212,23 @@ type ScannerValuerStruct struct {
|
|||||||
ExampleStructPtr *ExampleStruct
|
ExampleStructPtr *ExampleStruct
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type ScannerValuerStructOfArrays struct {
|
||||||
|
gorm.Model
|
||||||
|
StringArray []string
|
||||||
|
IntArray []int
|
||||||
|
Int8Array []int8
|
||||||
|
Int16Array []int16
|
||||||
|
Int32Array []int32
|
||||||
|
Int64Array []int64
|
||||||
|
UintArray []uint
|
||||||
|
Uint16Array []uint16
|
||||||
|
Uint32Array []uint32
|
||||||
|
Uint64Array []uint64
|
||||||
|
Float32Array []float32
|
||||||
|
Float64Array []float64
|
||||||
|
BoolArray []bool
|
||||||
|
}
|
||||||
|
|
||||||
type EncryptedData []byte
|
type EncryptedData []byte
|
||||||
|
|
||||||
func (data *EncryptedData) Scan(value interface{}) error {
|
func (data *EncryptedData) Scan(value interface{}) error {
|
||||||
|
@ -48,6 +48,7 @@ func init() {
|
|||||||
|
|
||||||
func OpenTestConnection(cfg *gorm.Config) (db *gorm.DB, err error) {
|
func OpenTestConnection(cfg *gorm.Config) (db *gorm.DB, err error) {
|
||||||
dbDSN := os.Getenv("GORM_DSN")
|
dbDSN := os.Getenv("GORM_DSN")
|
||||||
|
enableArrayHandler := os.Getenv("GORM_ENABLE_ARRAY_HANDLER")
|
||||||
switch os.Getenv("GORM_DIALECT") {
|
switch os.Getenv("GORM_DIALECT") {
|
||||||
case "mysql":
|
case "mysql":
|
||||||
log.Println("testing mysql...")
|
log.Println("testing mysql...")
|
||||||
@ -63,6 +64,7 @@ func OpenTestConnection(cfg *gorm.Config) (db *gorm.DB, err error) {
|
|||||||
db, err = gorm.Open(postgres.New(postgres.Config{
|
db, err = gorm.Open(postgres.New(postgres.Config{
|
||||||
DSN: dbDSN,
|
DSN: dbDSN,
|
||||||
PreferSimpleProtocol: true,
|
PreferSimpleProtocol: true,
|
||||||
|
EnableArrayHandler: enableArrayHandler == "true",
|
||||||
}), cfg)
|
}), cfg)
|
||||||
case "sqlserver":
|
case "sqlserver":
|
||||||
// go install github.com/microsoft/go-sqlcmd/cmd/sqlcmd@latest
|
// go install github.com/microsoft/go-sqlcmd/cmd/sqlcmd@latest
|
||||||
|
Loading…
x
Reference in New Issue
Block a user