commit
912bca0d1a
22
README.md
22
README.md
@ -28,7 +28,7 @@ go get github.com/jinzhu/gorm
|
||||
## Conventions
|
||||
|
||||
* Table name is the plural of struct name's snake case.
|
||||
Disable pluralization with `db.SingularTable(true)`, or [specify your table name](#specify-table-name)
|
||||
Disable pluralization with `db.SingularTable(true)`, or [Specifying the Table Name for Struct permanently with TableName](#Specifying-the-Table-Name-for-Struct-permanently-with-TableName)
|
||||
* Column name is the snake case of field's name.
|
||||
* Use `Id int64` field as primary key.
|
||||
* Use tag `sql` to change field's property, change the tag name with `db.SetTagIdentifier(new_name)`.
|
||||
@ -47,6 +47,20 @@ db.First(&user)
|
||||
DB.Save(&User{Name: "xxx"}) // table "users"
|
||||
```
|
||||
|
||||
## Existing schema
|
||||
|
||||
If you have and existing database schema and some of your tables does not follow the conventions, (and you can't rename your table names), please use: [Specifying the Table Name for Struct permanently with TableName](#Specifying-the-Table-Name-for-Struct-permanently-with-TableName).
|
||||
|
||||
If your primary key field is different from `id`, you can add a tag to the field structure to specify that this field is a primary key.
|
||||
|
||||
```go
|
||||
type Animal struct { // animals
|
||||
AnimalId int64 `primaryKey:"yes"`
|
||||
Birthday time.Time
|
||||
Age int64
|
||||
}
|
||||
```
|
||||
|
||||
# Getting Started
|
||||
|
||||
```go
|
||||
@ -97,9 +111,9 @@ import _ "github.com/lib/pq"
|
||||
// import _ "github.com/go-sql-driver/mysql"
|
||||
// import _ "github.com/mattn/go-sqlite3"
|
||||
|
||||
db, err := Open("postgres", "user=gorm dbname=gorm sslmode=disable")
|
||||
// db, err = Open("mysql", "gorm:gorm@/gorm?charset=utf8&parseTime=True")
|
||||
// db, err = Open("sqlite3", "/tmp/gorm.db")
|
||||
db, err := gorm.Open("postgres", "user=gorm dbname=gorm sslmode=disable")
|
||||
// db, err = gorm.Open("mysql", "gorm:gorm@/gorm?charset=utf8&parseTime=True")
|
||||
// db, err = gorm.Open("sqlite3", "/tmp/gorm.db")
|
||||
|
||||
// Get database connection handle [*sql.DB](http://golang.org/pkg/database/sql/#DB)
|
||||
d := db.DB()
|
||||
|
@ -1,10 +1,17 @@
|
||||
package dialect
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"time"
|
||||
)
|
||||
|
||||
var timeType = reflect.TypeOf(time.Time{})
|
||||
|
||||
type Dialect interface {
|
||||
BinVar(i int) string
|
||||
SupportLastInsertId() bool
|
||||
SqlTag(column interface{}, size int) string
|
||||
PrimaryKeyTag(column interface{}, size int) string
|
||||
SqlTag(value reflect.Value, size int) string
|
||||
PrimaryKeyTag(value reflect.Value, size int) string
|
||||
ReturningStr(key string) string
|
||||
Quote(key string) string
|
||||
}
|
||||
|
@ -1,9 +1,8 @@
|
||||
package dialect
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
type mysql struct{}
|
||||
@ -16,41 +15,44 @@ func (s *mysql) SupportLastInsertId() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (d *mysql) SqlTag(column interface{}, size int) string {
|
||||
switch column.(type) {
|
||||
case time.Time:
|
||||
return "datetime"
|
||||
case bool, sql.NullBool:
|
||||
func (d *mysql) SqlTag(value reflect.Value, size int) string {
|
||||
switch value.Kind() {
|
||||
case reflect.Bool:
|
||||
return "boolean"
|
||||
case int, int8, int16, int32, uint, uint8, uint16, uint32:
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||
return "int"
|
||||
case int64, uint64, sql.NullInt64:
|
||||
case reflect.Int64, reflect.Uint64:
|
||||
return "bigint"
|
||||
case float32, float64, sql.NullFloat64:
|
||||
case reflect.Float32, reflect.Float64:
|
||||
return "double"
|
||||
case []byte:
|
||||
if size > 0 && size < 65532 {
|
||||
return fmt.Sprintf("varbinary(%d)", size)
|
||||
} else {
|
||||
return "longblob"
|
||||
}
|
||||
case string, sql.NullString:
|
||||
case reflect.String:
|
||||
if size > 0 && size < 65532 {
|
||||
return fmt.Sprintf("varchar(%d)", size)
|
||||
} else {
|
||||
return "longtext"
|
||||
}
|
||||
case reflect.Struct:
|
||||
if value.Type() == timeType {
|
||||
return "datetime"
|
||||
}
|
||||
default:
|
||||
panic("Invalid sql type for mysql")
|
||||
if _, ok := value.Interface().([]byte); ok {
|
||||
if size > 0 && size < 65532 {
|
||||
return fmt.Sprintf("varbinary(%d)", size)
|
||||
} else {
|
||||
return "longblob"
|
||||
}
|
||||
}
|
||||
}
|
||||
panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", value.Type().Name(), value.Kind().String()))
|
||||
}
|
||||
|
||||
func (s *mysql) PrimaryKeyTag(column interface{}, size int) string {
|
||||
func (s *mysql) PrimaryKeyTag(value reflect.Value, size int) string {
|
||||
suffix_str := " NOT NULL AUTO_INCREMENT PRIMARY KEY"
|
||||
switch column.(type) {
|
||||
case int, int8, int16, int32, uint, uint8, uint16, uint32:
|
||||
switch value.Kind() {
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||
return "int" + suffix_str
|
||||
case int64, uint64:
|
||||
case reflect.Int64, reflect.Uint64:
|
||||
return "bigint" + suffix_str
|
||||
default:
|
||||
panic("Invalid primary key type")
|
||||
|
@ -1,9 +1,8 @@
|
||||
package dialect
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
type postgres struct {
|
||||
@ -17,36 +16,38 @@ func (s *postgres) SupportLastInsertId() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (d *postgres) SqlTag(column interface{}, size int) string {
|
||||
switch column.(type) {
|
||||
case time.Time:
|
||||
return "timestamp with time zone"
|
||||
case bool, sql.NullBool:
|
||||
func (d *postgres) SqlTag(value reflect.Value, size int) string {
|
||||
switch value.Kind() {
|
||||
case reflect.Bool:
|
||||
return "boolean"
|
||||
case int, int8, int16, int32, uint, uint8, uint16, uint32:
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||
return "integer"
|
||||
case int64, uint64, sql.NullInt64:
|
||||
case reflect.Int64, reflect.Uint64:
|
||||
return "bigint"
|
||||
case float32, float64, sql.NullFloat64:
|
||||
case reflect.Float32, reflect.Float64:
|
||||
return "numeric"
|
||||
case []byte:
|
||||
return "bytea"
|
||||
case string, sql.NullString:
|
||||
case reflect.String:
|
||||
if size > 0 && size < 65532 {
|
||||
return fmt.Sprintf("varchar(%d)", size)
|
||||
} else {
|
||||
return "text"
|
||||
}
|
||||
return "text"
|
||||
case reflect.Struct:
|
||||
if value.Type() == timeType {
|
||||
return "timestamp with time zone"
|
||||
}
|
||||
default:
|
||||
panic("Invalid sql type for postgres")
|
||||
if _, ok := value.Interface().([]byte); ok {
|
||||
return "bytea"
|
||||
}
|
||||
}
|
||||
panic(fmt.Sprintf("invalid sql type %s (%s) for postgres", value.Type().Name(), value.Kind().String()))
|
||||
}
|
||||
|
||||
func (s *postgres) PrimaryKeyTag(column interface{}, size int) string {
|
||||
switch column.(type) {
|
||||
case int, int8, int16, int32, uint, uint8, uint16, uint32:
|
||||
func (s *postgres) PrimaryKeyTag(value reflect.Value, size int) string {
|
||||
switch value.Kind() {
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||
return "serial PRIMARY KEY"
|
||||
case int64, uint64:
|
||||
case reflect.Int64, reflect.Uint64:
|
||||
return "bigserial PRIMARY KEY"
|
||||
default:
|
||||
panic("Invalid primary key type")
|
||||
|
@ -1,9 +1,8 @@
|
||||
package dialect
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
type sqlite3 struct{}
|
||||
@ -16,33 +15,41 @@ func (s *sqlite3) SupportLastInsertId() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *sqlite3) SqlTag(column interface{}, size int) string {
|
||||
switch column.(type) {
|
||||
case time.Time:
|
||||
return "datetime"
|
||||
case bool, sql.NullBool:
|
||||
func (s *sqlite3) SqlTag(value reflect.Value, size int) string {
|
||||
switch value.Kind() {
|
||||
case reflect.Bool:
|
||||
return "bool"
|
||||
case int, int8, int16, int32, uint, uint8, uint16, uint32:
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||
return "integer"
|
||||
case int64, uint64, sql.NullInt64:
|
||||
case reflect.Int64, reflect.Uint64:
|
||||
return "bigint"
|
||||
case float32, float64, sql.NullFloat64:
|
||||
case reflect.Float32, reflect.Float64:
|
||||
return "real"
|
||||
case []byte:
|
||||
return "blob"
|
||||
case string, sql.NullString:
|
||||
case reflect.String:
|
||||
if size > 0 && size < 65532 {
|
||||
return fmt.Sprintf("varchar(%d)", size)
|
||||
} else {
|
||||
return "text"
|
||||
}
|
||||
case reflect.Struct:
|
||||
if value.Type() == timeType {
|
||||
return "datetime"
|
||||
}
|
||||
default:
|
||||
panic("Invalid sql type for sqlite3")
|
||||
if _, ok := value.Interface().([]byte); ok {
|
||||
return "blob"
|
||||
}
|
||||
}
|
||||
panic(fmt.Sprintf("invalid sql type %s (%s) for sqlite3", value.Type().Name(), value.Kind().String()))
|
||||
}
|
||||
|
||||
func (s *sqlite3) PrimaryKeyTag(column interface{}, size int) string {
|
||||
return "INTEGER PRIMARY KEY"
|
||||
func (s *sqlite3) PrimaryKeyTag(value reflect.Value, size int) string {
|
||||
switch value.Kind() {
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr, reflect.Int64, reflect.Uint64:
|
||||
return "INTEGER PRIMARY KEY"
|
||||
default:
|
||||
panic("Invalid primary key type")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *sqlite3) ReturningStr(key string) (str string) {
|
||||
|
11
main.go
11
main.go
@ -32,6 +32,13 @@ func (s *DB) DB() *sql.DB {
|
||||
return s.db.(*sql.DB)
|
||||
}
|
||||
|
||||
// Return the underlying sql.DB or sql.Tx instance.
|
||||
// Use of this method is discouraged. It's mainly intended to allow
|
||||
// coexistence with legacy non-GORM code.
|
||||
func (s *DB) CommonDB() sqlCommon {
|
||||
return s.db
|
||||
}
|
||||
|
||||
func (s *DB) Callback() *callback {
|
||||
s.parent.callback = s.parent.callback.clone()
|
||||
return s.parent.callback
|
||||
@ -124,13 +131,13 @@ func (s *DB) Assign(attrs ...interface{}) *DB {
|
||||
|
||||
func (s *DB) First(out interface{}, where ...interface{}) *DB {
|
||||
scope := s.clone().NewScope(out)
|
||||
scope.Search = scope.Search.clone().order(scope.PrimaryKey()).limit(1)
|
||||
scope.Search = scope.Search.clone().order(scope.TableName()+"."+scope.PrimaryKey()).limit(1)
|
||||
return scope.inlineCondition(where...).callCallbacks(s.parent.callback.queries).db
|
||||
}
|
||||
|
||||
func (s *DB) Last(out interface{}, where ...interface{}) *DB {
|
||||
scope := s.clone().NewScope(out)
|
||||
scope.Search = scope.Search.clone().order(scope.PrimaryKey() + " DESC").limit(1)
|
||||
scope.Search = scope.Search.clone().order(scope.TableName()+"."+scope.PrimaryKey() + " DESC").limit(1)
|
||||
return scope.inlineCondition(where...).callCallbacks(s.parent.callback.queries).db
|
||||
}
|
||||
|
||||
|
277
main_test.go
277
main_test.go
@ -7,9 +7,9 @@ import (
|
||||
"fmt"
|
||||
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
"github.com/jinzhu/gorm"
|
||||
_ "github.com/lib/pq"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"github.com/jinzhu/gorm"
|
||||
|
||||
"os"
|
||||
"reflect"
|
||||
@ -22,9 +22,45 @@ type IgnoredEmbedStruct struct {
|
||||
Name string
|
||||
}
|
||||
|
||||
type Num int64
|
||||
|
||||
func (i *Num) Scan(src interface{}) error {
|
||||
switch s := src.(type) {
|
||||
case []byte:
|
||||
case int64:
|
||||
*i = Num(s)
|
||||
default:
|
||||
return errors.New("Cannot scan NamedInt from " + reflect.ValueOf(src).String())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type Company struct {
|
||||
Id int64
|
||||
Name string
|
||||
}
|
||||
|
||||
type Role struct {
|
||||
Name string
|
||||
}
|
||||
|
||||
func (role *Role) Scan(value interface{}) error {
|
||||
role.Name = string(value.([]uint8))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (role Role) Value() (driver.Value, error) {
|
||||
return role.Name, nil
|
||||
}
|
||||
|
||||
func (role Role) IsAdmin() bool {
|
||||
return role.Name == "admin"
|
||||
}
|
||||
|
||||
type User struct {
|
||||
Id int64 // Id: Primary key
|
||||
Age int64
|
||||
UserNum Num
|
||||
Name string `sql:"size:255"`
|
||||
Birthday time.Time // Time
|
||||
CreatedAt time.Time // CreatedAt: Time of record is created, will be insert automatically
|
||||
@ -39,8 +75,12 @@ type User struct {
|
||||
When time.Time
|
||||
CreditCard CreditCard
|
||||
Latitude float64
|
||||
PasswordHash []byte
|
||||
IgnoreMe int64 `sql:"-"`
|
||||
CompanyId int64
|
||||
Company
|
||||
Role
|
||||
PasswordHash []byte
|
||||
IgnoreMe int64 `sql:"-"`
|
||||
IgnoreStringSlice []string `sql:"-"`
|
||||
}
|
||||
|
||||
type CreditCard struct {
|
||||
@ -87,6 +127,13 @@ type Product struct {
|
||||
AfterDeleteCallTimes int64
|
||||
}
|
||||
|
||||
type Animal struct {
|
||||
Counter int64 `primaryKey:"yes"`
|
||||
Name string
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
var (
|
||||
db gorm.DB
|
||||
t1, t2, t3, t4, t5 time.Time
|
||||
@ -128,6 +175,13 @@ func init() {
|
||||
db.Exec("drop table emails;")
|
||||
db.Exec("drop table addresses")
|
||||
db.Exec("drop table credit_cards")
|
||||
db.Exec("drop table roles")
|
||||
db.Exec("drop table companies")
|
||||
db.Exec("drop table animals")
|
||||
|
||||
if err = db.CreateTable(&Animal{}).Error; err != nil {
|
||||
panic(fmt.Sprintf("No error should happen when create table, but got %+v", err))
|
||||
}
|
||||
|
||||
if err = db.CreateTable(&User{}).Error; err != nil {
|
||||
panic(fmt.Sprintf("No error should happen when create table, but got %+v", err))
|
||||
@ -149,17 +203,30 @@ func init() {
|
||||
panic(fmt.Sprintf("No error should happen when create table, but got %+v", err))
|
||||
}
|
||||
|
||||
if err = db.AutoMigrate(Company{}).Error; err != nil {
|
||||
panic(fmt.Sprintf("No error should happen when create table, but got %+v", err))
|
||||
}
|
||||
|
||||
if err = db.AutoMigrate(Role{}).Error; err != nil {
|
||||
panic(fmt.Sprintf("No error should happen when create table, but got %+v", err))
|
||||
}
|
||||
|
||||
var shortForm = "2006-01-02 15:04:05"
|
||||
t1, _ = time.Parse(shortForm, "2000-10-27 12:02:40")
|
||||
t2, _ = time.Parse(shortForm, "2002-01-01 00:00:00")
|
||||
t3, _ = time.Parse(shortForm, "2005-01-01 00:00:00")
|
||||
t4, _ = time.Parse(shortForm, "2010-01-01 00:00:00")
|
||||
t5, _ = time.Parse(shortForm, "2020-01-01 00:00:00")
|
||||
db.Save(&User{Name: "1", Age: 18, Birthday: t1, When: time.Now()})
|
||||
db.Save(&User{Name: "1", Age: 18, Birthday: t1, When: time.Now(), UserNum: Num(111)})
|
||||
db.Save(&User{Name: "2", Age: 20, Birthday: t2})
|
||||
db.Save(&User{Name: "3", Age: 22, Birthday: t3})
|
||||
db.Save(&User{Name: "3", Age: 24, Birthday: t4})
|
||||
db.Save(&User{Name: "5", Age: 26, Birthday: t4})
|
||||
|
||||
db.Save(&Animal{Name: "First"})
|
||||
db.Save(&Animal{Name: "Amazing"})
|
||||
db.Save(&Animal{Name: "Horse"})
|
||||
db.Save(&Animal{Name: "Last"})
|
||||
}
|
||||
|
||||
func TestFirstAndLast(t *testing.T) {
|
||||
@ -170,7 +237,7 @@ func TestFirstAndLast(t *testing.T) {
|
||||
db.Last(&user3)
|
||||
db.Order("id desc").Find(&user4)
|
||||
if user1.Id != user2.Id || user3.Id != user4.Id {
|
||||
t.Errorf("First and Last should works correctly")
|
||||
t.Errorf("First and Last should work correctly")
|
||||
}
|
||||
|
||||
var users []User
|
||||
@ -180,6 +247,52 @@ func TestFirstAndLast(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestFirstAndLastWithJoins(t *testing.T) {
|
||||
var user1, user2, user3, user4 User
|
||||
db.Joins("left join emails on emails.user_id = users.id").First(&user1)
|
||||
db.Order("id").Find(&user2)
|
||||
|
||||
db.Joins("left join emails on emails.user_id = users.id").Last(&user3)
|
||||
db.Order("id desc").Find(&user4)
|
||||
if user1.Id != user2.Id || user3.Id != user4.Id {
|
||||
t.Errorf("First and Last should work correctly with Joins")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFirstAndLastForTableWithNoStdPrimaryKey(t *testing.T) {
|
||||
var animal1, animal2, animal3, animal4 Animal
|
||||
db.First(&animal1)
|
||||
db.Order("counter").Find(&animal2)
|
||||
|
||||
db.Last(&animal3)
|
||||
db.Order("counter desc").Find(&animal4)
|
||||
if animal1.Counter != animal2.Counter || animal3.Counter != animal4.Counter {
|
||||
t.Errorf("First and Last should work correctly")
|
||||
}
|
||||
|
||||
var animals []Animal
|
||||
db.First(&animals)
|
||||
if len(animals) != 1 {
|
||||
t.Errorf("Find first record as map")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveCustomType(t *testing.T) {
|
||||
var user, user1 User
|
||||
db.First(&user, "name = ?", "1")
|
||||
if user.UserNum != Num(111) {
|
||||
t.Errorf("UserNum should be saved correctly")
|
||||
}
|
||||
|
||||
user.UserNum = Num(222)
|
||||
db.Save(&user)
|
||||
|
||||
db.First(&user1, "name = ?", "1")
|
||||
if user1.UserNum != Num(222) {
|
||||
t.Errorf("UserNum should be updated correctly")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrecision(t *testing.T) {
|
||||
f := 35.03554004971999
|
||||
user := User{Name: "Precision", Latitude: f}
|
||||
@ -518,7 +631,7 @@ func TestOrderAndPluck(t *testing.T) {
|
||||
|
||||
db.Model(&User{}).Order("age desc").Pluck("age", &ages3).Order("age", true).Pluck("age", &ages4)
|
||||
if reflect.DeepEqual(ages3, ages4) {
|
||||
t.Errorf("Reorder should works")
|
||||
t.Errorf("Reorder should work")
|
||||
}
|
||||
|
||||
var names []string
|
||||
@ -535,7 +648,7 @@ func TestLimit(t *testing.T) {
|
||||
db.Order("age desc").Limit(3).Find(&users1).Limit(5).Find(&users2).Limit(-1).Find(&users3)
|
||||
|
||||
if len(users1) != 3 || len(users2) != 5 || len(users3) <= 5 {
|
||||
t.Errorf("Limit should works")
|
||||
t.Errorf("Limit should work")
|
||||
}
|
||||
}
|
||||
|
||||
@ -544,7 +657,7 @@ func TestOffset(t *testing.T) {
|
||||
db.Limit(100).Order("age desc").Find(&users1).Offset(3).Find(&users2).Offset(5).Find(&users3).Offset(-1).Find(&users4)
|
||||
|
||||
if (len(users1) != len(users4)) || (len(users1)-len(users2) != 3) || (len(users1)-len(users3) != 5) {
|
||||
t.Errorf("Offset should works")
|
||||
t.Errorf("Offset should work")
|
||||
}
|
||||
}
|
||||
|
||||
@ -561,7 +674,7 @@ func TestCount(t *testing.T) {
|
||||
var users []User
|
||||
|
||||
if err := db.Where("name = ?", "1").Or("name = ?", "3").Find(&users).Count(&count).Error; err != nil {
|
||||
t.Errorf("Count should works", err)
|
||||
t.Errorf("Count should work", err)
|
||||
}
|
||||
|
||||
if count != int64(len(users)) {
|
||||
@ -570,7 +683,7 @@ func TestCount(t *testing.T) {
|
||||
|
||||
db.Model(&User{}).Where("name = ?", "1").Count(&count1).Or("name = ?", "3").Count(&count2)
|
||||
if count1 != 1 || count2 != 3 {
|
||||
t.Errorf("Multiple count should works")
|
||||
t.Errorf("Multiple count should work")
|
||||
}
|
||||
}
|
||||
|
||||
@ -688,7 +801,7 @@ func TestRunCallbacks(t *testing.T) {
|
||||
var products []Product
|
||||
db.Find(&products, "code = ?", "unique_code")
|
||||
if products[0].AfterFindCallTimes != 2 {
|
||||
t.Errorf("AfterFind callbacks should works with slice")
|
||||
t.Errorf("AfterFind callbacks should work with slice")
|
||||
}
|
||||
|
||||
db.Where("Code = ?", "unique_code").First(&p)
|
||||
@ -869,17 +982,31 @@ func TestSetTableDirectly(t *testing.T) {
|
||||
func TestUpdate(t *testing.T) {
|
||||
product1 := Product{Code: "123"}
|
||||
product2 := Product{Code: "234"}
|
||||
animal1 := Animal{Name: "Ferdinand"}
|
||||
animal2 := Animal{Name: "nerdz"}
|
||||
|
||||
db.Save(&product1).Save(&product2).Update("code", "456")
|
||||
|
||||
if product2.Code != "456" {
|
||||
t.Errorf("Record should be updated with update attributes")
|
||||
}
|
||||
|
||||
db.Save(&animal1).Save(&animal2).Update("name", "Francis")
|
||||
|
||||
if animal2.Name != "Francis" {
|
||||
t.Errorf("Record should be updated with update attributes")
|
||||
}
|
||||
|
||||
db.First(&product1, product1.Id)
|
||||
db.First(&product2, product2.Id)
|
||||
updated_at1 := product1.UpdatedAt
|
||||
updated_at2 := product2.UpdatedAt
|
||||
|
||||
db.First(&animal1, animal1.Counter)
|
||||
db.First(&animal2, animal2.Counter)
|
||||
animalUpdated_at1 := animal1.UpdatedAt
|
||||
animalUpdated_at2 := animal2.UpdatedAt
|
||||
|
||||
var product3 Product
|
||||
db.First(&product3, product2.Id).Update("code", "456")
|
||||
if updated_at2.Format(time.RFC3339Nano) != product3.UpdatedAt.Format(time.RFC3339Nano) {
|
||||
@ -898,6 +1025,25 @@ func TestUpdate(t *testing.T) {
|
||||
t.Errorf("Product 234 should be changed to 456")
|
||||
}
|
||||
|
||||
var animal3 Animal
|
||||
db.First(&animal3, animal2.Counter).Update("Name", "Robert")
|
||||
|
||||
if animalUpdated_at2.Format(time.RFC3339Nano) != animal2.UpdatedAt.Format(time.RFC3339Nano) {
|
||||
t.Errorf("updated_at should not be updated if nothing changed")
|
||||
}
|
||||
|
||||
if db.First(&Animal{}, "name = 'Ferdinand'").Error != nil {
|
||||
t.Errorf("Animal 'Ferdinand' should not be updated")
|
||||
}
|
||||
|
||||
if db.First(&Animal{}, "name = 'nerdz'").Error == nil {
|
||||
t.Errorf("Animal 'nerdz' should be changed to 'Francis'")
|
||||
}
|
||||
|
||||
if db.First(&Animal{}, "name = 'Robert'").Error != nil {
|
||||
t.Errorf("Animal 'nerdz' should be changed to 'Robert'")
|
||||
}
|
||||
|
||||
db.Table("products").Where("code in (?)", []string{"123"}).Update("code", "789")
|
||||
|
||||
var product4 Product
|
||||
@ -925,6 +1071,34 @@ func TestUpdate(t *testing.T) {
|
||||
if db.Model(&product2).UpdateColumn("CreatedAt", time.Now().Add(time.Hour)).Error != nil {
|
||||
t.Error("No error should raise when update_column with CamelCase")
|
||||
}
|
||||
|
||||
db.Table("animals").Where("name in (?)", []string{"Ferdinand"}).Update("name", "Franz")
|
||||
|
||||
var animal4 Animal
|
||||
db.First(&animal4, animal1.Counter)
|
||||
if animalUpdated_at1.Format(time.RFC3339Nano) != animal4.UpdatedAt.Format(time.RFC3339Nano) {
|
||||
t.Errorf("animalUpdated_at should be updated if something changed")
|
||||
}
|
||||
|
||||
if db.First(&Animal{}, "name = 'Ferdinand'").Error == nil {
|
||||
t.Errorf("Animal 'Fredinand' should be changed to 'Franz'")
|
||||
}
|
||||
|
||||
if db.First(&Animal{}, "name = 'Robert'").Error != nil {
|
||||
t.Errorf("Animal 'Robert' should not be changed to 'Francis'")
|
||||
}
|
||||
|
||||
if db.First(&Animal{}, "name = 'Franz'").Error != nil {
|
||||
t.Errorf("Product 'nerdz' should be changed to 'Franz'")
|
||||
}
|
||||
|
||||
if db.Model(animal2).Update("CreatedAt", time.Now().Add(time.Hour)).Error != nil {
|
||||
t.Error("No error should raise when update with CamelCase")
|
||||
}
|
||||
|
||||
if db.Model(&animal2).UpdateColumn("CreatedAt", time.Now().Add(time.Hour)).Error != nil {
|
||||
t.Error("No error should raise when update_column with CamelCase")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdates(t *testing.T) {
|
||||
@ -1308,10 +1482,17 @@ func TestRelated(t *testing.T) {
|
||||
if len(emails) != 2 {
|
||||
t.Errorf("Should have two emails")
|
||||
}
|
||||
|
||||
var emails2 []Email
|
||||
db.Model(&user).Where("email = ?", "jinzhu@example.com").Related(&emails2)
|
||||
if len(emails2) != 1 {
|
||||
t.Errorf("Should have two emails")
|
||||
}
|
||||
|
||||
var user1 User
|
||||
db.Model(&user).Related(&user1.Emails)
|
||||
if len(user1.Emails) != 2 {
|
||||
t.Errorf("Should have two emails")
|
||||
t.Errorf("Should have only one email match related condition")
|
||||
}
|
||||
|
||||
var address1 Address
|
||||
@ -1511,6 +1692,10 @@ func TestTransaction(t *testing.T) {
|
||||
t.Errorf("Should find saved record, but got", err)
|
||||
}
|
||||
|
||||
if sql_tx, ok := tx.CommonDB().(*sql.Tx); !ok || sql_tx == nil {
|
||||
t.Errorf("Should return the underlying sql.Tx")
|
||||
}
|
||||
|
||||
tx.Rollback()
|
||||
|
||||
if err := tx.First(&User{}, "name = ?", "transcation").Error; err == nil {
|
||||
@ -1580,7 +1765,7 @@ func TestScan(t *testing.T) {
|
||||
var res result
|
||||
db.Table("users").Select("name, age").Where("name = ?", 3).Scan(&res)
|
||||
if res.Name != "3" {
|
||||
t.Errorf("Scan into struct should works")
|
||||
t.Errorf("Scan into struct should work")
|
||||
}
|
||||
|
||||
var ress []result
|
||||
@ -1694,6 +1879,33 @@ func TestHaving(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAnonymousField(t *testing.T) {
|
||||
user := User{Name: "anonymous_field", Company: Company{Name: "company"}}
|
||||
db.Save(&user)
|
||||
|
||||
var user2 User
|
||||
db.First(&user2, "name = ?", "anonymous_field")
|
||||
db.Model(&user2).Related(&user2.Company)
|
||||
if user2.Company.Name != "company" {
|
||||
t.Errorf("Should be able to get anonymous field")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAnonymousScanner(t *testing.T) {
|
||||
user := User{Name: "anonymous_scanner", Role: Role{Name: "admin"}}
|
||||
db.Save(&user)
|
||||
|
||||
var user2 User
|
||||
db.First(&user2, "name = ?", "anonymous_scanner")
|
||||
if user2.Role.Name != "admin" {
|
||||
t.Errorf("Should be able to get anonymous scanner")
|
||||
}
|
||||
|
||||
if !user2.IsAdmin() {
|
||||
t.Errorf("Should be able to get anonymous scanner")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecRawSql(t *testing.T) {
|
||||
db.Exec("update users set name=? where name in (?)", "jinzhu", []string{"1", "2", "3"})
|
||||
if db.Where("name in (?)", []string{"1", "2", "3"}).First(&User{}).Error != gorm.RecordNotFound {
|
||||
@ -1701,6 +1913,45 @@ func TestExecRawSql(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestTimeWithZone(t *testing.T) {
|
||||
var format = "2006-01-02 15:04:05 -0700"
|
||||
var times []time.Time
|
||||
GMT8, _ := time.LoadLocation("Asia/Shanghai")
|
||||
times = append(times, time.Date(2013, 02, 19, 1, 51, 49, 123456789, GMT8))
|
||||
times = append(times, time.Date(2013, 02, 18, 17, 51, 49, 123456789, time.UTC))
|
||||
|
||||
for index, vtime := range times {
|
||||
name := "time_with_zone_" + strconv.Itoa(index)
|
||||
user := User{Name: name, Birthday: vtime}
|
||||
db.Save(&user)
|
||||
if user.Birthday.UTC().Format(format) != "2013-02-18 17:51:49 +0000" {
|
||||
t.Errorf("User's birthday should not be changed after save")
|
||||
}
|
||||
|
||||
if user.DeletedAt.UTC().Format(format) != "0001-01-01 00:00:00 +0000" {
|
||||
t.Errorf("User's deleted at should be zero")
|
||||
}
|
||||
|
||||
var findUser, findUser2, findUser3 User
|
||||
db.First(&findUser, "name = ?", name)
|
||||
if findUser.Birthday.UTC().Format(format) != "2013-02-18 17:51:49 +0000" {
|
||||
t.Errorf("User's birthday should not be changed after find")
|
||||
}
|
||||
|
||||
if findUser.DeletedAt.UTC().Format(format) != "0001-01-01 00:00:00 +0000" {
|
||||
t.Errorf("User's deleted at should be zero")
|
||||
}
|
||||
|
||||
if db.Where("birthday >= ?", vtime.Add(-time.Minute)).First(&findUser2).RecordNotFound() {
|
||||
t.Errorf("User should be found")
|
||||
}
|
||||
|
||||
if !db.Where("birthday >= ?", vtime.Add(time.Minute)).First(&findUser3).RecordNotFound() {
|
||||
t.Errorf("User should not be found")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkGorm(b *testing.B) {
|
||||
b.N = 2000
|
||||
for x := 0; x < b.N; x++ {
|
||||
|
34
scope.go
34
scope.go
@ -13,13 +13,14 @@ import (
|
||||
)
|
||||
|
||||
type Scope struct {
|
||||
Value interface{}
|
||||
Search *search
|
||||
Sql string
|
||||
SqlVars []interface{}
|
||||
db *DB
|
||||
_values map[string]interface{}
|
||||
skipLeft bool
|
||||
Value interface{}
|
||||
Search *search
|
||||
Sql string
|
||||
SqlVars []interface{}
|
||||
db *DB
|
||||
_values map[string]interface{}
|
||||
skipLeft bool
|
||||
primaryKey string
|
||||
}
|
||||
|
||||
// NewScope create scope for callbacks, including DB's search information
|
||||
@ -78,7 +79,12 @@ func (scope *Scope) HasError() bool {
|
||||
|
||||
// PrimaryKey get the primary key's column name
|
||||
func (scope *Scope) PrimaryKey() string {
|
||||
return "id"
|
||||
if scope.primaryKey != "" {
|
||||
return scope.primaryKey
|
||||
}
|
||||
|
||||
scope.primaryKey = scope.getPrimaryKey()
|
||||
return scope.primaryKey
|
||||
}
|
||||
|
||||
// PrimaryKeyZero check the primary key is blank or not
|
||||
@ -227,7 +233,7 @@ func (scope *Scope) Fields() []*Field {
|
||||
scopeTyp := indirectValue.Type()
|
||||
for i := 0; i < scopeTyp.NumField(); i++ {
|
||||
fieldStruct := scopeTyp.Field(i)
|
||||
if fieldStruct.Anonymous || !ast.IsExported(fieldStruct.Name) {
|
||||
if !ast.IsExported(fieldStruct.Name) {
|
||||
continue
|
||||
}
|
||||
|
||||
@ -238,7 +244,13 @@ func (scope *Scope) Fields() []*Field {
|
||||
value := indirectValue.FieldByName(fieldStruct.Name)
|
||||
field.Value = value.Interface()
|
||||
field.IsBlank = isBlank(value)
|
||||
field.isPrimaryKey = scope.PrimaryKey() == field.DBName
|
||||
|
||||
// Search for primary key tag identifier
|
||||
field.isPrimaryKey = scope.PrimaryKey() == field.DBName || fieldStruct.Tag.Get("primaryKey") != ""
|
||||
|
||||
if field.isPrimaryKey {
|
||||
scope.primaryKey = field.DBName
|
||||
}
|
||||
|
||||
if scope.db != nil {
|
||||
field.Tag = fieldStruct.Tag
|
||||
@ -252,7 +264,7 @@ func (scope *Scope) Fields() []*Field {
|
||||
case reflect.Slice:
|
||||
typ = typ.Elem()
|
||||
|
||||
if _, ok := field.Value.([]byte); !ok {
|
||||
if typ.Kind() == reflect.Struct {
|
||||
foreignKey := scopeTyp.Name() + "Id"
|
||||
if reflect.New(typ).Elem().FieldByName(foreignKey).IsValid() {
|
||||
field.ForeignKey = foreignKey
|
||||
|
@ -5,6 +5,7 @@ import (
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"fmt"
|
||||
"go/ast"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strconv"
|
||||
@ -309,26 +310,24 @@ func (scope *Scope) sqlTagForField(field *Field) (tag string) {
|
||||
value := field.Value
|
||||
reflectValue := reflect.ValueOf(value)
|
||||
|
||||
if field.IsScanner() {
|
||||
value = reflectValue.Field(0).Interface()
|
||||
}
|
||||
|
||||
switch reflectValue.Kind() {
|
||||
case reflect.Slice:
|
||||
if _, ok := value.([]byte); !ok {
|
||||
return
|
||||
}
|
||||
case reflect.Struct:
|
||||
if !field.IsTime() && !field.IsScanner() {
|
||||
if field.IsScanner() {
|
||||
reflectValue = reflectValue.Field(0)
|
||||
} else if !field.IsTime() {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if len(tag) == 0 {
|
||||
if field.isPrimaryKey {
|
||||
tag = scope.Dialect().PrimaryKeyTag(value, size)
|
||||
tag = scope.Dialect().PrimaryKeyTag(reflectValue, size)
|
||||
} else {
|
||||
tag = scope.Dialect().SqlTag(value, size)
|
||||
tag = scope.Dialect().SqlTag(reflectValue, size)
|
||||
}
|
||||
}
|
||||
|
||||
@ -395,7 +394,7 @@ func (scope *Scope) typeName() string {
|
||||
}
|
||||
|
||||
func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope {
|
||||
toScope := scope.New(value)
|
||||
toScope := scope.db.NewScope(value)
|
||||
|
||||
for _, foreignKey := range append(foreignKeys, toScope.typeName()+"Id", scope.typeName()+"Id") {
|
||||
if foreignValue, ok := scope.FieldByName(foreignKey); ok {
|
||||
@ -474,3 +473,33 @@ func (scope *Scope) autoMigrate() *Scope {
|
||||
}
|
||||
return scope
|
||||
}
|
||||
|
||||
func (scope *Scope) getPrimaryKey() string {
|
||||
var indirectValue reflect.Value
|
||||
|
||||
indirectValue = reflect.Indirect(reflect.ValueOf(scope.Value))
|
||||
|
||||
if indirectValue.Kind() == reflect.Slice {
|
||||
indirectValue = reflect.New(indirectValue.Type().Elem()).Elem()
|
||||
}
|
||||
|
||||
if !indirectValue.IsValid() {
|
||||
return "id"
|
||||
}
|
||||
|
||||
scopeTyp := indirectValue.Type()
|
||||
for i := 0; i < scopeTyp.NumField(); i++ {
|
||||
fieldStruct := scopeTyp.Field(i)
|
||||
if !ast.IsExported(fieldStruct.Name) {
|
||||
continue
|
||||
}
|
||||
|
||||
// if primaryKey tag found, return column name
|
||||
if fieldStruct.Tag.Get("primaryKey") != "" {
|
||||
return toSnake(fieldStruct.Name)
|
||||
}
|
||||
}
|
||||
|
||||
//If primaryKey tag not found, fallback to id
|
||||
return "id"
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user