Merge pull request #1 from jinzhu/master

sync with master
This commit is contained in:
Pavel Larkin 2014-04-18 17:14:19 -04:00
commit 912bca0d1a
9 changed files with 428 additions and 98 deletions

View File

@ -28,7 +28,7 @@ go get github.com/jinzhu/gorm
## Conventions
* Table name is the plural of struct name's snake case.
Disable pluralization with `db.SingularTable(true)`, or [specify your table name](#specify-table-name)
Disable pluralization with `db.SingularTable(true)`, or [Specifying the Table Name for Struct permanently with TableName](#Specifying-the-Table-Name-for-Struct-permanently-with-TableName)
* Column name is the snake case of field's name.
* Use `Id int64` field as primary key.
* Use tag `sql` to change field's property, change the tag name with `db.SetTagIdentifier(new_name)`.
@ -47,6 +47,20 @@ db.First(&user)
DB.Save(&User{Name: "xxx"}) // table "users"
```
## Existing schema
If you have and existing database schema and some of your tables does not follow the conventions, (and you can't rename your table names), please use: [Specifying the Table Name for Struct permanently with TableName](#Specifying-the-Table-Name-for-Struct-permanently-with-TableName).
If your primary key field is different from `id`, you can add a tag to the field structure to specify that this field is a primary key.
```go
type Animal struct { // animals
AnimalId int64 `primaryKey:"yes"`
Birthday time.Time
Age int64
}
```
# Getting Started
```go
@ -97,9 +111,9 @@ import _ "github.com/lib/pq"
// import _ "github.com/go-sql-driver/mysql"
// import _ "github.com/mattn/go-sqlite3"
db, err := Open("postgres", "user=gorm dbname=gorm sslmode=disable")
// db, err = Open("mysql", "gorm:gorm@/gorm?charset=utf8&parseTime=True")
// db, err = Open("sqlite3", "/tmp/gorm.db")
db, err := gorm.Open("postgres", "user=gorm dbname=gorm sslmode=disable")
// db, err = gorm.Open("mysql", "gorm:gorm@/gorm?charset=utf8&parseTime=True")
// db, err = gorm.Open("sqlite3", "/tmp/gorm.db")
// Get database connection handle [*sql.DB](http://golang.org/pkg/database/sql/#DB)
d := db.DB()

View File

@ -1,10 +1,17 @@
package dialect
import (
"reflect"
"time"
)
var timeType = reflect.TypeOf(time.Time{})
type Dialect interface {
BinVar(i int) string
SupportLastInsertId() bool
SqlTag(column interface{}, size int) string
PrimaryKeyTag(column interface{}, size int) string
SqlTag(value reflect.Value, size int) string
PrimaryKeyTag(value reflect.Value, size int) string
ReturningStr(key string) string
Quote(key string) string
}

View File

@ -1,9 +1,8 @@
package dialect
import (
"database/sql"
"fmt"
"time"
"reflect"
)
type mysql struct{}
@ -16,41 +15,44 @@ func (s *mysql) SupportLastInsertId() bool {
return true
}
func (d *mysql) SqlTag(column interface{}, size int) string {
switch column.(type) {
case time.Time:
return "datetime"
case bool, sql.NullBool:
func (d *mysql) SqlTag(value reflect.Value, size int) string {
switch value.Kind() {
case reflect.Bool:
return "boolean"
case int, int8, int16, int32, uint, uint8, uint16, uint32:
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
return "int"
case int64, uint64, sql.NullInt64:
case reflect.Int64, reflect.Uint64:
return "bigint"
case float32, float64, sql.NullFloat64:
case reflect.Float32, reflect.Float64:
return "double"
case []byte:
if size > 0 && size < 65532 {
return fmt.Sprintf("varbinary(%d)", size)
} else {
return "longblob"
}
case string, sql.NullString:
case reflect.String:
if size > 0 && size < 65532 {
return fmt.Sprintf("varchar(%d)", size)
} else {
return "longtext"
}
case reflect.Struct:
if value.Type() == timeType {
return "datetime"
}
default:
panic("Invalid sql type for mysql")
if _, ok := value.Interface().([]byte); ok {
if size > 0 && size < 65532 {
return fmt.Sprintf("varbinary(%d)", size)
} else {
return "longblob"
}
}
}
panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", value.Type().Name(), value.Kind().String()))
}
func (s *mysql) PrimaryKeyTag(column interface{}, size int) string {
func (s *mysql) PrimaryKeyTag(value reflect.Value, size int) string {
suffix_str := " NOT NULL AUTO_INCREMENT PRIMARY KEY"
switch column.(type) {
case int, int8, int16, int32, uint, uint8, uint16, uint32:
switch value.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
return "int" + suffix_str
case int64, uint64:
case reflect.Int64, reflect.Uint64:
return "bigint" + suffix_str
default:
panic("Invalid primary key type")

View File

@ -1,9 +1,8 @@
package dialect
import (
"database/sql"
"fmt"
"time"
"reflect"
)
type postgres struct {
@ -17,36 +16,38 @@ func (s *postgres) SupportLastInsertId() bool {
return false
}
func (d *postgres) SqlTag(column interface{}, size int) string {
switch column.(type) {
case time.Time:
return "timestamp with time zone"
case bool, sql.NullBool:
func (d *postgres) SqlTag(value reflect.Value, size int) string {
switch value.Kind() {
case reflect.Bool:
return "boolean"
case int, int8, int16, int32, uint, uint8, uint16, uint32:
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
return "integer"
case int64, uint64, sql.NullInt64:
case reflect.Int64, reflect.Uint64:
return "bigint"
case float32, float64, sql.NullFloat64:
case reflect.Float32, reflect.Float64:
return "numeric"
case []byte:
return "bytea"
case string, sql.NullString:
case reflect.String:
if size > 0 && size < 65532 {
return fmt.Sprintf("varchar(%d)", size)
} else {
return "text"
}
return "text"
case reflect.Struct:
if value.Type() == timeType {
return "timestamp with time zone"
}
default:
panic("Invalid sql type for postgres")
if _, ok := value.Interface().([]byte); ok {
return "bytea"
}
}
panic(fmt.Sprintf("invalid sql type %s (%s) for postgres", value.Type().Name(), value.Kind().String()))
}
func (s *postgres) PrimaryKeyTag(column interface{}, size int) string {
switch column.(type) {
case int, int8, int16, int32, uint, uint8, uint16, uint32:
func (s *postgres) PrimaryKeyTag(value reflect.Value, size int) string {
switch value.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
return "serial PRIMARY KEY"
case int64, uint64:
case reflect.Int64, reflect.Uint64:
return "bigserial PRIMARY KEY"
default:
panic("Invalid primary key type")

View File

@ -1,9 +1,8 @@
package dialect
import (
"database/sql"
"fmt"
"time"
"reflect"
)
type sqlite3 struct{}
@ -16,33 +15,41 @@ func (s *sqlite3) SupportLastInsertId() bool {
return true
}
func (s *sqlite3) SqlTag(column interface{}, size int) string {
switch column.(type) {
case time.Time:
return "datetime"
case bool, sql.NullBool:
func (s *sqlite3) SqlTag(value reflect.Value, size int) string {
switch value.Kind() {
case reflect.Bool:
return "bool"
case int, int8, int16, int32, uint, uint8, uint16, uint32:
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
return "integer"
case int64, uint64, sql.NullInt64:
case reflect.Int64, reflect.Uint64:
return "bigint"
case float32, float64, sql.NullFloat64:
case reflect.Float32, reflect.Float64:
return "real"
case []byte:
return "blob"
case string, sql.NullString:
case reflect.String:
if size > 0 && size < 65532 {
return fmt.Sprintf("varchar(%d)", size)
} else {
return "text"
}
case reflect.Struct:
if value.Type() == timeType {
return "datetime"
}
default:
panic("Invalid sql type for sqlite3")
if _, ok := value.Interface().([]byte); ok {
return "blob"
}
}
panic(fmt.Sprintf("invalid sql type %s (%s) for sqlite3", value.Type().Name(), value.Kind().String()))
}
func (s *sqlite3) PrimaryKeyTag(column interface{}, size int) string {
return "INTEGER PRIMARY KEY"
func (s *sqlite3) PrimaryKeyTag(value reflect.Value, size int) string {
switch value.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr, reflect.Int64, reflect.Uint64:
return "INTEGER PRIMARY KEY"
default:
panic("Invalid primary key type")
}
}
func (s *sqlite3) ReturningStr(key string) (str string) {

11
main.go
View File

@ -32,6 +32,13 @@ func (s *DB) DB() *sql.DB {
return s.db.(*sql.DB)
}
// Return the underlying sql.DB or sql.Tx instance.
// Use of this method is discouraged. It's mainly intended to allow
// coexistence with legacy non-GORM code.
func (s *DB) CommonDB() sqlCommon {
return s.db
}
func (s *DB) Callback() *callback {
s.parent.callback = s.parent.callback.clone()
return s.parent.callback
@ -124,13 +131,13 @@ func (s *DB) Assign(attrs ...interface{}) *DB {
func (s *DB) First(out interface{}, where ...interface{}) *DB {
scope := s.clone().NewScope(out)
scope.Search = scope.Search.clone().order(scope.PrimaryKey()).limit(1)
scope.Search = scope.Search.clone().order(scope.TableName()+"."+scope.PrimaryKey()).limit(1)
return scope.inlineCondition(where...).callCallbacks(s.parent.callback.queries).db
}
func (s *DB) Last(out interface{}, where ...interface{}) *DB {
scope := s.clone().NewScope(out)
scope.Search = scope.Search.clone().order(scope.PrimaryKey() + " DESC").limit(1)
scope.Search = scope.Search.clone().order(scope.TableName()+"."+scope.PrimaryKey() + " DESC").limit(1)
return scope.inlineCondition(where...).callCallbacks(s.parent.callback.queries).db
}

View File

@ -7,9 +7,9 @@ import (
"fmt"
_ "github.com/go-sql-driver/mysql"
"github.com/jinzhu/gorm"
_ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3"
"github.com/jinzhu/gorm"
"os"
"reflect"
@ -22,9 +22,45 @@ type IgnoredEmbedStruct struct {
Name string
}
type Num int64
func (i *Num) Scan(src interface{}) error {
switch s := src.(type) {
case []byte:
case int64:
*i = Num(s)
default:
return errors.New("Cannot scan NamedInt from " + reflect.ValueOf(src).String())
}
return nil
}
type Company struct {
Id int64
Name string
}
type Role struct {
Name string
}
func (role *Role) Scan(value interface{}) error {
role.Name = string(value.([]uint8))
return nil
}
func (role Role) Value() (driver.Value, error) {
return role.Name, nil
}
func (role Role) IsAdmin() bool {
return role.Name == "admin"
}
type User struct {
Id int64 // Id: Primary key
Age int64
UserNum Num
Name string `sql:"size:255"`
Birthday time.Time // Time
CreatedAt time.Time // CreatedAt: Time of record is created, will be insert automatically
@ -39,8 +75,12 @@ type User struct {
When time.Time
CreditCard CreditCard
Latitude float64
PasswordHash []byte
IgnoreMe int64 `sql:"-"`
CompanyId int64
Company
Role
PasswordHash []byte
IgnoreMe int64 `sql:"-"`
IgnoreStringSlice []string `sql:"-"`
}
type CreditCard struct {
@ -87,6 +127,13 @@ type Product struct {
AfterDeleteCallTimes int64
}
type Animal struct {
Counter int64 `primaryKey:"yes"`
Name string
CreatedAt time.Time
UpdatedAt time.Time
}
var (
db gorm.DB
t1, t2, t3, t4, t5 time.Time
@ -128,6 +175,13 @@ func init() {
db.Exec("drop table emails;")
db.Exec("drop table addresses")
db.Exec("drop table credit_cards")
db.Exec("drop table roles")
db.Exec("drop table companies")
db.Exec("drop table animals")
if err = db.CreateTable(&Animal{}).Error; err != nil {
panic(fmt.Sprintf("No error should happen when create table, but got %+v", err))
}
if err = db.CreateTable(&User{}).Error; err != nil {
panic(fmt.Sprintf("No error should happen when create table, but got %+v", err))
@ -149,17 +203,30 @@ func init() {
panic(fmt.Sprintf("No error should happen when create table, but got %+v", err))
}
if err = db.AutoMigrate(Company{}).Error; err != nil {
panic(fmt.Sprintf("No error should happen when create table, but got %+v", err))
}
if err = db.AutoMigrate(Role{}).Error; err != nil {
panic(fmt.Sprintf("No error should happen when create table, but got %+v", err))
}
var shortForm = "2006-01-02 15:04:05"
t1, _ = time.Parse(shortForm, "2000-10-27 12:02:40")
t2, _ = time.Parse(shortForm, "2002-01-01 00:00:00")
t3, _ = time.Parse(shortForm, "2005-01-01 00:00:00")
t4, _ = time.Parse(shortForm, "2010-01-01 00:00:00")
t5, _ = time.Parse(shortForm, "2020-01-01 00:00:00")
db.Save(&User{Name: "1", Age: 18, Birthday: t1, When: time.Now()})
db.Save(&User{Name: "1", Age: 18, Birthday: t1, When: time.Now(), UserNum: Num(111)})
db.Save(&User{Name: "2", Age: 20, Birthday: t2})
db.Save(&User{Name: "3", Age: 22, Birthday: t3})
db.Save(&User{Name: "3", Age: 24, Birthday: t4})
db.Save(&User{Name: "5", Age: 26, Birthday: t4})
db.Save(&Animal{Name: "First"})
db.Save(&Animal{Name: "Amazing"})
db.Save(&Animal{Name: "Horse"})
db.Save(&Animal{Name: "Last"})
}
func TestFirstAndLast(t *testing.T) {
@ -170,7 +237,7 @@ func TestFirstAndLast(t *testing.T) {
db.Last(&user3)
db.Order("id desc").Find(&user4)
if user1.Id != user2.Id || user3.Id != user4.Id {
t.Errorf("First and Last should works correctly")
t.Errorf("First and Last should work correctly")
}
var users []User
@ -180,6 +247,52 @@ func TestFirstAndLast(t *testing.T) {
}
}
func TestFirstAndLastWithJoins(t *testing.T) {
var user1, user2, user3, user4 User
db.Joins("left join emails on emails.user_id = users.id").First(&user1)
db.Order("id").Find(&user2)
db.Joins("left join emails on emails.user_id = users.id").Last(&user3)
db.Order("id desc").Find(&user4)
if user1.Id != user2.Id || user3.Id != user4.Id {
t.Errorf("First and Last should work correctly with Joins")
}
}
func TestFirstAndLastForTableWithNoStdPrimaryKey(t *testing.T) {
var animal1, animal2, animal3, animal4 Animal
db.First(&animal1)
db.Order("counter").Find(&animal2)
db.Last(&animal3)
db.Order("counter desc").Find(&animal4)
if animal1.Counter != animal2.Counter || animal3.Counter != animal4.Counter {
t.Errorf("First and Last should work correctly")
}
var animals []Animal
db.First(&animals)
if len(animals) != 1 {
t.Errorf("Find first record as map")
}
}
func TestSaveCustomType(t *testing.T) {
var user, user1 User
db.First(&user, "name = ?", "1")
if user.UserNum != Num(111) {
t.Errorf("UserNum should be saved correctly")
}
user.UserNum = Num(222)
db.Save(&user)
db.First(&user1, "name = ?", "1")
if user1.UserNum != Num(222) {
t.Errorf("UserNum should be updated correctly")
}
}
func TestPrecision(t *testing.T) {
f := 35.03554004971999
user := User{Name: "Precision", Latitude: f}
@ -518,7 +631,7 @@ func TestOrderAndPluck(t *testing.T) {
db.Model(&User{}).Order("age desc").Pluck("age", &ages3).Order("age", true).Pluck("age", &ages4)
if reflect.DeepEqual(ages3, ages4) {
t.Errorf("Reorder should works")
t.Errorf("Reorder should work")
}
var names []string
@ -535,7 +648,7 @@ func TestLimit(t *testing.T) {
db.Order("age desc").Limit(3).Find(&users1).Limit(5).Find(&users2).Limit(-1).Find(&users3)
if len(users1) != 3 || len(users2) != 5 || len(users3) <= 5 {
t.Errorf("Limit should works")
t.Errorf("Limit should work")
}
}
@ -544,7 +657,7 @@ func TestOffset(t *testing.T) {
db.Limit(100).Order("age desc").Find(&users1).Offset(3).Find(&users2).Offset(5).Find(&users3).Offset(-1).Find(&users4)
if (len(users1) != len(users4)) || (len(users1)-len(users2) != 3) || (len(users1)-len(users3) != 5) {
t.Errorf("Offset should works")
t.Errorf("Offset should work")
}
}
@ -561,7 +674,7 @@ func TestCount(t *testing.T) {
var users []User
if err := db.Where("name = ?", "1").Or("name = ?", "3").Find(&users).Count(&count).Error; err != nil {
t.Errorf("Count should works", err)
t.Errorf("Count should work", err)
}
if count != int64(len(users)) {
@ -570,7 +683,7 @@ func TestCount(t *testing.T) {
db.Model(&User{}).Where("name = ?", "1").Count(&count1).Or("name = ?", "3").Count(&count2)
if count1 != 1 || count2 != 3 {
t.Errorf("Multiple count should works")
t.Errorf("Multiple count should work")
}
}
@ -688,7 +801,7 @@ func TestRunCallbacks(t *testing.T) {
var products []Product
db.Find(&products, "code = ?", "unique_code")
if products[0].AfterFindCallTimes != 2 {
t.Errorf("AfterFind callbacks should works with slice")
t.Errorf("AfterFind callbacks should work with slice")
}
db.Where("Code = ?", "unique_code").First(&p)
@ -869,17 +982,31 @@ func TestSetTableDirectly(t *testing.T) {
func TestUpdate(t *testing.T) {
product1 := Product{Code: "123"}
product2 := Product{Code: "234"}
animal1 := Animal{Name: "Ferdinand"}
animal2 := Animal{Name: "nerdz"}
db.Save(&product1).Save(&product2).Update("code", "456")
if product2.Code != "456" {
t.Errorf("Record should be updated with update attributes")
}
db.Save(&animal1).Save(&animal2).Update("name", "Francis")
if animal2.Name != "Francis" {
t.Errorf("Record should be updated with update attributes")
}
db.First(&product1, product1.Id)
db.First(&product2, product2.Id)
updated_at1 := product1.UpdatedAt
updated_at2 := product2.UpdatedAt
db.First(&animal1, animal1.Counter)
db.First(&animal2, animal2.Counter)
animalUpdated_at1 := animal1.UpdatedAt
animalUpdated_at2 := animal2.UpdatedAt
var product3 Product
db.First(&product3, product2.Id).Update("code", "456")
if updated_at2.Format(time.RFC3339Nano) != product3.UpdatedAt.Format(time.RFC3339Nano) {
@ -898,6 +1025,25 @@ func TestUpdate(t *testing.T) {
t.Errorf("Product 234 should be changed to 456")
}
var animal3 Animal
db.First(&animal3, animal2.Counter).Update("Name", "Robert")
if animalUpdated_at2.Format(time.RFC3339Nano) != animal2.UpdatedAt.Format(time.RFC3339Nano) {
t.Errorf("updated_at should not be updated if nothing changed")
}
if db.First(&Animal{}, "name = 'Ferdinand'").Error != nil {
t.Errorf("Animal 'Ferdinand' should not be updated")
}
if db.First(&Animal{}, "name = 'nerdz'").Error == nil {
t.Errorf("Animal 'nerdz' should be changed to 'Francis'")
}
if db.First(&Animal{}, "name = 'Robert'").Error != nil {
t.Errorf("Animal 'nerdz' should be changed to 'Robert'")
}
db.Table("products").Where("code in (?)", []string{"123"}).Update("code", "789")
var product4 Product
@ -925,6 +1071,34 @@ func TestUpdate(t *testing.T) {
if db.Model(&product2).UpdateColumn("CreatedAt", time.Now().Add(time.Hour)).Error != nil {
t.Error("No error should raise when update_column with CamelCase")
}
db.Table("animals").Where("name in (?)", []string{"Ferdinand"}).Update("name", "Franz")
var animal4 Animal
db.First(&animal4, animal1.Counter)
if animalUpdated_at1.Format(time.RFC3339Nano) != animal4.UpdatedAt.Format(time.RFC3339Nano) {
t.Errorf("animalUpdated_at should be updated if something changed")
}
if db.First(&Animal{}, "name = 'Ferdinand'").Error == nil {
t.Errorf("Animal 'Fredinand' should be changed to 'Franz'")
}
if db.First(&Animal{}, "name = 'Robert'").Error != nil {
t.Errorf("Animal 'Robert' should not be changed to 'Francis'")
}
if db.First(&Animal{}, "name = 'Franz'").Error != nil {
t.Errorf("Product 'nerdz' should be changed to 'Franz'")
}
if db.Model(animal2).Update("CreatedAt", time.Now().Add(time.Hour)).Error != nil {
t.Error("No error should raise when update with CamelCase")
}
if db.Model(&animal2).UpdateColumn("CreatedAt", time.Now().Add(time.Hour)).Error != nil {
t.Error("No error should raise when update_column with CamelCase")
}
}
func TestUpdates(t *testing.T) {
@ -1308,10 +1482,17 @@ func TestRelated(t *testing.T) {
if len(emails) != 2 {
t.Errorf("Should have two emails")
}
var emails2 []Email
db.Model(&user).Where("email = ?", "jinzhu@example.com").Related(&emails2)
if len(emails2) != 1 {
t.Errorf("Should have two emails")
}
var user1 User
db.Model(&user).Related(&user1.Emails)
if len(user1.Emails) != 2 {
t.Errorf("Should have two emails")
t.Errorf("Should have only one email match related condition")
}
var address1 Address
@ -1511,6 +1692,10 @@ func TestTransaction(t *testing.T) {
t.Errorf("Should find saved record, but got", err)
}
if sql_tx, ok := tx.CommonDB().(*sql.Tx); !ok || sql_tx == nil {
t.Errorf("Should return the underlying sql.Tx")
}
tx.Rollback()
if err := tx.First(&User{}, "name = ?", "transcation").Error; err == nil {
@ -1580,7 +1765,7 @@ func TestScan(t *testing.T) {
var res result
db.Table("users").Select("name, age").Where("name = ?", 3).Scan(&res)
if res.Name != "3" {
t.Errorf("Scan into struct should works")
t.Errorf("Scan into struct should work")
}
var ress []result
@ -1694,6 +1879,33 @@ func TestHaving(t *testing.T) {
}
}
func TestAnonymousField(t *testing.T) {
user := User{Name: "anonymous_field", Company: Company{Name: "company"}}
db.Save(&user)
var user2 User
db.First(&user2, "name = ?", "anonymous_field")
db.Model(&user2).Related(&user2.Company)
if user2.Company.Name != "company" {
t.Errorf("Should be able to get anonymous field")
}
}
func TestAnonymousScanner(t *testing.T) {
user := User{Name: "anonymous_scanner", Role: Role{Name: "admin"}}
db.Save(&user)
var user2 User
db.First(&user2, "name = ?", "anonymous_scanner")
if user2.Role.Name != "admin" {
t.Errorf("Should be able to get anonymous scanner")
}
if !user2.IsAdmin() {
t.Errorf("Should be able to get anonymous scanner")
}
}
func TestExecRawSql(t *testing.T) {
db.Exec("update users set name=? where name in (?)", "jinzhu", []string{"1", "2", "3"})
if db.Where("name in (?)", []string{"1", "2", "3"}).First(&User{}).Error != gorm.RecordNotFound {
@ -1701,6 +1913,45 @@ func TestExecRawSql(t *testing.T) {
}
}
func TestTimeWithZone(t *testing.T) {
var format = "2006-01-02 15:04:05 -0700"
var times []time.Time
GMT8, _ := time.LoadLocation("Asia/Shanghai")
times = append(times, time.Date(2013, 02, 19, 1, 51, 49, 123456789, GMT8))
times = append(times, time.Date(2013, 02, 18, 17, 51, 49, 123456789, time.UTC))
for index, vtime := range times {
name := "time_with_zone_" + strconv.Itoa(index)
user := User{Name: name, Birthday: vtime}
db.Save(&user)
if user.Birthday.UTC().Format(format) != "2013-02-18 17:51:49 +0000" {
t.Errorf("User's birthday should not be changed after save")
}
if user.DeletedAt.UTC().Format(format) != "0001-01-01 00:00:00 +0000" {
t.Errorf("User's deleted at should be zero")
}
var findUser, findUser2, findUser3 User
db.First(&findUser, "name = ?", name)
if findUser.Birthday.UTC().Format(format) != "2013-02-18 17:51:49 +0000" {
t.Errorf("User's birthday should not be changed after find")
}
if findUser.DeletedAt.UTC().Format(format) != "0001-01-01 00:00:00 +0000" {
t.Errorf("User's deleted at should be zero")
}
if db.Where("birthday >= ?", vtime.Add(-time.Minute)).First(&findUser2).RecordNotFound() {
t.Errorf("User should be found")
}
if !db.Where("birthday >= ?", vtime.Add(time.Minute)).First(&findUser3).RecordNotFound() {
t.Errorf("User should not be found")
}
}
}
func BenchmarkGorm(b *testing.B) {
b.N = 2000
for x := 0; x < b.N; x++ {

View File

@ -13,13 +13,14 @@ import (
)
type Scope struct {
Value interface{}
Search *search
Sql string
SqlVars []interface{}
db *DB
_values map[string]interface{}
skipLeft bool
Value interface{}
Search *search
Sql string
SqlVars []interface{}
db *DB
_values map[string]interface{}
skipLeft bool
primaryKey string
}
// NewScope create scope for callbacks, including DB's search information
@ -78,7 +79,12 @@ func (scope *Scope) HasError() bool {
// PrimaryKey get the primary key's column name
func (scope *Scope) PrimaryKey() string {
return "id"
if scope.primaryKey != "" {
return scope.primaryKey
}
scope.primaryKey = scope.getPrimaryKey()
return scope.primaryKey
}
// PrimaryKeyZero check the primary key is blank or not
@ -227,7 +233,7 @@ func (scope *Scope) Fields() []*Field {
scopeTyp := indirectValue.Type()
for i := 0; i < scopeTyp.NumField(); i++ {
fieldStruct := scopeTyp.Field(i)
if fieldStruct.Anonymous || !ast.IsExported(fieldStruct.Name) {
if !ast.IsExported(fieldStruct.Name) {
continue
}
@ -238,7 +244,13 @@ func (scope *Scope) Fields() []*Field {
value := indirectValue.FieldByName(fieldStruct.Name)
field.Value = value.Interface()
field.IsBlank = isBlank(value)
field.isPrimaryKey = scope.PrimaryKey() == field.DBName
// Search for primary key tag identifier
field.isPrimaryKey = scope.PrimaryKey() == field.DBName || fieldStruct.Tag.Get("primaryKey") != ""
if field.isPrimaryKey {
scope.primaryKey = field.DBName
}
if scope.db != nil {
field.Tag = fieldStruct.Tag
@ -252,7 +264,7 @@ func (scope *Scope) Fields() []*Field {
case reflect.Slice:
typ = typ.Elem()
if _, ok := field.Value.([]byte); !ok {
if typ.Kind() == reflect.Struct {
foreignKey := scopeTyp.Name() + "Id"
if reflect.New(typ).Elem().FieldByName(foreignKey).IsValid() {
field.ForeignKey = foreignKey

View File

@ -5,6 +5,7 @@ import (
"database/sql/driver"
"errors"
"fmt"
"go/ast"
"reflect"
"regexp"
"strconv"
@ -309,26 +310,24 @@ func (scope *Scope) sqlTagForField(field *Field) (tag string) {
value := field.Value
reflectValue := reflect.ValueOf(value)
if field.IsScanner() {
value = reflectValue.Field(0).Interface()
}
switch reflectValue.Kind() {
case reflect.Slice:
if _, ok := value.([]byte); !ok {
return
}
case reflect.Struct:
if !field.IsTime() && !field.IsScanner() {
if field.IsScanner() {
reflectValue = reflectValue.Field(0)
} else if !field.IsTime() {
return
}
}
if len(tag) == 0 {
if field.isPrimaryKey {
tag = scope.Dialect().PrimaryKeyTag(value, size)
tag = scope.Dialect().PrimaryKeyTag(reflectValue, size)
} else {
tag = scope.Dialect().SqlTag(value, size)
tag = scope.Dialect().SqlTag(reflectValue, size)
}
}
@ -395,7 +394,7 @@ func (scope *Scope) typeName() string {
}
func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope {
toScope := scope.New(value)
toScope := scope.db.NewScope(value)
for _, foreignKey := range append(foreignKeys, toScope.typeName()+"Id", scope.typeName()+"Id") {
if foreignValue, ok := scope.FieldByName(foreignKey); ok {
@ -474,3 +473,33 @@ func (scope *Scope) autoMigrate() *Scope {
}
return scope
}
func (scope *Scope) getPrimaryKey() string {
var indirectValue reflect.Value
indirectValue = reflect.Indirect(reflect.ValueOf(scope.Value))
if indirectValue.Kind() == reflect.Slice {
indirectValue = reflect.New(indirectValue.Type().Elem()).Elem()
}
if !indirectValue.IsValid() {
return "id"
}
scopeTyp := indirectValue.Type()
for i := 0; i < scopeTyp.NumField(); i++ {
fieldStruct := scopeTyp.Field(i)
if !ast.IsExported(fieldStruct.Name) {
continue
}
// if primaryKey tag found, return column name
if fieldStruct.Tag.Get("primaryKey") != "" {
return toSnake(fieldStruct.Name)
}
}
//If primaryKey tag not found, fallback to id
return "id"
}