Lots of progress
This commit is contained in:
		
							parent
							
								
									9739cb853c
								
							
						
					
					
						commit
						c88df8eb44
					
				@ -16,7 +16,7 @@ func TestBelongsTo(t *testing.T) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if err := DB.Save(&post).Error; err != nil {
 | 
						if err := DB.Save(&post).Error; err != nil {
 | 
				
			||||||
		t.Errorf("Got errors when save post", err.Error())
 | 
							t.Errorf("Got errors when save post %s", err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if post.Category.ID == 0 || post.MainCategory.ID == 0 {
 | 
						if post.Category.ID == 0 || post.MainCategory.ID == 0 {
 | 
				
			||||||
@ -184,7 +184,7 @@ func TestHasOne(t *testing.T) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if err := DB.Save(&user).Error; err != nil {
 | 
						if err := DB.Save(&user).Error; err != nil {
 | 
				
			||||||
		t.Errorf("Got errors when save user", err.Error())
 | 
							t.Errorf("Got errors when save user %s", err.Error())
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if user.CreditCard.UserId.Int64 == 0 {
 | 
						if user.CreditCard.UserId.Int64 == 0 {
 | 
				
			||||||
@ -331,7 +331,7 @@ func TestHasMany(t *testing.T) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if err := DB.Save(&post).Error; err != nil {
 | 
						if err := DB.Save(&post).Error; err != nil {
 | 
				
			||||||
		t.Errorf("Got errors when save post", err.Error())
 | 
							t.Errorf("Got errors when save post %s", err.Error())
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	for _, comment := range post.Comments {
 | 
						for _, comment := range post.Comments {
 | 
				
			||||||
 | 
				
			|||||||
@ -1,7 +1,10 @@
 | 
				
			|||||||
package gorm
 | 
					package gorm
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
 | 
						"database/sql"
 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
 | 
						"log"
 | 
				
			||||||
 | 
						"reflect"
 | 
				
			||||||
	"strings"
 | 
						"strings"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -28,7 +31,16 @@ func Create(scope *Scope) {
 | 
				
			|||||||
		for _, field := range fields {
 | 
							for _, field := range fields {
 | 
				
			||||||
			if scope.changeableField(field) {
 | 
								if scope.changeableField(field) {
 | 
				
			||||||
				if field.IsNormal {
 | 
									if field.IsNormal {
 | 
				
			||||||
					if !field.IsPrimaryKey || (field.IsPrimaryKey && !field.IsBlank) {
 | 
										supportPrimary := scope.Dialect().SupportUniquePrimaryKey()
 | 
				
			||||||
 | 
										if !field.IsPrimaryKey || (field.IsPrimaryKey && (!field.IsBlank || !supportPrimary)) {
 | 
				
			||||||
 | 
											if field.IsPrimaryKey && !supportPrimary && field.IsBlank {
 | 
				
			||||||
 | 
												id := scope.Dialect().NewUniqueKey(scope)
 | 
				
			||||||
 | 
												if scope.HasError() {
 | 
				
			||||||
 | 
													return
 | 
				
			||||||
 | 
												}
 | 
				
			||||||
 | 
												log.Printf("ID %+v %+v", id, field.Field.Type().String())
 | 
				
			||||||
 | 
												field.Field.Set(reflect.ValueOf(id).Convert(field.Field.Type()))
 | 
				
			||||||
 | 
											}
 | 
				
			||||||
						if !field.IsBlank || !field.HasDefaultValue {
 | 
											if !field.IsBlank || !field.HasDefaultValue {
 | 
				
			||||||
							columns = append(columns, scope.Quote(field.DBName))
 | 
												columns = append(columns, scope.Quote(field.DBName))
 | 
				
			||||||
							sqls = append(sqls, scope.AddToVars(field.Field.Interface()))
 | 
												sqls = append(sqls, scope.AddToVars(field.Field.Interface()))
 | 
				
			||||||
@ -86,18 +98,37 @@ func Create(scope *Scope) {
 | 
				
			|||||||
			}
 | 
								}
 | 
				
			||||||
		} else {
 | 
							} else {
 | 
				
			||||||
			if primaryField == nil {
 | 
								if primaryField == nil {
 | 
				
			||||||
				if results, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); err == nil {
 | 
									if results, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); err == sql.ErrNoRows {
 | 
				
			||||||
 | 
									} else if err == nil {
 | 
				
			||||||
					scope.db.RowsAffected, _ = results.RowsAffected()
 | 
										scope.db.RowsAffected, _ = results.RowsAffected()
 | 
				
			||||||
				} else {
 | 
									} else {
 | 
				
			||||||
 | 
										log.Printf("create err no primary %#v eql %#v", err, err == sql.ErrNoRows)
 | 
				
			||||||
					scope.Err(err)
 | 
										scope.Err(err)
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
			} else {
 | 
								} else { // if scope.Dialect().SupportUniquePrimaryKey() {
 | 
				
			||||||
				if err := scope.Err(scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...).Scan(primaryField.Field.Addr().Interface())); err == nil {
 | 
									if err := scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...).Scan(primaryField.Field.Addr().Interface()); err == nil || err == sql.ErrNoRows {
 | 
				
			||||||
					scope.db.RowsAffected = 1
 | 
										scope.db.RowsAffected = 1
 | 
				
			||||||
				} else {
 | 
									} else {
 | 
				
			||||||
 | 
										log.Printf("create err %#v eql %#v", err, err == sql.ErrNoRows)
 | 
				
			||||||
					scope.Err(err)
 | 
										scope.Err(err)
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
			}
 | 
								} /* else {
 | 
				
			||||||
 | 
									// Create a new primary key if one is required, not set, and the server doesn't support unique primary keys.
 | 
				
			||||||
 | 
									log.Printf("key type %T %#v", val.Interface(), val.Interface())
 | 
				
			||||||
 | 
									if key, ok := val.Interface().(*uint); ok && (key == nil || *key == 0) {
 | 
				
			||||||
 | 
									val := primaryField.Field.Addr()
 | 
				
			||||||
 | 
										id := scope.Dialect().NewUniqueKey(scope)
 | 
				
			||||||
 | 
										v := reflect.Indirect(val)
 | 
				
			||||||
 | 
										v.SetUint(id)
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
									if results, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); err == sql.ErrNoRows {
 | 
				
			||||||
 | 
									} else if err == nil {
 | 
				
			||||||
 | 
										scope.db.RowsAffected, _ = results.RowsAffected()
 | 
				
			||||||
 | 
									} else {
 | 
				
			||||||
 | 
										log.Printf("create err no primary %#v eql %#v", err, err == sql.ErrNoRows)
 | 
				
			||||||
 | 
										scope.Err(err)
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								}*/
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
@ -3,6 +3,7 @@ package gorm
 | 
				
			|||||||
import (
 | 
					import (
 | 
				
			||||||
	"errors"
 | 
						"errors"
 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
 | 
						"log"
 | 
				
			||||||
	"reflect"
 | 
						"reflect"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -83,6 +84,8 @@ func Query(scope *Scope) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
			scope.Err(rows.Scan(values...))
 | 
								scope.Err(rows.Scan(values...))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								log.Println("result values", values)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			for index, column := range columns {
 | 
								for index, column := range columns {
 | 
				
			||||||
				value := values[index]
 | 
									value := values[index]
 | 
				
			||||||
				if field, ok := fields[column]; ok {
 | 
									if field, ok := fields[column]; ok {
 | 
				
			||||||
 | 
				
			|||||||
@ -1,6 +1,9 @@
 | 
				
			|||||||
package gorm
 | 
					package gorm
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import "reflect"
 | 
					import (
 | 
				
			||||||
 | 
						"database/sql"
 | 
				
			||||||
 | 
						"reflect"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func BeginTransaction(scope *Scope) {
 | 
					func BeginTransaction(scope *Scope) {
 | 
				
			||||||
	scope.Begin()
 | 
						scope.Begin()
 | 
				
			||||||
@ -18,7 +21,9 @@ func SaveBeforeAssociations(scope *Scope) {
 | 
				
			|||||||
		if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored {
 | 
							if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored {
 | 
				
			||||||
			if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
 | 
								if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
 | 
				
			||||||
				value := field.Field
 | 
									value := field.Field
 | 
				
			||||||
				scope.Err(scope.NewDB().Save(value.Addr().Interface()).Error)
 | 
									if err := scope.NewDB().Save(value.Addr().Interface()).Error; err != nil && err != sql.ErrNoRows {
 | 
				
			||||||
 | 
										scope.Err(err)
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
				if len(relationship.ForeignFieldNames) != 0 {
 | 
									if len(relationship.ForeignFieldNames) != 0 {
 | 
				
			||||||
					for idx, fieldName := range relationship.ForeignFieldNames {
 | 
										for idx, fieldName := range relationship.ForeignFieldNames {
 | 
				
			||||||
						associationForeignName := relationship.AssociationForeignDBNames[idx]
 | 
											associationForeignName := relationship.AssociationForeignDBNames[idx]
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										141
									
								
								cockroach.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										141
									
								
								cockroach.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,141 @@
 | 
				
			|||||||
 | 
					package gorm
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
 | 
						"log"
 | 
				
			||||||
 | 
						"reflect"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type cockroach struct {
 | 
				
			||||||
 | 
						commonDialect
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (cockroach) BinVar(i int) string {
 | 
				
			||||||
 | 
						return fmt.Sprintf("$%v", i)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (cockroach) SupportLastInsertId() bool {
 | 
				
			||||||
 | 
						return false
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (cockroach) SupportUniquePrimaryKey() bool {
 | 
				
			||||||
 | 
						return false
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (cockroach) NewUniqueKey(scope *Scope) uint64 {
 | 
				
			||||||
 | 
						rows, err := scope.NewDB().Raw(`SELECT experimental_unique_int()`).Rows()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							scope.Err(err)
 | 
				
			||||||
 | 
							return 0
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						var id int64
 | 
				
			||||||
 | 
						for rows.Next() {
 | 
				
			||||||
 | 
							if err := rows.Scan(&id); err != nil {
 | 
				
			||||||
 | 
								log.Fatal("ERR UNIQUE ID", id, err)
 | 
				
			||||||
 | 
								scope.Err(err)
 | 
				
			||||||
 | 
								return 0
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						log.Printf("UNIQUE ID %#v", id)
 | 
				
			||||||
 | 
						return uint64(id)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (cockroach) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
 | 
				
			||||||
 | 
						switch value.Kind() {
 | 
				
			||||||
 | 
						case reflect.Bool:
 | 
				
			||||||
 | 
							return "BOOLEAN"
 | 
				
			||||||
 | 
						case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
 | 
				
			||||||
 | 
							if autoIncrease {
 | 
				
			||||||
 | 
								return "INTEGER PRIMARY KEY"
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							return "INTEGER"
 | 
				
			||||||
 | 
						case reflect.Int64, reflect.Uint64:
 | 
				
			||||||
 | 
							if autoIncrease {
 | 
				
			||||||
 | 
								return "BIGINT PRIMARY KEY"
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							return "BIGINT"
 | 
				
			||||||
 | 
						case reflect.Float32, reflect.Float64:
 | 
				
			||||||
 | 
							return "FLOAT"
 | 
				
			||||||
 | 
						case reflect.String:
 | 
				
			||||||
 | 
							if size > 0 && size < 65532 {
 | 
				
			||||||
 | 
								return "VARCHAR"
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							return "TEXT"
 | 
				
			||||||
 | 
						case reflect.Struct:
 | 
				
			||||||
 | 
							if _, ok := value.Interface().(time.Time); ok {
 | 
				
			||||||
 | 
								return "TIMESTAMP"
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						default:
 | 
				
			||||||
 | 
							if _, ok := value.Interface().([]byte); ok {
 | 
				
			||||||
 | 
								return "BYTES"
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						panic(fmt.Sprintf("invalid sql type %s (%s) for cockroach", value.Type().Name(), value.Kind().String()))
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (s cockroach) HasTable(scope *Scope, tableName string) bool {
 | 
				
			||||||
 | 
						rows, err := scope.NewDB().Raw("show tables").Rows()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							scope.Err(err)
 | 
				
			||||||
 | 
							return false
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						defer rows.Close()
 | 
				
			||||||
 | 
						var name string
 | 
				
			||||||
 | 
						for rows.Next() {
 | 
				
			||||||
 | 
							rows.Scan(&name)
 | 
				
			||||||
 | 
							if name == tableName {
 | 
				
			||||||
 | 
								return true
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return false
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (s cockroach) HasColumn(scope *Scope, tableName string, columnName string) bool {
 | 
				
			||||||
 | 
						rows, err := scope.NewDB().Raw(fmt.Sprintf("show columns from %s", tableName)).Rows()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							scope.Err(err)
 | 
				
			||||||
 | 
							return false
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						defer rows.Close()
 | 
				
			||||||
 | 
						var column string
 | 
				
			||||||
 | 
						for rows.Next() {
 | 
				
			||||||
 | 
							rows.Scan(&column)
 | 
				
			||||||
 | 
							if column == columnName {
 | 
				
			||||||
 | 
								return true
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return false
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (s cockroach) HasIndex(scope *Scope, tableName string, indexName string) bool {
 | 
				
			||||||
 | 
						/*
 | 
				
			||||||
 | 
							var count int
 | 
				
			||||||
 | 
							s.RawScanInt(scope, &count, fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'", indexName), tableName)
 | 
				
			||||||
 | 
							return count > 0
 | 
				
			||||||
 | 
						*/
 | 
				
			||||||
 | 
						rows, err := scope.NewDB().Raw(fmt.Sprintf("show index from %s", tableName)).Rows()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							scope.Err(err)
 | 
				
			||||||
 | 
							return false
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						defer rows.Close()
 | 
				
			||||||
 | 
						var name string
 | 
				
			||||||
 | 
						for rows.Next() {
 | 
				
			||||||
 | 
							rows.Scan(nil, &name)
 | 
				
			||||||
 | 
							if name == indexName {
 | 
				
			||||||
 | 
								return true
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return false
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (cockroach) RemoveIndex(scope *Scope, indexName string) {
 | 
				
			||||||
 | 
						scope.Err(scope.NewDB().Raw(fmt.Sprintf("DROP INDEX %v@%v", scope.QuotedTableName(), indexName)).Error)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (s cockroach) CurrentDatabase(scope *Scope) string {
 | 
				
			||||||
 | 
						var name string
 | 
				
			||||||
 | 
						s.RawScanString(scope, &name, "SHOW DATABASE")
 | 
				
			||||||
 | 
						return name
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@ -16,6 +16,14 @@ func (commonDialect) SupportLastInsertId() bool {
 | 
				
			|||||||
	return true
 | 
						return true
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (commonDialect) SupportUniquePrimaryKey() bool {
 | 
				
			||||||
 | 
						return true
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (commonDialect) NewUniqueKey(scope *Scope) uint64 {
 | 
				
			||||||
 | 
						panic("NewUniqueKey not supported by commonDialect")
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (commonDialect) HasTop() bool {
 | 
					func (commonDialect) HasTop() bool {
 | 
				
			||||||
	return false
 | 
						return false
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
@ -8,6 +8,8 @@ import (
 | 
				
			|||||||
type Dialect interface {
 | 
					type Dialect interface {
 | 
				
			||||||
	BinVar(i int) string
 | 
						BinVar(i int) string
 | 
				
			||||||
	SupportLastInsertId() bool
 | 
						SupportLastInsertId() bool
 | 
				
			||||||
 | 
						SupportUniquePrimaryKey() bool
 | 
				
			||||||
 | 
						NewUniqueKey(scope *Scope) uint64
 | 
				
			||||||
	HasTop() bool
 | 
						HasTop() bool
 | 
				
			||||||
	SqlTag(value reflect.Value, size int, autoIncrease bool) string
 | 
						SqlTag(value reflect.Value, size int, autoIncrease bool) string
 | 
				
			||||||
	ReturningStr(tableName, key string) string
 | 
						ReturningStr(tableName, key string) string
 | 
				
			||||||
@ -23,6 +25,8 @@ type Dialect interface {
 | 
				
			|||||||
func NewDialect(driver string) Dialect {
 | 
					func NewDialect(driver string) Dialect {
 | 
				
			||||||
	var d Dialect
 | 
						var d Dialect
 | 
				
			||||||
	switch driver {
 | 
						switch driver {
 | 
				
			||||||
 | 
						case "cockroach":
 | 
				
			||||||
 | 
							d = &cockroach{}
 | 
				
			||||||
	case "postgres":
 | 
						case "postgres":
 | 
				
			||||||
		d = &postgres{}
 | 
							d = &postgres{}
 | 
				
			||||||
	case "foundation":
 | 
						case "foundation":
 | 
				
			||||||
 | 
				
			|||||||
@ -6,6 +6,7 @@ import (
 | 
				
			|||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
	"strconv"
 | 
						"strconv"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						_ "github.com/cockroachdb/cockroach/sql/driver"
 | 
				
			||||||
	_ "github.com/denisenkom/go-mssqldb"
 | 
						_ "github.com/denisenkom/go-mssqldb"
 | 
				
			||||||
	testdb "github.com/erikstmartin/go-testdb"
 | 
						testdb "github.com/erikstmartin/go-testdb"
 | 
				
			||||||
	_ "github.com/go-sql-driver/mysql"
 | 
						_ "github.com/go-sql-driver/mysql"
 | 
				
			||||||
@ -53,6 +54,9 @@ func OpenTestConnection() (db gorm.DB, err error) {
 | 
				
			|||||||
	case "postgres":
 | 
						case "postgres":
 | 
				
			||||||
		fmt.Println("testing postgres...")
 | 
							fmt.Println("testing postgres...")
 | 
				
			||||||
		db, err = gorm.Open("postgres", "user=gorm DB.name=gorm sslmode=disable")
 | 
							db, err = gorm.Open("postgres", "user=gorm DB.name=gorm sslmode=disable")
 | 
				
			||||||
 | 
						case "cockroach":
 | 
				
			||||||
 | 
							fmt.Println("testing cockroach...")
 | 
				
			||||||
 | 
							db, err = gorm.Open("cockroach", "http://localhost:26257?database=gorm&user=root")
 | 
				
			||||||
	case "foundation":
 | 
						case "foundation":
 | 
				
			||||||
		fmt.Println("testing foundation...")
 | 
							fmt.Println("testing foundation...")
 | 
				
			||||||
		db, err = gorm.Open("foundation", "dbname=gorm port=15432 sslmode=disable")
 | 
							db, err = gorm.Open("foundation", "dbname=gorm port=15432 sslmode=disable")
 | 
				
			||||||
 | 
				
			|||||||
@ -524,11 +524,6 @@ func (scope *Scope) generateSqlTag(field *StructField) string {
 | 
				
			|||||||
		sqlType = value
 | 
							sqlType = value
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	additionalType := field.TagSettings["NOT NULL"] + " " + field.TagSettings["UNIQUE"]
 | 
					 | 
				
			||||||
	if value, ok := field.TagSettings["DEFAULT"]; ok {
 | 
					 | 
				
			||||||
		additionalType = additionalType + " DEFAULT " + value
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if field.IsScanner {
 | 
						if field.IsScanner {
 | 
				
			||||||
		var getScannerValue func(reflect.Value)
 | 
							var getScannerValue func(reflect.Value)
 | 
				
			||||||
		getScannerValue = func(value reflect.Value) {
 | 
							getScannerValue = func(value reflect.Value) {
 | 
				
			||||||
@ -558,6 +553,14 @@ func (scope *Scope) generateSqlTag(field *StructField) string {
 | 
				
			|||||||
		sqlType = scope.Dialect().SqlTag(reflectValue, size, autoIncrease)
 | 
							sqlType = scope.Dialect().SqlTag(reflectValue, size, autoIncrease)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						additionalType := field.TagSettings["NOT NULL"] + " " + field.TagSettings["UNIQUE"]
 | 
				
			||||||
 | 
						if value, ok := field.TagSettings["DEFAULT"]; ok {
 | 
				
			||||||
 | 
							if _, ok := scope.Dialect().(*cockroach); ok && strings.TrimSpace(strings.ToLower(value)) == "null" {
 | 
				
			||||||
 | 
								value = value + "::" + strings.Split(sqlType, " ")[0]
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							additionalType = additionalType + " DEFAULT " + value
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if strings.TrimSpace(additionalType) == "" {
 | 
						if strings.TrimSpace(additionalType) == "" {
 | 
				
			||||||
		return sqlType
 | 
							return sqlType
 | 
				
			||||||
	} else {
 | 
						} else {
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										11
									
								
								scope.go
									
									
									
									
									
								
							
							
						
						
									
										11
									
								
								scope.go
									
									
									
									
									
								
							@ -1,9 +1,12 @@
 | 
				
			|||||||
package gorm
 | 
					package gorm
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
 | 
						"database/sql/driver"
 | 
				
			||||||
	"errors"
 | 
						"errors"
 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
 | 
						"log"
 | 
				
			||||||
	"regexp"
 | 
						"regexp"
 | 
				
			||||||
 | 
						"runtime/debug"
 | 
				
			||||||
	"strings"
 | 
						"strings"
 | 
				
			||||||
	"time"
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -103,6 +106,8 @@ func (scope *Scope) Dialect() Dialect {
 | 
				
			|||||||
// Err write error
 | 
					// Err write error
 | 
				
			||||||
func (scope *Scope) Err(err error) error {
 | 
					func (scope *Scope) Err(err error) error {
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
 | 
							log.Println("ERR", err)
 | 
				
			||||||
 | 
							debug.PrintStack()
 | 
				
			||||||
		scope.db.AddError(err)
 | 
							scope.db.AddError(err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return err
 | 
						return err
 | 
				
			||||||
@ -314,13 +319,15 @@ func (scope *Scope) Raw(sql string) *Scope {
 | 
				
			|||||||
	return scope
 | 
						return scope
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					var _, driverResultNoRows = driver.ResultNoRows.RowsAffected()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Exec invoke sql
 | 
					// Exec invoke sql
 | 
				
			||||||
func (scope *Scope) Exec() *Scope {
 | 
					func (scope *Scope) Exec() *Scope {
 | 
				
			||||||
	defer scope.Trace(NowFunc())
 | 
						defer scope.Trace(NowFunc())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if !scope.HasError() {
 | 
						if !scope.HasError() {
 | 
				
			||||||
		if result, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil {
 | 
							if result, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil {
 | 
				
			||||||
			if count, err := result.RowsAffected(); scope.Err(err) == nil {
 | 
								if count, err := result.RowsAffected(); err != nil && err.Error() == driverResultNoRows.Error() || scope.Err(err) == nil {
 | 
				
			||||||
				scope.db.RowsAffected = count
 | 
									scope.db.RowsAffected = count
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
@ -358,6 +365,8 @@ func (scope *Scope) InstanceGet(name string) (interface{}, bool) {
 | 
				
			|||||||
// Trace print sql log
 | 
					// Trace print sql log
 | 
				
			||||||
func (scope *Scope) Trace(t time.Time) {
 | 
					func (scope *Scope) Trace(t time.Time) {
 | 
				
			||||||
	if len(scope.Sql) > 0 {
 | 
						if len(scope.Sql) > 0 {
 | 
				
			||||||
 | 
							// TODO(d4l3k): Remove this line
 | 
				
			||||||
 | 
							log.Println("sql", scope.Sql, scope.SqlVars)
 | 
				
			||||||
		scope.db.slog(scope.Sql, t, scope.SqlVars...)
 | 
							scope.db.slog(scope.Sql, t, scope.SqlVars...)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
@ -8,6 +8,7 @@ import (
 | 
				
			|||||||
	"regexp"
 | 
						"regexp"
 | 
				
			||||||
	"strconv"
 | 
						"strconv"
 | 
				
			||||||
	"strings"
 | 
						"strings"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (scope *Scope) primaryCondition(value interface{}) string {
 | 
					func (scope *Scope) primaryCondition(value interface{}) string {
 | 
				
			||||||
@ -170,7 +171,12 @@ func (scope *Scope) whereSql() (sql string) {
 | 
				
			|||||||
	var primaryConditions, andConditions, orConditions []string
 | 
						var primaryConditions, andConditions, orConditions []string
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if !scope.Search.Unscoped && scope.Fields()["deleted_at"] != nil {
 | 
						if !scope.Search.Unscoped && scope.Fields()["deleted_at"] != nil {
 | 
				
			||||||
		sql := fmt.Sprintf("(%v.deleted_at IS NULL OR %v.deleted_at <= '0001-01-02')", scope.QuotedTableName(), scope.QuotedTableName())
 | 
							time, err := time.Parse("2006-01-02", "0001-01-02")
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								scope.Err(err)
 | 
				
			||||||
 | 
								return
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							sql := fmt.Sprintf("(%v.deleted_at IS NULL OR %v.deleted_at <= %v)", scope.QuotedTableName(), scope.QuotedTableName(), scope.AddToVars(time))
 | 
				
			||||||
		primaryConditions = append(primaryConditions, sql)
 | 
							primaryConditions = append(primaryConditions, sql)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -39,7 +39,7 @@ type User struct {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type CreditCard struct {
 | 
					type CreditCard struct {
 | 
				
			||||||
	ID        int8
 | 
						ID        uint64
 | 
				
			||||||
	Number    string
 | 
						Number    string
 | 
				
			||||||
	UserId    sql.NullInt64
 | 
						UserId    sql.NullInt64
 | 
				
			||||||
	CreatedAt time.Time
 | 
						CreatedAt time.Time
 | 
				
			||||||
@ -48,7 +48,7 @@ type CreditCard struct {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type Email struct {
 | 
					type Email struct {
 | 
				
			||||||
	Id        int16
 | 
						Id        int64
 | 
				
			||||||
	UserId    int
 | 
						UserId    int
 | 
				
			||||||
	Email     string `sql:"type:varchar(100);"`
 | 
						Email     string `sql:"type:varchar(100);"`
 | 
				
			||||||
	CreatedAt time.Time
 | 
						CreatedAt time.Time
 | 
				
			||||||
 | 
				
			|||||||
@ -1,4 +1,4 @@
 | 
				
			|||||||
dialects=("postgres" "mysql" "sqlite")
 | 
					dialects=("postgres" "mysql" "sqlite" "cockroach")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
for dialect in "${dialects[@]}" ; do
 | 
					for dialect in "${dialects[@]}" ; do
 | 
				
			||||||
    GORM_DIALECT=${dialect} go test
 | 
					    GORM_DIALECT=${dialect} go test
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user