diff --git a/association_test.go b/association_test.go index f02d4620..58746572 100644 --- a/association_test.go +++ b/association_test.go @@ -16,7 +16,7 @@ func TestBelongsTo(t *testing.T) { } if err := DB.Save(&post).Error; err != nil { - t.Errorf("Got errors when save post", err.Error()) + t.Errorf("Got errors when save post %s", err) } if post.Category.ID == 0 || post.MainCategory.ID == 0 { @@ -184,7 +184,7 @@ func TestHasOne(t *testing.T) { } if err := DB.Save(&user).Error; err != nil { - t.Errorf("Got errors when save user", err.Error()) + t.Errorf("Got errors when save user %s", err.Error()) } if user.CreditCard.UserId.Int64 == 0 { @@ -331,7 +331,7 @@ func TestHasMany(t *testing.T) { } if err := DB.Save(&post).Error; err != nil { - t.Errorf("Got errors when save post", err.Error()) + t.Errorf("Got errors when save post %s", err.Error()) } for _, comment := range post.Comments { diff --git a/callback_create.go b/callback_create.go index d13a71be..c0e9c59d 100644 --- a/callback_create.go +++ b/callback_create.go @@ -1,7 +1,9 @@ package gorm import ( + "database/sql" "fmt" + "reflect" "strings" ) @@ -28,7 +30,15 @@ func Create(scope *Scope) { for _, field := range fields { if scope.changeableField(field) { if field.IsNormal { - if !field.IsPrimaryKey || (field.IsPrimaryKey && !field.IsBlank) { + supportPrimary := scope.Dialect().SupportUniquePrimaryKey() + if !field.IsPrimaryKey || (field.IsPrimaryKey && (!field.IsBlank || !supportPrimary)) { + if field.IsPrimaryKey && !supportPrimary && field.IsBlank { + id := scope.Dialect().NewUniqueKey(scope) + if scope.HasError() { + return + } + field.Field.Set(reflect.ValueOf(id).Convert(field.Field.Type())) + } if !field.IsBlank || !field.HasDefaultValue { columns = append(columns, scope.Quote(field.DBName)) sqls = append(sqls, scope.AddToVars(field.Field.Interface())) @@ -86,13 +96,14 @@ func Create(scope *Scope) { } } else { if primaryField == nil { - if results, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); err == nil { + if results, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); err == sql.ErrNoRows { + } else if err == nil { scope.db.RowsAffected, _ = results.RowsAffected() } else { scope.Err(err) } } else { - if err := scope.Err(scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...).Scan(primaryField.Field.Addr().Interface())); err == nil { + if err := scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...).Scan(primaryField.Field.Addr().Interface()); err == nil || err == sql.ErrNoRows { scope.db.RowsAffected = 1 } else { scope.Err(err) diff --git a/callback_shared.go b/callback_shared.go index 547059e3..da45b18c 100644 --- a/callback_shared.go +++ b/callback_shared.go @@ -1,6 +1,9 @@ package gorm -import "reflect" +import ( + "database/sql" + "reflect" +) func BeginTransaction(scope *Scope) { scope.Begin() @@ -18,7 +21,9 @@ func SaveBeforeAssociations(scope *Scope) { if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored { if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" { value := field.Field - scope.Err(scope.NewDB().Save(value.Addr().Interface()).Error) + if err := scope.NewDB().Save(value.Addr().Interface()).Error; err != nil && err != sql.ErrNoRows { + scope.Err(err) + } if len(relationship.ForeignFieldNames) != 0 { for idx, fieldName := range relationship.ForeignFieldNames { associationForeignName := relationship.AssociationForeignDBNames[idx] diff --git a/callback_update.go b/callback_update.go index 4c9952d2..64d71d7c 100644 --- a/callback_update.go +++ b/callback_update.go @@ -43,6 +43,11 @@ func Update(scope *Scope) { if updateAttrs, ok := scope.InstanceGet("gorm:update_attrs"); ok { for key, value := range updateAttrs.(map[string]interface{}) { + if !scope.Dialect().SupportUpdatePrimaryKey() { + if field, ok := scope.Fields()[key]; ok && field.IsPrimaryKey { + continue + } + } if scope.changeableDBColumn(key) { sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(key), scope.AddToVars(value))) } @@ -50,6 +55,9 @@ func Update(scope *Scope) { } else { fields := scope.Fields() for _, field := range fields { + if field.IsPrimaryKey && !scope.Dialect().SupportUpdatePrimaryKey() { + continue + } if scope.changeableField(field) && !field.IsPrimaryKey && field.IsNormal { sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) } else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" { diff --git a/cockroach.go b/cockroach.go new file mode 100644 index 00000000..7de61b8d --- /dev/null +++ b/cockroach.go @@ -0,0 +1,151 @@ +package gorm + +import ( + "fmt" + "reflect" + "time" +) + +type cockroach struct { + commonDialect +} + +func (cockroach) BinVar(i int) string { + return fmt.Sprintf("$%v", i) +} + +func (cockroach) SupportLastInsertId() bool { + return false +} + +func (cockroach) SupportUniquePrimaryKey() bool { + return false +} + +func (cockroach) SupportUpdatePrimaryKey() bool { + return false +} + +func (cockroach) NewUniqueKey(scope *Scope) uint64 { + rows, err := scope.NewDB().Raw(`SELECT experimental_unique_int()`).Rows() + if err != nil { + scope.Err(err) + return 0 + } + defer rows.Close() + var id int64 + for rows.Next() { + if err := rows.Scan(&id); err != nil { + scope.Err(err) + return 0 + } + } + return uint64(id) +} + +func (cockroach) SqlTag(value reflect.Value, size int, autoIncrease bool) string { + switch value.Kind() { + case reflect.Bool: + return "BOOLEAN" + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: + if autoIncrease { + return "INTEGER PRIMARY KEY" + } + return "INTEGER" + case reflect.Int64, reflect.Uint64: + if autoIncrease { + return "BIGINT PRIMARY KEY" + } + return "BIGINT" + case reflect.Float32, reflect.Float64: + return "FLOAT" + case reflect.String: + if size > 0 && size < 65532 { + return "VARCHAR" + } + return "TEXT" + case reflect.Struct: + if _, ok := value.Interface().(time.Time); ok { + return "TIMESTAMP" + } + default: + if _, ok := value.Interface().([]byte); ok { + return "BYTES" + } + } + panic(fmt.Sprintf("invalid sql type %s (%s) for cockroach", value.Type().Name(), value.Kind().String())) +} + +func (s cockroach) HasTable(scope *Scope, tableName string) bool { + rows, err := scope.NewDB().Raw("show tables").Rows() + if err != nil { + scope.Err(err) + return false + } + defer rows.Close() + var name string + for rows.Next() { + if err := rows.Scan(&name); err != nil { + scope.Err(err) + return false + } + if name == tableName { + return true + } + } + return false +} + +func (s cockroach) HasColumn(scope *Scope, tableName string, columnName string) bool { + rows, err := scope.NewDB().Raw(fmt.Sprintf("show columns from %s", tableName)).Rows() + if err != nil { + scope.Err(err) + return false + } + defer rows.Close() + var column string + var typ, null, defaultVal interface{} + for rows.Next() { + if err := rows.Scan(&column, &typ, &null, &defaultVal); err != nil { + scope.Err(err) + return false + } + if column == columnName { + return true + } + } + return false +} + +func (s cockroach) HasIndex(scope *Scope, tableName string, indexName string) bool { + rows, err := scope.NewDB().Raw(fmt.Sprintf("show index from %s", tableName)).Rows() + if err != nil { + scope.Err(err) + return false + } + defer rows.Close() + + var table, name, column, direction string + var unique, storing bool + var seq int + for rows.Next() { + if err := rows.Scan(&table, &name, &unique, &seq, &column, &direction, &storing); err != nil { + scope.Err(err) + return false + } + if name == indexName { + return true + } + } + return false +} + +func (cockroach) RemoveIndex(scope *Scope, indexName string) { + scope.Err(scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v@%v", scope.TableName(), indexName)).Error) +} + +func (s cockroach) CurrentDatabase(scope *Scope) string { + var name string + s.RawScanString(scope, &name, "SHOW DATABASE") + return name +} diff --git a/common_dialect.go b/common_dialect.go index 7f08b04f..874523b2 100644 --- a/common_dialect.go +++ b/common_dialect.go @@ -16,6 +16,18 @@ func (commonDialect) SupportLastInsertId() bool { return true } +func (commonDialect) SupportUniquePrimaryKey() bool { + return true +} + +func (commonDialect) SupportUpdatePrimaryKey() bool { + return true +} + +func (commonDialect) NewUniqueKey(scope *Scope) uint64 { + panic("NewUniqueKey not supported by commonDialect") +} + func (commonDialect) HasTop() bool { return false } diff --git a/dialect.go b/dialect.go index 926f8a11..94294717 100644 --- a/dialect.go +++ b/dialect.go @@ -8,6 +8,9 @@ import ( type Dialect interface { BinVar(i int) string SupportLastInsertId() bool + SupportUniquePrimaryKey() bool + SupportUpdatePrimaryKey() bool + NewUniqueKey(scope *Scope) uint64 HasTop() bool SqlTag(value reflect.Value, size int, autoIncrease bool) string ReturningStr(tableName, key string) string @@ -23,6 +26,8 @@ type Dialect interface { func NewDialect(driver string) Dialect { var d Dialect switch driver { + case "cockroach": + d = &cockroach{} case "postgres": d = &postgres{} case "foundation": diff --git a/main_test.go b/main_test.go index 65467d73..28b110d1 100644 --- a/main_test.go +++ b/main_test.go @@ -6,6 +6,7 @@ import ( "fmt" "strconv" + _ "github.com/cockroachdb/cockroach/sql/driver" _ "github.com/denisenkom/go-mssqldb" testdb "github.com/erikstmartin/go-testdb" _ "github.com/go-sql-driver/mysql" @@ -53,6 +54,9 @@ func OpenTestConnection() (db gorm.DB, err error) { case "postgres": fmt.Println("testing postgres...") db, err = gorm.Open("postgres", "user=gorm DB.name=gorm sslmode=disable") + case "cockroach": + fmt.Println("testing cockroach...") + db, err = gorm.Open("cockroach", "http://localhost:26257?database=gorm&user=root") case "foundation": fmt.Println("testing foundation...") db, err = gorm.Open("foundation", "dbname=gorm port=15432 sslmode=disable") @@ -463,7 +467,9 @@ func TestJoins(t *testing.T) { DB.Save(&user) var result User - DB.Joins("left join emails on emails.user_id = users.id").Where("name = ?", "joins").First(&result) + if err := DB.Joins("left join emails on emails.user_id = users.id").Where("name = ?", "joins").First(&result).Error; err != nil { + t.Errorf("Error while joining: %s", err) + } if result.Name != "joins" || result.Id != user.Id { t.Errorf("Should find all two emails with Join") } diff --git a/model_struct.go b/model_struct.go index d80165c8..866d276f 100644 --- a/model_struct.go +++ b/model_struct.go @@ -524,11 +524,6 @@ func (scope *Scope) generateSqlTag(field *StructField) string { sqlType = value } - additionalType := field.TagSettings["NOT NULL"] + " " + field.TagSettings["UNIQUE"] - if value, ok := field.TagSettings["DEFAULT"]; ok { - additionalType = additionalType + " DEFAULT " + value - } - if field.IsScanner { var getScannerValue func(reflect.Value) getScannerValue = func(value reflect.Value) { @@ -558,6 +553,14 @@ func (scope *Scope) generateSqlTag(field *StructField) string { sqlType = scope.Dialect().SqlTag(reflectValue, size, autoIncrease) } + additionalType := field.TagSettings["NOT NULL"] + " " + field.TagSettings["UNIQUE"] + if value, ok := field.TagSettings["DEFAULT"]; ok { + if _, ok := scope.Dialect().(*cockroach); ok && strings.TrimSpace(strings.ToLower(value)) == "null" { + value = value + "::" + strings.Split(sqlType, " ")[0] + } + additionalType = additionalType + " DEFAULT " + value + } + if strings.TrimSpace(additionalType) == "" { return sqlType } else { diff --git a/query_test.go b/query_test.go index a7d5bc0e..148c395c 100644 --- a/query_test.go +++ b/query_test.go @@ -31,8 +31,8 @@ func TestFirstAndLast(t *testing.T) { t.Errorf("Find first record as slice") } - if DB.Joins("left join emails on emails.user_id = users.id").First(&User{}).Error != nil { - t.Errorf("Should not raise any error when order with Join table") + if err := DB.Joins("left join emails on emails.user_id = users.id").First(&User{}).Error; err != nil { + t.Errorf("Should not raise any error when order with Join table: %s", err) } } @@ -52,15 +52,19 @@ func TestFirstAndLastWithNoStdPrimaryKey(t *testing.T) { } func TestUIntPrimaryKey(t *testing.T) { + insertedAnimal := &Animal{Name: "animalUint1"} + insertedAnimal2 := &Animal{Name: "animalUint2"} + DB.Save(insertedAnimal) + DB.Save(insertedAnimal2) var animal Animal - DB.First(&animal, uint64(1)) - if animal.Counter != 1 { - t.Errorf("Fetch a record from with a non-int primary key should work, but failed") + DB.First(&animal, insertedAnimal.Counter) + if animal.Counter != insertedAnimal.Counter || animal.Counter <= 0 { + t.Errorf("Fetch a record from with a non-int primary key should work, but failed; got %d", animal.Counter) } - DB.Model(Animal{}).Where(Animal{Counter: uint64(2)}).Scan(&animal) - if animal.Counter != 2 { - t.Errorf("Fetch a record from with a non-int primary key should work, but failed") + DB.Model(Animal{}).Where(Animal{Counter: insertedAnimal2.Counter}).Scan(&animal) + if animal.Counter != insertedAnimal2.Counter || animal.Counter <= 0 { + t.Errorf("Fetch a record from with a non-int primary key should work, but failed; got %d", animal.Counter) } } diff --git a/scope.go b/scope.go index a11d4ec4..29354133 100644 --- a/scope.go +++ b/scope.go @@ -1,6 +1,7 @@ package gorm import ( + "database/sql/driver" "errors" "fmt" "regexp" @@ -314,13 +315,15 @@ func (scope *Scope) Raw(sql string) *Scope { return scope } +var _, driverResultNoRows = driver.ResultNoRows.RowsAffected() + // Exec invoke sql func (scope *Scope) Exec() *Scope { defer scope.Trace(NowFunc()) if !scope.HasError() { if result, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil { - if count, err := result.RowsAffected(); scope.Err(err) == nil { + if count, err := result.RowsAffected(); err != nil && err.Error() == driverResultNoRows.Error() || scope.Err(err) == nil { scope.db.RowsAffected = count } } diff --git a/scope_private.go b/scope_private.go index fa5d5f44..269b2a7d 100644 --- a/scope_private.go +++ b/scope_private.go @@ -8,6 +8,7 @@ import ( "regexp" "strconv" "strings" + "time" ) func (scope *Scope) primaryCondition(value interface{}) string { @@ -170,7 +171,12 @@ func (scope *Scope) whereSql() (sql string) { var primaryConditions, andConditions, orConditions []string if !scope.Search.Unscoped && scope.Fields()["deleted_at"] != nil { - sql := fmt.Sprintf("(%v.deleted_at IS NULL OR %v.deleted_at <= '0001-01-02')", scope.QuotedTableName(), scope.QuotedTableName()) + time, err := time.Parse("2006-01-02", "0001-01-02") + if err != nil { + scope.Err(err) + return + } + sql := fmt.Sprintf("(%v.deleted_at IS NULL OR %v.deleted_at <= %v)", scope.QuotedTableName(), scope.QuotedTableName(), scope.AddToVars(time)) primaryConditions = append(primaryConditions, sql) } diff --git a/slice_test.go b/slice_test.go index 21410548..fa1c6c76 100644 --- a/slice_test.go +++ b/slice_test.go @@ -20,13 +20,13 @@ func TestScannableSlices(t *testing.T) { } if err := DB.Save(&r1).Error; err != nil { - t.Errorf("Should save record with slice values") + t.Errorf("Should save record with slice values; err %s", err) } var r2 RecordWithSlice if err := DB.Find(&r2).Error; err != nil { - t.Errorf("Should fetch record with slice values") + t.Errorf("Should fetch record with slice values; err %s", err) } if len(r2.Strings) != 3 || r2.Strings[0] != "a" || r2.Strings[1] != "b" || r2.Strings[2] != "c" { diff --git a/structs_test.go b/structs_test.go index cb9c9260..da3159af 100644 --- a/structs_test.go +++ b/structs_test.go @@ -39,7 +39,7 @@ type User struct { } type CreditCard struct { - ID int8 + ID uint64 Number string UserId sql.NullInt64 CreatedAt time.Time @@ -48,7 +48,7 @@ type CreditCard struct { } type Email struct { - Id int16 + Id int64 UserId int Email string `sql:"type:varchar(100);"` CreatedAt time.Time diff --git a/test_all.sh b/test_all.sh index 6c5593b3..701018e3 100755 --- a/test_all.sh +++ b/test_all.sh @@ -1,4 +1,4 @@ -dialects=("postgres" "mysql" "sqlite") +dialects=("postgres" "mysql" "sqlite" "cockroach") for dialect in "${dialects[@]}" ; do GORM_DIALECT=${dialect} go test