diff --git a/scan.go b/scan.go index be8782ed..48e474fb 100644 --- a/scan.go +++ b/scan.go @@ -12,7 +12,7 @@ import ( func prepareValues(values []interface{}, db *DB, columnTypes []*sql.ColumnType, columns []string) { if db.Statement.Schema != nil { for idx, name := range columns { - if field := db.Statement.Schema.LookUpField(name); field != nil { + if field := db.Statement.Schema.LookUpField(name); field != nil && !field.XJSON { values[idx] = reflect.New(reflect.PtrTo(field.FieldType)).Interface() continue } @@ -145,7 +145,11 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { } else { for idx, field := range fields { if field != nil { - values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface() + if field.XJSON { + values[idx] = &sql.RawBytes{} + } else { + values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface() + } } } @@ -183,11 +187,11 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { if initialized || rows.Next() { for idx, column := range columns { - if field := Schema.LookUpField(column); field != nil && field.Readable { + if field := Schema.LookUpField(column); field != nil && field.Readable && !field.XJSON { values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface() } else if names := strings.Split(column, "__"); len(names) > 1 { if rel, ok := Schema.Relationships.Relations[names[0]]; ok { - if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { + if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable && !field.XJSON { values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface() continue } diff --git a/schema/field.go b/schema/field.go index 4b8a5a2a..ef9ad4a2 100644 --- a/schema/field.go +++ b/schema/field.go @@ -3,6 +3,8 @@ package schema import ( "database/sql" "database/sql/driver" + "encoding/json" + "errors" "fmt" "reflect" "strconv" @@ -67,6 +69,8 @@ type Field struct { ReflectValueOf func(reflect.Value) reflect.Value ValueOf func(reflect.Value) (value interface{}, zero bool) Set func(reflect.Value, interface{}) error + + XJSON bool } func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { @@ -178,6 +182,10 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { field.Comment = val } + if val, ok := field.TagSettings["XJSON"]; ok && utils.CheckTruth(val) { + field.XJSON = true + } + // default value is function or null or blank (primary keys) skipParseDefaultValue := strings.Contains(field.DefaultValue, "(") && strings.Contains(field.DefaultValue, ")") || strings.ToLower(field.DefaultValue) == "null" || field.DefaultValue == "" @@ -381,13 +389,23 @@ func (field *Field) setupValuerAndSetter() { switch { case len(field.StructField.Index) == 1: field.ValueOf = func(value reflect.Value) (interface{}, bool) { - fieldValue := reflect.Indirect(value).Field(field.StructField.Index[0]) - return fieldValue.Interface(), fieldValue.IsZero() + v := reflect.Indirect(value).Field(field.StructField.Index[0]) + if field.XJSON { + b, _ := json.Marshal(v.Interface()) + // pgx will santize bytes with special prefix + return string(b), v.IsZero() + } + return v.Interface(), v.IsZero() } case len(field.StructField.Index) == 2 && field.StructField.Index[0] >= 0: field.ValueOf = func(value reflect.Value) (interface{}, bool) { - fieldValue := reflect.Indirect(value).Field(field.StructField.Index[0]).Field(field.StructField.Index[1]) - return fieldValue.Interface(), fieldValue.IsZero() + v := reflect.Indirect(value).Field(field.StructField.Index[0]).Field(field.StructField.Index[1]) + if field.XJSON { + b, _ := json.Marshal(v.Interface()) + // pgx will santize bytes with special prefix + return string(b), v.IsZero() + } + return v.Interface(), v.IsZero() } default: field.ValueOf = func(value reflect.Value) (interface{}, bool) { @@ -410,6 +428,11 @@ func (field *Field) setupValuerAndSetter() { } } } + if field.XJSON { + // pgx will santize bytes with special prefix + b, _ := json.Marshal(v.Interface()) + return string(b), v.IsZero() + } return v.Interface(), v.IsZero() } } @@ -500,6 +523,16 @@ func (field *Field) setupValuerAndSetter() { if v, err = valuer.Value(); err == nil { err = setter(value, v) } + } else if field.XJSON { + if s, ok := v.(sql.RawBytes); ok { + p := reflect.New(field.FieldType).Interface() + err = json.Unmarshal([]byte(s), p) + if err != nil { + return err + } + return setter(value, p) + } + return errors.New("inlvaid type for tag: xjson") } else { return fmt.Errorf("failed to set value %+v to field %v", v, field.Name) } diff --git a/statement.go b/statement.go index ee80f8cd..8877d314 100644 --- a/statement.go +++ b/statement.go @@ -315,7 +315,8 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) (c switch reflectValue.Kind() { case reflect.Struct: for _, field := range s.Fields { - if field.Readable { + // NOTE: xjson doesn't support where with object + if field.Readable && !field.XJSON { if v, isZero := field.ValueOf(reflectValue); !isZero { if field.DBName != "" { conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v}) @@ -328,7 +329,7 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) (c case reflect.Slice, reflect.Array: for i := 0; i < reflectValue.Len(); i++ { for _, field := range s.Fields { - if field.Readable { + if field.Readable && !field.XJSON { if v, isZero := field.ValueOf(reflectValue.Index(i)); !isZero { if field.DBName != "" { conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v}) diff --git a/tests/create_test.go b/tests/create_test.go index 00674eec..4628931b 100644 --- a/tests/create_test.go +++ b/tests/create_test.go @@ -371,7 +371,7 @@ func TestCreateWithNoGORMPrimaryKey(t *testing.T) { func TestSelectWithCreate(t *testing.T) { user := *GetUser("select_create", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4}) - DB.Select("Account", "Toys", "Manager", "ManagerID", "Languages", "Name", "CreatedAt", "Age", "Active").Create(&user) + DB.Select("Account", "Toys", "Manager", "ManagerID", "Languages", "Name", "CreatedAt", "Age", "Active", "Phones").Create(&user) var user2 User DB.Preload("Account").Preload("Pets").Preload("Toys").Preload("Company").Preload("Manager").Preload("Team").Preload("Languages").Preload("Friends").First(&user2, user.ID) diff --git a/tests/go.mod b/tests/go.mod index 0db87934..677feb3f 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -6,8 +6,10 @@ require ( github.com/google/uuid v1.1.1 github.com/jinzhu/now v1.1.1 github.com/lib/pq v1.6.0 + gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 // indirect + gopkg.in/yaml.v2 v2.2.4 // indirect gorm.io/driver/mysql v1.0.1 - gorm.io/driver/postgres v1.0.0 + gorm.io/driver/postgres v1.0.1 gorm.io/driver/sqlite v1.1.3 gorm.io/driver/sqlserver v1.0.4 gorm.io/gorm v1.20.1 diff --git a/tests/helper_test.go b/tests/helper_test.go index eee34e99..156767a2 100644 --- a/tests/helper_test.go +++ b/tests/helper_test.go @@ -19,6 +19,7 @@ type Config struct { Team int Languages int Friends int + Phones int } func GetUser(name string, config Config) *User { @@ -65,6 +66,11 @@ func GetUser(name string, config Config) *User { user.Friends = append(user.Friends, GetUser(name+"_friend_"+strconv.Itoa(i+1), Config{})) } + user.Phones = append(user.Phones, "phone_"+strconv.Itoa(10086)) + for i := 0; i < config.Phones; i++ { + user.Phones = append(user.Phones, "phone_"+strconv.Itoa(i+1)) + } + return &user } @@ -93,11 +99,11 @@ func CheckUser(t *testing.T, user User, expect User) { if err := DB.Where("id = ?", user.ID).First(&newUser).Error; err != nil { t.Fatalf("errors happened when query: %v", err) } else { - AssertObjEqual(t, newUser, user, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active") + AssertObjEqual(t, newUser, user, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active", "Phones") } } - AssertObjEqual(t, user, expect, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active") + AssertObjEqual(t, user, expect, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active", "Phones") t.Run("Account", func(t *testing.T) { AssertObjEqual(t, user.Account, expect.Account, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "UserID", "Number") diff --git a/tests/sql_builder_test.go b/tests/sql_builder_test.go index c0176fc3..5fab4cdf 100644 --- a/tests/sql_builder_test.go +++ b/tests/sql_builder_test.go @@ -156,7 +156,7 @@ func TestDryRun(t *testing.T) { dryRunDB := DB.Session(&gorm.Session{DryRun: true}) stmt := dryRunDB.Create(&user).Statement - if stmt.SQL.String() == "" || len(stmt.Vars) != 9 { + if stmt.SQL.String() == "" || len(stmt.Vars) != 10 { t.Errorf("Failed to generate sql, got %v", stmt.SQL.String()) } diff --git a/utils/tests/models.go b/utils/tests/models.go index 2c5e71c0..e7caa41e 100644 --- a/utils/tests/models.go +++ b/utils/tests/models.go @@ -27,6 +27,7 @@ type User struct { Languages []Language `gorm:"many2many:UserSpeak;"` Friends []*User `gorm:"many2many:user_friends;"` Active bool + Phones []string `gorm:"type:json;xjson"` } type Account struct {