diff --git a/association_test.go b/association_test.go index f02d4620..58746572 100644 --- a/association_test.go +++ b/association_test.go @@ -16,7 +16,7 @@ func TestBelongsTo(t *testing.T) { } if err := DB.Save(&post).Error; err != nil { - t.Errorf("Got errors when save post", err.Error()) + t.Errorf("Got errors when save post %s", err) } if post.Category.ID == 0 || post.MainCategory.ID == 0 { @@ -184,7 +184,7 @@ func TestHasOne(t *testing.T) { } if err := DB.Save(&user).Error; err != nil { - t.Errorf("Got errors when save user", err.Error()) + t.Errorf("Got errors when save user %s", err.Error()) } if user.CreditCard.UserId.Int64 == 0 { @@ -331,7 +331,7 @@ func TestHasMany(t *testing.T) { } if err := DB.Save(&post).Error; err != nil { - t.Errorf("Got errors when save post", err.Error()) + t.Errorf("Got errors when save post %s", err.Error()) } for _, comment := range post.Comments { diff --git a/callback_create.go b/callback_create.go index d13a71be..6c418662 100644 --- a/callback_create.go +++ b/callback_create.go @@ -1,7 +1,10 @@ package gorm import ( + "database/sql" "fmt" + "log" + "reflect" "strings" ) @@ -28,7 +31,16 @@ func Create(scope *Scope) { for _, field := range fields { if scope.changeableField(field) { if field.IsNormal { - if !field.IsPrimaryKey || (field.IsPrimaryKey && !field.IsBlank) { + supportPrimary := scope.Dialect().SupportUniquePrimaryKey() + if !field.IsPrimaryKey || (field.IsPrimaryKey && (!field.IsBlank || !supportPrimary)) { + if field.IsPrimaryKey && !supportPrimary && field.IsBlank { + id := scope.Dialect().NewUniqueKey(scope) + if scope.HasError() { + return + } + log.Printf("ID %+v %+v", id, field.Field.Type().String()) + field.Field.Set(reflect.ValueOf(id).Convert(field.Field.Type())) + } if !field.IsBlank || !field.HasDefaultValue { columns = append(columns, scope.Quote(field.DBName)) sqls = append(sqls, scope.AddToVars(field.Field.Interface())) @@ -86,18 +98,37 @@ func Create(scope *Scope) { } } else { if primaryField == nil { - if results, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); err == nil { + if results, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); err == sql.ErrNoRows { + } else if err == nil { scope.db.RowsAffected, _ = results.RowsAffected() } else { + log.Printf("create err no primary %#v eql %#v", err, err == sql.ErrNoRows) scope.Err(err) } - } else { - if err := scope.Err(scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...).Scan(primaryField.Field.Addr().Interface())); err == nil { + } else { // if scope.Dialect().SupportUniquePrimaryKey() { + if err := scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...).Scan(primaryField.Field.Addr().Interface()); err == nil || err == sql.ErrNoRows { scope.db.RowsAffected = 1 } else { + log.Printf("create err %#v eql %#v", err, err == sql.ErrNoRows) scope.Err(err) } - } + } /* else { + // Create a new primary key if one is required, not set, and the server doesn't support unique primary keys. + log.Printf("key type %T %#v", val.Interface(), val.Interface()) + if key, ok := val.Interface().(*uint); ok && (key == nil || *key == 0) { + val := primaryField.Field.Addr() + id := scope.Dialect().NewUniqueKey(scope) + v := reflect.Indirect(val) + v.SetUint(id) + } + if results, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); err == sql.ErrNoRows { + } else if err == nil { + scope.db.RowsAffected, _ = results.RowsAffected() + } else { + log.Printf("create err no primary %#v eql %#v", err, err == sql.ErrNoRows) + scope.Err(err) + } + }*/ } } } diff --git a/callback_query.go b/callback_query.go index 5473f232..0db197ff 100644 --- a/callback_query.go +++ b/callback_query.go @@ -3,6 +3,7 @@ package gorm import ( "errors" "fmt" + "log" "reflect" ) @@ -83,6 +84,8 @@ func Query(scope *Scope) { scope.Err(rows.Scan(values...)) + log.Println("result values", values) + for index, column := range columns { value := values[index] if field, ok := fields[column]; ok { diff --git a/callback_shared.go b/callback_shared.go index 547059e3..da45b18c 100644 --- a/callback_shared.go +++ b/callback_shared.go @@ -1,6 +1,9 @@ package gorm -import "reflect" +import ( + "database/sql" + "reflect" +) func BeginTransaction(scope *Scope) { scope.Begin() @@ -18,7 +21,9 @@ func SaveBeforeAssociations(scope *Scope) { if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored { if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" { value := field.Field - scope.Err(scope.NewDB().Save(value.Addr().Interface()).Error) + if err := scope.NewDB().Save(value.Addr().Interface()).Error; err != nil && err != sql.ErrNoRows { + scope.Err(err) + } if len(relationship.ForeignFieldNames) != 0 { for idx, fieldName := range relationship.ForeignFieldNames { associationForeignName := relationship.AssociationForeignDBNames[idx] diff --git a/cockroach.go b/cockroach.go new file mode 100644 index 00000000..f4d572d0 --- /dev/null +++ b/cockroach.go @@ -0,0 +1,141 @@ +package gorm + +import ( + "fmt" + "log" + "reflect" + "time" +) + +type cockroach struct { + commonDialect +} + +func (cockroach) BinVar(i int) string { + return fmt.Sprintf("$%v", i) +} + +func (cockroach) SupportLastInsertId() bool { + return false +} + +func (cockroach) SupportUniquePrimaryKey() bool { + return false +} + +func (cockroach) NewUniqueKey(scope *Scope) uint64 { + rows, err := scope.NewDB().Raw(`SELECT experimental_unique_int()`).Rows() + if err != nil { + scope.Err(err) + return 0 + } + var id int64 + for rows.Next() { + if err := rows.Scan(&id); err != nil { + log.Fatal("ERR UNIQUE ID", id, err) + scope.Err(err) + return 0 + } + } + log.Printf("UNIQUE ID %#v", id) + return uint64(id) +} + +func (cockroach) SqlTag(value reflect.Value, size int, autoIncrease bool) string { + switch value.Kind() { + case reflect.Bool: + return "BOOLEAN" + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: + if autoIncrease { + return "INTEGER PRIMARY KEY" + } + return "INTEGER" + case reflect.Int64, reflect.Uint64: + if autoIncrease { + return "BIGINT PRIMARY KEY" + } + return "BIGINT" + case reflect.Float32, reflect.Float64: + return "FLOAT" + case reflect.String: + if size > 0 && size < 65532 { + return "VARCHAR" + } + return "TEXT" + case reflect.Struct: + if _, ok := value.Interface().(time.Time); ok { + return "TIMESTAMP" + } + default: + if _, ok := value.Interface().([]byte); ok { + return "BYTES" + } + } + panic(fmt.Sprintf("invalid sql type %s (%s) for cockroach", value.Type().Name(), value.Kind().String())) +} + +func (s cockroach) HasTable(scope *Scope, tableName string) bool { + rows, err := scope.NewDB().Raw("show tables").Rows() + if err != nil { + scope.Err(err) + return false + } + defer rows.Close() + var name string + for rows.Next() { + rows.Scan(&name) + if name == tableName { + return true + } + } + return false +} + +func (s cockroach) HasColumn(scope *Scope, tableName string, columnName string) bool { + rows, err := scope.NewDB().Raw(fmt.Sprintf("show columns from %s", tableName)).Rows() + if err != nil { + scope.Err(err) + return false + } + defer rows.Close() + var column string + for rows.Next() { + rows.Scan(&column) + if column == columnName { + return true + } + } + return false +} + +func (s cockroach) HasIndex(scope *Scope, tableName string, indexName string) bool { + /* + var count int + s.RawScanInt(scope, &count, fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'", indexName), tableName) + return count > 0 + */ + rows, err := scope.NewDB().Raw(fmt.Sprintf("show index from %s", tableName)).Rows() + if err != nil { + scope.Err(err) + return false + } + defer rows.Close() + var name string + for rows.Next() { + rows.Scan(nil, &name) + if name == indexName { + return true + } + } + return false +} + +func (cockroach) RemoveIndex(scope *Scope, indexName string) { + scope.Err(scope.NewDB().Raw(fmt.Sprintf("DROP INDEX %v@%v", scope.QuotedTableName(), indexName)).Error) +} + +func (s cockroach) CurrentDatabase(scope *Scope) string { + var name string + s.RawScanString(scope, &name, "SHOW DATABASE") + return name +} diff --git a/common_dialect.go b/common_dialect.go index 7f08b04f..4d5d00a5 100644 --- a/common_dialect.go +++ b/common_dialect.go @@ -16,6 +16,14 @@ func (commonDialect) SupportLastInsertId() bool { return true } +func (commonDialect) SupportUniquePrimaryKey() bool { + return true +} + +func (commonDialect) NewUniqueKey(scope *Scope) uint64 { + panic("NewUniqueKey not supported by commonDialect") +} + func (commonDialect) HasTop() bool { return false } diff --git a/dialect.go b/dialect.go index 926f8a11..f2074d13 100644 --- a/dialect.go +++ b/dialect.go @@ -8,6 +8,8 @@ import ( type Dialect interface { BinVar(i int) string SupportLastInsertId() bool + SupportUniquePrimaryKey() bool + NewUniqueKey(scope *Scope) uint64 HasTop() bool SqlTag(value reflect.Value, size int, autoIncrease bool) string ReturningStr(tableName, key string) string @@ -23,6 +25,8 @@ type Dialect interface { func NewDialect(driver string) Dialect { var d Dialect switch driver { + case "cockroach": + d = &cockroach{} case "postgres": d = &postgres{} case "foundation": diff --git a/main_test.go b/main_test.go index 65467d73..a1505126 100644 --- a/main_test.go +++ b/main_test.go @@ -6,6 +6,7 @@ import ( "fmt" "strconv" + _ "github.com/cockroachdb/cockroach/sql/driver" _ "github.com/denisenkom/go-mssqldb" testdb "github.com/erikstmartin/go-testdb" _ "github.com/go-sql-driver/mysql" @@ -53,6 +54,9 @@ func OpenTestConnection() (db gorm.DB, err error) { case "postgres": fmt.Println("testing postgres...") db, err = gorm.Open("postgres", "user=gorm DB.name=gorm sslmode=disable") + case "cockroach": + fmt.Println("testing cockroach...") + db, err = gorm.Open("cockroach", "http://localhost:26257?database=gorm&user=root") case "foundation": fmt.Println("testing foundation...") db, err = gorm.Open("foundation", "dbname=gorm port=15432 sslmode=disable") diff --git a/model_struct.go b/model_struct.go index d80165c8..866d276f 100644 --- a/model_struct.go +++ b/model_struct.go @@ -524,11 +524,6 @@ func (scope *Scope) generateSqlTag(field *StructField) string { sqlType = value } - additionalType := field.TagSettings["NOT NULL"] + " " + field.TagSettings["UNIQUE"] - if value, ok := field.TagSettings["DEFAULT"]; ok { - additionalType = additionalType + " DEFAULT " + value - } - if field.IsScanner { var getScannerValue func(reflect.Value) getScannerValue = func(value reflect.Value) { @@ -558,6 +553,14 @@ func (scope *Scope) generateSqlTag(field *StructField) string { sqlType = scope.Dialect().SqlTag(reflectValue, size, autoIncrease) } + additionalType := field.TagSettings["NOT NULL"] + " " + field.TagSettings["UNIQUE"] + if value, ok := field.TagSettings["DEFAULT"]; ok { + if _, ok := scope.Dialect().(*cockroach); ok && strings.TrimSpace(strings.ToLower(value)) == "null" { + value = value + "::" + strings.Split(sqlType, " ")[0] + } + additionalType = additionalType + " DEFAULT " + value + } + if strings.TrimSpace(additionalType) == "" { return sqlType } else { diff --git a/scope.go b/scope.go index a11d4ec4..0871842d 100644 --- a/scope.go +++ b/scope.go @@ -1,9 +1,12 @@ package gorm import ( + "database/sql/driver" "errors" "fmt" + "log" "regexp" + "runtime/debug" "strings" "time" @@ -103,6 +106,8 @@ func (scope *Scope) Dialect() Dialect { // Err write error func (scope *Scope) Err(err error) error { if err != nil { + log.Println("ERR", err) + debug.PrintStack() scope.db.AddError(err) } return err @@ -314,13 +319,15 @@ func (scope *Scope) Raw(sql string) *Scope { return scope } +var _, driverResultNoRows = driver.ResultNoRows.RowsAffected() + // Exec invoke sql func (scope *Scope) Exec() *Scope { defer scope.Trace(NowFunc()) if !scope.HasError() { if result, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil { - if count, err := result.RowsAffected(); scope.Err(err) == nil { + if count, err := result.RowsAffected(); err != nil && err.Error() == driverResultNoRows.Error() || scope.Err(err) == nil { scope.db.RowsAffected = count } } @@ -358,6 +365,8 @@ func (scope *Scope) InstanceGet(name string) (interface{}, bool) { // Trace print sql log func (scope *Scope) Trace(t time.Time) { if len(scope.Sql) > 0 { + // TODO(d4l3k): Remove this line + log.Println("sql", scope.Sql, scope.SqlVars) scope.db.slog(scope.Sql, t, scope.SqlVars...) } } diff --git a/scope_private.go b/scope_private.go index fa5d5f44..269b2a7d 100644 --- a/scope_private.go +++ b/scope_private.go @@ -8,6 +8,7 @@ import ( "regexp" "strconv" "strings" + "time" ) func (scope *Scope) primaryCondition(value interface{}) string { @@ -170,7 +171,12 @@ func (scope *Scope) whereSql() (sql string) { var primaryConditions, andConditions, orConditions []string if !scope.Search.Unscoped && scope.Fields()["deleted_at"] != nil { - sql := fmt.Sprintf("(%v.deleted_at IS NULL OR %v.deleted_at <= '0001-01-02')", scope.QuotedTableName(), scope.QuotedTableName()) + time, err := time.Parse("2006-01-02", "0001-01-02") + if err != nil { + scope.Err(err) + return + } + sql := fmt.Sprintf("(%v.deleted_at IS NULL OR %v.deleted_at <= %v)", scope.QuotedTableName(), scope.QuotedTableName(), scope.AddToVars(time)) primaryConditions = append(primaryConditions, sql) } diff --git a/structs_test.go b/structs_test.go index cb9c9260..da3159af 100644 --- a/structs_test.go +++ b/structs_test.go @@ -39,7 +39,7 @@ type User struct { } type CreditCard struct { - ID int8 + ID uint64 Number string UserId sql.NullInt64 CreatedAt time.Time @@ -48,7 +48,7 @@ type CreditCard struct { } type Email struct { - Id int16 + Id int64 UserId int Email string `sql:"type:varchar(100);"` CreatedAt time.Time diff --git a/test_all.sh b/test_all.sh index 6c5593b3..701018e3 100755 --- a/test_all.sh +++ b/test_all.sh @@ -1,4 +1,4 @@ -dialects=("postgres" "mysql" "sqlite") +dialects=("postgres" "mysql" "sqlite" "cockroach") for dialect in "${dialects[@]}" ; do GORM_DIALECT=${dialect} go test