commit
912bca0d1a
22
README.md
22
README.md
@ -28,7 +28,7 @@ go get github.com/jinzhu/gorm
|
|||||||
## Conventions
|
## Conventions
|
||||||
|
|
||||||
* Table name is the plural of struct name's snake case.
|
* Table name is the plural of struct name's snake case.
|
||||||
Disable pluralization with `db.SingularTable(true)`, or [specify your table name](#specify-table-name)
|
Disable pluralization with `db.SingularTable(true)`, or [Specifying the Table Name for Struct permanently with TableName](#Specifying-the-Table-Name-for-Struct-permanently-with-TableName)
|
||||||
* Column name is the snake case of field's name.
|
* Column name is the snake case of field's name.
|
||||||
* Use `Id int64` field as primary key.
|
* Use `Id int64` field as primary key.
|
||||||
* Use tag `sql` to change field's property, change the tag name with `db.SetTagIdentifier(new_name)`.
|
* Use tag `sql` to change field's property, change the tag name with `db.SetTagIdentifier(new_name)`.
|
||||||
@ -47,6 +47,20 @@ db.First(&user)
|
|||||||
DB.Save(&User{Name: "xxx"}) // table "users"
|
DB.Save(&User{Name: "xxx"}) // table "users"
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Existing schema
|
||||||
|
|
||||||
|
If you have and existing database schema and some of your tables does not follow the conventions, (and you can't rename your table names), please use: [Specifying the Table Name for Struct permanently with TableName](#Specifying-the-Table-Name-for-Struct-permanently-with-TableName).
|
||||||
|
|
||||||
|
If your primary key field is different from `id`, you can add a tag to the field structure to specify that this field is a primary key.
|
||||||
|
|
||||||
|
```go
|
||||||
|
type Animal struct { // animals
|
||||||
|
AnimalId int64 `primaryKey:"yes"`
|
||||||
|
Birthday time.Time
|
||||||
|
Age int64
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
# Getting Started
|
# Getting Started
|
||||||
|
|
||||||
```go
|
```go
|
||||||
@ -97,9 +111,9 @@ import _ "github.com/lib/pq"
|
|||||||
// import _ "github.com/go-sql-driver/mysql"
|
// import _ "github.com/go-sql-driver/mysql"
|
||||||
// import _ "github.com/mattn/go-sqlite3"
|
// import _ "github.com/mattn/go-sqlite3"
|
||||||
|
|
||||||
db, err := Open("postgres", "user=gorm dbname=gorm sslmode=disable")
|
db, err := gorm.Open("postgres", "user=gorm dbname=gorm sslmode=disable")
|
||||||
// db, err = Open("mysql", "gorm:gorm@/gorm?charset=utf8&parseTime=True")
|
// db, err = gorm.Open("mysql", "gorm:gorm@/gorm?charset=utf8&parseTime=True")
|
||||||
// db, err = Open("sqlite3", "/tmp/gorm.db")
|
// db, err = gorm.Open("sqlite3", "/tmp/gorm.db")
|
||||||
|
|
||||||
// Get database connection handle [*sql.DB](http://golang.org/pkg/database/sql/#DB)
|
// Get database connection handle [*sql.DB](http://golang.org/pkg/database/sql/#DB)
|
||||||
d := db.DB()
|
d := db.DB()
|
||||||
|
@ -1,10 +1,17 @@
|
|||||||
package dialect
|
package dialect
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
var timeType = reflect.TypeOf(time.Time{})
|
||||||
|
|
||||||
type Dialect interface {
|
type Dialect interface {
|
||||||
BinVar(i int) string
|
BinVar(i int) string
|
||||||
SupportLastInsertId() bool
|
SupportLastInsertId() bool
|
||||||
SqlTag(column interface{}, size int) string
|
SqlTag(value reflect.Value, size int) string
|
||||||
PrimaryKeyTag(column interface{}, size int) string
|
PrimaryKeyTag(value reflect.Value, size int) string
|
||||||
ReturningStr(key string) string
|
ReturningStr(key string) string
|
||||||
Quote(key string) string
|
Quote(key string) string
|
||||||
}
|
}
|
||||||
|
@ -1,9 +1,8 @@
|
|||||||
package dialect
|
package dialect
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
"reflect"
|
||||||
)
|
)
|
||||||
|
|
||||||
type mysql struct{}
|
type mysql struct{}
|
||||||
@ -16,41 +15,44 @@ func (s *mysql) SupportLastInsertId() bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *mysql) SqlTag(column interface{}, size int) string {
|
func (d *mysql) SqlTag(value reflect.Value, size int) string {
|
||||||
switch column.(type) {
|
switch value.Kind() {
|
||||||
case time.Time:
|
case reflect.Bool:
|
||||||
return "datetime"
|
|
||||||
case bool, sql.NullBool:
|
|
||||||
return "boolean"
|
return "boolean"
|
||||||
case int, int8, int16, int32, uint, uint8, uint16, uint32:
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||||
return "int"
|
return "int"
|
||||||
case int64, uint64, sql.NullInt64:
|
case reflect.Int64, reflect.Uint64:
|
||||||
return "bigint"
|
return "bigint"
|
||||||
case float32, float64, sql.NullFloat64:
|
case reflect.Float32, reflect.Float64:
|
||||||
return "double"
|
return "double"
|
||||||
case []byte:
|
case reflect.String:
|
||||||
if size > 0 && size < 65532 {
|
|
||||||
return fmt.Sprintf("varbinary(%d)", size)
|
|
||||||
} else {
|
|
||||||
return "longblob"
|
|
||||||
}
|
|
||||||
case string, sql.NullString:
|
|
||||||
if size > 0 && size < 65532 {
|
if size > 0 && size < 65532 {
|
||||||
return fmt.Sprintf("varchar(%d)", size)
|
return fmt.Sprintf("varchar(%d)", size)
|
||||||
} else {
|
} else {
|
||||||
return "longtext"
|
return "longtext"
|
||||||
}
|
}
|
||||||
|
case reflect.Struct:
|
||||||
|
if value.Type() == timeType {
|
||||||
|
return "datetime"
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
panic("Invalid sql type for mysql")
|
if _, ok := value.Interface().([]byte); ok {
|
||||||
|
if size > 0 && size < 65532 {
|
||||||
|
return fmt.Sprintf("varbinary(%d)", size)
|
||||||
|
} else {
|
||||||
|
return "longblob"
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", value.Type().Name(), value.Kind().String()))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *mysql) PrimaryKeyTag(column interface{}, size int) string {
|
func (s *mysql) PrimaryKeyTag(value reflect.Value, size int) string {
|
||||||
suffix_str := " NOT NULL AUTO_INCREMENT PRIMARY KEY"
|
suffix_str := " NOT NULL AUTO_INCREMENT PRIMARY KEY"
|
||||||
switch column.(type) {
|
switch value.Kind() {
|
||||||
case int, int8, int16, int32, uint, uint8, uint16, uint32:
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||||
return "int" + suffix_str
|
return "int" + suffix_str
|
||||||
case int64, uint64:
|
case reflect.Int64, reflect.Uint64:
|
||||||
return "bigint" + suffix_str
|
return "bigint" + suffix_str
|
||||||
default:
|
default:
|
||||||
panic("Invalid primary key type")
|
panic("Invalid primary key type")
|
||||||
|
@ -1,9 +1,8 @@
|
|||||||
package dialect
|
package dialect
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
"reflect"
|
||||||
)
|
)
|
||||||
|
|
||||||
type postgres struct {
|
type postgres struct {
|
||||||
@ -17,36 +16,38 @@ func (s *postgres) SupportLastInsertId() bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *postgres) SqlTag(column interface{}, size int) string {
|
func (d *postgres) SqlTag(value reflect.Value, size int) string {
|
||||||
switch column.(type) {
|
switch value.Kind() {
|
||||||
case time.Time:
|
case reflect.Bool:
|
||||||
return "timestamp with time zone"
|
|
||||||
case bool, sql.NullBool:
|
|
||||||
return "boolean"
|
return "boolean"
|
||||||
case int, int8, int16, int32, uint, uint8, uint16, uint32:
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||||
return "integer"
|
return "integer"
|
||||||
case int64, uint64, sql.NullInt64:
|
case reflect.Int64, reflect.Uint64:
|
||||||
return "bigint"
|
return "bigint"
|
||||||
case float32, float64, sql.NullFloat64:
|
case reflect.Float32, reflect.Float64:
|
||||||
return "numeric"
|
return "numeric"
|
||||||
case []byte:
|
case reflect.String:
|
||||||
return "bytea"
|
|
||||||
case string, sql.NullString:
|
|
||||||
if size > 0 && size < 65532 {
|
if size > 0 && size < 65532 {
|
||||||
return fmt.Sprintf("varchar(%d)", size)
|
return fmt.Sprintf("varchar(%d)", size)
|
||||||
} else {
|
}
|
||||||
return "text"
|
return "text"
|
||||||
|
case reflect.Struct:
|
||||||
|
if value.Type() == timeType {
|
||||||
|
return "timestamp with time zone"
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
panic("Invalid sql type for postgres")
|
if _, ok := value.Interface().([]byte); ok {
|
||||||
|
return "bytea"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
panic(fmt.Sprintf("invalid sql type %s (%s) for postgres", value.Type().Name(), value.Kind().String()))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *postgres) PrimaryKeyTag(column interface{}, size int) string {
|
func (s *postgres) PrimaryKeyTag(value reflect.Value, size int) string {
|
||||||
switch column.(type) {
|
switch value.Kind() {
|
||||||
case int, int8, int16, int32, uint, uint8, uint16, uint32:
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||||
return "serial PRIMARY KEY"
|
return "serial PRIMARY KEY"
|
||||||
case int64, uint64:
|
case reflect.Int64, reflect.Uint64:
|
||||||
return "bigserial PRIMARY KEY"
|
return "bigserial PRIMARY KEY"
|
||||||
default:
|
default:
|
||||||
panic("Invalid primary key type")
|
panic("Invalid primary key type")
|
||||||
|
@ -1,9 +1,8 @@
|
|||||||
package dialect
|
package dialect
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
"reflect"
|
||||||
)
|
)
|
||||||
|
|
||||||
type sqlite3 struct{}
|
type sqlite3 struct{}
|
||||||
@ -16,33 +15,41 @@ func (s *sqlite3) SupportLastInsertId() bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *sqlite3) SqlTag(column interface{}, size int) string {
|
func (s *sqlite3) SqlTag(value reflect.Value, size int) string {
|
||||||
switch column.(type) {
|
switch value.Kind() {
|
||||||
case time.Time:
|
case reflect.Bool:
|
||||||
return "datetime"
|
|
||||||
case bool, sql.NullBool:
|
|
||||||
return "bool"
|
return "bool"
|
||||||
case int, int8, int16, int32, uint, uint8, uint16, uint32:
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||||
return "integer"
|
return "integer"
|
||||||
case int64, uint64, sql.NullInt64:
|
case reflect.Int64, reflect.Uint64:
|
||||||
return "bigint"
|
return "bigint"
|
||||||
case float32, float64, sql.NullFloat64:
|
case reflect.Float32, reflect.Float64:
|
||||||
return "real"
|
return "real"
|
||||||
case []byte:
|
case reflect.String:
|
||||||
return "blob"
|
|
||||||
case string, sql.NullString:
|
|
||||||
if size > 0 && size < 65532 {
|
if size > 0 && size < 65532 {
|
||||||
return fmt.Sprintf("varchar(%d)", size)
|
return fmt.Sprintf("varchar(%d)", size)
|
||||||
} else {
|
} else {
|
||||||
return "text"
|
return "text"
|
||||||
}
|
}
|
||||||
|
case reflect.Struct:
|
||||||
|
if value.Type() == timeType {
|
||||||
|
return "datetime"
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
panic("Invalid sql type for sqlite3")
|
if _, ok := value.Interface().([]byte); ok {
|
||||||
|
return "blob"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
panic(fmt.Sprintf("invalid sql type %s (%s) for sqlite3", value.Type().Name(), value.Kind().String()))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *sqlite3) PrimaryKeyTag(column interface{}, size int) string {
|
func (s *sqlite3) PrimaryKeyTag(value reflect.Value, size int) string {
|
||||||
return "INTEGER PRIMARY KEY"
|
switch value.Kind() {
|
||||||
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr, reflect.Int64, reflect.Uint64:
|
||||||
|
return "INTEGER PRIMARY KEY"
|
||||||
|
default:
|
||||||
|
panic("Invalid primary key type")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *sqlite3) ReturningStr(key string) (str string) {
|
func (s *sqlite3) ReturningStr(key string) (str string) {
|
||||||
|
11
main.go
11
main.go
@ -32,6 +32,13 @@ func (s *DB) DB() *sql.DB {
|
|||||||
return s.db.(*sql.DB)
|
return s.db.(*sql.DB)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Return the underlying sql.DB or sql.Tx instance.
|
||||||
|
// Use of this method is discouraged. It's mainly intended to allow
|
||||||
|
// coexistence with legacy non-GORM code.
|
||||||
|
func (s *DB) CommonDB() sqlCommon {
|
||||||
|
return s.db
|
||||||
|
}
|
||||||
|
|
||||||
func (s *DB) Callback() *callback {
|
func (s *DB) Callback() *callback {
|
||||||
s.parent.callback = s.parent.callback.clone()
|
s.parent.callback = s.parent.callback.clone()
|
||||||
return s.parent.callback
|
return s.parent.callback
|
||||||
@ -124,13 +131,13 @@ func (s *DB) Assign(attrs ...interface{}) *DB {
|
|||||||
|
|
||||||
func (s *DB) First(out interface{}, where ...interface{}) *DB {
|
func (s *DB) First(out interface{}, where ...interface{}) *DB {
|
||||||
scope := s.clone().NewScope(out)
|
scope := s.clone().NewScope(out)
|
||||||
scope.Search = scope.Search.clone().order(scope.PrimaryKey()).limit(1)
|
scope.Search = scope.Search.clone().order(scope.TableName()+"."+scope.PrimaryKey()).limit(1)
|
||||||
return scope.inlineCondition(where...).callCallbacks(s.parent.callback.queries).db
|
return scope.inlineCondition(where...).callCallbacks(s.parent.callback.queries).db
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) Last(out interface{}, where ...interface{}) *DB {
|
func (s *DB) Last(out interface{}, where ...interface{}) *DB {
|
||||||
scope := s.clone().NewScope(out)
|
scope := s.clone().NewScope(out)
|
||||||
scope.Search = scope.Search.clone().order(scope.PrimaryKey() + " DESC").limit(1)
|
scope.Search = scope.Search.clone().order(scope.TableName()+"."+scope.PrimaryKey() + " DESC").limit(1)
|
||||||
return scope.inlineCondition(where...).callCallbacks(s.parent.callback.queries).db
|
return scope.inlineCondition(where...).callCallbacks(s.parent.callback.queries).db
|
||||||
}
|
}
|
||||||
|
|
||||||
|
277
main_test.go
277
main_test.go
@ -7,9 +7,9 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
_ "github.com/go-sql-driver/mysql"
|
_ "github.com/go-sql-driver/mysql"
|
||||||
"github.com/jinzhu/gorm"
|
|
||||||
_ "github.com/lib/pq"
|
_ "github.com/lib/pq"
|
||||||
_ "github.com/mattn/go-sqlite3"
|
_ "github.com/mattn/go-sqlite3"
|
||||||
|
"github.com/jinzhu/gorm"
|
||||||
|
|
||||||
"os"
|
"os"
|
||||||
"reflect"
|
"reflect"
|
||||||
@ -22,9 +22,45 @@ type IgnoredEmbedStruct struct {
|
|||||||
Name string
|
Name string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type Num int64
|
||||||
|
|
||||||
|
func (i *Num) Scan(src interface{}) error {
|
||||||
|
switch s := src.(type) {
|
||||||
|
case []byte:
|
||||||
|
case int64:
|
||||||
|
*i = Num(s)
|
||||||
|
default:
|
||||||
|
return errors.New("Cannot scan NamedInt from " + reflect.ValueOf(src).String())
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type Company struct {
|
||||||
|
Id int64
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
|
||||||
|
type Role struct {
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (role *Role) Scan(value interface{}) error {
|
||||||
|
role.Name = string(value.([]uint8))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (role Role) Value() (driver.Value, error) {
|
||||||
|
return role.Name, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (role Role) IsAdmin() bool {
|
||||||
|
return role.Name == "admin"
|
||||||
|
}
|
||||||
|
|
||||||
type User struct {
|
type User struct {
|
||||||
Id int64 // Id: Primary key
|
Id int64 // Id: Primary key
|
||||||
Age int64
|
Age int64
|
||||||
|
UserNum Num
|
||||||
Name string `sql:"size:255"`
|
Name string `sql:"size:255"`
|
||||||
Birthday time.Time // Time
|
Birthday time.Time // Time
|
||||||
CreatedAt time.Time // CreatedAt: Time of record is created, will be insert automatically
|
CreatedAt time.Time // CreatedAt: Time of record is created, will be insert automatically
|
||||||
@ -39,8 +75,12 @@ type User struct {
|
|||||||
When time.Time
|
When time.Time
|
||||||
CreditCard CreditCard
|
CreditCard CreditCard
|
||||||
Latitude float64
|
Latitude float64
|
||||||
PasswordHash []byte
|
CompanyId int64
|
||||||
IgnoreMe int64 `sql:"-"`
|
Company
|
||||||
|
Role
|
||||||
|
PasswordHash []byte
|
||||||
|
IgnoreMe int64 `sql:"-"`
|
||||||
|
IgnoreStringSlice []string `sql:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type CreditCard struct {
|
type CreditCard struct {
|
||||||
@ -87,6 +127,13 @@ type Product struct {
|
|||||||
AfterDeleteCallTimes int64
|
AfterDeleteCallTimes int64
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type Animal struct {
|
||||||
|
Counter int64 `primaryKey:"yes"`
|
||||||
|
Name string
|
||||||
|
CreatedAt time.Time
|
||||||
|
UpdatedAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
db gorm.DB
|
db gorm.DB
|
||||||
t1, t2, t3, t4, t5 time.Time
|
t1, t2, t3, t4, t5 time.Time
|
||||||
@ -128,6 +175,13 @@ func init() {
|
|||||||
db.Exec("drop table emails;")
|
db.Exec("drop table emails;")
|
||||||
db.Exec("drop table addresses")
|
db.Exec("drop table addresses")
|
||||||
db.Exec("drop table credit_cards")
|
db.Exec("drop table credit_cards")
|
||||||
|
db.Exec("drop table roles")
|
||||||
|
db.Exec("drop table companies")
|
||||||
|
db.Exec("drop table animals")
|
||||||
|
|
||||||
|
if err = db.CreateTable(&Animal{}).Error; err != nil {
|
||||||
|
panic(fmt.Sprintf("No error should happen when create table, but got %+v", err))
|
||||||
|
}
|
||||||
|
|
||||||
if err = db.CreateTable(&User{}).Error; err != nil {
|
if err = db.CreateTable(&User{}).Error; err != nil {
|
||||||
panic(fmt.Sprintf("No error should happen when create table, but got %+v", err))
|
panic(fmt.Sprintf("No error should happen when create table, but got %+v", err))
|
||||||
@ -149,17 +203,30 @@ func init() {
|
|||||||
panic(fmt.Sprintf("No error should happen when create table, but got %+v", err))
|
panic(fmt.Sprintf("No error should happen when create table, but got %+v", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err = db.AutoMigrate(Company{}).Error; err != nil {
|
||||||
|
panic(fmt.Sprintf("No error should happen when create table, but got %+v", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = db.AutoMigrate(Role{}).Error; err != nil {
|
||||||
|
panic(fmt.Sprintf("No error should happen when create table, but got %+v", err))
|
||||||
|
}
|
||||||
|
|
||||||
var shortForm = "2006-01-02 15:04:05"
|
var shortForm = "2006-01-02 15:04:05"
|
||||||
t1, _ = time.Parse(shortForm, "2000-10-27 12:02:40")
|
t1, _ = time.Parse(shortForm, "2000-10-27 12:02:40")
|
||||||
t2, _ = time.Parse(shortForm, "2002-01-01 00:00:00")
|
t2, _ = time.Parse(shortForm, "2002-01-01 00:00:00")
|
||||||
t3, _ = time.Parse(shortForm, "2005-01-01 00:00:00")
|
t3, _ = time.Parse(shortForm, "2005-01-01 00:00:00")
|
||||||
t4, _ = time.Parse(shortForm, "2010-01-01 00:00:00")
|
t4, _ = time.Parse(shortForm, "2010-01-01 00:00:00")
|
||||||
t5, _ = time.Parse(shortForm, "2020-01-01 00:00:00")
|
t5, _ = time.Parse(shortForm, "2020-01-01 00:00:00")
|
||||||
db.Save(&User{Name: "1", Age: 18, Birthday: t1, When: time.Now()})
|
db.Save(&User{Name: "1", Age: 18, Birthday: t1, When: time.Now(), UserNum: Num(111)})
|
||||||
db.Save(&User{Name: "2", Age: 20, Birthday: t2})
|
db.Save(&User{Name: "2", Age: 20, Birthday: t2})
|
||||||
db.Save(&User{Name: "3", Age: 22, Birthday: t3})
|
db.Save(&User{Name: "3", Age: 22, Birthday: t3})
|
||||||
db.Save(&User{Name: "3", Age: 24, Birthday: t4})
|
db.Save(&User{Name: "3", Age: 24, Birthday: t4})
|
||||||
db.Save(&User{Name: "5", Age: 26, Birthday: t4})
|
db.Save(&User{Name: "5", Age: 26, Birthday: t4})
|
||||||
|
|
||||||
|
db.Save(&Animal{Name: "First"})
|
||||||
|
db.Save(&Animal{Name: "Amazing"})
|
||||||
|
db.Save(&Animal{Name: "Horse"})
|
||||||
|
db.Save(&Animal{Name: "Last"})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFirstAndLast(t *testing.T) {
|
func TestFirstAndLast(t *testing.T) {
|
||||||
@ -170,7 +237,7 @@ func TestFirstAndLast(t *testing.T) {
|
|||||||
db.Last(&user3)
|
db.Last(&user3)
|
||||||
db.Order("id desc").Find(&user4)
|
db.Order("id desc").Find(&user4)
|
||||||
if user1.Id != user2.Id || user3.Id != user4.Id {
|
if user1.Id != user2.Id || user3.Id != user4.Id {
|
||||||
t.Errorf("First and Last should works correctly")
|
t.Errorf("First and Last should work correctly")
|
||||||
}
|
}
|
||||||
|
|
||||||
var users []User
|
var users []User
|
||||||
@ -180,6 +247,52 @@ func TestFirstAndLast(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestFirstAndLastWithJoins(t *testing.T) {
|
||||||
|
var user1, user2, user3, user4 User
|
||||||
|
db.Joins("left join emails on emails.user_id = users.id").First(&user1)
|
||||||
|
db.Order("id").Find(&user2)
|
||||||
|
|
||||||
|
db.Joins("left join emails on emails.user_id = users.id").Last(&user3)
|
||||||
|
db.Order("id desc").Find(&user4)
|
||||||
|
if user1.Id != user2.Id || user3.Id != user4.Id {
|
||||||
|
t.Errorf("First and Last should work correctly with Joins")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFirstAndLastForTableWithNoStdPrimaryKey(t *testing.T) {
|
||||||
|
var animal1, animal2, animal3, animal4 Animal
|
||||||
|
db.First(&animal1)
|
||||||
|
db.Order("counter").Find(&animal2)
|
||||||
|
|
||||||
|
db.Last(&animal3)
|
||||||
|
db.Order("counter desc").Find(&animal4)
|
||||||
|
if animal1.Counter != animal2.Counter || animal3.Counter != animal4.Counter {
|
||||||
|
t.Errorf("First and Last should work correctly")
|
||||||
|
}
|
||||||
|
|
||||||
|
var animals []Animal
|
||||||
|
db.First(&animals)
|
||||||
|
if len(animals) != 1 {
|
||||||
|
t.Errorf("Find first record as map")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSaveCustomType(t *testing.T) {
|
||||||
|
var user, user1 User
|
||||||
|
db.First(&user, "name = ?", "1")
|
||||||
|
if user.UserNum != Num(111) {
|
||||||
|
t.Errorf("UserNum should be saved correctly")
|
||||||
|
}
|
||||||
|
|
||||||
|
user.UserNum = Num(222)
|
||||||
|
db.Save(&user)
|
||||||
|
|
||||||
|
db.First(&user1, "name = ?", "1")
|
||||||
|
if user1.UserNum != Num(222) {
|
||||||
|
t.Errorf("UserNum should be updated correctly")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestPrecision(t *testing.T) {
|
func TestPrecision(t *testing.T) {
|
||||||
f := 35.03554004971999
|
f := 35.03554004971999
|
||||||
user := User{Name: "Precision", Latitude: f}
|
user := User{Name: "Precision", Latitude: f}
|
||||||
@ -518,7 +631,7 @@ func TestOrderAndPluck(t *testing.T) {
|
|||||||
|
|
||||||
db.Model(&User{}).Order("age desc").Pluck("age", &ages3).Order("age", true).Pluck("age", &ages4)
|
db.Model(&User{}).Order("age desc").Pluck("age", &ages3).Order("age", true).Pluck("age", &ages4)
|
||||||
if reflect.DeepEqual(ages3, ages4) {
|
if reflect.DeepEqual(ages3, ages4) {
|
||||||
t.Errorf("Reorder should works")
|
t.Errorf("Reorder should work")
|
||||||
}
|
}
|
||||||
|
|
||||||
var names []string
|
var names []string
|
||||||
@ -535,7 +648,7 @@ func TestLimit(t *testing.T) {
|
|||||||
db.Order("age desc").Limit(3).Find(&users1).Limit(5).Find(&users2).Limit(-1).Find(&users3)
|
db.Order("age desc").Limit(3).Find(&users1).Limit(5).Find(&users2).Limit(-1).Find(&users3)
|
||||||
|
|
||||||
if len(users1) != 3 || len(users2) != 5 || len(users3) <= 5 {
|
if len(users1) != 3 || len(users2) != 5 || len(users3) <= 5 {
|
||||||
t.Errorf("Limit should works")
|
t.Errorf("Limit should work")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -544,7 +657,7 @@ func TestOffset(t *testing.T) {
|
|||||||
db.Limit(100).Order("age desc").Find(&users1).Offset(3).Find(&users2).Offset(5).Find(&users3).Offset(-1).Find(&users4)
|
db.Limit(100).Order("age desc").Find(&users1).Offset(3).Find(&users2).Offset(5).Find(&users3).Offset(-1).Find(&users4)
|
||||||
|
|
||||||
if (len(users1) != len(users4)) || (len(users1)-len(users2) != 3) || (len(users1)-len(users3) != 5) {
|
if (len(users1) != len(users4)) || (len(users1)-len(users2) != 3) || (len(users1)-len(users3) != 5) {
|
||||||
t.Errorf("Offset should works")
|
t.Errorf("Offset should work")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -561,7 +674,7 @@ func TestCount(t *testing.T) {
|
|||||||
var users []User
|
var users []User
|
||||||
|
|
||||||
if err := db.Where("name = ?", "1").Or("name = ?", "3").Find(&users).Count(&count).Error; err != nil {
|
if err := db.Where("name = ?", "1").Or("name = ?", "3").Find(&users).Count(&count).Error; err != nil {
|
||||||
t.Errorf("Count should works", err)
|
t.Errorf("Count should work", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if count != int64(len(users)) {
|
if count != int64(len(users)) {
|
||||||
@ -570,7 +683,7 @@ func TestCount(t *testing.T) {
|
|||||||
|
|
||||||
db.Model(&User{}).Where("name = ?", "1").Count(&count1).Or("name = ?", "3").Count(&count2)
|
db.Model(&User{}).Where("name = ?", "1").Count(&count1).Or("name = ?", "3").Count(&count2)
|
||||||
if count1 != 1 || count2 != 3 {
|
if count1 != 1 || count2 != 3 {
|
||||||
t.Errorf("Multiple count should works")
|
t.Errorf("Multiple count should work")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -688,7 +801,7 @@ func TestRunCallbacks(t *testing.T) {
|
|||||||
var products []Product
|
var products []Product
|
||||||
db.Find(&products, "code = ?", "unique_code")
|
db.Find(&products, "code = ?", "unique_code")
|
||||||
if products[0].AfterFindCallTimes != 2 {
|
if products[0].AfterFindCallTimes != 2 {
|
||||||
t.Errorf("AfterFind callbacks should works with slice")
|
t.Errorf("AfterFind callbacks should work with slice")
|
||||||
}
|
}
|
||||||
|
|
||||||
db.Where("Code = ?", "unique_code").First(&p)
|
db.Where("Code = ?", "unique_code").First(&p)
|
||||||
@ -869,17 +982,31 @@ func TestSetTableDirectly(t *testing.T) {
|
|||||||
func TestUpdate(t *testing.T) {
|
func TestUpdate(t *testing.T) {
|
||||||
product1 := Product{Code: "123"}
|
product1 := Product{Code: "123"}
|
||||||
product2 := Product{Code: "234"}
|
product2 := Product{Code: "234"}
|
||||||
|
animal1 := Animal{Name: "Ferdinand"}
|
||||||
|
animal2 := Animal{Name: "nerdz"}
|
||||||
|
|
||||||
db.Save(&product1).Save(&product2).Update("code", "456")
|
db.Save(&product1).Save(&product2).Update("code", "456")
|
||||||
|
|
||||||
if product2.Code != "456" {
|
if product2.Code != "456" {
|
||||||
t.Errorf("Record should be updated with update attributes")
|
t.Errorf("Record should be updated with update attributes")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
db.Save(&animal1).Save(&animal2).Update("name", "Francis")
|
||||||
|
|
||||||
|
if animal2.Name != "Francis" {
|
||||||
|
t.Errorf("Record should be updated with update attributes")
|
||||||
|
}
|
||||||
|
|
||||||
db.First(&product1, product1.Id)
|
db.First(&product1, product1.Id)
|
||||||
db.First(&product2, product2.Id)
|
db.First(&product2, product2.Id)
|
||||||
updated_at1 := product1.UpdatedAt
|
updated_at1 := product1.UpdatedAt
|
||||||
updated_at2 := product2.UpdatedAt
|
updated_at2 := product2.UpdatedAt
|
||||||
|
|
||||||
|
db.First(&animal1, animal1.Counter)
|
||||||
|
db.First(&animal2, animal2.Counter)
|
||||||
|
animalUpdated_at1 := animal1.UpdatedAt
|
||||||
|
animalUpdated_at2 := animal2.UpdatedAt
|
||||||
|
|
||||||
var product3 Product
|
var product3 Product
|
||||||
db.First(&product3, product2.Id).Update("code", "456")
|
db.First(&product3, product2.Id).Update("code", "456")
|
||||||
if updated_at2.Format(time.RFC3339Nano) != product3.UpdatedAt.Format(time.RFC3339Nano) {
|
if updated_at2.Format(time.RFC3339Nano) != product3.UpdatedAt.Format(time.RFC3339Nano) {
|
||||||
@ -898,6 +1025,25 @@ func TestUpdate(t *testing.T) {
|
|||||||
t.Errorf("Product 234 should be changed to 456")
|
t.Errorf("Product 234 should be changed to 456")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var animal3 Animal
|
||||||
|
db.First(&animal3, animal2.Counter).Update("Name", "Robert")
|
||||||
|
|
||||||
|
if animalUpdated_at2.Format(time.RFC3339Nano) != animal2.UpdatedAt.Format(time.RFC3339Nano) {
|
||||||
|
t.Errorf("updated_at should not be updated if nothing changed")
|
||||||
|
}
|
||||||
|
|
||||||
|
if db.First(&Animal{}, "name = 'Ferdinand'").Error != nil {
|
||||||
|
t.Errorf("Animal 'Ferdinand' should not be updated")
|
||||||
|
}
|
||||||
|
|
||||||
|
if db.First(&Animal{}, "name = 'nerdz'").Error == nil {
|
||||||
|
t.Errorf("Animal 'nerdz' should be changed to 'Francis'")
|
||||||
|
}
|
||||||
|
|
||||||
|
if db.First(&Animal{}, "name = 'Robert'").Error != nil {
|
||||||
|
t.Errorf("Animal 'nerdz' should be changed to 'Robert'")
|
||||||
|
}
|
||||||
|
|
||||||
db.Table("products").Where("code in (?)", []string{"123"}).Update("code", "789")
|
db.Table("products").Where("code in (?)", []string{"123"}).Update("code", "789")
|
||||||
|
|
||||||
var product4 Product
|
var product4 Product
|
||||||
@ -925,6 +1071,34 @@ func TestUpdate(t *testing.T) {
|
|||||||
if db.Model(&product2).UpdateColumn("CreatedAt", time.Now().Add(time.Hour)).Error != nil {
|
if db.Model(&product2).UpdateColumn("CreatedAt", time.Now().Add(time.Hour)).Error != nil {
|
||||||
t.Error("No error should raise when update_column with CamelCase")
|
t.Error("No error should raise when update_column with CamelCase")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
db.Table("animals").Where("name in (?)", []string{"Ferdinand"}).Update("name", "Franz")
|
||||||
|
|
||||||
|
var animal4 Animal
|
||||||
|
db.First(&animal4, animal1.Counter)
|
||||||
|
if animalUpdated_at1.Format(time.RFC3339Nano) != animal4.UpdatedAt.Format(time.RFC3339Nano) {
|
||||||
|
t.Errorf("animalUpdated_at should be updated if something changed")
|
||||||
|
}
|
||||||
|
|
||||||
|
if db.First(&Animal{}, "name = 'Ferdinand'").Error == nil {
|
||||||
|
t.Errorf("Animal 'Fredinand' should be changed to 'Franz'")
|
||||||
|
}
|
||||||
|
|
||||||
|
if db.First(&Animal{}, "name = 'Robert'").Error != nil {
|
||||||
|
t.Errorf("Animal 'Robert' should not be changed to 'Francis'")
|
||||||
|
}
|
||||||
|
|
||||||
|
if db.First(&Animal{}, "name = 'Franz'").Error != nil {
|
||||||
|
t.Errorf("Product 'nerdz' should be changed to 'Franz'")
|
||||||
|
}
|
||||||
|
|
||||||
|
if db.Model(animal2).Update("CreatedAt", time.Now().Add(time.Hour)).Error != nil {
|
||||||
|
t.Error("No error should raise when update with CamelCase")
|
||||||
|
}
|
||||||
|
|
||||||
|
if db.Model(&animal2).UpdateColumn("CreatedAt", time.Now().Add(time.Hour)).Error != nil {
|
||||||
|
t.Error("No error should raise when update_column with CamelCase")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUpdates(t *testing.T) {
|
func TestUpdates(t *testing.T) {
|
||||||
@ -1308,10 +1482,17 @@ func TestRelated(t *testing.T) {
|
|||||||
if len(emails) != 2 {
|
if len(emails) != 2 {
|
||||||
t.Errorf("Should have two emails")
|
t.Errorf("Should have two emails")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var emails2 []Email
|
||||||
|
db.Model(&user).Where("email = ?", "jinzhu@example.com").Related(&emails2)
|
||||||
|
if len(emails2) != 1 {
|
||||||
|
t.Errorf("Should have two emails")
|
||||||
|
}
|
||||||
|
|
||||||
var user1 User
|
var user1 User
|
||||||
db.Model(&user).Related(&user1.Emails)
|
db.Model(&user).Related(&user1.Emails)
|
||||||
if len(user1.Emails) != 2 {
|
if len(user1.Emails) != 2 {
|
||||||
t.Errorf("Should have two emails")
|
t.Errorf("Should have only one email match related condition")
|
||||||
}
|
}
|
||||||
|
|
||||||
var address1 Address
|
var address1 Address
|
||||||
@ -1511,6 +1692,10 @@ func TestTransaction(t *testing.T) {
|
|||||||
t.Errorf("Should find saved record, but got", err)
|
t.Errorf("Should find saved record, but got", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if sql_tx, ok := tx.CommonDB().(*sql.Tx); !ok || sql_tx == nil {
|
||||||
|
t.Errorf("Should return the underlying sql.Tx")
|
||||||
|
}
|
||||||
|
|
||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
|
|
||||||
if err := tx.First(&User{}, "name = ?", "transcation").Error; err == nil {
|
if err := tx.First(&User{}, "name = ?", "transcation").Error; err == nil {
|
||||||
@ -1580,7 +1765,7 @@ func TestScan(t *testing.T) {
|
|||||||
var res result
|
var res result
|
||||||
db.Table("users").Select("name, age").Where("name = ?", 3).Scan(&res)
|
db.Table("users").Select("name, age").Where("name = ?", 3).Scan(&res)
|
||||||
if res.Name != "3" {
|
if res.Name != "3" {
|
||||||
t.Errorf("Scan into struct should works")
|
t.Errorf("Scan into struct should work")
|
||||||
}
|
}
|
||||||
|
|
||||||
var ress []result
|
var ress []result
|
||||||
@ -1694,6 +1879,33 @@ func TestHaving(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAnonymousField(t *testing.T) {
|
||||||
|
user := User{Name: "anonymous_field", Company: Company{Name: "company"}}
|
||||||
|
db.Save(&user)
|
||||||
|
|
||||||
|
var user2 User
|
||||||
|
db.First(&user2, "name = ?", "anonymous_field")
|
||||||
|
db.Model(&user2).Related(&user2.Company)
|
||||||
|
if user2.Company.Name != "company" {
|
||||||
|
t.Errorf("Should be able to get anonymous field")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAnonymousScanner(t *testing.T) {
|
||||||
|
user := User{Name: "anonymous_scanner", Role: Role{Name: "admin"}}
|
||||||
|
db.Save(&user)
|
||||||
|
|
||||||
|
var user2 User
|
||||||
|
db.First(&user2, "name = ?", "anonymous_scanner")
|
||||||
|
if user2.Role.Name != "admin" {
|
||||||
|
t.Errorf("Should be able to get anonymous scanner")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !user2.IsAdmin() {
|
||||||
|
t.Errorf("Should be able to get anonymous scanner")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestExecRawSql(t *testing.T) {
|
func TestExecRawSql(t *testing.T) {
|
||||||
db.Exec("update users set name=? where name in (?)", "jinzhu", []string{"1", "2", "3"})
|
db.Exec("update users set name=? where name in (?)", "jinzhu", []string{"1", "2", "3"})
|
||||||
if db.Where("name in (?)", []string{"1", "2", "3"}).First(&User{}).Error != gorm.RecordNotFound {
|
if db.Where("name in (?)", []string{"1", "2", "3"}).First(&User{}).Error != gorm.RecordNotFound {
|
||||||
@ -1701,6 +1913,45 @@ func TestExecRawSql(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestTimeWithZone(t *testing.T) {
|
||||||
|
var format = "2006-01-02 15:04:05 -0700"
|
||||||
|
var times []time.Time
|
||||||
|
GMT8, _ := time.LoadLocation("Asia/Shanghai")
|
||||||
|
times = append(times, time.Date(2013, 02, 19, 1, 51, 49, 123456789, GMT8))
|
||||||
|
times = append(times, time.Date(2013, 02, 18, 17, 51, 49, 123456789, time.UTC))
|
||||||
|
|
||||||
|
for index, vtime := range times {
|
||||||
|
name := "time_with_zone_" + strconv.Itoa(index)
|
||||||
|
user := User{Name: name, Birthday: vtime}
|
||||||
|
db.Save(&user)
|
||||||
|
if user.Birthday.UTC().Format(format) != "2013-02-18 17:51:49 +0000" {
|
||||||
|
t.Errorf("User's birthday should not be changed after save")
|
||||||
|
}
|
||||||
|
|
||||||
|
if user.DeletedAt.UTC().Format(format) != "0001-01-01 00:00:00 +0000" {
|
||||||
|
t.Errorf("User's deleted at should be zero")
|
||||||
|
}
|
||||||
|
|
||||||
|
var findUser, findUser2, findUser3 User
|
||||||
|
db.First(&findUser, "name = ?", name)
|
||||||
|
if findUser.Birthday.UTC().Format(format) != "2013-02-18 17:51:49 +0000" {
|
||||||
|
t.Errorf("User's birthday should not be changed after find")
|
||||||
|
}
|
||||||
|
|
||||||
|
if findUser.DeletedAt.UTC().Format(format) != "0001-01-01 00:00:00 +0000" {
|
||||||
|
t.Errorf("User's deleted at should be zero")
|
||||||
|
}
|
||||||
|
|
||||||
|
if db.Where("birthday >= ?", vtime.Add(-time.Minute)).First(&findUser2).RecordNotFound() {
|
||||||
|
t.Errorf("User should be found")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !db.Where("birthday >= ?", vtime.Add(time.Minute)).First(&findUser3).RecordNotFound() {
|
||||||
|
t.Errorf("User should not be found")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func BenchmarkGorm(b *testing.B) {
|
func BenchmarkGorm(b *testing.B) {
|
||||||
b.N = 2000
|
b.N = 2000
|
||||||
for x := 0; x < b.N; x++ {
|
for x := 0; x < b.N; x++ {
|
||||||
|
34
scope.go
34
scope.go
@ -13,13 +13,14 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Scope struct {
|
type Scope struct {
|
||||||
Value interface{}
|
Value interface{}
|
||||||
Search *search
|
Search *search
|
||||||
Sql string
|
Sql string
|
||||||
SqlVars []interface{}
|
SqlVars []interface{}
|
||||||
db *DB
|
db *DB
|
||||||
_values map[string]interface{}
|
_values map[string]interface{}
|
||||||
skipLeft bool
|
skipLeft bool
|
||||||
|
primaryKey string
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewScope create scope for callbacks, including DB's search information
|
// NewScope create scope for callbacks, including DB's search information
|
||||||
@ -78,7 +79,12 @@ func (scope *Scope) HasError() bool {
|
|||||||
|
|
||||||
// PrimaryKey get the primary key's column name
|
// PrimaryKey get the primary key's column name
|
||||||
func (scope *Scope) PrimaryKey() string {
|
func (scope *Scope) PrimaryKey() string {
|
||||||
return "id"
|
if scope.primaryKey != "" {
|
||||||
|
return scope.primaryKey
|
||||||
|
}
|
||||||
|
|
||||||
|
scope.primaryKey = scope.getPrimaryKey()
|
||||||
|
return scope.primaryKey
|
||||||
}
|
}
|
||||||
|
|
||||||
// PrimaryKeyZero check the primary key is blank or not
|
// PrimaryKeyZero check the primary key is blank or not
|
||||||
@ -227,7 +233,7 @@ func (scope *Scope) Fields() []*Field {
|
|||||||
scopeTyp := indirectValue.Type()
|
scopeTyp := indirectValue.Type()
|
||||||
for i := 0; i < scopeTyp.NumField(); i++ {
|
for i := 0; i < scopeTyp.NumField(); i++ {
|
||||||
fieldStruct := scopeTyp.Field(i)
|
fieldStruct := scopeTyp.Field(i)
|
||||||
if fieldStruct.Anonymous || !ast.IsExported(fieldStruct.Name) {
|
if !ast.IsExported(fieldStruct.Name) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -238,7 +244,13 @@ func (scope *Scope) Fields() []*Field {
|
|||||||
value := indirectValue.FieldByName(fieldStruct.Name)
|
value := indirectValue.FieldByName(fieldStruct.Name)
|
||||||
field.Value = value.Interface()
|
field.Value = value.Interface()
|
||||||
field.IsBlank = isBlank(value)
|
field.IsBlank = isBlank(value)
|
||||||
field.isPrimaryKey = scope.PrimaryKey() == field.DBName
|
|
||||||
|
// Search for primary key tag identifier
|
||||||
|
field.isPrimaryKey = scope.PrimaryKey() == field.DBName || fieldStruct.Tag.Get("primaryKey") != ""
|
||||||
|
|
||||||
|
if field.isPrimaryKey {
|
||||||
|
scope.primaryKey = field.DBName
|
||||||
|
}
|
||||||
|
|
||||||
if scope.db != nil {
|
if scope.db != nil {
|
||||||
field.Tag = fieldStruct.Tag
|
field.Tag = fieldStruct.Tag
|
||||||
@ -252,7 +264,7 @@ func (scope *Scope) Fields() []*Field {
|
|||||||
case reflect.Slice:
|
case reflect.Slice:
|
||||||
typ = typ.Elem()
|
typ = typ.Elem()
|
||||||
|
|
||||||
if _, ok := field.Value.([]byte); !ok {
|
if typ.Kind() == reflect.Struct {
|
||||||
foreignKey := scopeTyp.Name() + "Id"
|
foreignKey := scopeTyp.Name() + "Id"
|
||||||
if reflect.New(typ).Elem().FieldByName(foreignKey).IsValid() {
|
if reflect.New(typ).Elem().FieldByName(foreignKey).IsValid() {
|
||||||
field.ForeignKey = foreignKey
|
field.ForeignKey = foreignKey
|
||||||
|
@ -5,6 +5,7 @@ import (
|
|||||||
"database/sql/driver"
|
"database/sql/driver"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"go/ast"
|
||||||
"reflect"
|
"reflect"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strconv"
|
"strconv"
|
||||||
@ -309,26 +310,24 @@ func (scope *Scope) sqlTagForField(field *Field) (tag string) {
|
|||||||
value := field.Value
|
value := field.Value
|
||||||
reflectValue := reflect.ValueOf(value)
|
reflectValue := reflect.ValueOf(value)
|
||||||
|
|
||||||
if field.IsScanner() {
|
|
||||||
value = reflectValue.Field(0).Interface()
|
|
||||||
}
|
|
||||||
|
|
||||||
switch reflectValue.Kind() {
|
switch reflectValue.Kind() {
|
||||||
case reflect.Slice:
|
case reflect.Slice:
|
||||||
if _, ok := value.([]byte); !ok {
|
if _, ok := value.([]byte); !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
if !field.IsTime() && !field.IsScanner() {
|
if field.IsScanner() {
|
||||||
|
reflectValue = reflectValue.Field(0)
|
||||||
|
} else if !field.IsTime() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(tag) == 0 {
|
if len(tag) == 0 {
|
||||||
if field.isPrimaryKey {
|
if field.isPrimaryKey {
|
||||||
tag = scope.Dialect().PrimaryKeyTag(value, size)
|
tag = scope.Dialect().PrimaryKeyTag(reflectValue, size)
|
||||||
} else {
|
} else {
|
||||||
tag = scope.Dialect().SqlTag(value, size)
|
tag = scope.Dialect().SqlTag(reflectValue, size)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -395,7 +394,7 @@ func (scope *Scope) typeName() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope {
|
func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope {
|
||||||
toScope := scope.New(value)
|
toScope := scope.db.NewScope(value)
|
||||||
|
|
||||||
for _, foreignKey := range append(foreignKeys, toScope.typeName()+"Id", scope.typeName()+"Id") {
|
for _, foreignKey := range append(foreignKeys, toScope.typeName()+"Id", scope.typeName()+"Id") {
|
||||||
if foreignValue, ok := scope.FieldByName(foreignKey); ok {
|
if foreignValue, ok := scope.FieldByName(foreignKey); ok {
|
||||||
@ -474,3 +473,33 @@ func (scope *Scope) autoMigrate() *Scope {
|
|||||||
}
|
}
|
||||||
return scope
|
return scope
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (scope *Scope) getPrimaryKey() string {
|
||||||
|
var indirectValue reflect.Value
|
||||||
|
|
||||||
|
indirectValue = reflect.Indirect(reflect.ValueOf(scope.Value))
|
||||||
|
|
||||||
|
if indirectValue.Kind() == reflect.Slice {
|
||||||
|
indirectValue = reflect.New(indirectValue.Type().Elem()).Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
if !indirectValue.IsValid() {
|
||||||
|
return "id"
|
||||||
|
}
|
||||||
|
|
||||||
|
scopeTyp := indirectValue.Type()
|
||||||
|
for i := 0; i < scopeTyp.NumField(); i++ {
|
||||||
|
fieldStruct := scopeTyp.Field(i)
|
||||||
|
if !ast.IsExported(fieldStruct.Name) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// if primaryKey tag found, return column name
|
||||||
|
if fieldStruct.Tag.Get("primaryKey") != "" {
|
||||||
|
return toSnake(fieldStruct.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//If primaryKey tag not found, fallback to id
|
||||||
|
return "id"
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user