Lots of progress
This commit is contained in:
parent
9739cb853c
commit
c88df8eb44
@ -16,7 +16,7 @@ func TestBelongsTo(t *testing.T) {
|
||||
}
|
||||
|
||||
if err := DB.Save(&post).Error; err != nil {
|
||||
t.Errorf("Got errors when save post", err.Error())
|
||||
t.Errorf("Got errors when save post %s", err)
|
||||
}
|
||||
|
||||
if post.Category.ID == 0 || post.MainCategory.ID == 0 {
|
||||
@ -184,7 +184,7 @@ func TestHasOne(t *testing.T) {
|
||||
}
|
||||
|
||||
if err := DB.Save(&user).Error; err != nil {
|
||||
t.Errorf("Got errors when save user", err.Error())
|
||||
t.Errorf("Got errors when save user %s", err.Error())
|
||||
}
|
||||
|
||||
if user.CreditCard.UserId.Int64 == 0 {
|
||||
@ -331,7 +331,7 @@ func TestHasMany(t *testing.T) {
|
||||
}
|
||||
|
||||
if err := DB.Save(&post).Error; err != nil {
|
||||
t.Errorf("Got errors when save post", err.Error())
|
||||
t.Errorf("Got errors when save post %s", err.Error())
|
||||
}
|
||||
|
||||
for _, comment := range post.Comments {
|
||||
|
@ -1,7 +1,10 @@
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"log"
|
||||
"reflect"
|
||||
"strings"
|
||||
)
|
||||
|
||||
@ -28,7 +31,16 @@ func Create(scope *Scope) {
|
||||
for _, field := range fields {
|
||||
if scope.changeableField(field) {
|
||||
if field.IsNormal {
|
||||
if !field.IsPrimaryKey || (field.IsPrimaryKey && !field.IsBlank) {
|
||||
supportPrimary := scope.Dialect().SupportUniquePrimaryKey()
|
||||
if !field.IsPrimaryKey || (field.IsPrimaryKey && (!field.IsBlank || !supportPrimary)) {
|
||||
if field.IsPrimaryKey && !supportPrimary && field.IsBlank {
|
||||
id := scope.Dialect().NewUniqueKey(scope)
|
||||
if scope.HasError() {
|
||||
return
|
||||
}
|
||||
log.Printf("ID %+v %+v", id, field.Field.Type().String())
|
||||
field.Field.Set(reflect.ValueOf(id).Convert(field.Field.Type()))
|
||||
}
|
||||
if !field.IsBlank || !field.HasDefaultValue {
|
||||
columns = append(columns, scope.Quote(field.DBName))
|
||||
sqls = append(sqls, scope.AddToVars(field.Field.Interface()))
|
||||
@ -86,18 +98,37 @@ func Create(scope *Scope) {
|
||||
}
|
||||
} else {
|
||||
if primaryField == nil {
|
||||
if results, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); err == nil {
|
||||
if results, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); err == sql.ErrNoRows {
|
||||
} else if err == nil {
|
||||
scope.db.RowsAffected, _ = results.RowsAffected()
|
||||
} else {
|
||||
log.Printf("create err no primary %#v eql %#v", err, err == sql.ErrNoRows)
|
||||
scope.Err(err)
|
||||
}
|
||||
} else {
|
||||
if err := scope.Err(scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...).Scan(primaryField.Field.Addr().Interface())); err == nil {
|
||||
} else { // if scope.Dialect().SupportUniquePrimaryKey() {
|
||||
if err := scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...).Scan(primaryField.Field.Addr().Interface()); err == nil || err == sql.ErrNoRows {
|
||||
scope.db.RowsAffected = 1
|
||||
} else {
|
||||
log.Printf("create err %#v eql %#v", err, err == sql.ErrNoRows)
|
||||
scope.Err(err)
|
||||
}
|
||||
} /* else {
|
||||
// Create a new primary key if one is required, not set, and the server doesn't support unique primary keys.
|
||||
log.Printf("key type %T %#v", val.Interface(), val.Interface())
|
||||
if key, ok := val.Interface().(*uint); ok && (key == nil || *key == 0) {
|
||||
val := primaryField.Field.Addr()
|
||||
id := scope.Dialect().NewUniqueKey(scope)
|
||||
v := reflect.Indirect(val)
|
||||
v.SetUint(id)
|
||||
}
|
||||
if results, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); err == sql.ErrNoRows {
|
||||
} else if err == nil {
|
||||
scope.db.RowsAffected, _ = results.RowsAffected()
|
||||
} else {
|
||||
log.Printf("create err no primary %#v eql %#v", err, err == sql.ErrNoRows)
|
||||
scope.Err(err)
|
||||
}
|
||||
}*/
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -3,6 +3,7 @@ package gorm
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
@ -83,6 +84,8 @@ func Query(scope *Scope) {
|
||||
|
||||
scope.Err(rows.Scan(values...))
|
||||
|
||||
log.Println("result values", values)
|
||||
|
||||
for index, column := range columns {
|
||||
value := values[index]
|
||||
if field, ok := fields[column]; ok {
|
||||
|
@ -1,6 +1,9 @@
|
||||
package gorm
|
||||
|
||||
import "reflect"
|
||||
import (
|
||||
"database/sql"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
func BeginTransaction(scope *Scope) {
|
||||
scope.Begin()
|
||||
@ -18,7 +21,9 @@ func SaveBeforeAssociations(scope *Scope) {
|
||||
if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored {
|
||||
if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
|
||||
value := field.Field
|
||||
scope.Err(scope.NewDB().Save(value.Addr().Interface()).Error)
|
||||
if err := scope.NewDB().Save(value.Addr().Interface()).Error; err != nil && err != sql.ErrNoRows {
|
||||
scope.Err(err)
|
||||
}
|
||||
if len(relationship.ForeignFieldNames) != 0 {
|
||||
for idx, fieldName := range relationship.ForeignFieldNames {
|
||||
associationForeignName := relationship.AssociationForeignDBNames[idx]
|
||||
|
141
cockroach.go
Normal file
141
cockroach.go
Normal file
@ -0,0 +1,141 @@
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"reflect"
|
||||
"time"
|
||||
)
|
||||
|
||||
type cockroach struct {
|
||||
commonDialect
|
||||
}
|
||||
|
||||
func (cockroach) BinVar(i int) string {
|
||||
return fmt.Sprintf("$%v", i)
|
||||
}
|
||||
|
||||
func (cockroach) SupportLastInsertId() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (cockroach) SupportUniquePrimaryKey() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (cockroach) NewUniqueKey(scope *Scope) uint64 {
|
||||
rows, err := scope.NewDB().Raw(`SELECT experimental_unique_int()`).Rows()
|
||||
if err != nil {
|
||||
scope.Err(err)
|
||||
return 0
|
||||
}
|
||||
var id int64
|
||||
for rows.Next() {
|
||||
if err := rows.Scan(&id); err != nil {
|
||||
log.Fatal("ERR UNIQUE ID", id, err)
|
||||
scope.Err(err)
|
||||
return 0
|
||||
}
|
||||
}
|
||||
log.Printf("UNIQUE ID %#v", id)
|
||||
return uint64(id)
|
||||
}
|
||||
|
||||
func (cockroach) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
|
||||
switch value.Kind() {
|
||||
case reflect.Bool:
|
||||
return "BOOLEAN"
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||
if autoIncrease {
|
||||
return "INTEGER PRIMARY KEY"
|
||||
}
|
||||
return "INTEGER"
|
||||
case reflect.Int64, reflect.Uint64:
|
||||
if autoIncrease {
|
||||
return "BIGINT PRIMARY KEY"
|
||||
}
|
||||
return "BIGINT"
|
||||
case reflect.Float32, reflect.Float64:
|
||||
return "FLOAT"
|
||||
case reflect.String:
|
||||
if size > 0 && size < 65532 {
|
||||
return "VARCHAR"
|
||||
}
|
||||
return "TEXT"
|
||||
case reflect.Struct:
|
||||
if _, ok := value.Interface().(time.Time); ok {
|
||||
return "TIMESTAMP"
|
||||
}
|
||||
default:
|
||||
if _, ok := value.Interface().([]byte); ok {
|
||||
return "BYTES"
|
||||
}
|
||||
}
|
||||
panic(fmt.Sprintf("invalid sql type %s (%s) for cockroach", value.Type().Name(), value.Kind().String()))
|
||||
}
|
||||
|
||||
func (s cockroach) HasTable(scope *Scope, tableName string) bool {
|
||||
rows, err := scope.NewDB().Raw("show tables").Rows()
|
||||
if err != nil {
|
||||
scope.Err(err)
|
||||
return false
|
||||
}
|
||||
defer rows.Close()
|
||||
var name string
|
||||
for rows.Next() {
|
||||
rows.Scan(&name)
|
||||
if name == tableName {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (s cockroach) HasColumn(scope *Scope, tableName string, columnName string) bool {
|
||||
rows, err := scope.NewDB().Raw(fmt.Sprintf("show columns from %s", tableName)).Rows()
|
||||
if err != nil {
|
||||
scope.Err(err)
|
||||
return false
|
||||
}
|
||||
defer rows.Close()
|
||||
var column string
|
||||
for rows.Next() {
|
||||
rows.Scan(&column)
|
||||
if column == columnName {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (s cockroach) HasIndex(scope *Scope, tableName string, indexName string) bool {
|
||||
/*
|
||||
var count int
|
||||
s.RawScanInt(scope, &count, fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'", indexName), tableName)
|
||||
return count > 0
|
||||
*/
|
||||
rows, err := scope.NewDB().Raw(fmt.Sprintf("show index from %s", tableName)).Rows()
|
||||
if err != nil {
|
||||
scope.Err(err)
|
||||
return false
|
||||
}
|
||||
defer rows.Close()
|
||||
var name string
|
||||
for rows.Next() {
|
||||
rows.Scan(nil, &name)
|
||||
if name == indexName {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (cockroach) RemoveIndex(scope *Scope, indexName string) {
|
||||
scope.Err(scope.NewDB().Raw(fmt.Sprintf("DROP INDEX %v@%v", scope.QuotedTableName(), indexName)).Error)
|
||||
}
|
||||
|
||||
func (s cockroach) CurrentDatabase(scope *Scope) string {
|
||||
var name string
|
||||
s.RawScanString(scope, &name, "SHOW DATABASE")
|
||||
return name
|
||||
}
|
@ -16,6 +16,14 @@ func (commonDialect) SupportLastInsertId() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (commonDialect) SupportUniquePrimaryKey() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (commonDialect) NewUniqueKey(scope *Scope) uint64 {
|
||||
panic("NewUniqueKey not supported by commonDialect")
|
||||
}
|
||||
|
||||
func (commonDialect) HasTop() bool {
|
||||
return false
|
||||
}
|
||||
|
@ -8,6 +8,8 @@ import (
|
||||
type Dialect interface {
|
||||
BinVar(i int) string
|
||||
SupportLastInsertId() bool
|
||||
SupportUniquePrimaryKey() bool
|
||||
NewUniqueKey(scope *Scope) uint64
|
||||
HasTop() bool
|
||||
SqlTag(value reflect.Value, size int, autoIncrease bool) string
|
||||
ReturningStr(tableName, key string) string
|
||||
@ -23,6 +25,8 @@ type Dialect interface {
|
||||
func NewDialect(driver string) Dialect {
|
||||
var d Dialect
|
||||
switch driver {
|
||||
case "cockroach":
|
||||
d = &cockroach{}
|
||||
case "postgres":
|
||||
d = &postgres{}
|
||||
case "foundation":
|
||||
|
@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
_ "github.com/cockroachdb/cockroach/sql/driver"
|
||||
_ "github.com/denisenkom/go-mssqldb"
|
||||
testdb "github.com/erikstmartin/go-testdb"
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
@ -53,6 +54,9 @@ func OpenTestConnection() (db gorm.DB, err error) {
|
||||
case "postgres":
|
||||
fmt.Println("testing postgres...")
|
||||
db, err = gorm.Open("postgres", "user=gorm DB.name=gorm sslmode=disable")
|
||||
case "cockroach":
|
||||
fmt.Println("testing cockroach...")
|
||||
db, err = gorm.Open("cockroach", "http://localhost:26257?database=gorm&user=root")
|
||||
case "foundation":
|
||||
fmt.Println("testing foundation...")
|
||||
db, err = gorm.Open("foundation", "dbname=gorm port=15432 sslmode=disable")
|
||||
|
@ -524,11 +524,6 @@ func (scope *Scope) generateSqlTag(field *StructField) string {
|
||||
sqlType = value
|
||||
}
|
||||
|
||||
additionalType := field.TagSettings["NOT NULL"] + " " + field.TagSettings["UNIQUE"]
|
||||
if value, ok := field.TagSettings["DEFAULT"]; ok {
|
||||
additionalType = additionalType + " DEFAULT " + value
|
||||
}
|
||||
|
||||
if field.IsScanner {
|
||||
var getScannerValue func(reflect.Value)
|
||||
getScannerValue = func(value reflect.Value) {
|
||||
@ -558,6 +553,14 @@ func (scope *Scope) generateSqlTag(field *StructField) string {
|
||||
sqlType = scope.Dialect().SqlTag(reflectValue, size, autoIncrease)
|
||||
}
|
||||
|
||||
additionalType := field.TagSettings["NOT NULL"] + " " + field.TagSettings["UNIQUE"]
|
||||
if value, ok := field.TagSettings["DEFAULT"]; ok {
|
||||
if _, ok := scope.Dialect().(*cockroach); ok && strings.TrimSpace(strings.ToLower(value)) == "null" {
|
||||
value = value + "::" + strings.Split(sqlType, " ")[0]
|
||||
}
|
||||
additionalType = additionalType + " DEFAULT " + value
|
||||
}
|
||||
|
||||
if strings.TrimSpace(additionalType) == "" {
|
||||
return sqlType
|
||||
} else {
|
||||
|
11
scope.go
11
scope.go
@ -1,9 +1,12 @@
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"regexp"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@ -103,6 +106,8 @@ func (scope *Scope) Dialect() Dialect {
|
||||
// Err write error
|
||||
func (scope *Scope) Err(err error) error {
|
||||
if err != nil {
|
||||
log.Println("ERR", err)
|
||||
debug.PrintStack()
|
||||
scope.db.AddError(err)
|
||||
}
|
||||
return err
|
||||
@ -314,13 +319,15 @@ func (scope *Scope) Raw(sql string) *Scope {
|
||||
return scope
|
||||
}
|
||||
|
||||
var _, driverResultNoRows = driver.ResultNoRows.RowsAffected()
|
||||
|
||||
// Exec invoke sql
|
||||
func (scope *Scope) Exec() *Scope {
|
||||
defer scope.Trace(NowFunc())
|
||||
|
||||
if !scope.HasError() {
|
||||
if result, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil {
|
||||
if count, err := result.RowsAffected(); scope.Err(err) == nil {
|
||||
if count, err := result.RowsAffected(); err != nil && err.Error() == driverResultNoRows.Error() || scope.Err(err) == nil {
|
||||
scope.db.RowsAffected = count
|
||||
}
|
||||
}
|
||||
@ -358,6 +365,8 @@ func (scope *Scope) InstanceGet(name string) (interface{}, bool) {
|
||||
// Trace print sql log
|
||||
func (scope *Scope) Trace(t time.Time) {
|
||||
if len(scope.Sql) > 0 {
|
||||
// TODO(d4l3k): Remove this line
|
||||
log.Println("sql", scope.Sql, scope.SqlVars)
|
||||
scope.db.slog(scope.Sql, t, scope.SqlVars...)
|
||||
}
|
||||
}
|
||||
|
@ -8,6 +8,7 @@ import (
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
func (scope *Scope) primaryCondition(value interface{}) string {
|
||||
@ -170,7 +171,12 @@ func (scope *Scope) whereSql() (sql string) {
|
||||
var primaryConditions, andConditions, orConditions []string
|
||||
|
||||
if !scope.Search.Unscoped && scope.Fields()["deleted_at"] != nil {
|
||||
sql := fmt.Sprintf("(%v.deleted_at IS NULL OR %v.deleted_at <= '0001-01-02')", scope.QuotedTableName(), scope.QuotedTableName())
|
||||
time, err := time.Parse("2006-01-02", "0001-01-02")
|
||||
if err != nil {
|
||||
scope.Err(err)
|
||||
return
|
||||
}
|
||||
sql := fmt.Sprintf("(%v.deleted_at IS NULL OR %v.deleted_at <= %v)", scope.QuotedTableName(), scope.QuotedTableName(), scope.AddToVars(time))
|
||||
primaryConditions = append(primaryConditions, sql)
|
||||
}
|
||||
|
||||
|
@ -39,7 +39,7 @@ type User struct {
|
||||
}
|
||||
|
||||
type CreditCard struct {
|
||||
ID int8
|
||||
ID uint64
|
||||
Number string
|
||||
UserId sql.NullInt64
|
||||
CreatedAt time.Time
|
||||
@ -48,7 +48,7 @@ type CreditCard struct {
|
||||
}
|
||||
|
||||
type Email struct {
|
||||
Id int16
|
||||
Id int64
|
||||
UserId int
|
||||
Email string `sql:"type:varchar(100);"`
|
||||
CreatedAt time.Time
|
||||
|
@ -1,4 +1,4 @@
|
||||
dialects=("postgres" "mysql" "sqlite")
|
||||
dialects=("postgres" "mysql" "sqlite" "cockroach")
|
||||
|
||||
for dialect in "${dialects[@]}" ; do
|
||||
GORM_DIALECT=${dialect} go test
|
||||
|
Loading…
x
Reference in New Issue
Block a user