feat: support complexity field with special tag saved as json.

This commit is contained in:
j2gg0s 2020-09-23 14:15:57 +08:00
parent 68920449f9
commit 95a7cb5291
8 changed files with 62 additions and 15 deletions

12
scan.go
View File

@ -12,7 +12,7 @@ import (
func prepareValues(values []interface{}, db *DB, columnTypes []*sql.ColumnType, columns []string) {
if db.Statement.Schema != nil {
for idx, name := range columns {
if field := db.Statement.Schema.LookUpField(name); field != nil {
if field := db.Statement.Schema.LookUpField(name); field != nil && !field.XJSON {
values[idx] = reflect.New(reflect.PtrTo(field.FieldType)).Interface()
continue
}
@ -145,7 +145,11 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
} else {
for idx, field := range fields {
if field != nil {
values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface()
if field.XJSON {
values[idx] = &sql.RawBytes{}
} else {
values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface()
}
}
}
@ -183,11 +187,11 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
if initialized || rows.Next() {
for idx, column := range columns {
if field := Schema.LookUpField(column); field != nil && field.Readable {
if field := Schema.LookUpField(column); field != nil && field.Readable && !field.XJSON {
values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface()
} else if names := strings.Split(column, "__"); len(names) > 1 {
if rel, ok := Schema.Relationships.Relations[names[0]]; ok {
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable && !field.XJSON {
values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface()
continue
}

View File

@ -3,6 +3,8 @@ package schema
import (
"database/sql"
"database/sql/driver"
"encoding/json"
"errors"
"fmt"
"reflect"
"strconv"
@ -67,6 +69,8 @@ type Field struct {
ReflectValueOf func(reflect.Value) reflect.Value
ValueOf func(reflect.Value) (value interface{}, zero bool)
Set func(reflect.Value, interface{}) error
XJSON bool
}
func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
@ -178,6 +182,10 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
field.Comment = val
}
if val, ok := field.TagSettings["XJSON"]; ok && utils.CheckTruth(val) {
field.XJSON = true
}
// default value is function or null or blank (primary keys)
skipParseDefaultValue := strings.Contains(field.DefaultValue, "(") &&
strings.Contains(field.DefaultValue, ")") || strings.ToLower(field.DefaultValue) == "null" || field.DefaultValue == ""
@ -381,13 +389,23 @@ func (field *Field) setupValuerAndSetter() {
switch {
case len(field.StructField.Index) == 1:
field.ValueOf = func(value reflect.Value) (interface{}, bool) {
fieldValue := reflect.Indirect(value).Field(field.StructField.Index[0])
return fieldValue.Interface(), fieldValue.IsZero()
v := reflect.Indirect(value).Field(field.StructField.Index[0])
if field.XJSON {
b, _ := json.Marshal(v.Interface())
// pgx will santize bytes with special prefix
return string(b), v.IsZero()
}
return v.Interface(), v.IsZero()
}
case len(field.StructField.Index) == 2 && field.StructField.Index[0] >= 0:
field.ValueOf = func(value reflect.Value) (interface{}, bool) {
fieldValue := reflect.Indirect(value).Field(field.StructField.Index[0]).Field(field.StructField.Index[1])
return fieldValue.Interface(), fieldValue.IsZero()
v := reflect.Indirect(value).Field(field.StructField.Index[0]).Field(field.StructField.Index[1])
if field.XJSON {
b, _ := json.Marshal(v.Interface())
// pgx will santize bytes with special prefix
return string(b), v.IsZero()
}
return v.Interface(), v.IsZero()
}
default:
field.ValueOf = func(value reflect.Value) (interface{}, bool) {
@ -410,6 +428,11 @@ func (field *Field) setupValuerAndSetter() {
}
}
}
if field.XJSON {
// pgx will santize bytes with special prefix
b, _ := json.Marshal(v.Interface())
return string(b), v.IsZero()
}
return v.Interface(), v.IsZero()
}
}
@ -500,6 +523,16 @@ func (field *Field) setupValuerAndSetter() {
if v, err = valuer.Value(); err == nil {
err = setter(value, v)
}
} else if field.XJSON {
if s, ok := v.(sql.RawBytes); ok {
p := reflect.New(field.FieldType).Interface()
err = json.Unmarshal([]byte(s), p)
if err != nil {
return err
}
return setter(value, p)
}
return errors.New("inlvaid type for tag: xjson")
} else {
return fmt.Errorf("failed to set value %+v to field %v", v, field.Name)
}

View File

@ -315,7 +315,8 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) (c
switch reflectValue.Kind() {
case reflect.Struct:
for _, field := range s.Fields {
if field.Readable {
// NOTE: xjson doesn't support where with object
if field.Readable && !field.XJSON {
if v, isZero := field.ValueOf(reflectValue); !isZero {
if field.DBName != "" {
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v})
@ -328,7 +329,7 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) (c
case reflect.Slice, reflect.Array:
for i := 0; i < reflectValue.Len(); i++ {
for _, field := range s.Fields {
if field.Readable {
if field.Readable && !field.XJSON {
if v, isZero := field.ValueOf(reflectValue.Index(i)); !isZero {
if field.DBName != "" {
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v})

View File

@ -371,7 +371,7 @@ func TestCreateWithNoGORMPrimaryKey(t *testing.T) {
func TestSelectWithCreate(t *testing.T) {
user := *GetUser("select_create", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4})
DB.Select("Account", "Toys", "Manager", "ManagerID", "Languages", "Name", "CreatedAt", "Age", "Active").Create(&user)
DB.Select("Account", "Toys", "Manager", "ManagerID", "Languages", "Name", "CreatedAt", "Age", "Active", "Phones").Create(&user)
var user2 User
DB.Preload("Account").Preload("Pets").Preload("Toys").Preload("Company").Preload("Manager").Preload("Team").Preload("Languages").Preload("Friends").First(&user2, user.ID)

View File

@ -6,8 +6,10 @@ require (
github.com/google/uuid v1.1.1
github.com/jinzhu/now v1.1.1
github.com/lib/pq v1.6.0
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 // indirect
gopkg.in/yaml.v2 v2.2.4 // indirect
gorm.io/driver/mysql v1.0.1
gorm.io/driver/postgres v1.0.0
gorm.io/driver/postgres v1.0.1
gorm.io/driver/sqlite v1.1.3
gorm.io/driver/sqlserver v1.0.4
gorm.io/gorm v1.20.1

View File

@ -19,6 +19,7 @@ type Config struct {
Team int
Languages int
Friends int
Phones int
}
func GetUser(name string, config Config) *User {
@ -65,6 +66,11 @@ func GetUser(name string, config Config) *User {
user.Friends = append(user.Friends, GetUser(name+"_friend_"+strconv.Itoa(i+1), Config{}))
}
user.Phones = append(user.Phones, "phone_"+strconv.Itoa(10086))
for i := 0; i < config.Phones; i++ {
user.Phones = append(user.Phones, "phone_"+strconv.Itoa(i+1))
}
return &user
}
@ -93,11 +99,11 @@ func CheckUser(t *testing.T, user User, expect User) {
if err := DB.Where("id = ?", user.ID).First(&newUser).Error; err != nil {
t.Fatalf("errors happened when query: %v", err)
} else {
AssertObjEqual(t, newUser, user, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active")
AssertObjEqual(t, newUser, user, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active", "Phones")
}
}
AssertObjEqual(t, user, expect, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active")
AssertObjEqual(t, user, expect, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active", "Phones")
t.Run("Account", func(t *testing.T) {
AssertObjEqual(t, user.Account, expect.Account, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "UserID", "Number")

View File

@ -156,7 +156,7 @@ func TestDryRun(t *testing.T) {
dryRunDB := DB.Session(&gorm.Session{DryRun: true})
stmt := dryRunDB.Create(&user).Statement
if stmt.SQL.String() == "" || len(stmt.Vars) != 9 {
if stmt.SQL.String() == "" || len(stmt.Vars) != 10 {
t.Errorf("Failed to generate sql, got %v", stmt.SQL.String())
}

View File

@ -27,6 +27,7 @@ type User struct {
Languages []Language `gorm:"many2many:UserSpeak;"`
Friends []*User `gorm:"many2many:user_friends;"`
Active bool
Phones []string `gorm:"type:json;xjson"`
}
type Account struct {