feat: support complexity field with special tag saved as json.
This commit is contained in:
parent
68920449f9
commit
95a7cb5291
12
scan.go
12
scan.go
@ -12,7 +12,7 @@ import (
|
||||
func prepareValues(values []interface{}, db *DB, columnTypes []*sql.ColumnType, columns []string) {
|
||||
if db.Statement.Schema != nil {
|
||||
for idx, name := range columns {
|
||||
if field := db.Statement.Schema.LookUpField(name); field != nil {
|
||||
if field := db.Statement.Schema.LookUpField(name); field != nil && !field.XJSON {
|
||||
values[idx] = reflect.New(reflect.PtrTo(field.FieldType)).Interface()
|
||||
continue
|
||||
}
|
||||
@ -145,7 +145,11 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
|
||||
} else {
|
||||
for idx, field := range fields {
|
||||
if field != nil {
|
||||
values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface()
|
||||
if field.XJSON {
|
||||
values[idx] = &sql.RawBytes{}
|
||||
} else {
|
||||
values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -183,11 +187,11 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
|
||||
|
||||
if initialized || rows.Next() {
|
||||
for idx, column := range columns {
|
||||
if field := Schema.LookUpField(column); field != nil && field.Readable {
|
||||
if field := Schema.LookUpField(column); field != nil && field.Readable && !field.XJSON {
|
||||
values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface()
|
||||
} else if names := strings.Split(column, "__"); len(names) > 1 {
|
||||
if rel, ok := Schema.Relationships.Relations[names[0]]; ok {
|
||||
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
|
||||
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable && !field.XJSON {
|
||||
values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface()
|
||||
continue
|
||||
}
|
||||
|
@ -3,6 +3,8 @@ package schema
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strconv"
|
||||
@ -67,6 +69,8 @@ type Field struct {
|
||||
ReflectValueOf func(reflect.Value) reflect.Value
|
||||
ValueOf func(reflect.Value) (value interface{}, zero bool)
|
||||
Set func(reflect.Value, interface{}) error
|
||||
|
||||
XJSON bool
|
||||
}
|
||||
|
||||
func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
||||
@ -178,6 +182,10 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
||||
field.Comment = val
|
||||
}
|
||||
|
||||
if val, ok := field.TagSettings["XJSON"]; ok && utils.CheckTruth(val) {
|
||||
field.XJSON = true
|
||||
}
|
||||
|
||||
// default value is function or null or blank (primary keys)
|
||||
skipParseDefaultValue := strings.Contains(field.DefaultValue, "(") &&
|
||||
strings.Contains(field.DefaultValue, ")") || strings.ToLower(field.DefaultValue) == "null" || field.DefaultValue == ""
|
||||
@ -381,13 +389,23 @@ func (field *Field) setupValuerAndSetter() {
|
||||
switch {
|
||||
case len(field.StructField.Index) == 1:
|
||||
field.ValueOf = func(value reflect.Value) (interface{}, bool) {
|
||||
fieldValue := reflect.Indirect(value).Field(field.StructField.Index[0])
|
||||
return fieldValue.Interface(), fieldValue.IsZero()
|
||||
v := reflect.Indirect(value).Field(field.StructField.Index[0])
|
||||
if field.XJSON {
|
||||
b, _ := json.Marshal(v.Interface())
|
||||
// pgx will santize bytes with special prefix
|
||||
return string(b), v.IsZero()
|
||||
}
|
||||
return v.Interface(), v.IsZero()
|
||||
}
|
||||
case len(field.StructField.Index) == 2 && field.StructField.Index[0] >= 0:
|
||||
field.ValueOf = func(value reflect.Value) (interface{}, bool) {
|
||||
fieldValue := reflect.Indirect(value).Field(field.StructField.Index[0]).Field(field.StructField.Index[1])
|
||||
return fieldValue.Interface(), fieldValue.IsZero()
|
||||
v := reflect.Indirect(value).Field(field.StructField.Index[0]).Field(field.StructField.Index[1])
|
||||
if field.XJSON {
|
||||
b, _ := json.Marshal(v.Interface())
|
||||
// pgx will santize bytes with special prefix
|
||||
return string(b), v.IsZero()
|
||||
}
|
||||
return v.Interface(), v.IsZero()
|
||||
}
|
||||
default:
|
||||
field.ValueOf = func(value reflect.Value) (interface{}, bool) {
|
||||
@ -410,6 +428,11 @@ func (field *Field) setupValuerAndSetter() {
|
||||
}
|
||||
}
|
||||
}
|
||||
if field.XJSON {
|
||||
// pgx will santize bytes with special prefix
|
||||
b, _ := json.Marshal(v.Interface())
|
||||
return string(b), v.IsZero()
|
||||
}
|
||||
return v.Interface(), v.IsZero()
|
||||
}
|
||||
}
|
||||
@ -500,6 +523,16 @@ func (field *Field) setupValuerAndSetter() {
|
||||
if v, err = valuer.Value(); err == nil {
|
||||
err = setter(value, v)
|
||||
}
|
||||
} else if field.XJSON {
|
||||
if s, ok := v.(sql.RawBytes); ok {
|
||||
p := reflect.New(field.FieldType).Interface()
|
||||
err = json.Unmarshal([]byte(s), p)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return setter(value, p)
|
||||
}
|
||||
return errors.New("inlvaid type for tag: xjson")
|
||||
} else {
|
||||
return fmt.Errorf("failed to set value %+v to field %v", v, field.Name)
|
||||
}
|
||||
|
@ -315,7 +315,8 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) (c
|
||||
switch reflectValue.Kind() {
|
||||
case reflect.Struct:
|
||||
for _, field := range s.Fields {
|
||||
if field.Readable {
|
||||
// NOTE: xjson doesn't support where with object
|
||||
if field.Readable && !field.XJSON {
|
||||
if v, isZero := field.ValueOf(reflectValue); !isZero {
|
||||
if field.DBName != "" {
|
||||
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v})
|
||||
@ -328,7 +329,7 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) (c
|
||||
case reflect.Slice, reflect.Array:
|
||||
for i := 0; i < reflectValue.Len(); i++ {
|
||||
for _, field := range s.Fields {
|
||||
if field.Readable {
|
||||
if field.Readable && !field.XJSON {
|
||||
if v, isZero := field.ValueOf(reflectValue.Index(i)); !isZero {
|
||||
if field.DBName != "" {
|
||||
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v})
|
||||
|
@ -371,7 +371,7 @@ func TestCreateWithNoGORMPrimaryKey(t *testing.T) {
|
||||
|
||||
func TestSelectWithCreate(t *testing.T) {
|
||||
user := *GetUser("select_create", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4})
|
||||
DB.Select("Account", "Toys", "Manager", "ManagerID", "Languages", "Name", "CreatedAt", "Age", "Active").Create(&user)
|
||||
DB.Select("Account", "Toys", "Manager", "ManagerID", "Languages", "Name", "CreatedAt", "Age", "Active", "Phones").Create(&user)
|
||||
|
||||
var user2 User
|
||||
DB.Preload("Account").Preload("Pets").Preload("Toys").Preload("Company").Preload("Manager").Preload("Team").Preload("Languages").Preload("Friends").First(&user2, user.ID)
|
||||
|
@ -6,8 +6,10 @@ require (
|
||||
github.com/google/uuid v1.1.1
|
||||
github.com/jinzhu/now v1.1.1
|
||||
github.com/lib/pq v1.6.0
|
||||
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 // indirect
|
||||
gopkg.in/yaml.v2 v2.2.4 // indirect
|
||||
gorm.io/driver/mysql v1.0.1
|
||||
gorm.io/driver/postgres v1.0.0
|
||||
gorm.io/driver/postgres v1.0.1
|
||||
gorm.io/driver/sqlite v1.1.3
|
||||
gorm.io/driver/sqlserver v1.0.4
|
||||
gorm.io/gorm v1.20.1
|
||||
|
@ -19,6 +19,7 @@ type Config struct {
|
||||
Team int
|
||||
Languages int
|
||||
Friends int
|
||||
Phones int
|
||||
}
|
||||
|
||||
func GetUser(name string, config Config) *User {
|
||||
@ -65,6 +66,11 @@ func GetUser(name string, config Config) *User {
|
||||
user.Friends = append(user.Friends, GetUser(name+"_friend_"+strconv.Itoa(i+1), Config{}))
|
||||
}
|
||||
|
||||
user.Phones = append(user.Phones, "phone_"+strconv.Itoa(10086))
|
||||
for i := 0; i < config.Phones; i++ {
|
||||
user.Phones = append(user.Phones, "phone_"+strconv.Itoa(i+1))
|
||||
}
|
||||
|
||||
return &user
|
||||
}
|
||||
|
||||
@ -93,11 +99,11 @@ func CheckUser(t *testing.T, user User, expect User) {
|
||||
if err := DB.Where("id = ?", user.ID).First(&newUser).Error; err != nil {
|
||||
t.Fatalf("errors happened when query: %v", err)
|
||||
} else {
|
||||
AssertObjEqual(t, newUser, user, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active")
|
||||
AssertObjEqual(t, newUser, user, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active", "Phones")
|
||||
}
|
||||
}
|
||||
|
||||
AssertObjEqual(t, user, expect, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active")
|
||||
AssertObjEqual(t, user, expect, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active", "Phones")
|
||||
|
||||
t.Run("Account", func(t *testing.T) {
|
||||
AssertObjEqual(t, user.Account, expect.Account, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "UserID", "Number")
|
||||
|
@ -156,7 +156,7 @@ func TestDryRun(t *testing.T) {
|
||||
dryRunDB := DB.Session(&gorm.Session{DryRun: true})
|
||||
|
||||
stmt := dryRunDB.Create(&user).Statement
|
||||
if stmt.SQL.String() == "" || len(stmt.Vars) != 9 {
|
||||
if stmt.SQL.String() == "" || len(stmt.Vars) != 10 {
|
||||
t.Errorf("Failed to generate sql, got %v", stmt.SQL.String())
|
||||
}
|
||||
|
||||
|
@ -27,6 +27,7 @@ type User struct {
|
||||
Languages []Language `gorm:"many2many:UserSpeak;"`
|
||||
Friends []*User `gorm:"many2many:user_friends;"`
|
||||
Active bool
|
||||
Phones []string `gorm:"type:json;xjson"`
|
||||
}
|
||||
|
||||
type Account struct {
|
||||
|
Loading…
x
Reference in New Issue
Block a user