From 24d527670b0d64a094c915522cad3f612eb02fed Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 7 Mar 2014 19:08:33 +0800 Subject: [PATCH 01/20] Use the same database for Related --- scope_private.go | 1 + 1 file changed, 1 insertion(+) diff --git a/scope_private.go b/scope_private.go index dae943e8..7212af2b 100644 --- a/scope_private.go +++ b/scope_private.go @@ -396,6 +396,7 @@ func (scope *Scope) typeName() string { func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope { toScope := scope.New(value) + toScope.db = scope.db for _, foreignKey := range append(foreignKeys, toScope.typeName()+"Id", scope.typeName()+"Id") { if foreignValue, ok := scope.FieldByName(foreignKey); ok { From 1086009fce2796a205e0eeff70c50b917613023d Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 11 Mar 2014 10:08:16 +0800 Subject: [PATCH 02/20] Check if value is struct before check Field --- main_test.go | 3 ++- scope.go | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/main_test.go b/main_test.go index 6d89659a..dc60de24 100644 --- a/main_test.go +++ b/main_test.go @@ -40,7 +40,8 @@ type User struct { CreditCard CreditCard Latitude float64 PasswordHash []byte - IgnoreMe int64 `sql:"-"` + IgnoreMe int64 `sql:"-"` + IgnoreStringSlice []string `sql:"-"` } type CreditCard struct { diff --git a/scope.go b/scope.go index dba091af..78742d89 100644 --- a/scope.go +++ b/scope.go @@ -252,7 +252,7 @@ func (scope *Scope) Fields() []*Field { case reflect.Slice: typ = typ.Elem() - if _, ok := field.Value.([]byte); !ok { + if typ.Kind() == reflect.Struct { foreignKey := scopeTyp.Name() + "Id" if reflect.New(typ).Elem().FieldByName(foreignKey).IsValid() { field.ForeignKey = foreignKey From 65e594e2d627477855cef929b6e4f2f3b0edca25 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 15 Mar 2014 09:05:08 +0800 Subject: [PATCH 03/20] Fix README --- README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index d3dfdd11..8a904a3e 100644 --- a/README.md +++ b/README.md @@ -97,9 +97,9 @@ import _ "github.com/lib/pq" // import _ "github.com/go-sql-driver/mysql" // import _ "github.com/mattn/go-sqlite3" -db, err := Open("postgres", "user=gorm dbname=gorm sslmode=disable") -// db, err = Open("mysql", "gorm:gorm@/gorm?charset=utf8&parseTime=True") -// db, err = Open("sqlite3", "/tmp/gorm.db") +db, err := gorm.Open("postgres", "user=gorm dbname=gorm sslmode=disable") +// db, err = gorm.Open("mysql", "gorm:gorm@/gorm?charset=utf8&parseTime=True") +// db, err = gorm.Open("sqlite3", "/tmp/gorm.db") // Get database connection handle [*sql.DB](http://golang.org/pkg/database/sql/#DB) d := db.DB() From 7bbf71fb29aea449284f42d0c29f13198a484079 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 15 Mar 2014 10:17:43 +0800 Subject: [PATCH 04/20] Add tests to make sure time with zone won't be changed after save --- main_test.go | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/main_test.go b/main_test.go index dc60de24..65caf58d 100644 --- a/main_test.go +++ b/main_test.go @@ -1702,6 +1702,36 @@ func TestExecRawSql(t *testing.T) { } } +func TestTimeWithZone(t *testing.T) { + var format = "2006-01-02 15:04:05 -0700" + var times []time.Time + GMT8, _ := time.LoadLocation("Asia/Shanghai") + times = append(times, time.Date(2013, 02, 19, 1, 51, 49, 123456789, GMT8)) + times = append(times, time.Date(2013, 02, 18, 17, 51, 49, 123456789, time.UTC)) + + for _, vtime := range times { + user := User{Name: "time_with_zone", Birthday: vtime} + db.Save(&user) + if user.Birthday.UTC().Format(format) != "2013-02-18 17:51:49 +0000" { + t.Errorf("User's birthday should not be changed after save") + } + + if user.DeletedAt.UTC().Format(format) != "0001-01-01 00:00:00 +0000" { + t.Errorf("User's deleted at should be zero") + } + + var findUser User + db.First(&findUser, "name = ?", "time_with_zone") + if findUser.Birthday.UTC().Format(format) != "2013-02-18 17:51:49 +0000" { + t.Errorf("User's birthday should not be changed after find") + } + + if findUser.DeletedAt.UTC().Format(format) != "0001-01-01 00:00:00 +0000" { + t.Errorf("User's deleted at should be zero") + } + } +} + func BenchmarkGorm(b *testing.B) { b.N = 2000 for x := 0; x < b.N; x++ { From 844a0ddfccf65890a41d917b64e6c6fc07014f25 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 15 Mar 2014 10:31:26 +0800 Subject: [PATCH 05/20] update tests for time with zone --- main_test.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/main_test.go b/main_test.go index 65caf58d..87fa9b21 100644 --- a/main_test.go +++ b/main_test.go @@ -1709,8 +1709,9 @@ func TestTimeWithZone(t *testing.T) { times = append(times, time.Date(2013, 02, 19, 1, 51, 49, 123456789, GMT8)) times = append(times, time.Date(2013, 02, 18, 17, 51, 49, 123456789, time.UTC)) - for _, vtime := range times { - user := User{Name: "time_with_zone", Birthday: vtime} + for index, vtime := range times { + name := "time_with_zone_" + strconv.Itoa(index) + user := User{Name: name, Birthday: vtime} db.Save(&user) if user.Birthday.UTC().Format(format) != "2013-02-18 17:51:49 +0000" { t.Errorf("User's birthday should not be changed after save") @@ -1721,7 +1722,7 @@ func TestTimeWithZone(t *testing.T) { } var findUser User - db.First(&findUser, "name = ?", "time_with_zone") + db.First(&findUser, "name = ?", name) if findUser.Birthday.UTC().Format(format) != "2013-02-18 17:51:49 +0000" { t.Errorf("User's birthday should not be changed after find") } From dc2f27401eb44bb575423b467709f8f529dd9002 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 15 Mar 2014 10:41:12 +0800 Subject: [PATCH 06/20] Test search data using time with zone --- main_test.go | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/main_test.go b/main_test.go index 87fa9b21..1df4558a 100644 --- a/main_test.go +++ b/main_test.go @@ -1721,7 +1721,7 @@ func TestTimeWithZone(t *testing.T) { t.Errorf("User's deleted at should be zero") } - var findUser User + var findUser, findUser2, findUser3 User db.First(&findUser, "name = ?", name) if findUser.Birthday.UTC().Format(format) != "2013-02-18 17:51:49 +0000" { t.Errorf("User's birthday should not be changed after find") @@ -1730,6 +1730,14 @@ func TestTimeWithZone(t *testing.T) { if findUser.DeletedAt.UTC().Format(format) != "0001-01-01 00:00:00 +0000" { t.Errorf("User's deleted at should be zero") } + + if db.Where("birthday >= ?", vtime.Add(-time.Minute)).First(&findUser2).RecordNotFound() { + t.Errorf("User should be found") + } + + if !db.Where("birthday >= ?", vtime.Add(time.Minute)).First(&findUser3).RecordNotFound() { + t.Errorf("User should not be found") + } } } From 4969fc9cb52529ecbb0e1d4ff5dfec5494f9ffdd Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 16 Mar 2014 07:50:57 +0800 Subject: [PATCH 07/20] Remove unused Scanner type from dialect --- dialect/mysql.go | 9 ++++----- dialect/postgres.go | 9 ++++----- dialect/sqlite3.go | 9 ++++----- 3 files changed, 12 insertions(+), 15 deletions(-) diff --git a/dialect/mysql.go b/dialect/mysql.go index a1c0a26e..2a5ac0c0 100644 --- a/dialect/mysql.go +++ b/dialect/mysql.go @@ -1,7 +1,6 @@ package dialect import ( - "database/sql" "fmt" "time" ) @@ -20,13 +19,13 @@ func (d *mysql) SqlTag(column interface{}, size int) string { switch column.(type) { case time.Time: return "datetime" - case bool, sql.NullBool: + case bool: return "boolean" case int, int8, int16, int32, uint, uint8, uint16, uint32: return "int" - case int64, uint64, sql.NullInt64: + case int64, uint64: return "bigint" - case float32, float64, sql.NullFloat64: + case float32, float64: return "double" case []byte: if size > 0 && size < 65532 { @@ -34,7 +33,7 @@ func (d *mysql) SqlTag(column interface{}, size int) string { } else { return "longblob" } - case string, sql.NullString: + case string: if size > 0 && size < 65532 { return fmt.Sprintf("varchar(%d)", size) } else { diff --git a/dialect/postgres.go b/dialect/postgres.go index c0981cd0..7cffe9dc 100644 --- a/dialect/postgres.go +++ b/dialect/postgres.go @@ -1,7 +1,6 @@ package dialect import ( - "database/sql" "fmt" "time" ) @@ -21,17 +20,17 @@ func (d *postgres) SqlTag(column interface{}, size int) string { switch column.(type) { case time.Time: return "timestamp with time zone" - case bool, sql.NullBool: + case bool: return "boolean" case int, int8, int16, int32, uint, uint8, uint16, uint32: return "integer" - case int64, uint64, sql.NullInt64: + case int64, uint64: return "bigint" - case float32, float64, sql.NullFloat64: + case float32, float64: return "numeric" case []byte: return "bytea" - case string, sql.NullString: + case string: if size > 0 && size < 65532 { return fmt.Sprintf("varchar(%d)", size) } else { diff --git a/dialect/sqlite3.go b/dialect/sqlite3.go index 92063786..4d16c3a8 100644 --- a/dialect/sqlite3.go +++ b/dialect/sqlite3.go @@ -1,7 +1,6 @@ package dialect import ( - "database/sql" "fmt" "time" ) @@ -20,17 +19,17 @@ func (s *sqlite3) SqlTag(column interface{}, size int) string { switch column.(type) { case time.Time: return "datetime" - case bool, sql.NullBool: + case bool: return "bool" case int, int8, int16, int32, uint, uint8, uint16, uint32: return "integer" - case int64, uint64, sql.NullInt64: + case int64, uint64: return "bigint" - case float32, float64, sql.NullFloat64: + case float32, float64: return "real" case []byte: return "blob" - case string, sql.NullString: + case string: if size > 0 && size < 65532 { return fmt.Sprintf("varchar(%d)", size) } else { From e6c953dd4c5370a09dc30be60d4e0fc6bc65bab8 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 16 Mar 2014 09:28:43 +0800 Subject: [PATCH 08/20] Support custom types from base types --- dialect/dialect.go | 11 +++++++++-- dialect/mysql.go | 45 ++++++++++++++++++++++++--------------------- dialect/postgres.go | 40 +++++++++++++++++++++------------------- dialect/sqlite3.go | 38 +++++++++++++++++++++++--------------- main_test.go | 30 +++++++++++++++++++++++++++++- scope_private.go | 12 +++++------- 6 files changed, 111 insertions(+), 65 deletions(-) diff --git a/dialect/dialect.go b/dialect/dialect.go index 9418e533..a0f403c4 100644 --- a/dialect/dialect.go +++ b/dialect/dialect.go @@ -1,10 +1,17 @@ package dialect +import ( + "reflect" + "time" +) + +var timeType = reflect.TypeOf(time.Time{}) + type Dialect interface { BinVar(i int) string SupportLastInsertId() bool - SqlTag(column interface{}, size int) string - PrimaryKeyTag(column interface{}, size int) string + SqlTag(value reflect.Value, size int) string + PrimaryKeyTag(value reflect.Value, size int) string ReturningStr(key string) string Quote(key string) string } diff --git a/dialect/mysql.go b/dialect/mysql.go index 2a5ac0c0..9602bcc1 100644 --- a/dialect/mysql.go +++ b/dialect/mysql.go @@ -2,7 +2,7 @@ package dialect import ( "fmt" - "time" + "reflect" ) type mysql struct{} @@ -15,41 +15,44 @@ func (s *mysql) SupportLastInsertId() bool { return true } -func (d *mysql) SqlTag(column interface{}, size int) string { - switch column.(type) { - case time.Time: - return "datetime" - case bool: +func (d *mysql) SqlTag(value reflect.Value, size int) string { + switch value.Kind() { + case reflect.Bool: return "boolean" - case int, int8, int16, int32, uint, uint8, uint16, uint32: + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: return "int" - case int64, uint64: + case reflect.Int64, reflect.Uint64: return "bigint" - case float32, float64: + case reflect.Float32, reflect.Float64: return "double" - case []byte: - if size > 0 && size < 65532 { - return fmt.Sprintf("varbinary(%d)", size) - } else { - return "longblob" - } - case string: + case reflect.String: if size > 0 && size < 65532 { return fmt.Sprintf("varchar(%d)", size) } else { return "longtext" } + case reflect.Struct: + if value.Type() == timeType { + return "datetime" + } default: - panic("Invalid sql type for mysql") + if _, ok := value.Interface().([]byte); ok { + if size > 0 && size < 65532 { + return fmt.Sprintf("varbinary(%d)", size) + } else { + return "longblob" + } + } } + panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", value.Type().Name(), value.Kind().String())) } -func (s *mysql) PrimaryKeyTag(column interface{}, size int) string { +func (s *mysql) PrimaryKeyTag(value reflect.Value, size int) string { suffix_str := " NOT NULL AUTO_INCREMENT PRIMARY KEY" - switch column.(type) { - case int, int8, int16, int32, uint, uint8, uint16, uint32: + switch value.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: return "int" + suffix_str - case int64, uint64: + case reflect.Int64, reflect.Uint64: return "bigint" + suffix_str default: panic("Invalid primary key type") diff --git a/dialect/postgres.go b/dialect/postgres.go index 7cffe9dc..7b744fa4 100644 --- a/dialect/postgres.go +++ b/dialect/postgres.go @@ -2,7 +2,7 @@ package dialect import ( "fmt" - "time" + "reflect" ) type postgres struct { @@ -16,36 +16,38 @@ func (s *postgres) SupportLastInsertId() bool { return false } -func (d *postgres) SqlTag(column interface{}, size int) string { - switch column.(type) { - case time.Time: - return "timestamp with time zone" - case bool: +func (d *postgres) SqlTag(value reflect.Value, size int) string { + switch value.Kind() { + case reflect.Bool: return "boolean" - case int, int8, int16, int32, uint, uint8, uint16, uint32: + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: return "integer" - case int64, uint64: + case reflect.Int64, reflect.Uint64: return "bigint" - case float32, float64: + case reflect.Float32, reflect.Float64: return "numeric" - case []byte: - return "bytea" - case string: + case reflect.String: if size > 0 && size < 65532 { return fmt.Sprintf("varchar(%d)", size) - } else { - return "text" + } + return "text" + case reflect.Struct: + if value.Type() == timeType { + return "timestamp with time zone" } default: - panic("Invalid sql type for postgres") + if _, ok := value.Interface().([]byte); ok { + return "bytea" + } } + panic(fmt.Sprintf("invalid sql type %s (%s) for postgres", value.Type().Name(), value.Kind().String())) } -func (s *postgres) PrimaryKeyTag(column interface{}, size int) string { - switch column.(type) { - case int, int8, int16, int32, uint, uint8, uint16, uint32: +func (s *postgres) PrimaryKeyTag(value reflect.Value, size int) string { + switch value.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: return "serial PRIMARY KEY" - case int64, uint64: + case reflect.Int64, reflect.Uint64: return "bigserial PRIMARY KEY" default: panic("Invalid primary key type") diff --git a/dialect/sqlite3.go b/dialect/sqlite3.go index 4d16c3a8..ae54e603 100644 --- a/dialect/sqlite3.go +++ b/dialect/sqlite3.go @@ -2,7 +2,7 @@ package dialect import ( "fmt" - "time" + "reflect" ) type sqlite3 struct{} @@ -15,33 +15,41 @@ func (s *sqlite3) SupportLastInsertId() bool { return true } -func (s *sqlite3) SqlTag(column interface{}, size int) string { - switch column.(type) { - case time.Time: - return "datetime" - case bool: +func (s *sqlite3) SqlTag(value reflect.Value, size int) string { + switch value.Kind() { + case reflect.Bool: return "bool" - case int, int8, int16, int32, uint, uint8, uint16, uint32: + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: return "integer" - case int64, uint64: + case reflect.Int64, reflect.Uint64: return "bigint" - case float32, float64: + case reflect.Float32, reflect.Float64: return "real" - case []byte: - return "blob" - case string: + case reflect.String: if size > 0 && size < 65532 { return fmt.Sprintf("varchar(%d)", size) } else { return "text" } + case reflect.Struct: + if value.Type() == timeType { + return "datetime" + } default: - panic("Invalid sql type for sqlite3") + if _, ok := value.Interface().([]byte); ok { + return "blob" + } } + panic(fmt.Sprintf("invalid sql type %s (%s) for sqlite3", value.Type().Name(), value.Kind().String())) } -func (s *sqlite3) PrimaryKeyTag(column interface{}, size int) string { - return "INTEGER PRIMARY KEY" +func (s *sqlite3) PrimaryKeyTag(value reflect.Value, size int) string { + switch value.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr, reflect.Int64, reflect.Uint64: + return "INTEGER PRIMARY KEY" + default: + panic("Invalid primary key type") + } } func (s *sqlite3) ReturningStr(key string) (str string) { diff --git a/main_test.go b/main_test.go index 1df4558a..f244c03d 100644 --- a/main_test.go +++ b/main_test.go @@ -22,6 +22,17 @@ type IgnoredEmbedStruct struct { Name string } +type Num int64 + +func (i *Num) Scan(src interface{}) error { + v := reflect.ValueOf(src) + if v.Kind() != reflect.Int64 { + return errors.New("Cannot scan NamedInt from " + v.String()) + } + *i = Num(v.Int()) + return nil +} + type User struct { Id int64 // Id: Primary key Age int64 @@ -42,6 +53,7 @@ type User struct { PasswordHash []byte IgnoreMe int64 `sql:"-"` IgnoreStringSlice []string `sql:"-"` + UserNum Num } type CreditCard struct { @@ -156,7 +168,7 @@ func init() { t3, _ = time.Parse(shortForm, "2005-01-01 00:00:00") t4, _ = time.Parse(shortForm, "2010-01-01 00:00:00") t5, _ = time.Parse(shortForm, "2020-01-01 00:00:00") - db.Save(&User{Name: "1", Age: 18, Birthday: t1, When: time.Now()}) + db.Save(&User{Name: "1", Age: 18, Birthday: t1, When: time.Now(), UserNum: Num(111)}) db.Save(&User{Name: "2", Age: 20, Birthday: t2}) db.Save(&User{Name: "3", Age: 22, Birthday: t3}) db.Save(&User{Name: "3", Age: 24, Birthday: t4}) @@ -181,6 +193,22 @@ func TestFirstAndLast(t *testing.T) { } } +func TestSaveCustomType(t *testing.T) { + var user, user1 User + db.First(&user, "name = ?", "1") + if user.UserNum != Num(111) { + t.Errorf("UserNum should be saved correctly") + } + + user.UserNum = Num(222) + db.Save(&user) + + db.First(&user1, "name = ?", "1") + if user1.UserNum != Num(222) { + t.Errorf("UserNum should be updated correctly") + } +} + func TestPrecision(t *testing.T) { f := 35.03554004971999 user := User{Name: "Precision", Latitude: f} diff --git a/scope_private.go b/scope_private.go index 7212af2b..c2e5627e 100644 --- a/scope_private.go +++ b/scope_private.go @@ -309,26 +309,24 @@ func (scope *Scope) sqlTagForField(field *Field) (tag string) { value := field.Value reflectValue := reflect.ValueOf(value) - if field.IsScanner() { - value = reflectValue.Field(0).Interface() - } - switch reflectValue.Kind() { case reflect.Slice: if _, ok := value.([]byte); !ok { return } case reflect.Struct: - if !field.IsTime() && !field.IsScanner() { + if field.IsScanner() { + reflectValue = reflectValue.Field(0) + } else if !field.IsTime() { return } } if len(tag) == 0 { if field.isPrimaryKey { - tag = scope.Dialect().PrimaryKeyTag(value, size) + tag = scope.Dialect().PrimaryKeyTag(reflectValue, size) } else { - tag = scope.Dialect().SqlTag(value, size) + tag = scope.Dialect().SqlTag(reflectValue, size) } } From d232c69369c25db72812831eb92b34ea9dddcc2c Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 16 Mar 2014 10:57:38 +0800 Subject: [PATCH 09/20] Fix exception in mysql --- main_test.go | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/main_test.go b/main_test.go index f244c03d..47611908 100644 --- a/main_test.go +++ b/main_test.go @@ -25,17 +25,20 @@ type IgnoredEmbedStruct struct { type Num int64 func (i *Num) Scan(src interface{}) error { - v := reflect.ValueOf(src) - if v.Kind() != reflect.Int64 { - return errors.New("Cannot scan NamedInt from " + v.String()) + switch s := src.(type) { + case []byte: + case int64: + *i = Num(s) + default: + return errors.New("Cannot scan NamedInt from " + reflect.ValueOf(src).String()) } - *i = Num(v.Int()) return nil } type User struct { Id int64 // Id: Primary key Age int64 + UserNum Num Name string `sql:"size:255"` Birthday time.Time // Time CreatedAt time.Time // CreatedAt: Time of record is created, will be insert automatically @@ -53,7 +56,6 @@ type User struct { PasswordHash []byte IgnoreMe int64 `sql:"-"` IgnoreStringSlice []string `sql:"-"` - UserNum Num } type CreditCard struct { From a336f51444af6ea727308630346efccd8b0e716b Mon Sep 17 00:00:00 2001 From: Timothy Stranex Date: Sun, 16 Mar 2014 18:24:32 +0200 Subject: [PATCH 10/20] Add DB.Tx() method to provice access to the underlying sql.Tx instance. --- main.go | 12 ++++++++++++ main_test.go | 5 +++++ 2 files changed, 17 insertions(+) diff --git a/main.go b/main.go index ca1d24bb..f205b6c3 100644 --- a/main.go +++ b/main.go @@ -28,10 +28,22 @@ func Open(driver, source string) (DB, error) { return db, err } +// Return the underlying sql.DB instance. +// +// If called inside a transaction, it will panic. +// Use Tx() instead in this case. func (s *DB) DB() *sql.DB { return s.db.(*sql.DB) } +// Return the underlying sql.Tx instance. +// +// If called outside of a transaction, it will panic. +// Use DB() instead in this case. +func (s *DB) Tx() *sql.Tx { + return s.db.(*sql.Tx) +} + func (s *DB) Callback() *callback { s.parent.callback = s.parent.callback.clone() return s.parent.callback diff --git a/main_test.go b/main_test.go index 47611908..14de6701 100644 --- a/main_test.go +++ b/main_test.go @@ -1542,6 +1542,11 @@ func TestTransaction(t *testing.T) { t.Errorf("Should find saved record, but got", err) } + sql_tx := tx.Tx() // This shouldn't panic. + if sql_tx == nil { + t.Errorf("Should return the underlying sql.Tx, but got nil") + } + tx.Rollback() if err := tx.First(&User{}, "name = ?", "transcation").Error; err == nil { From 42448cb5d6f68cf2c3a61bf68b217c5bcfb5ab30 Mon Sep 17 00:00:00 2001 From: Timothy Stranex Date: Mon, 17 Mar 2014 12:08:44 +0200 Subject: [PATCH 11/20] Add DB.CommonDB() instead of DB.Tx(), as discussed in the PR thread. --- main.go | 15 +++++---------- main_test.go | 7 +++---- 2 files changed, 8 insertions(+), 14 deletions(-) diff --git a/main.go b/main.go index f205b6c3..942880b5 100644 --- a/main.go +++ b/main.go @@ -28,20 +28,15 @@ func Open(driver, source string) (DB, error) { return db, err } -// Return the underlying sql.DB instance. -// -// If called inside a transaction, it will panic. -// Use Tx() instead in this case. func (s *DB) DB() *sql.DB { return s.db.(*sql.DB) } -// Return the underlying sql.Tx instance. -// -// If called outside of a transaction, it will panic. -// Use DB() instead in this case. -func (s *DB) Tx() *sql.Tx { - return s.db.(*sql.Tx) +// Return the underlying sql.DB or sql.Tx instance. +// Use of this method is discouraged. It's mainly intended to allow +// coexistence with legacy non-GORM code. +func (s *DB) CommonDB() sqlCommon { + return s.db } func (s *DB) Callback() *callback { diff --git a/main_test.go b/main_test.go index 14de6701..32d9ac62 100644 --- a/main_test.go +++ b/main_test.go @@ -1542,10 +1542,9 @@ func TestTransaction(t *testing.T) { t.Errorf("Should find saved record, but got", err) } - sql_tx := tx.Tx() // This shouldn't panic. - if sql_tx == nil { - t.Errorf("Should return the underlying sql.Tx, but got nil") - } + if sql_tx, ok := tx.CommonDB().(*sql.Tx); !ok || sql_tx == nil { + t.Errorf("Should return the underlying sql.Tx") + } tx.Rollback() From d7d9e24e1ef251b76ec847f9c7153489f6e4d1a9 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 26 Mar 2014 08:26:45 +0800 Subject: [PATCH 12/20] Add test for anonymous field --- main_test.go | 30 +++++++++++++++++++++++++++--- scope.go | 2 +- 2 files changed, 28 insertions(+), 4 deletions(-) diff --git a/main_test.go b/main_test.go index 32d9ac62..540c1e13 100644 --- a/main_test.go +++ b/main_test.go @@ -35,6 +35,11 @@ func (i *Num) Scan(src interface{}) error { return nil } +type Role struct { + Id int64 + Name string +} + type User struct { Id int64 // Id: Primary key Age int64 @@ -53,9 +58,11 @@ type User struct { When time.Time CreditCard CreditCard Latitude float64 - PasswordHash []byte - IgnoreMe int64 `sql:"-"` - IgnoreStringSlice []string `sql:"-"` + Role + RoleId int64 + PasswordHash []byte + IgnoreMe int64 `sql:"-"` + IgnoreStringSlice []string `sql:"-"` } type CreditCard struct { @@ -143,6 +150,7 @@ func init() { db.Exec("drop table emails;") db.Exec("drop table addresses") db.Exec("drop table credit_cards") + db.Exec("drop table roles") if err = db.CreateTable(&User{}).Error; err != nil { panic(fmt.Sprintf("No error should happen when create table, but got %+v", err)) @@ -164,6 +172,10 @@ func init() { panic(fmt.Sprintf("No error should happen when create table, but got %+v", err)) } + if err = db.AutoMigrate(Role{}).Error; err != nil { + panic(fmt.Sprintf("No error should happen when create table, but got %+v", err)) + } + var shortForm = "2006-01-02 15:04:05" t1, _ = time.Parse(shortForm, "2000-10-27 12:02:40") t2, _ = time.Parse(shortForm, "2002-01-01 00:00:00") @@ -1729,6 +1741,18 @@ func TestHaving(t *testing.T) { } } +func TestAnonymousField(t *testing.T) { + user := User{Name: "anonymous_field", Role: Role{Name: "admin"}} + db.Save(&user) + + var user2 User + db.First(&user2, "name = ?", "anonymous_field") + db.Model(&user2).Related(&user2.Role) + if user2.Role.Name != "admin" { + t.Errorf("Should be able to get anonymous field") + } +} + func TestExecRawSql(t *testing.T) { db.Exec("update users set name=? where name in (?)", "jinzhu", []string{"1", "2", "3"}) if db.Where("name in (?)", []string{"1", "2", "3"}).First(&User{}).Error != gorm.RecordNotFound { diff --git a/scope.go b/scope.go index 78742d89..40697f9a 100644 --- a/scope.go +++ b/scope.go @@ -227,7 +227,7 @@ func (scope *Scope) Fields() []*Field { scopeTyp := indirectValue.Type() for i := 0; i < scopeTyp.NumField(); i++ { fieldStruct := scopeTyp.Field(i) - if fieldStruct.Anonymous || !ast.IsExported(fieldStruct.Name) { + if !ast.IsExported(fieldStruct.Name) { continue } From 1949baf5c87d06e22625f1cb18e63b05f4c5a1c6 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 26 Mar 2014 08:48:40 +0800 Subject: [PATCH 13/20] Test Related with search conditions --- main_test.go | 15 +++++++++++---- scope_private.go | 3 +-- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/main_test.go b/main_test.go index 540c1e13..b6d1c17c 100644 --- a/main_test.go +++ b/main_test.go @@ -1351,10 +1351,17 @@ func TestRelated(t *testing.T) { if len(emails) != 2 { t.Errorf("Should have two emails") } + + var emails2 []Email + db.Model(&user).Where("email = ?", "jinzhu@example.com").Related(&emails2) + if len(emails2) != 1 { + t.Errorf("Should have two emails") + } + var user1 User db.Model(&user).Related(&user1.Emails) if len(user1.Emails) != 2 { - t.Errorf("Should have two emails") + t.Errorf("Should have only one email match related condition") } var address1 Address @@ -1554,9 +1561,9 @@ func TestTransaction(t *testing.T) { t.Errorf("Should find saved record, but got", err) } - if sql_tx, ok := tx.CommonDB().(*sql.Tx); !ok || sql_tx == nil { - t.Errorf("Should return the underlying sql.Tx") - } + if sql_tx, ok := tx.CommonDB().(*sql.Tx); !ok || sql_tx == nil { + t.Errorf("Should return the underlying sql.Tx") + } tx.Rollback() diff --git a/scope_private.go b/scope_private.go index c2e5627e..72f631cc 100644 --- a/scope_private.go +++ b/scope_private.go @@ -393,8 +393,7 @@ func (scope *Scope) typeName() string { } func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope { - toScope := scope.New(value) - toScope.db = scope.db + toScope := scope.db.NewScope(value) for _, foreignKey := range append(foreignKeys, toScope.typeName()+"Id", scope.typeName()+"Id") { if foreignValue, ok := scope.FieldByName(foreignKey); ok { From 663c06cfb1bc0bb9d0135185d41180517fbb2ddb Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 26 Mar 2014 10:28:34 +0800 Subject: [PATCH 14/20] Add test for anonymous scanner --- main_test.go | 45 +++++++++++++++++++++++++++++++++++++++------ 1 file changed, 39 insertions(+), 6 deletions(-) diff --git a/main_test.go b/main_test.go index b6d1c17c..2852395b 100644 --- a/main_test.go +++ b/main_test.go @@ -35,11 +35,28 @@ func (i *Num) Scan(src interface{}) error { return nil } -type Role struct { +type Company struct { Id int64 Name string } +type Role struct { + Name string +} + +func (role *Role) Scan(value interface{}) error { + role.Name = string(value.([]uint8)) + return nil +} + +func (role Role) Value() (driver.Value, error) { + return role.Name, nil +} + +func (role Role) IsAdmin() bool { + return role.Name == "admin" +} + type User struct { Id int64 // Id: Primary key Age int64 @@ -58,8 +75,9 @@ type User struct { When time.Time CreditCard CreditCard Latitude float64 + CompanyId int64 + Company Role - RoleId int64 PasswordHash []byte IgnoreMe int64 `sql:"-"` IgnoreStringSlice []string `sql:"-"` @@ -172,7 +190,7 @@ func init() { panic(fmt.Sprintf("No error should happen when create table, but got %+v", err)) } - if err = db.AutoMigrate(Role{}).Error; err != nil { + if err = db.AutoMigrate(Company{}).Error; err != nil { panic(fmt.Sprintf("No error should happen when create table, but got %+v", err)) } @@ -1749,17 +1767,32 @@ func TestHaving(t *testing.T) { } func TestAnonymousField(t *testing.T) { - user := User{Name: "anonymous_field", Role: Role{Name: "admin"}} + user := User{Name: "anonymous_field", Company: Company{Name: "company"}} db.Save(&user) var user2 User db.First(&user2, "name = ?", "anonymous_field") - db.Model(&user2).Related(&user2.Role) - if user2.Role.Name != "admin" { + db.Model(&user2).Related(&user2.Company) + if user2.Company.Name != "company" { t.Errorf("Should be able to get anonymous field") } } +func TestAnonymousScanner(t *testing.T) { + user := User{Name: "anonymous_scanner", Role: Role{Name: "admin"}} + db.Save(&user) + + var user2 User + db.First(&user2, "name = ?", "anonymous_scanner") + if user2.Role.Name != "admin" { + t.Errorf("Should be able to get anonymous scanner") + } + + if !user2.IsAdmin() { + t.Errorf("Should be able to get anonymous scanner") + } +} + func TestExecRawSql(t *testing.T) { db.Exec("update users set name=? where name in (?)", "jinzhu", []string{"1", "2", "3"}) if db.Where("name in (?)", []string{"1", "2", "3"}).First(&User{}).Error != gorm.RecordNotFound { From 22cf9719bf2a7306dd19748be83cebf96a6d55fd Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 26 Mar 2014 11:02:17 +0800 Subject: [PATCH 15/20] update test --- main_test.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/main_test.go b/main_test.go index 2852395b..5e8e3070 100644 --- a/main_test.go +++ b/main_test.go @@ -169,6 +169,7 @@ func init() { db.Exec("drop table addresses") db.Exec("drop table credit_cards") db.Exec("drop table roles") + db.Exec("drop table companies") if err = db.CreateTable(&User{}).Error; err != nil { panic(fmt.Sprintf("No error should happen when create table, but got %+v", err)) @@ -194,6 +195,10 @@ func init() { panic(fmt.Sprintf("No error should happen when create table, but got %+v", err)) } + if err = db.AutoMigrate(Role{}).Error; err != nil { + panic(fmt.Sprintf("No error should happen when create table, but got %+v", err)) + } + var shortForm = "2006-01-02 15:04:05" t1, _ = time.Parse(shortForm, "2000-10-27 12:02:40") t2, _ = time.Parse(shortForm, "2002-01-01 00:00:00") From bcb1ca67c0e22470fa12c2d3afd50401f5c32b1d Mon Sep 17 00:00:00 2001 From: Paolo Galeone Date: Wed, 2 Apr 2014 10:59:07 +0200 Subject: [PATCH 16/20] Add support for primary key different from id --- README.md | 15 +++++++- main_test.go | 98 +++++++++++++++++++++++++++++++++++++++++++++++- scope.go | 30 ++++++++++----- scope_private.go | 31 +++++++++++++++ 4 files changed, 163 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index 8a904a3e..5059f71d 100644 --- a/README.md +++ b/README.md @@ -28,7 +28,7 @@ go get github.com/jinzhu/gorm ## Conventions * Table name is the plural of struct name's snake case. - Disable pluralization with `db.SingularTable(true)`, or [specify your table name](#specify-table-name) + Disable pluralization with `db.SingularTable(true)`, or [Specifying the Table Name for Struct permanently with TableName](#Specifying-the-Table-Name-for-Struct-permanently-with-TableName) * Column name is the snake case of field's name. * Use `Id int64` field as primary key. * Use tag `sql` to change field's property, change the tag name with `db.SetTagIdentifier(new_name)`. @@ -47,6 +47,19 @@ db.First(&user) DB.Save(&User{Name: "xxx"}) // table "users" ``` +## Existing schema + +If you have and existing database schema and some of your tables does not follow the conventions, (and you can't rename your table names), please use: [Specifying the Table Name for Struct permanently with TableName](#Specifying-the-Table-Name-for-Struct-permanently-with-TableName). + +If your primary key field is different from `id`, you can add a tag to the field structure to specify that this field is a primary key. + +```go +type Animal struct { // animals + AnimalId int64 `primaryKey:"yes"` + Birthday time.Time + Age int64 +``` + # Getting Started ```go diff --git a/main_test.go b/main_test.go index 5e8e3070..5deb7822 100644 --- a/main_test.go +++ b/main_test.go @@ -7,9 +7,9 @@ import ( "fmt" _ "github.com/go-sql-driver/mysql" - "github.com/jinzhu/gorm" _ "github.com/lib/pq" _ "github.com/mattn/go-sqlite3" + "github.com/nerdzeu/gorm" "os" "reflect" @@ -127,6 +127,13 @@ type Product struct { AfterDeleteCallTimes int64 } +type Animal struct { + Counter int64 `primaryKey:"yes"` + Name string + CreatedAt time.Time + UpdatedAt time.Time +} + var ( db gorm.DB t1, t2, t3, t4, t5 time.Time @@ -170,6 +177,11 @@ func init() { db.Exec("drop table credit_cards") db.Exec("drop table roles") db.Exec("drop table companies") + db.Exec("drop table animals") + + if err = db.CreateTable(&Animal{}).Error; err != nil { + panic(fmt.Sprintf("No error should happen when create table, but got %+v", err)) + } if err = db.CreateTable(&User{}).Error; err != nil { panic(fmt.Sprintf("No error should happen when create table, but got %+v", err)) @@ -210,6 +222,11 @@ func init() { db.Save(&User{Name: "3", Age: 22, Birthday: t3}) db.Save(&User{Name: "3", Age: 24, Birthday: t4}) db.Save(&User{Name: "5", Age: 26, Birthday: t4}) + + db.Save(&Animal{Name: "First"}) + db.Save(&Animal{Name: "Amazing"}) + db.Save(&Animal{Name: "Horse"}) + db.Save(&Animal{Name: "Last"}) } func TestFirstAndLast(t *testing.T) { @@ -230,6 +247,24 @@ func TestFirstAndLast(t *testing.T) { } } +func TestFirstAndLastForTableWithNoStdPrimaryKey(t *testing.T) { + var animal1, animal2, animal3, animal4 Animal + db.First(&animal1) + db.Order("counter").Find(&animal2) + + db.Last(&animal3) + db.Order("counter desc").Find(&animal4) + if animal1.Counter != animal2.Counter || animal3.Counter != animal4.Counter { + t.Errorf("First and Last should works correctly") + } + + var animals []Animal + db.First(&animals) + if len(animals) != 1 { + t.Errorf("Find first record as map") + } +} + func TestSaveCustomType(t *testing.T) { var user, user1 User db.First(&user, "name = ?", "1") @@ -935,17 +970,31 @@ func TestSetTableDirectly(t *testing.T) { func TestUpdate(t *testing.T) { product1 := Product{Code: "123"} product2 := Product{Code: "234"} + animal1 := Animal{Name: "Ferdinand"} + animal2 := Animal{Name: "nerdz"} + db.Save(&product1).Save(&product2).Update("code", "456") if product2.Code != "456" { t.Errorf("Record should be updated with update attributes") } + db.Save(&animal1).Save(&animal2).Update("name", "Francis") + + if animal2.Name != "Francis" { + t.Errorf("Record should be updated with update attributes") + } + db.First(&product1, product1.Id) db.First(&product2, product2.Id) updated_at1 := product1.UpdatedAt updated_at2 := product2.UpdatedAt + db.First(&animal1, animal1.Counter) + db.First(&animal2, animal2.Counter) + animalUpdated_at1 := animal1.UpdatedAt + animalUpdated_at2 := animal2.UpdatedAt + var product3 Product db.First(&product3, product2.Id).Update("code", "456") if updated_at2.Format(time.RFC3339Nano) != product3.UpdatedAt.Format(time.RFC3339Nano) { @@ -964,6 +1013,25 @@ func TestUpdate(t *testing.T) { t.Errorf("Product 234 should be changed to 456") } + var animal3 Animal + db.First(&animal3, animal2.Counter).Update("Name", "Robert") + + if animalUpdated_at2.Format(time.RFC3339Nano) != animal2.UpdatedAt.Format(time.RFC3339Nano) { + t.Errorf("updated_at should not be updated if nothing changed") + } + + if db.First(&Animal{}, "name = 'Ferdinand'").Error != nil { + t.Errorf("Animal 'Ferdinand' should not be updated") + } + + if db.First(&Animal{}, "name = 'nerdz'").Error == nil { + t.Errorf("Animal 'nerdz' should be changed to 'Francis'") + } + + if db.First(&Animal{}, "name = 'Robert'").Error != nil { + t.Errorf("Animal 'nerdz' should be changed to 'Robert'") + } + db.Table("products").Where("code in (?)", []string{"123"}).Update("code", "789") var product4 Product @@ -991,6 +1059,34 @@ func TestUpdate(t *testing.T) { if db.Model(&product2).UpdateColumn("CreatedAt", time.Now().Add(time.Hour)).Error != nil { t.Error("No error should raise when update_column with CamelCase") } + + db.Table("animals").Where("name in (?)", []string{"Ferdinand"}).Update("name", "Franz") + + var animal4 Animal + db.First(&animal4, animal1.Counter) + if animalUpdated_at1.Format(time.RFC3339Nano) != animal4.UpdatedAt.Format(time.RFC3339Nano) { + t.Errorf("animalUpdated_at should be updated if something changed") + } + + if db.First(&Animal{}, "name = 'Ferdinand'").Error == nil { + t.Errorf("Animal 'Fredinand' should be changed to 'Franz'") + } + + if db.First(&Animal{}, "name = 'Robert'").Error != nil { + t.Errorf("Animal 'Robert' should not be changed to 'Francis'") + } + + if db.First(&Animal{}, "name = 'Franz'").Error != nil { + t.Errorf("Product 'nerdz' should be changed to 'Franz'") + } + + if db.Model(animal2).Update("CreatedAt", time.Now().Add(time.Hour)).Error != nil { + t.Error("No error should raise when update with CamelCase") + } + + if db.Model(&animal2).UpdateColumn("CreatedAt", time.Now().Add(time.Hour)).Error != nil { + t.Error("No error should raise when update_column with CamelCase") + } } func TestUpdates(t *testing.T) { diff --git a/scope.go b/scope.go index 40697f9a..543c4cde 100644 --- a/scope.go +++ b/scope.go @@ -13,13 +13,14 @@ import ( ) type Scope struct { - Value interface{} - Search *search - Sql string - SqlVars []interface{} - db *DB - _values map[string]interface{} - skipLeft bool + Value interface{} + Search *search + Sql string + SqlVars []interface{} + db *DB + _values map[string]interface{} + skipLeft bool + primaryKey string } // NewScope create scope for callbacks, including DB's search information @@ -78,7 +79,12 @@ func (scope *Scope) HasError() bool { // PrimaryKey get the primary key's column name func (scope *Scope) PrimaryKey() string { - return "id" + if scope.primaryKey != "" { + return scope.primaryKey + } + + scope.primaryKey = scope.getPrimaryKey() + return scope.primaryKey } // PrimaryKeyZero check the primary key is blank or not @@ -238,7 +244,13 @@ func (scope *Scope) Fields() []*Field { value := indirectValue.FieldByName(fieldStruct.Name) field.Value = value.Interface() field.IsBlank = isBlank(value) - field.isPrimaryKey = scope.PrimaryKey() == field.DBName + + // Search for primary key tag identifier + field.isPrimaryKey = scope.PrimaryKey() == field.DBName || fieldStruct.Tag.Get("primaryKey") != "" + + if field.isPrimaryKey { + scope.primaryKey = field.DBName + } if scope.db != nil { field.Tag = fieldStruct.Tag diff --git a/scope_private.go b/scope_private.go index 72f631cc..d9f16438 100644 --- a/scope_private.go +++ b/scope_private.go @@ -5,6 +5,7 @@ import ( "database/sql/driver" "errors" "fmt" + "go/ast" "reflect" "regexp" "strconv" @@ -472,3 +473,33 @@ func (scope *Scope) autoMigrate() *Scope { } return scope } + +func (scope *Scope) getPrimaryKey() string { + var indirectValue reflect.Value + + indirectValue = reflect.Indirect(reflect.ValueOf(scope.Value)) + + if indirectValue.Kind() == reflect.Slice { + indirectValue = reflect.New(indirectValue.Type().Elem()).Elem() + } + + if !indirectValue.IsValid() { + return "id" + } + + scopeTyp := indirectValue.Type() + for i := 0; i < scopeTyp.NumField(); i++ { + fieldStruct := scopeTyp.Field(i) + if !ast.IsExported(fieldStruct.Name) { + continue + } + + // if primaryKey tag found, return column name + if fieldStruct.Tag.Get("primaryKey") != "" { + return toSnake(fieldStruct.Name) + } + } + + //If primaryKey tag not found, fallback to id + return "id" +} From 1a5a4b707d48c411faaf3685361cc208a7d77969 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 6 Apr 2014 18:04:10 +0800 Subject: [PATCH 17/20] Use offical gorm package in tests --- main_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/main_test.go b/main_test.go index 5deb7822..2434c8ef 100644 --- a/main_test.go +++ b/main_test.go @@ -9,7 +9,7 @@ import ( _ "github.com/go-sql-driver/mysql" _ "github.com/lib/pq" _ "github.com/mattn/go-sqlite3" - "github.com/nerdzeu/gorm" + "github.com/jinzhu/gorm" "os" "reflect" From 98e9670b8e0bc0b44c42048e30a7198349e764d4 Mon Sep 17 00:00:00 2001 From: Duke Date: Thu, 10 Apr 2014 03:49:00 -0300 Subject: [PATCH 18/20] fixed example code --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 5059f71d..b06743b5 100644 --- a/README.md +++ b/README.md @@ -58,6 +58,7 @@ type Animal struct { // animals AnimalId int64 `primaryKey:"yes"` Birthday time.Time Age int64 +} ``` # Getting Started From 5e62e7fdad2a56ccd57f67bc4831b3816c8d1779 Mon Sep 17 00:00:00 2001 From: Xavier Dumesnil Date: Thu, 10 Apr 2014 16:29:09 +0200 Subject: [PATCH 19/20] Include scope.TableName() in ORDER statement for First() & Last() --- main.go | 4 ++-- main_test.go | 12 ++++++++++++ 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/main.go b/main.go index 942880b5..c98ca2b6 100644 --- a/main.go +++ b/main.go @@ -131,13 +131,13 @@ func (s *DB) Assign(attrs ...interface{}) *DB { func (s *DB) First(out interface{}, where ...interface{}) *DB { scope := s.clone().NewScope(out) - scope.Search = scope.Search.clone().order(scope.PrimaryKey()).limit(1) + scope.Search = scope.Search.clone().order(scope.TableName()+"."+scope.PrimaryKey()).limit(1) return scope.inlineCondition(where...).callCallbacks(s.parent.callback.queries).db } func (s *DB) Last(out interface{}, where ...interface{}) *DB { scope := s.clone().NewScope(out) - scope.Search = scope.Search.clone().order(scope.PrimaryKey() + " DESC").limit(1) + scope.Search = scope.Search.clone().order(scope.TableName()+"."+scope.PrimaryKey() + " DESC").limit(1) return scope.inlineCondition(where...).callCallbacks(s.parent.callback.queries).db } diff --git a/main_test.go b/main_test.go index 2434c8ef..31771a2c 100644 --- a/main_test.go +++ b/main_test.go @@ -247,6 +247,18 @@ func TestFirstAndLast(t *testing.T) { } } +func TestFirstAndLastWithJoins(t *testing.T) { + var user1, user2, user3, user4 User + db.Joins("left join emails on emails.user_id = users.id").First(&user1) + db.Order("id").Find(&user2) + + db.Joins("left join emails on emails.user_id = users.id").Last(&user3) + db.Order("id desc").Find(&user4) + if user1.Id != user2.Id || user3.Id != user4.Id { + t.Errorf("First and Last should works correctly") + } +} + func TestFirstAndLastForTableWithNoStdPrimaryKey(t *testing.T) { var animal1, animal2, animal3, animal4 Animal db.First(&animal1) From 2b7306aca1a250bf53984eed68954334d84fb9d6 Mon Sep 17 00:00:00 2001 From: Xavier Dumesnil Date: Fri, 11 Apr 2014 09:58:23 +0200 Subject: [PATCH 20/20] Fix typos --- main_test.go | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/main_test.go b/main_test.go index 31771a2c..e4016407 100644 --- a/main_test.go +++ b/main_test.go @@ -237,7 +237,7 @@ func TestFirstAndLast(t *testing.T) { db.Last(&user3) db.Order("id desc").Find(&user4) if user1.Id != user2.Id || user3.Id != user4.Id { - t.Errorf("First and Last should works correctly") + t.Errorf("First and Last should work correctly") } var users []User @@ -255,7 +255,7 @@ func TestFirstAndLastWithJoins(t *testing.T) { db.Joins("left join emails on emails.user_id = users.id").Last(&user3) db.Order("id desc").Find(&user4) if user1.Id != user2.Id || user3.Id != user4.Id { - t.Errorf("First and Last should works correctly") + t.Errorf("First and Last should work correctly with Joins") } } @@ -267,7 +267,7 @@ func TestFirstAndLastForTableWithNoStdPrimaryKey(t *testing.T) { db.Last(&animal3) db.Order("counter desc").Find(&animal4) if animal1.Counter != animal2.Counter || animal3.Counter != animal4.Counter { - t.Errorf("First and Last should works correctly") + t.Errorf("First and Last should work correctly") } var animals []Animal @@ -631,7 +631,7 @@ func TestOrderAndPluck(t *testing.T) { db.Model(&User{}).Order("age desc").Pluck("age", &ages3).Order("age", true).Pluck("age", &ages4) if reflect.DeepEqual(ages3, ages4) { - t.Errorf("Reorder should works") + t.Errorf("Reorder should work") } var names []string @@ -648,7 +648,7 @@ func TestLimit(t *testing.T) { db.Order("age desc").Limit(3).Find(&users1).Limit(5).Find(&users2).Limit(-1).Find(&users3) if len(users1) != 3 || len(users2) != 5 || len(users3) <= 5 { - t.Errorf("Limit should works") + t.Errorf("Limit should work") } } @@ -657,7 +657,7 @@ func TestOffset(t *testing.T) { db.Limit(100).Order("age desc").Find(&users1).Offset(3).Find(&users2).Offset(5).Find(&users3).Offset(-1).Find(&users4) if (len(users1) != len(users4)) || (len(users1)-len(users2) != 3) || (len(users1)-len(users3) != 5) { - t.Errorf("Offset should works") + t.Errorf("Offset should work") } } @@ -674,7 +674,7 @@ func TestCount(t *testing.T) { var users []User if err := db.Where("name = ?", "1").Or("name = ?", "3").Find(&users).Count(&count).Error; err != nil { - t.Errorf("Count should works", err) + t.Errorf("Count should work", err) } if count != int64(len(users)) { @@ -683,7 +683,7 @@ func TestCount(t *testing.T) { db.Model(&User{}).Where("name = ?", "1").Count(&count1).Or("name = ?", "3").Count(&count2) if count1 != 1 || count2 != 3 { - t.Errorf("Multiple count should works") + t.Errorf("Multiple count should work") } } @@ -801,7 +801,7 @@ func TestRunCallbacks(t *testing.T) { var products []Product db.Find(&products, "code = ?", "unique_code") if products[0].AfterFindCallTimes != 2 { - t.Errorf("AfterFind callbacks should works with slice") + t.Errorf("AfterFind callbacks should work with slice") } db.Where("Code = ?", "unique_code").First(&p) @@ -1765,7 +1765,7 @@ func TestScan(t *testing.T) { var res result db.Table("users").Select("name, age").Where("name = ?", 3).Scan(&res) if res.Name != "3" { - t.Errorf("Scan into struct should works") + t.Errorf("Scan into struct should work") } var ress []result