Merge 5bdf677206f9c18ad3f36554b61f5e7892ace855 into c7b9acefb7d4570e0f46d94ce453467cc486d8fd

This commit is contained in:
Tristan Rice 2016-01-29 23:29:12 +00:00
commit cf92ceae63
15 changed files with 243 additions and 29 deletions

View File

@ -16,7 +16,7 @@ func TestBelongsTo(t *testing.T) {
}
if err := DB.Save(&post).Error; err != nil {
t.Errorf("Got errors when save post", err.Error())
t.Errorf("Got errors when save post %s", err)
}
if post.Category.ID == 0 || post.MainCategory.ID == 0 {
@ -184,7 +184,7 @@ func TestHasOne(t *testing.T) {
}
if err := DB.Save(&user).Error; err != nil {
t.Errorf("Got errors when save user", err.Error())
t.Errorf("Got errors when save user %s", err.Error())
}
if user.CreditCard.UserId.Int64 == 0 {
@ -331,7 +331,7 @@ func TestHasMany(t *testing.T) {
}
if err := DB.Save(&post).Error; err != nil {
t.Errorf("Got errors when save post", err.Error())
t.Errorf("Got errors when save post %s", err.Error())
}
for _, comment := range post.Comments {

View File

@ -1,7 +1,9 @@
package gorm
import (
"database/sql"
"fmt"
"reflect"
"strings"
)
@ -28,7 +30,15 @@ func Create(scope *Scope) {
for _, field := range fields {
if scope.changeableField(field) {
if field.IsNormal {
if !field.IsPrimaryKey || (field.IsPrimaryKey && !field.IsBlank) {
supportPrimary := scope.Dialect().SupportUniquePrimaryKey()
if !field.IsPrimaryKey || (field.IsPrimaryKey && (!field.IsBlank || !supportPrimary)) {
if field.IsPrimaryKey && !supportPrimary && field.IsBlank {
id := scope.Dialect().NewUniqueKey(scope)
if scope.HasError() {
return
}
field.Field.Set(reflect.ValueOf(id).Convert(field.Field.Type()))
}
if !field.IsBlank || !field.HasDefaultValue {
columns = append(columns, scope.Quote(field.DBName))
sqls = append(sqls, scope.AddToVars(field.Field.Interface()))
@ -86,13 +96,14 @@ func Create(scope *Scope) {
}
} else {
if primaryField == nil {
if results, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); err == nil {
if results, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); err == sql.ErrNoRows {
} else if err == nil {
scope.db.RowsAffected, _ = results.RowsAffected()
} else {
scope.Err(err)
}
} else {
if err := scope.Err(scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...).Scan(primaryField.Field.Addr().Interface())); err == nil {
if err := scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...).Scan(primaryField.Field.Addr().Interface()); err == nil || err == sql.ErrNoRows {
scope.db.RowsAffected = 1
} else {
scope.Err(err)

View File

@ -1,6 +1,9 @@
package gorm
import "reflect"
import (
"database/sql"
"reflect"
)
func BeginTransaction(scope *Scope) {
scope.Begin()
@ -18,7 +21,9 @@ func SaveBeforeAssociations(scope *Scope) {
if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored {
if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
value := field.Field
scope.Err(scope.NewDB().Save(value.Addr().Interface()).Error)
if err := scope.NewDB().Save(value.Addr().Interface()).Error; err != nil && err != sql.ErrNoRows {
scope.Err(err)
}
if len(relationship.ForeignFieldNames) != 0 {
for idx, fieldName := range relationship.ForeignFieldNames {
associationForeignName := relationship.AssociationForeignDBNames[idx]

View File

@ -43,6 +43,11 @@ func Update(scope *Scope) {
if updateAttrs, ok := scope.InstanceGet("gorm:update_attrs"); ok {
for key, value := range updateAttrs.(map[string]interface{}) {
if !scope.Dialect().SupportUpdatePrimaryKey() {
if field, ok := scope.Fields()[key]; ok && field.IsPrimaryKey {
continue
}
}
if scope.changeableDBColumn(key) {
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(key), scope.AddToVars(value)))
}
@ -50,6 +55,9 @@ func Update(scope *Scope) {
} else {
fields := scope.Fields()
for _, field := range fields {
if field.IsPrimaryKey && !scope.Dialect().SupportUpdatePrimaryKey() {
continue
}
if scope.changeableField(field) && !field.IsPrimaryKey && field.IsNormal {
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
} else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {

151
cockroach.go Normal file
View File

@ -0,0 +1,151 @@
package gorm
import (
"fmt"
"reflect"
"time"
)
type cockroach struct {
commonDialect
}
func (cockroach) BinVar(i int) string {
return fmt.Sprintf("$%v", i)
}
func (cockroach) SupportLastInsertId() bool {
return false
}
func (cockroach) SupportUniquePrimaryKey() bool {
return false
}
func (cockroach) SupportUpdatePrimaryKey() bool {
return false
}
func (cockroach) NewUniqueKey(scope *Scope) uint64 {
rows, err := scope.NewDB().Raw(`SELECT experimental_unique_int()`).Rows()
if err != nil {
scope.Err(err)
return 0
}
defer rows.Close()
var id int64
for rows.Next() {
if err := rows.Scan(&id); err != nil {
scope.Err(err)
return 0
}
}
return uint64(id)
}
func (cockroach) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
switch value.Kind() {
case reflect.Bool:
return "BOOLEAN"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
if autoIncrease {
return "INTEGER PRIMARY KEY"
}
return "INTEGER"
case reflect.Int64, reflect.Uint64:
if autoIncrease {
return "BIGINT PRIMARY KEY"
}
return "BIGINT"
case reflect.Float32, reflect.Float64:
return "FLOAT"
case reflect.String:
if size > 0 && size < 65532 {
return "VARCHAR"
}
return "TEXT"
case reflect.Struct:
if _, ok := value.Interface().(time.Time); ok {
return "TIMESTAMP"
}
default:
if _, ok := value.Interface().([]byte); ok {
return "BYTES"
}
}
panic(fmt.Sprintf("invalid sql type %s (%s) for cockroach", value.Type().Name(), value.Kind().String()))
}
func (s cockroach) HasTable(scope *Scope, tableName string) bool {
rows, err := scope.NewDB().Raw("show tables").Rows()
if err != nil {
scope.Err(err)
return false
}
defer rows.Close()
var name string
for rows.Next() {
if err := rows.Scan(&name); err != nil {
scope.Err(err)
return false
}
if name == tableName {
return true
}
}
return false
}
func (s cockroach) HasColumn(scope *Scope, tableName string, columnName string) bool {
rows, err := scope.NewDB().Raw(fmt.Sprintf("show columns from %s", tableName)).Rows()
if err != nil {
scope.Err(err)
return false
}
defer rows.Close()
var column string
var typ, null, defaultVal interface{}
for rows.Next() {
if err := rows.Scan(&column, &typ, &null, &defaultVal); err != nil {
scope.Err(err)
return false
}
if column == columnName {
return true
}
}
return false
}
func (s cockroach) HasIndex(scope *Scope, tableName string, indexName string) bool {
rows, err := scope.NewDB().Raw(fmt.Sprintf("show index from %s", tableName)).Rows()
if err != nil {
scope.Err(err)
return false
}
defer rows.Close()
var table, name, column, direction string
var unique, storing bool
var seq int
for rows.Next() {
if err := rows.Scan(&table, &name, &unique, &seq, &column, &direction, &storing); err != nil {
scope.Err(err)
return false
}
if name == indexName {
return true
}
}
return false
}
func (cockroach) RemoveIndex(scope *Scope, indexName string) {
scope.Err(scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v@%v", scope.TableName(), indexName)).Error)
}
func (s cockroach) CurrentDatabase(scope *Scope) string {
var name string
s.RawScanString(scope, &name, "SHOW DATABASE")
return name
}

View File

@ -16,6 +16,18 @@ func (commonDialect) SupportLastInsertId() bool {
return true
}
func (commonDialect) SupportUniquePrimaryKey() bool {
return true
}
func (commonDialect) SupportUpdatePrimaryKey() bool {
return true
}
func (commonDialect) NewUniqueKey(scope *Scope) uint64 {
panic("NewUniqueKey not supported by commonDialect")
}
func (commonDialect) HasTop() bool {
return false
}

View File

@ -8,6 +8,9 @@ import (
type Dialect interface {
BinVar(i int) string
SupportLastInsertId() bool
SupportUniquePrimaryKey() bool
SupportUpdatePrimaryKey() bool
NewUniqueKey(scope *Scope) uint64
HasTop() bool
SqlTag(value reflect.Value, size int, autoIncrease bool) string
ReturningStr(tableName, key string) string
@ -23,6 +26,8 @@ type Dialect interface {
func NewDialect(driver string) Dialect {
var d Dialect
switch driver {
case "cockroach":
d = &cockroach{}
case "postgres":
d = &postgres{}
case "foundation":

View File

@ -6,6 +6,7 @@ import (
"fmt"
"strconv"
_ "github.com/cockroachdb/cockroach/sql/driver"
_ "github.com/denisenkom/go-mssqldb"
testdb "github.com/erikstmartin/go-testdb"
_ "github.com/go-sql-driver/mysql"
@ -53,6 +54,9 @@ func OpenTestConnection() (db gorm.DB, err error) {
case "postgres":
fmt.Println("testing postgres...")
db, err = gorm.Open("postgres", "user=gorm DB.name=gorm sslmode=disable")
case "cockroach":
fmt.Println("testing cockroach...")
db, err = gorm.Open("cockroach", "http://localhost:26257?database=gorm&user=root")
case "foundation":
fmt.Println("testing foundation...")
db, err = gorm.Open("foundation", "dbname=gorm port=15432 sslmode=disable")
@ -463,7 +467,9 @@ func TestJoins(t *testing.T) {
DB.Save(&user)
var result User
DB.Joins("left join emails on emails.user_id = users.id").Where("name = ?", "joins").First(&result)
if err := DB.Joins("left join emails on emails.user_id = users.id").Where("name = ?", "joins").First(&result).Error; err != nil {
t.Errorf("Error while joining: %s", err)
}
if result.Name != "joins" || result.Id != user.Id {
t.Errorf("Should find all two emails with Join")
}

View File

@ -524,11 +524,6 @@ func (scope *Scope) generateSqlTag(field *StructField) string {
sqlType = value
}
additionalType := field.TagSettings["NOT NULL"] + " " + field.TagSettings["UNIQUE"]
if value, ok := field.TagSettings["DEFAULT"]; ok {
additionalType = additionalType + " DEFAULT " + value
}
if field.IsScanner {
var getScannerValue func(reflect.Value)
getScannerValue = func(value reflect.Value) {
@ -558,6 +553,14 @@ func (scope *Scope) generateSqlTag(field *StructField) string {
sqlType = scope.Dialect().SqlTag(reflectValue, size, autoIncrease)
}
additionalType := field.TagSettings["NOT NULL"] + " " + field.TagSettings["UNIQUE"]
if value, ok := field.TagSettings["DEFAULT"]; ok {
if _, ok := scope.Dialect().(*cockroach); ok && strings.TrimSpace(strings.ToLower(value)) == "null" {
value = value + "::" + strings.Split(sqlType, " ")[0]
}
additionalType = additionalType + " DEFAULT " + value
}
if strings.TrimSpace(additionalType) == "" {
return sqlType
} else {

View File

@ -31,8 +31,8 @@ func TestFirstAndLast(t *testing.T) {
t.Errorf("Find first record as slice")
}
if DB.Joins("left join emails on emails.user_id = users.id").First(&User{}).Error != nil {
t.Errorf("Should not raise any error when order with Join table")
if err := DB.Joins("left join emails on emails.user_id = users.id").First(&User{}).Error; err != nil {
t.Errorf("Should not raise any error when order with Join table: %s", err)
}
}
@ -52,15 +52,19 @@ func TestFirstAndLastWithNoStdPrimaryKey(t *testing.T) {
}
func TestUIntPrimaryKey(t *testing.T) {
insertedAnimal := &Animal{Name: "animalUint1"}
insertedAnimal2 := &Animal{Name: "animalUint2"}
DB.Save(insertedAnimal)
DB.Save(insertedAnimal2)
var animal Animal
DB.First(&animal, uint64(1))
if animal.Counter != 1 {
t.Errorf("Fetch a record from with a non-int primary key should work, but failed")
DB.First(&animal, insertedAnimal.Counter)
if animal.Counter != insertedAnimal.Counter || animal.Counter <= 0 {
t.Errorf("Fetch a record from with a non-int primary key should work, but failed; got %d", animal.Counter)
}
DB.Model(Animal{}).Where(Animal{Counter: uint64(2)}).Scan(&animal)
if animal.Counter != 2 {
t.Errorf("Fetch a record from with a non-int primary key should work, but failed")
DB.Model(Animal{}).Where(Animal{Counter: insertedAnimal2.Counter}).Scan(&animal)
if animal.Counter != insertedAnimal2.Counter || animal.Counter <= 0 {
t.Errorf("Fetch a record from with a non-int primary key should work, but failed; got %d", animal.Counter)
}
}

View File

@ -1,6 +1,7 @@
package gorm
import (
"database/sql/driver"
"errors"
"fmt"
"regexp"
@ -314,13 +315,15 @@ func (scope *Scope) Raw(sql string) *Scope {
return scope
}
var _, driverResultNoRows = driver.ResultNoRows.RowsAffected()
// Exec invoke sql
func (scope *Scope) Exec() *Scope {
defer scope.Trace(NowFunc())
if !scope.HasError() {
if result, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil {
if count, err := result.RowsAffected(); scope.Err(err) == nil {
if count, err := result.RowsAffected(); err != nil && err.Error() == driverResultNoRows.Error() || scope.Err(err) == nil {
scope.db.RowsAffected = count
}
}

View File

@ -8,6 +8,7 @@ import (
"regexp"
"strconv"
"strings"
"time"
)
func (scope *Scope) primaryCondition(value interface{}) string {
@ -170,7 +171,12 @@ func (scope *Scope) whereSql() (sql string) {
var primaryConditions, andConditions, orConditions []string
if !scope.Search.Unscoped && scope.Fields()["deleted_at"] != nil {
sql := fmt.Sprintf("(%v.deleted_at IS NULL OR %v.deleted_at <= '0001-01-02')", scope.QuotedTableName(), scope.QuotedTableName())
time, err := time.Parse("2006-01-02", "0001-01-02")
if err != nil {
scope.Err(err)
return
}
sql := fmt.Sprintf("(%v.deleted_at IS NULL OR %v.deleted_at <= %v)", scope.QuotedTableName(), scope.QuotedTableName(), scope.AddToVars(time))
primaryConditions = append(primaryConditions, sql)
}

View File

@ -20,13 +20,13 @@ func TestScannableSlices(t *testing.T) {
}
if err := DB.Save(&r1).Error; err != nil {
t.Errorf("Should save record with slice values")
t.Errorf("Should save record with slice values; err %s", err)
}
var r2 RecordWithSlice
if err := DB.Find(&r2).Error; err != nil {
t.Errorf("Should fetch record with slice values")
t.Errorf("Should fetch record with slice values; err %s", err)
}
if len(r2.Strings) != 3 || r2.Strings[0] != "a" || r2.Strings[1] != "b" || r2.Strings[2] != "c" {

View File

@ -39,7 +39,7 @@ type User struct {
}
type CreditCard struct {
ID int8
ID uint64
Number string
UserId sql.NullInt64
CreatedAt time.Time
@ -48,7 +48,7 @@ type CreditCard struct {
}
type Email struct {
Id int16
Id int64
UserId int
Email string `sql:"type:varchar(100);"`
CreatedAt time.Time

View File

@ -1,4 +1,4 @@
dialects=("postgres" "mysql" "sqlite")
dialects=("postgres" "mysql" "sqlite" "cockroach")
for dialect in "${dialects[@]}" ; do
GORM_DIALECT=${dialect} go test