feat: support complexity field with special tag saved as json.

This commit is contained in:
j2gg0s 2020-09-23 14:15:57 +08:00
parent 68920449f9
commit 95a7cb5291
8 changed files with 62 additions and 15 deletions

10
scan.go
View File

@ -12,7 +12,7 @@ import (
func prepareValues(values []interface{}, db *DB, columnTypes []*sql.ColumnType, columns []string) { func prepareValues(values []interface{}, db *DB, columnTypes []*sql.ColumnType, columns []string) {
if db.Statement.Schema != nil { if db.Statement.Schema != nil {
for idx, name := range columns { for idx, name := range columns {
if field := db.Statement.Schema.LookUpField(name); field != nil { if field := db.Statement.Schema.LookUpField(name); field != nil && !field.XJSON {
values[idx] = reflect.New(reflect.PtrTo(field.FieldType)).Interface() values[idx] = reflect.New(reflect.PtrTo(field.FieldType)).Interface()
continue continue
} }
@ -145,9 +145,13 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
} else { } else {
for idx, field := range fields { for idx, field := range fields {
if field != nil { if field != nil {
if field.XJSON {
values[idx] = &sql.RawBytes{}
} else {
values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface() values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface()
} }
} }
}
db.AddError(rows.Scan(values...)) db.AddError(rows.Scan(values...))
@ -183,11 +187,11 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
if initialized || rows.Next() { if initialized || rows.Next() {
for idx, column := range columns { for idx, column := range columns {
if field := Schema.LookUpField(column); field != nil && field.Readable { if field := Schema.LookUpField(column); field != nil && field.Readable && !field.XJSON {
values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface() values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface()
} else if names := strings.Split(column, "__"); len(names) > 1 { } else if names := strings.Split(column, "__"); len(names) > 1 {
if rel, ok := Schema.Relationships.Relations[names[0]]; ok { if rel, ok := Schema.Relationships.Relations[names[0]]; ok {
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable && !field.XJSON {
values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface() values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface()
continue continue
} }

View File

@ -3,6 +3,8 @@ package schema
import ( import (
"database/sql" "database/sql"
"database/sql/driver" "database/sql/driver"
"encoding/json"
"errors"
"fmt" "fmt"
"reflect" "reflect"
"strconv" "strconv"
@ -67,6 +69,8 @@ type Field struct {
ReflectValueOf func(reflect.Value) reflect.Value ReflectValueOf func(reflect.Value) reflect.Value
ValueOf func(reflect.Value) (value interface{}, zero bool) ValueOf func(reflect.Value) (value interface{}, zero bool)
Set func(reflect.Value, interface{}) error Set func(reflect.Value, interface{}) error
XJSON bool
} }
func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
@ -178,6 +182,10 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
field.Comment = val field.Comment = val
} }
if val, ok := field.TagSettings["XJSON"]; ok && utils.CheckTruth(val) {
field.XJSON = true
}
// default value is function or null or blank (primary keys) // default value is function or null or blank (primary keys)
skipParseDefaultValue := strings.Contains(field.DefaultValue, "(") && skipParseDefaultValue := strings.Contains(field.DefaultValue, "(") &&
strings.Contains(field.DefaultValue, ")") || strings.ToLower(field.DefaultValue) == "null" || field.DefaultValue == "" strings.Contains(field.DefaultValue, ")") || strings.ToLower(field.DefaultValue) == "null" || field.DefaultValue == ""
@ -381,13 +389,23 @@ func (field *Field) setupValuerAndSetter() {
switch { switch {
case len(field.StructField.Index) == 1: case len(field.StructField.Index) == 1:
field.ValueOf = func(value reflect.Value) (interface{}, bool) { field.ValueOf = func(value reflect.Value) (interface{}, bool) {
fieldValue := reflect.Indirect(value).Field(field.StructField.Index[0]) v := reflect.Indirect(value).Field(field.StructField.Index[0])
return fieldValue.Interface(), fieldValue.IsZero() if field.XJSON {
b, _ := json.Marshal(v.Interface())
// pgx will santize bytes with special prefix
return string(b), v.IsZero()
}
return v.Interface(), v.IsZero()
} }
case len(field.StructField.Index) == 2 && field.StructField.Index[0] >= 0: case len(field.StructField.Index) == 2 && field.StructField.Index[0] >= 0:
field.ValueOf = func(value reflect.Value) (interface{}, bool) { field.ValueOf = func(value reflect.Value) (interface{}, bool) {
fieldValue := reflect.Indirect(value).Field(field.StructField.Index[0]).Field(field.StructField.Index[1]) v := reflect.Indirect(value).Field(field.StructField.Index[0]).Field(field.StructField.Index[1])
return fieldValue.Interface(), fieldValue.IsZero() if field.XJSON {
b, _ := json.Marshal(v.Interface())
// pgx will santize bytes with special prefix
return string(b), v.IsZero()
}
return v.Interface(), v.IsZero()
} }
default: default:
field.ValueOf = func(value reflect.Value) (interface{}, bool) { field.ValueOf = func(value reflect.Value) (interface{}, bool) {
@ -410,6 +428,11 @@ func (field *Field) setupValuerAndSetter() {
} }
} }
} }
if field.XJSON {
// pgx will santize bytes with special prefix
b, _ := json.Marshal(v.Interface())
return string(b), v.IsZero()
}
return v.Interface(), v.IsZero() return v.Interface(), v.IsZero()
} }
} }
@ -500,6 +523,16 @@ func (field *Field) setupValuerAndSetter() {
if v, err = valuer.Value(); err == nil { if v, err = valuer.Value(); err == nil {
err = setter(value, v) err = setter(value, v)
} }
} else if field.XJSON {
if s, ok := v.(sql.RawBytes); ok {
p := reflect.New(field.FieldType).Interface()
err = json.Unmarshal([]byte(s), p)
if err != nil {
return err
}
return setter(value, p)
}
return errors.New("inlvaid type for tag: xjson")
} else { } else {
return fmt.Errorf("failed to set value %+v to field %v", v, field.Name) return fmt.Errorf("failed to set value %+v to field %v", v, field.Name)
} }

View File

@ -315,7 +315,8 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) (c
switch reflectValue.Kind() { switch reflectValue.Kind() {
case reflect.Struct: case reflect.Struct:
for _, field := range s.Fields { for _, field := range s.Fields {
if field.Readable { // NOTE: xjson doesn't support where with object
if field.Readable && !field.XJSON {
if v, isZero := field.ValueOf(reflectValue); !isZero { if v, isZero := field.ValueOf(reflectValue); !isZero {
if field.DBName != "" { if field.DBName != "" {
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v}) conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v})
@ -328,7 +329,7 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) (c
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
for i := 0; i < reflectValue.Len(); i++ { for i := 0; i < reflectValue.Len(); i++ {
for _, field := range s.Fields { for _, field := range s.Fields {
if field.Readable { if field.Readable && !field.XJSON {
if v, isZero := field.ValueOf(reflectValue.Index(i)); !isZero { if v, isZero := field.ValueOf(reflectValue.Index(i)); !isZero {
if field.DBName != "" { if field.DBName != "" {
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v}) conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v})

View File

@ -371,7 +371,7 @@ func TestCreateWithNoGORMPrimaryKey(t *testing.T) {
func TestSelectWithCreate(t *testing.T) { func TestSelectWithCreate(t *testing.T) {
user := *GetUser("select_create", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4}) user := *GetUser("select_create", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4})
DB.Select("Account", "Toys", "Manager", "ManagerID", "Languages", "Name", "CreatedAt", "Age", "Active").Create(&user) DB.Select("Account", "Toys", "Manager", "ManagerID", "Languages", "Name", "CreatedAt", "Age", "Active", "Phones").Create(&user)
var user2 User var user2 User
DB.Preload("Account").Preload("Pets").Preload("Toys").Preload("Company").Preload("Manager").Preload("Team").Preload("Languages").Preload("Friends").First(&user2, user.ID) DB.Preload("Account").Preload("Pets").Preload("Toys").Preload("Company").Preload("Manager").Preload("Team").Preload("Languages").Preload("Friends").First(&user2, user.ID)

View File

@ -6,8 +6,10 @@ require (
github.com/google/uuid v1.1.1 github.com/google/uuid v1.1.1
github.com/jinzhu/now v1.1.1 github.com/jinzhu/now v1.1.1
github.com/lib/pq v1.6.0 github.com/lib/pq v1.6.0
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 // indirect
gopkg.in/yaml.v2 v2.2.4 // indirect
gorm.io/driver/mysql v1.0.1 gorm.io/driver/mysql v1.0.1
gorm.io/driver/postgres v1.0.0 gorm.io/driver/postgres v1.0.1
gorm.io/driver/sqlite v1.1.3 gorm.io/driver/sqlite v1.1.3
gorm.io/driver/sqlserver v1.0.4 gorm.io/driver/sqlserver v1.0.4
gorm.io/gorm v1.20.1 gorm.io/gorm v1.20.1

View File

@ -19,6 +19,7 @@ type Config struct {
Team int Team int
Languages int Languages int
Friends int Friends int
Phones int
} }
func GetUser(name string, config Config) *User { func GetUser(name string, config Config) *User {
@ -65,6 +66,11 @@ func GetUser(name string, config Config) *User {
user.Friends = append(user.Friends, GetUser(name+"_friend_"+strconv.Itoa(i+1), Config{})) user.Friends = append(user.Friends, GetUser(name+"_friend_"+strconv.Itoa(i+1), Config{}))
} }
user.Phones = append(user.Phones, "phone_"+strconv.Itoa(10086))
for i := 0; i < config.Phones; i++ {
user.Phones = append(user.Phones, "phone_"+strconv.Itoa(i+1))
}
return &user return &user
} }
@ -93,11 +99,11 @@ func CheckUser(t *testing.T, user User, expect User) {
if err := DB.Where("id = ?", user.ID).First(&newUser).Error; err != nil { if err := DB.Where("id = ?", user.ID).First(&newUser).Error; err != nil {
t.Fatalf("errors happened when query: %v", err) t.Fatalf("errors happened when query: %v", err)
} else { } else {
AssertObjEqual(t, newUser, user, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active") AssertObjEqual(t, newUser, user, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active", "Phones")
} }
} }
AssertObjEqual(t, user, expect, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active") AssertObjEqual(t, user, expect, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active", "Phones")
t.Run("Account", func(t *testing.T) { t.Run("Account", func(t *testing.T) {
AssertObjEqual(t, user.Account, expect.Account, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "UserID", "Number") AssertObjEqual(t, user.Account, expect.Account, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "UserID", "Number")

View File

@ -156,7 +156,7 @@ func TestDryRun(t *testing.T) {
dryRunDB := DB.Session(&gorm.Session{DryRun: true}) dryRunDB := DB.Session(&gorm.Session{DryRun: true})
stmt := dryRunDB.Create(&user).Statement stmt := dryRunDB.Create(&user).Statement
if stmt.SQL.String() == "" || len(stmt.Vars) != 9 { if stmt.SQL.String() == "" || len(stmt.Vars) != 10 {
t.Errorf("Failed to generate sql, got %v", stmt.SQL.String()) t.Errorf("Failed to generate sql, got %v", stmt.SQL.String())
} }

View File

@ -27,6 +27,7 @@ type User struct {
Languages []Language `gorm:"many2many:UserSpeak;"` Languages []Language `gorm:"many2many:UserSpeak;"`
Friends []*User `gorm:"many2many:user_friends;"` Friends []*User `gorm:"many2many:user_friends;"`
Active bool Active bool
Phones []string `gorm:"type:json;xjson"`
} }
type Account struct { type Account struct {