diff --git a/README.md b/README.md index d3dfdd11..b06743b5 100644 --- a/README.md +++ b/README.md @@ -28,7 +28,7 @@ go get github.com/jinzhu/gorm ## Conventions * Table name is the plural of struct name's snake case. - Disable pluralization with `db.SingularTable(true)`, or [specify your table name](#specify-table-name) + Disable pluralization with `db.SingularTable(true)`, or [Specifying the Table Name for Struct permanently with TableName](#Specifying-the-Table-Name-for-Struct-permanently-with-TableName) * Column name is the snake case of field's name. * Use `Id int64` field as primary key. * Use tag `sql` to change field's property, change the tag name with `db.SetTagIdentifier(new_name)`. @@ -47,6 +47,20 @@ db.First(&user) DB.Save(&User{Name: "xxx"}) // table "users" ``` +## Existing schema + +If you have and existing database schema and some of your tables does not follow the conventions, (and you can't rename your table names), please use: [Specifying the Table Name for Struct permanently with TableName](#Specifying-the-Table-Name-for-Struct-permanently-with-TableName). + +If your primary key field is different from `id`, you can add a tag to the field structure to specify that this field is a primary key. + +```go +type Animal struct { // animals + AnimalId int64 `primaryKey:"yes"` + Birthday time.Time + Age int64 +} +``` + # Getting Started ```go @@ -97,9 +111,9 @@ import _ "github.com/lib/pq" // import _ "github.com/go-sql-driver/mysql" // import _ "github.com/mattn/go-sqlite3" -db, err := Open("postgres", "user=gorm dbname=gorm sslmode=disable") -// db, err = Open("mysql", "gorm:gorm@/gorm?charset=utf8&parseTime=True") -// db, err = Open("sqlite3", "/tmp/gorm.db") +db, err := gorm.Open("postgres", "user=gorm dbname=gorm sslmode=disable") +// db, err = gorm.Open("mysql", "gorm:gorm@/gorm?charset=utf8&parseTime=True") +// db, err = gorm.Open("sqlite3", "/tmp/gorm.db") // Get database connection handle [*sql.DB](http://golang.org/pkg/database/sql/#DB) d := db.DB() diff --git a/dialect/dialect.go b/dialect/dialect.go index 9418e533..a0f403c4 100644 --- a/dialect/dialect.go +++ b/dialect/dialect.go @@ -1,10 +1,17 @@ package dialect +import ( + "reflect" + "time" +) + +var timeType = reflect.TypeOf(time.Time{}) + type Dialect interface { BinVar(i int) string SupportLastInsertId() bool - SqlTag(column interface{}, size int) string - PrimaryKeyTag(column interface{}, size int) string + SqlTag(value reflect.Value, size int) string + PrimaryKeyTag(value reflect.Value, size int) string ReturningStr(key string) string Quote(key string) string } diff --git a/dialect/mysql.go b/dialect/mysql.go index a1c0a26e..9602bcc1 100644 --- a/dialect/mysql.go +++ b/dialect/mysql.go @@ -1,9 +1,8 @@ package dialect import ( - "database/sql" "fmt" - "time" + "reflect" ) type mysql struct{} @@ -16,41 +15,44 @@ func (s *mysql) SupportLastInsertId() bool { return true } -func (d *mysql) SqlTag(column interface{}, size int) string { - switch column.(type) { - case time.Time: - return "datetime" - case bool, sql.NullBool: +func (d *mysql) SqlTag(value reflect.Value, size int) string { + switch value.Kind() { + case reflect.Bool: return "boolean" - case int, int8, int16, int32, uint, uint8, uint16, uint32: + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: return "int" - case int64, uint64, sql.NullInt64: + case reflect.Int64, reflect.Uint64: return "bigint" - case float32, float64, sql.NullFloat64: + case reflect.Float32, reflect.Float64: return "double" - case []byte: - if size > 0 && size < 65532 { - return fmt.Sprintf("varbinary(%d)", size) - } else { - return "longblob" - } - case string, sql.NullString: + case reflect.String: if size > 0 && size < 65532 { return fmt.Sprintf("varchar(%d)", size) } else { return "longtext" } + case reflect.Struct: + if value.Type() == timeType { + return "datetime" + } default: - panic("Invalid sql type for mysql") + if _, ok := value.Interface().([]byte); ok { + if size > 0 && size < 65532 { + return fmt.Sprintf("varbinary(%d)", size) + } else { + return "longblob" + } + } } + panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", value.Type().Name(), value.Kind().String())) } -func (s *mysql) PrimaryKeyTag(column interface{}, size int) string { +func (s *mysql) PrimaryKeyTag(value reflect.Value, size int) string { suffix_str := " NOT NULL AUTO_INCREMENT PRIMARY KEY" - switch column.(type) { - case int, int8, int16, int32, uint, uint8, uint16, uint32: + switch value.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: return "int" + suffix_str - case int64, uint64: + case reflect.Int64, reflect.Uint64: return "bigint" + suffix_str default: panic("Invalid primary key type") diff --git a/dialect/postgres.go b/dialect/postgres.go index c0981cd0..7b744fa4 100644 --- a/dialect/postgres.go +++ b/dialect/postgres.go @@ -1,9 +1,8 @@ package dialect import ( - "database/sql" "fmt" - "time" + "reflect" ) type postgres struct { @@ -17,36 +16,38 @@ func (s *postgres) SupportLastInsertId() bool { return false } -func (d *postgres) SqlTag(column interface{}, size int) string { - switch column.(type) { - case time.Time: - return "timestamp with time zone" - case bool, sql.NullBool: +func (d *postgres) SqlTag(value reflect.Value, size int) string { + switch value.Kind() { + case reflect.Bool: return "boolean" - case int, int8, int16, int32, uint, uint8, uint16, uint32: + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: return "integer" - case int64, uint64, sql.NullInt64: + case reflect.Int64, reflect.Uint64: return "bigint" - case float32, float64, sql.NullFloat64: + case reflect.Float32, reflect.Float64: return "numeric" - case []byte: - return "bytea" - case string, sql.NullString: + case reflect.String: if size > 0 && size < 65532 { return fmt.Sprintf("varchar(%d)", size) - } else { - return "text" + } + return "text" + case reflect.Struct: + if value.Type() == timeType { + return "timestamp with time zone" } default: - panic("Invalid sql type for postgres") + if _, ok := value.Interface().([]byte); ok { + return "bytea" + } } + panic(fmt.Sprintf("invalid sql type %s (%s) for postgres", value.Type().Name(), value.Kind().String())) } -func (s *postgres) PrimaryKeyTag(column interface{}, size int) string { - switch column.(type) { - case int, int8, int16, int32, uint, uint8, uint16, uint32: +func (s *postgres) PrimaryKeyTag(value reflect.Value, size int) string { + switch value.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: return "serial PRIMARY KEY" - case int64, uint64: + case reflect.Int64, reflect.Uint64: return "bigserial PRIMARY KEY" default: panic("Invalid primary key type") diff --git a/dialect/sqlite3.go b/dialect/sqlite3.go index 92063786..ae54e603 100644 --- a/dialect/sqlite3.go +++ b/dialect/sqlite3.go @@ -1,9 +1,8 @@ package dialect import ( - "database/sql" "fmt" - "time" + "reflect" ) type sqlite3 struct{} @@ -16,33 +15,41 @@ func (s *sqlite3) SupportLastInsertId() bool { return true } -func (s *sqlite3) SqlTag(column interface{}, size int) string { - switch column.(type) { - case time.Time: - return "datetime" - case bool, sql.NullBool: +func (s *sqlite3) SqlTag(value reflect.Value, size int) string { + switch value.Kind() { + case reflect.Bool: return "bool" - case int, int8, int16, int32, uint, uint8, uint16, uint32: + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: return "integer" - case int64, uint64, sql.NullInt64: + case reflect.Int64, reflect.Uint64: return "bigint" - case float32, float64, sql.NullFloat64: + case reflect.Float32, reflect.Float64: return "real" - case []byte: - return "blob" - case string, sql.NullString: + case reflect.String: if size > 0 && size < 65532 { return fmt.Sprintf("varchar(%d)", size) } else { return "text" } + case reflect.Struct: + if value.Type() == timeType { + return "datetime" + } default: - panic("Invalid sql type for sqlite3") + if _, ok := value.Interface().([]byte); ok { + return "blob" + } } + panic(fmt.Sprintf("invalid sql type %s (%s) for sqlite3", value.Type().Name(), value.Kind().String())) } -func (s *sqlite3) PrimaryKeyTag(column interface{}, size int) string { - return "INTEGER PRIMARY KEY" +func (s *sqlite3) PrimaryKeyTag(value reflect.Value, size int) string { + switch value.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr, reflect.Int64, reflect.Uint64: + return "INTEGER PRIMARY KEY" + default: + panic("Invalid primary key type") + } } func (s *sqlite3) ReturningStr(key string) (str string) { diff --git a/main.go b/main.go index ca1d24bb..c98ca2b6 100644 --- a/main.go +++ b/main.go @@ -32,6 +32,13 @@ func (s *DB) DB() *sql.DB { return s.db.(*sql.DB) } +// Return the underlying sql.DB or sql.Tx instance. +// Use of this method is discouraged. It's mainly intended to allow +// coexistence with legacy non-GORM code. +func (s *DB) CommonDB() sqlCommon { + return s.db +} + func (s *DB) Callback() *callback { s.parent.callback = s.parent.callback.clone() return s.parent.callback @@ -124,13 +131,13 @@ func (s *DB) Assign(attrs ...interface{}) *DB { func (s *DB) First(out interface{}, where ...interface{}) *DB { scope := s.clone().NewScope(out) - scope.Search = scope.Search.clone().order(scope.PrimaryKey()).limit(1) + scope.Search = scope.Search.clone().order(scope.TableName()+"."+scope.PrimaryKey()).limit(1) return scope.inlineCondition(where...).callCallbacks(s.parent.callback.queries).db } func (s *DB) Last(out interface{}, where ...interface{}) *DB { scope := s.clone().NewScope(out) - scope.Search = scope.Search.clone().order(scope.PrimaryKey() + " DESC").limit(1) + scope.Search = scope.Search.clone().order(scope.TableName()+"."+scope.PrimaryKey() + " DESC").limit(1) return scope.inlineCondition(where...).callCallbacks(s.parent.callback.queries).db } diff --git a/main_test.go b/main_test.go index 6d89659a..e4016407 100644 --- a/main_test.go +++ b/main_test.go @@ -7,9 +7,9 @@ import ( "fmt" _ "github.com/go-sql-driver/mysql" - "github.com/jinzhu/gorm" _ "github.com/lib/pq" _ "github.com/mattn/go-sqlite3" + "github.com/jinzhu/gorm" "os" "reflect" @@ -22,9 +22,45 @@ type IgnoredEmbedStruct struct { Name string } +type Num int64 + +func (i *Num) Scan(src interface{}) error { + switch s := src.(type) { + case []byte: + case int64: + *i = Num(s) + default: + return errors.New("Cannot scan NamedInt from " + reflect.ValueOf(src).String()) + } + return nil +} + +type Company struct { + Id int64 + Name string +} + +type Role struct { + Name string +} + +func (role *Role) Scan(value interface{}) error { + role.Name = string(value.([]uint8)) + return nil +} + +func (role Role) Value() (driver.Value, error) { + return role.Name, nil +} + +func (role Role) IsAdmin() bool { + return role.Name == "admin" +} + type User struct { Id int64 // Id: Primary key Age int64 + UserNum Num Name string `sql:"size:255"` Birthday time.Time // Time CreatedAt time.Time // CreatedAt: Time of record is created, will be insert automatically @@ -39,8 +75,12 @@ type User struct { When time.Time CreditCard CreditCard Latitude float64 - PasswordHash []byte - IgnoreMe int64 `sql:"-"` + CompanyId int64 + Company + Role + PasswordHash []byte + IgnoreMe int64 `sql:"-"` + IgnoreStringSlice []string `sql:"-"` } type CreditCard struct { @@ -87,6 +127,13 @@ type Product struct { AfterDeleteCallTimes int64 } +type Animal struct { + Counter int64 `primaryKey:"yes"` + Name string + CreatedAt time.Time + UpdatedAt time.Time +} + var ( db gorm.DB t1, t2, t3, t4, t5 time.Time @@ -128,6 +175,13 @@ func init() { db.Exec("drop table emails;") db.Exec("drop table addresses") db.Exec("drop table credit_cards") + db.Exec("drop table roles") + db.Exec("drop table companies") + db.Exec("drop table animals") + + if err = db.CreateTable(&Animal{}).Error; err != nil { + panic(fmt.Sprintf("No error should happen when create table, but got %+v", err)) + } if err = db.CreateTable(&User{}).Error; err != nil { panic(fmt.Sprintf("No error should happen when create table, but got %+v", err)) @@ -149,17 +203,30 @@ func init() { panic(fmt.Sprintf("No error should happen when create table, but got %+v", err)) } + if err = db.AutoMigrate(Company{}).Error; err != nil { + panic(fmt.Sprintf("No error should happen when create table, but got %+v", err)) + } + + if err = db.AutoMigrate(Role{}).Error; err != nil { + panic(fmt.Sprintf("No error should happen when create table, but got %+v", err)) + } + var shortForm = "2006-01-02 15:04:05" t1, _ = time.Parse(shortForm, "2000-10-27 12:02:40") t2, _ = time.Parse(shortForm, "2002-01-01 00:00:00") t3, _ = time.Parse(shortForm, "2005-01-01 00:00:00") t4, _ = time.Parse(shortForm, "2010-01-01 00:00:00") t5, _ = time.Parse(shortForm, "2020-01-01 00:00:00") - db.Save(&User{Name: "1", Age: 18, Birthday: t1, When: time.Now()}) + db.Save(&User{Name: "1", Age: 18, Birthday: t1, When: time.Now(), UserNum: Num(111)}) db.Save(&User{Name: "2", Age: 20, Birthday: t2}) db.Save(&User{Name: "3", Age: 22, Birthday: t3}) db.Save(&User{Name: "3", Age: 24, Birthday: t4}) db.Save(&User{Name: "5", Age: 26, Birthday: t4}) + + db.Save(&Animal{Name: "First"}) + db.Save(&Animal{Name: "Amazing"}) + db.Save(&Animal{Name: "Horse"}) + db.Save(&Animal{Name: "Last"}) } func TestFirstAndLast(t *testing.T) { @@ -170,7 +237,7 @@ func TestFirstAndLast(t *testing.T) { db.Last(&user3) db.Order("id desc").Find(&user4) if user1.Id != user2.Id || user3.Id != user4.Id { - t.Errorf("First and Last should works correctly") + t.Errorf("First and Last should work correctly") } var users []User @@ -180,6 +247,52 @@ func TestFirstAndLast(t *testing.T) { } } +func TestFirstAndLastWithJoins(t *testing.T) { + var user1, user2, user3, user4 User + db.Joins("left join emails on emails.user_id = users.id").First(&user1) + db.Order("id").Find(&user2) + + db.Joins("left join emails on emails.user_id = users.id").Last(&user3) + db.Order("id desc").Find(&user4) + if user1.Id != user2.Id || user3.Id != user4.Id { + t.Errorf("First and Last should work correctly with Joins") + } +} + +func TestFirstAndLastForTableWithNoStdPrimaryKey(t *testing.T) { + var animal1, animal2, animal3, animal4 Animal + db.First(&animal1) + db.Order("counter").Find(&animal2) + + db.Last(&animal3) + db.Order("counter desc").Find(&animal4) + if animal1.Counter != animal2.Counter || animal3.Counter != animal4.Counter { + t.Errorf("First and Last should work correctly") + } + + var animals []Animal + db.First(&animals) + if len(animals) != 1 { + t.Errorf("Find first record as map") + } +} + +func TestSaveCustomType(t *testing.T) { + var user, user1 User + db.First(&user, "name = ?", "1") + if user.UserNum != Num(111) { + t.Errorf("UserNum should be saved correctly") + } + + user.UserNum = Num(222) + db.Save(&user) + + db.First(&user1, "name = ?", "1") + if user1.UserNum != Num(222) { + t.Errorf("UserNum should be updated correctly") + } +} + func TestPrecision(t *testing.T) { f := 35.03554004971999 user := User{Name: "Precision", Latitude: f} @@ -518,7 +631,7 @@ func TestOrderAndPluck(t *testing.T) { db.Model(&User{}).Order("age desc").Pluck("age", &ages3).Order("age", true).Pluck("age", &ages4) if reflect.DeepEqual(ages3, ages4) { - t.Errorf("Reorder should works") + t.Errorf("Reorder should work") } var names []string @@ -535,7 +648,7 @@ func TestLimit(t *testing.T) { db.Order("age desc").Limit(3).Find(&users1).Limit(5).Find(&users2).Limit(-1).Find(&users3) if len(users1) != 3 || len(users2) != 5 || len(users3) <= 5 { - t.Errorf("Limit should works") + t.Errorf("Limit should work") } } @@ -544,7 +657,7 @@ func TestOffset(t *testing.T) { db.Limit(100).Order("age desc").Find(&users1).Offset(3).Find(&users2).Offset(5).Find(&users3).Offset(-1).Find(&users4) if (len(users1) != len(users4)) || (len(users1)-len(users2) != 3) || (len(users1)-len(users3) != 5) { - t.Errorf("Offset should works") + t.Errorf("Offset should work") } } @@ -561,7 +674,7 @@ func TestCount(t *testing.T) { var users []User if err := db.Where("name = ?", "1").Or("name = ?", "3").Find(&users).Count(&count).Error; err != nil { - t.Errorf("Count should works", err) + t.Errorf("Count should work", err) } if count != int64(len(users)) { @@ -570,7 +683,7 @@ func TestCount(t *testing.T) { db.Model(&User{}).Where("name = ?", "1").Count(&count1).Or("name = ?", "3").Count(&count2) if count1 != 1 || count2 != 3 { - t.Errorf("Multiple count should works") + t.Errorf("Multiple count should work") } } @@ -688,7 +801,7 @@ func TestRunCallbacks(t *testing.T) { var products []Product db.Find(&products, "code = ?", "unique_code") if products[0].AfterFindCallTimes != 2 { - t.Errorf("AfterFind callbacks should works with slice") + t.Errorf("AfterFind callbacks should work with slice") } db.Where("Code = ?", "unique_code").First(&p) @@ -869,17 +982,31 @@ func TestSetTableDirectly(t *testing.T) { func TestUpdate(t *testing.T) { product1 := Product{Code: "123"} product2 := Product{Code: "234"} + animal1 := Animal{Name: "Ferdinand"} + animal2 := Animal{Name: "nerdz"} + db.Save(&product1).Save(&product2).Update("code", "456") if product2.Code != "456" { t.Errorf("Record should be updated with update attributes") } + db.Save(&animal1).Save(&animal2).Update("name", "Francis") + + if animal2.Name != "Francis" { + t.Errorf("Record should be updated with update attributes") + } + db.First(&product1, product1.Id) db.First(&product2, product2.Id) updated_at1 := product1.UpdatedAt updated_at2 := product2.UpdatedAt + db.First(&animal1, animal1.Counter) + db.First(&animal2, animal2.Counter) + animalUpdated_at1 := animal1.UpdatedAt + animalUpdated_at2 := animal2.UpdatedAt + var product3 Product db.First(&product3, product2.Id).Update("code", "456") if updated_at2.Format(time.RFC3339Nano) != product3.UpdatedAt.Format(time.RFC3339Nano) { @@ -898,6 +1025,25 @@ func TestUpdate(t *testing.T) { t.Errorf("Product 234 should be changed to 456") } + var animal3 Animal + db.First(&animal3, animal2.Counter).Update("Name", "Robert") + + if animalUpdated_at2.Format(time.RFC3339Nano) != animal2.UpdatedAt.Format(time.RFC3339Nano) { + t.Errorf("updated_at should not be updated if nothing changed") + } + + if db.First(&Animal{}, "name = 'Ferdinand'").Error != nil { + t.Errorf("Animal 'Ferdinand' should not be updated") + } + + if db.First(&Animal{}, "name = 'nerdz'").Error == nil { + t.Errorf("Animal 'nerdz' should be changed to 'Francis'") + } + + if db.First(&Animal{}, "name = 'Robert'").Error != nil { + t.Errorf("Animal 'nerdz' should be changed to 'Robert'") + } + db.Table("products").Where("code in (?)", []string{"123"}).Update("code", "789") var product4 Product @@ -925,6 +1071,34 @@ func TestUpdate(t *testing.T) { if db.Model(&product2).UpdateColumn("CreatedAt", time.Now().Add(time.Hour)).Error != nil { t.Error("No error should raise when update_column with CamelCase") } + + db.Table("animals").Where("name in (?)", []string{"Ferdinand"}).Update("name", "Franz") + + var animal4 Animal + db.First(&animal4, animal1.Counter) + if animalUpdated_at1.Format(time.RFC3339Nano) != animal4.UpdatedAt.Format(time.RFC3339Nano) { + t.Errorf("animalUpdated_at should be updated if something changed") + } + + if db.First(&Animal{}, "name = 'Ferdinand'").Error == nil { + t.Errorf("Animal 'Fredinand' should be changed to 'Franz'") + } + + if db.First(&Animal{}, "name = 'Robert'").Error != nil { + t.Errorf("Animal 'Robert' should not be changed to 'Francis'") + } + + if db.First(&Animal{}, "name = 'Franz'").Error != nil { + t.Errorf("Product 'nerdz' should be changed to 'Franz'") + } + + if db.Model(animal2).Update("CreatedAt", time.Now().Add(time.Hour)).Error != nil { + t.Error("No error should raise when update with CamelCase") + } + + if db.Model(&animal2).UpdateColumn("CreatedAt", time.Now().Add(time.Hour)).Error != nil { + t.Error("No error should raise when update_column with CamelCase") + } } func TestUpdates(t *testing.T) { @@ -1308,10 +1482,17 @@ func TestRelated(t *testing.T) { if len(emails) != 2 { t.Errorf("Should have two emails") } + + var emails2 []Email + db.Model(&user).Where("email = ?", "jinzhu@example.com").Related(&emails2) + if len(emails2) != 1 { + t.Errorf("Should have two emails") + } + var user1 User db.Model(&user).Related(&user1.Emails) if len(user1.Emails) != 2 { - t.Errorf("Should have two emails") + t.Errorf("Should have only one email match related condition") } var address1 Address @@ -1511,6 +1692,10 @@ func TestTransaction(t *testing.T) { t.Errorf("Should find saved record, but got", err) } + if sql_tx, ok := tx.CommonDB().(*sql.Tx); !ok || sql_tx == nil { + t.Errorf("Should return the underlying sql.Tx") + } + tx.Rollback() if err := tx.First(&User{}, "name = ?", "transcation").Error; err == nil { @@ -1580,7 +1765,7 @@ func TestScan(t *testing.T) { var res result db.Table("users").Select("name, age").Where("name = ?", 3).Scan(&res) if res.Name != "3" { - t.Errorf("Scan into struct should works") + t.Errorf("Scan into struct should work") } var ress []result @@ -1694,6 +1879,33 @@ func TestHaving(t *testing.T) { } } +func TestAnonymousField(t *testing.T) { + user := User{Name: "anonymous_field", Company: Company{Name: "company"}} + db.Save(&user) + + var user2 User + db.First(&user2, "name = ?", "anonymous_field") + db.Model(&user2).Related(&user2.Company) + if user2.Company.Name != "company" { + t.Errorf("Should be able to get anonymous field") + } +} + +func TestAnonymousScanner(t *testing.T) { + user := User{Name: "anonymous_scanner", Role: Role{Name: "admin"}} + db.Save(&user) + + var user2 User + db.First(&user2, "name = ?", "anonymous_scanner") + if user2.Role.Name != "admin" { + t.Errorf("Should be able to get anonymous scanner") + } + + if !user2.IsAdmin() { + t.Errorf("Should be able to get anonymous scanner") + } +} + func TestExecRawSql(t *testing.T) { db.Exec("update users set name=? where name in (?)", "jinzhu", []string{"1", "2", "3"}) if db.Where("name in (?)", []string{"1", "2", "3"}).First(&User{}).Error != gorm.RecordNotFound { @@ -1701,6 +1913,45 @@ func TestExecRawSql(t *testing.T) { } } +func TestTimeWithZone(t *testing.T) { + var format = "2006-01-02 15:04:05 -0700" + var times []time.Time + GMT8, _ := time.LoadLocation("Asia/Shanghai") + times = append(times, time.Date(2013, 02, 19, 1, 51, 49, 123456789, GMT8)) + times = append(times, time.Date(2013, 02, 18, 17, 51, 49, 123456789, time.UTC)) + + for index, vtime := range times { + name := "time_with_zone_" + strconv.Itoa(index) + user := User{Name: name, Birthday: vtime} + db.Save(&user) + if user.Birthday.UTC().Format(format) != "2013-02-18 17:51:49 +0000" { + t.Errorf("User's birthday should not be changed after save") + } + + if user.DeletedAt.UTC().Format(format) != "0001-01-01 00:00:00 +0000" { + t.Errorf("User's deleted at should be zero") + } + + var findUser, findUser2, findUser3 User + db.First(&findUser, "name = ?", name) + if findUser.Birthday.UTC().Format(format) != "2013-02-18 17:51:49 +0000" { + t.Errorf("User's birthday should not be changed after find") + } + + if findUser.DeletedAt.UTC().Format(format) != "0001-01-01 00:00:00 +0000" { + t.Errorf("User's deleted at should be zero") + } + + if db.Where("birthday >= ?", vtime.Add(-time.Minute)).First(&findUser2).RecordNotFound() { + t.Errorf("User should be found") + } + + if !db.Where("birthday >= ?", vtime.Add(time.Minute)).First(&findUser3).RecordNotFound() { + t.Errorf("User should not be found") + } + } +} + func BenchmarkGorm(b *testing.B) { b.N = 2000 for x := 0; x < b.N; x++ { diff --git a/scope.go b/scope.go index dba091af..543c4cde 100644 --- a/scope.go +++ b/scope.go @@ -13,13 +13,14 @@ import ( ) type Scope struct { - Value interface{} - Search *search - Sql string - SqlVars []interface{} - db *DB - _values map[string]interface{} - skipLeft bool + Value interface{} + Search *search + Sql string + SqlVars []interface{} + db *DB + _values map[string]interface{} + skipLeft bool + primaryKey string } // NewScope create scope for callbacks, including DB's search information @@ -78,7 +79,12 @@ func (scope *Scope) HasError() bool { // PrimaryKey get the primary key's column name func (scope *Scope) PrimaryKey() string { - return "id" + if scope.primaryKey != "" { + return scope.primaryKey + } + + scope.primaryKey = scope.getPrimaryKey() + return scope.primaryKey } // PrimaryKeyZero check the primary key is blank or not @@ -227,7 +233,7 @@ func (scope *Scope) Fields() []*Field { scopeTyp := indirectValue.Type() for i := 0; i < scopeTyp.NumField(); i++ { fieldStruct := scopeTyp.Field(i) - if fieldStruct.Anonymous || !ast.IsExported(fieldStruct.Name) { + if !ast.IsExported(fieldStruct.Name) { continue } @@ -238,7 +244,13 @@ func (scope *Scope) Fields() []*Field { value := indirectValue.FieldByName(fieldStruct.Name) field.Value = value.Interface() field.IsBlank = isBlank(value) - field.isPrimaryKey = scope.PrimaryKey() == field.DBName + + // Search for primary key tag identifier + field.isPrimaryKey = scope.PrimaryKey() == field.DBName || fieldStruct.Tag.Get("primaryKey") != "" + + if field.isPrimaryKey { + scope.primaryKey = field.DBName + } if scope.db != nil { field.Tag = fieldStruct.Tag @@ -252,7 +264,7 @@ func (scope *Scope) Fields() []*Field { case reflect.Slice: typ = typ.Elem() - if _, ok := field.Value.([]byte); !ok { + if typ.Kind() == reflect.Struct { foreignKey := scopeTyp.Name() + "Id" if reflect.New(typ).Elem().FieldByName(foreignKey).IsValid() { field.ForeignKey = foreignKey diff --git a/scope_private.go b/scope_private.go index dae943e8..d9f16438 100644 --- a/scope_private.go +++ b/scope_private.go @@ -5,6 +5,7 @@ import ( "database/sql/driver" "errors" "fmt" + "go/ast" "reflect" "regexp" "strconv" @@ -309,26 +310,24 @@ func (scope *Scope) sqlTagForField(field *Field) (tag string) { value := field.Value reflectValue := reflect.ValueOf(value) - if field.IsScanner() { - value = reflectValue.Field(0).Interface() - } - switch reflectValue.Kind() { case reflect.Slice: if _, ok := value.([]byte); !ok { return } case reflect.Struct: - if !field.IsTime() && !field.IsScanner() { + if field.IsScanner() { + reflectValue = reflectValue.Field(0) + } else if !field.IsTime() { return } } if len(tag) == 0 { if field.isPrimaryKey { - tag = scope.Dialect().PrimaryKeyTag(value, size) + tag = scope.Dialect().PrimaryKeyTag(reflectValue, size) } else { - tag = scope.Dialect().SqlTag(value, size) + tag = scope.Dialect().SqlTag(reflectValue, size) } } @@ -395,7 +394,7 @@ func (scope *Scope) typeName() string { } func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope { - toScope := scope.New(value) + toScope := scope.db.NewScope(value) for _, foreignKey := range append(foreignKeys, toScope.typeName()+"Id", scope.typeName()+"Id") { if foreignValue, ok := scope.FieldByName(foreignKey); ok { @@ -474,3 +473,33 @@ func (scope *Scope) autoMigrate() *Scope { } return scope } + +func (scope *Scope) getPrimaryKey() string { + var indirectValue reflect.Value + + indirectValue = reflect.Indirect(reflect.ValueOf(scope.Value)) + + if indirectValue.Kind() == reflect.Slice { + indirectValue = reflect.New(indirectValue.Type().Elem()).Elem() + } + + if !indirectValue.IsValid() { + return "id" + } + + scopeTyp := indirectValue.Type() + for i := 0; i < scopeTyp.NumField(); i++ { + fieldStruct := scopeTyp.Field(i) + if !ast.IsExported(fieldStruct.Name) { + continue + } + + // if primaryKey tag found, return column name + if fieldStruct.Tag.Get("primaryKey") != "" { + return toSnake(fieldStruct.Name) + } + } + + //If primaryKey tag not found, fallback to id + return "id" +}