Lots of progress

This commit is contained in:
Tristan Rice 2016-01-27 17:44:58 -08:00
parent 9739cb853c
commit c88df8eb44
13 changed files with 234 additions and 20 deletions

View File

@ -16,7 +16,7 @@ func TestBelongsTo(t *testing.T) {
}
if err := DB.Save(&post).Error; err != nil {
t.Errorf("Got errors when save post", err.Error())
t.Errorf("Got errors when save post %s", err)
}
if post.Category.ID == 0 || post.MainCategory.ID == 0 {
@ -184,7 +184,7 @@ func TestHasOne(t *testing.T) {
}
if err := DB.Save(&user).Error; err != nil {
t.Errorf("Got errors when save user", err.Error())
t.Errorf("Got errors when save user %s", err.Error())
}
if user.CreditCard.UserId.Int64 == 0 {
@ -331,7 +331,7 @@ func TestHasMany(t *testing.T) {
}
if err := DB.Save(&post).Error; err != nil {
t.Errorf("Got errors when save post", err.Error())
t.Errorf("Got errors when save post %s", err.Error())
}
for _, comment := range post.Comments {

View File

@ -1,7 +1,10 @@
package gorm
import (
"database/sql"
"fmt"
"log"
"reflect"
"strings"
)
@ -28,7 +31,16 @@ func Create(scope *Scope) {
for _, field := range fields {
if scope.changeableField(field) {
if field.IsNormal {
if !field.IsPrimaryKey || (field.IsPrimaryKey && !field.IsBlank) {
supportPrimary := scope.Dialect().SupportUniquePrimaryKey()
if !field.IsPrimaryKey || (field.IsPrimaryKey && (!field.IsBlank || !supportPrimary)) {
if field.IsPrimaryKey && !supportPrimary && field.IsBlank {
id := scope.Dialect().NewUniqueKey(scope)
if scope.HasError() {
return
}
log.Printf("ID %+v %+v", id, field.Field.Type().String())
field.Field.Set(reflect.ValueOf(id).Convert(field.Field.Type()))
}
if !field.IsBlank || !field.HasDefaultValue {
columns = append(columns, scope.Quote(field.DBName))
sqls = append(sqls, scope.AddToVars(field.Field.Interface()))
@ -86,18 +98,37 @@ func Create(scope *Scope) {
}
} else {
if primaryField == nil {
if results, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); err == nil {
if results, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); err == sql.ErrNoRows {
} else if err == nil {
scope.db.RowsAffected, _ = results.RowsAffected()
} else {
log.Printf("create err no primary %#v eql %#v", err, err == sql.ErrNoRows)
scope.Err(err)
}
} else {
if err := scope.Err(scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...).Scan(primaryField.Field.Addr().Interface())); err == nil {
} else { // if scope.Dialect().SupportUniquePrimaryKey() {
if err := scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...).Scan(primaryField.Field.Addr().Interface()); err == nil || err == sql.ErrNoRows {
scope.db.RowsAffected = 1
} else {
log.Printf("create err %#v eql %#v", err, err == sql.ErrNoRows)
scope.Err(err)
}
} /* else {
// Create a new primary key if one is required, not set, and the server doesn't support unique primary keys.
log.Printf("key type %T %#v", val.Interface(), val.Interface())
if key, ok := val.Interface().(*uint); ok && (key == nil || *key == 0) {
val := primaryField.Field.Addr()
id := scope.Dialect().NewUniqueKey(scope)
v := reflect.Indirect(val)
v.SetUint(id)
}
if results, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); err == sql.ErrNoRows {
} else if err == nil {
scope.db.RowsAffected, _ = results.RowsAffected()
} else {
log.Printf("create err no primary %#v eql %#v", err, err == sql.ErrNoRows)
scope.Err(err)
}
}*/
}
}
}

View File

@ -3,6 +3,7 @@ package gorm
import (
"errors"
"fmt"
"log"
"reflect"
)
@ -83,6 +84,8 @@ func Query(scope *Scope) {
scope.Err(rows.Scan(values...))
log.Println("result values", values)
for index, column := range columns {
value := values[index]
if field, ok := fields[column]; ok {

View File

@ -1,6 +1,9 @@
package gorm
import "reflect"
import (
"database/sql"
"reflect"
)
func BeginTransaction(scope *Scope) {
scope.Begin()
@ -18,7 +21,9 @@ func SaveBeforeAssociations(scope *Scope) {
if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored {
if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
value := field.Field
scope.Err(scope.NewDB().Save(value.Addr().Interface()).Error)
if err := scope.NewDB().Save(value.Addr().Interface()).Error; err != nil && err != sql.ErrNoRows {
scope.Err(err)
}
if len(relationship.ForeignFieldNames) != 0 {
for idx, fieldName := range relationship.ForeignFieldNames {
associationForeignName := relationship.AssociationForeignDBNames[idx]

141
cockroach.go Normal file
View File

@ -0,0 +1,141 @@
package gorm
import (
"fmt"
"log"
"reflect"
"time"
)
type cockroach struct {
commonDialect
}
func (cockroach) BinVar(i int) string {
return fmt.Sprintf("$%v", i)
}
func (cockroach) SupportLastInsertId() bool {
return false
}
func (cockroach) SupportUniquePrimaryKey() bool {
return false
}
func (cockroach) NewUniqueKey(scope *Scope) uint64 {
rows, err := scope.NewDB().Raw(`SELECT experimental_unique_int()`).Rows()
if err != nil {
scope.Err(err)
return 0
}
var id int64
for rows.Next() {
if err := rows.Scan(&id); err != nil {
log.Fatal("ERR UNIQUE ID", id, err)
scope.Err(err)
return 0
}
}
log.Printf("UNIQUE ID %#v", id)
return uint64(id)
}
func (cockroach) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
switch value.Kind() {
case reflect.Bool:
return "BOOLEAN"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
if autoIncrease {
return "INTEGER PRIMARY KEY"
}
return "INTEGER"
case reflect.Int64, reflect.Uint64:
if autoIncrease {
return "BIGINT PRIMARY KEY"
}
return "BIGINT"
case reflect.Float32, reflect.Float64:
return "FLOAT"
case reflect.String:
if size > 0 && size < 65532 {
return "VARCHAR"
}
return "TEXT"
case reflect.Struct:
if _, ok := value.Interface().(time.Time); ok {
return "TIMESTAMP"
}
default:
if _, ok := value.Interface().([]byte); ok {
return "BYTES"
}
}
panic(fmt.Sprintf("invalid sql type %s (%s) for cockroach", value.Type().Name(), value.Kind().String()))
}
func (s cockroach) HasTable(scope *Scope, tableName string) bool {
rows, err := scope.NewDB().Raw("show tables").Rows()
if err != nil {
scope.Err(err)
return false
}
defer rows.Close()
var name string
for rows.Next() {
rows.Scan(&name)
if name == tableName {
return true
}
}
return false
}
func (s cockroach) HasColumn(scope *Scope, tableName string, columnName string) bool {
rows, err := scope.NewDB().Raw(fmt.Sprintf("show columns from %s", tableName)).Rows()
if err != nil {
scope.Err(err)
return false
}
defer rows.Close()
var column string
for rows.Next() {
rows.Scan(&column)
if column == columnName {
return true
}
}
return false
}
func (s cockroach) HasIndex(scope *Scope, tableName string, indexName string) bool {
/*
var count int
s.RawScanInt(scope, &count, fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'", indexName), tableName)
return count > 0
*/
rows, err := scope.NewDB().Raw(fmt.Sprintf("show index from %s", tableName)).Rows()
if err != nil {
scope.Err(err)
return false
}
defer rows.Close()
var name string
for rows.Next() {
rows.Scan(nil, &name)
if name == indexName {
return true
}
}
return false
}
func (cockroach) RemoveIndex(scope *Scope, indexName string) {
scope.Err(scope.NewDB().Raw(fmt.Sprintf("DROP INDEX %v@%v", scope.QuotedTableName(), indexName)).Error)
}
func (s cockroach) CurrentDatabase(scope *Scope) string {
var name string
s.RawScanString(scope, &name, "SHOW DATABASE")
return name
}

View File

@ -16,6 +16,14 @@ func (commonDialect) SupportLastInsertId() bool {
return true
}
func (commonDialect) SupportUniquePrimaryKey() bool {
return true
}
func (commonDialect) NewUniqueKey(scope *Scope) uint64 {
panic("NewUniqueKey not supported by commonDialect")
}
func (commonDialect) HasTop() bool {
return false
}

View File

@ -8,6 +8,8 @@ import (
type Dialect interface {
BinVar(i int) string
SupportLastInsertId() bool
SupportUniquePrimaryKey() bool
NewUniqueKey(scope *Scope) uint64
HasTop() bool
SqlTag(value reflect.Value, size int, autoIncrease bool) string
ReturningStr(tableName, key string) string
@ -23,6 +25,8 @@ type Dialect interface {
func NewDialect(driver string) Dialect {
var d Dialect
switch driver {
case "cockroach":
d = &cockroach{}
case "postgres":
d = &postgres{}
case "foundation":

View File

@ -6,6 +6,7 @@ import (
"fmt"
"strconv"
_ "github.com/cockroachdb/cockroach/sql/driver"
_ "github.com/denisenkom/go-mssqldb"
testdb "github.com/erikstmartin/go-testdb"
_ "github.com/go-sql-driver/mysql"
@ -53,6 +54,9 @@ func OpenTestConnection() (db gorm.DB, err error) {
case "postgres":
fmt.Println("testing postgres...")
db, err = gorm.Open("postgres", "user=gorm DB.name=gorm sslmode=disable")
case "cockroach":
fmt.Println("testing cockroach...")
db, err = gorm.Open("cockroach", "http://localhost:26257?database=gorm&user=root")
case "foundation":
fmt.Println("testing foundation...")
db, err = gorm.Open("foundation", "dbname=gorm port=15432 sslmode=disable")

View File

@ -524,11 +524,6 @@ func (scope *Scope) generateSqlTag(field *StructField) string {
sqlType = value
}
additionalType := field.TagSettings["NOT NULL"] + " " + field.TagSettings["UNIQUE"]
if value, ok := field.TagSettings["DEFAULT"]; ok {
additionalType = additionalType + " DEFAULT " + value
}
if field.IsScanner {
var getScannerValue func(reflect.Value)
getScannerValue = func(value reflect.Value) {
@ -558,6 +553,14 @@ func (scope *Scope) generateSqlTag(field *StructField) string {
sqlType = scope.Dialect().SqlTag(reflectValue, size, autoIncrease)
}
additionalType := field.TagSettings["NOT NULL"] + " " + field.TagSettings["UNIQUE"]
if value, ok := field.TagSettings["DEFAULT"]; ok {
if _, ok := scope.Dialect().(*cockroach); ok && strings.TrimSpace(strings.ToLower(value)) == "null" {
value = value + "::" + strings.Split(sqlType, " ")[0]
}
additionalType = additionalType + " DEFAULT " + value
}
if strings.TrimSpace(additionalType) == "" {
return sqlType
} else {

View File

@ -1,9 +1,12 @@
package gorm
import (
"database/sql/driver"
"errors"
"fmt"
"log"
"regexp"
"runtime/debug"
"strings"
"time"
@ -103,6 +106,8 @@ func (scope *Scope) Dialect() Dialect {
// Err write error
func (scope *Scope) Err(err error) error {
if err != nil {
log.Println("ERR", err)
debug.PrintStack()
scope.db.AddError(err)
}
return err
@ -314,13 +319,15 @@ func (scope *Scope) Raw(sql string) *Scope {
return scope
}
var _, driverResultNoRows = driver.ResultNoRows.RowsAffected()
// Exec invoke sql
func (scope *Scope) Exec() *Scope {
defer scope.Trace(NowFunc())
if !scope.HasError() {
if result, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil {
if count, err := result.RowsAffected(); scope.Err(err) == nil {
if count, err := result.RowsAffected(); err != nil && err.Error() == driverResultNoRows.Error() || scope.Err(err) == nil {
scope.db.RowsAffected = count
}
}
@ -358,6 +365,8 @@ func (scope *Scope) InstanceGet(name string) (interface{}, bool) {
// Trace print sql log
func (scope *Scope) Trace(t time.Time) {
if len(scope.Sql) > 0 {
// TODO(d4l3k): Remove this line
log.Println("sql", scope.Sql, scope.SqlVars)
scope.db.slog(scope.Sql, t, scope.SqlVars...)
}
}

View File

@ -8,6 +8,7 @@ import (
"regexp"
"strconv"
"strings"
"time"
)
func (scope *Scope) primaryCondition(value interface{}) string {
@ -170,7 +171,12 @@ func (scope *Scope) whereSql() (sql string) {
var primaryConditions, andConditions, orConditions []string
if !scope.Search.Unscoped && scope.Fields()["deleted_at"] != nil {
sql := fmt.Sprintf("(%v.deleted_at IS NULL OR %v.deleted_at <= '0001-01-02')", scope.QuotedTableName(), scope.QuotedTableName())
time, err := time.Parse("2006-01-02", "0001-01-02")
if err != nil {
scope.Err(err)
return
}
sql := fmt.Sprintf("(%v.deleted_at IS NULL OR %v.deleted_at <= %v)", scope.QuotedTableName(), scope.QuotedTableName(), scope.AddToVars(time))
primaryConditions = append(primaryConditions, sql)
}

View File

@ -39,7 +39,7 @@ type User struct {
}
type CreditCard struct {
ID int8
ID uint64
Number string
UserId sql.NullInt64
CreatedAt time.Time
@ -48,7 +48,7 @@ type CreditCard struct {
}
type Email struct {
Id int16
Id int64
UserId int
Email string `sql:"type:varchar(100);"`
CreatedAt time.Time

View File

@ -1,4 +1,4 @@
dialects=("postgres" "mysql" "sqlite")
dialects=("postgres" "mysql" "sqlite" "cockroach")
for dialect in "${dialects[@]}" ; do
GORM_DIALECT=${dialect} go test