Merge 5bdf677206f9c18ad3f36554b61f5e7892ace855 into c7b9acefb7d4570e0f46d94ce453467cc486d8fd
This commit is contained in:
commit
cf92ceae63
@ -16,7 +16,7 @@ func TestBelongsTo(t *testing.T) {
|
||||
}
|
||||
|
||||
if err := DB.Save(&post).Error; err != nil {
|
||||
t.Errorf("Got errors when save post", err.Error())
|
||||
t.Errorf("Got errors when save post %s", err)
|
||||
}
|
||||
|
||||
if post.Category.ID == 0 || post.MainCategory.ID == 0 {
|
||||
@ -184,7 +184,7 @@ func TestHasOne(t *testing.T) {
|
||||
}
|
||||
|
||||
if err := DB.Save(&user).Error; err != nil {
|
||||
t.Errorf("Got errors when save user", err.Error())
|
||||
t.Errorf("Got errors when save user %s", err.Error())
|
||||
}
|
||||
|
||||
if user.CreditCard.UserId.Int64 == 0 {
|
||||
@ -331,7 +331,7 @@ func TestHasMany(t *testing.T) {
|
||||
}
|
||||
|
||||
if err := DB.Save(&post).Error; err != nil {
|
||||
t.Errorf("Got errors when save post", err.Error())
|
||||
t.Errorf("Got errors when save post %s", err.Error())
|
||||
}
|
||||
|
||||
for _, comment := range post.Comments {
|
||||
|
@ -1,7 +1,9 @@
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
)
|
||||
|
||||
@ -28,7 +30,15 @@ func Create(scope *Scope) {
|
||||
for _, field := range fields {
|
||||
if scope.changeableField(field) {
|
||||
if field.IsNormal {
|
||||
if !field.IsPrimaryKey || (field.IsPrimaryKey && !field.IsBlank) {
|
||||
supportPrimary := scope.Dialect().SupportUniquePrimaryKey()
|
||||
if !field.IsPrimaryKey || (field.IsPrimaryKey && (!field.IsBlank || !supportPrimary)) {
|
||||
if field.IsPrimaryKey && !supportPrimary && field.IsBlank {
|
||||
id := scope.Dialect().NewUniqueKey(scope)
|
||||
if scope.HasError() {
|
||||
return
|
||||
}
|
||||
field.Field.Set(reflect.ValueOf(id).Convert(field.Field.Type()))
|
||||
}
|
||||
if !field.IsBlank || !field.HasDefaultValue {
|
||||
columns = append(columns, scope.Quote(field.DBName))
|
||||
sqls = append(sqls, scope.AddToVars(field.Field.Interface()))
|
||||
@ -86,13 +96,14 @@ func Create(scope *Scope) {
|
||||
}
|
||||
} else {
|
||||
if primaryField == nil {
|
||||
if results, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); err == nil {
|
||||
if results, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); err == sql.ErrNoRows {
|
||||
} else if err == nil {
|
||||
scope.db.RowsAffected, _ = results.RowsAffected()
|
||||
} else {
|
||||
scope.Err(err)
|
||||
}
|
||||
} else {
|
||||
if err := scope.Err(scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...).Scan(primaryField.Field.Addr().Interface())); err == nil {
|
||||
if err := scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...).Scan(primaryField.Field.Addr().Interface()); err == nil || err == sql.ErrNoRows {
|
||||
scope.db.RowsAffected = 1
|
||||
} else {
|
||||
scope.Err(err)
|
||||
|
@ -1,6 +1,9 @@
|
||||
package gorm
|
||||
|
||||
import "reflect"
|
||||
import (
|
||||
"database/sql"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
func BeginTransaction(scope *Scope) {
|
||||
scope.Begin()
|
||||
@ -18,7 +21,9 @@ func SaveBeforeAssociations(scope *Scope) {
|
||||
if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored {
|
||||
if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
|
||||
value := field.Field
|
||||
scope.Err(scope.NewDB().Save(value.Addr().Interface()).Error)
|
||||
if err := scope.NewDB().Save(value.Addr().Interface()).Error; err != nil && err != sql.ErrNoRows {
|
||||
scope.Err(err)
|
||||
}
|
||||
if len(relationship.ForeignFieldNames) != 0 {
|
||||
for idx, fieldName := range relationship.ForeignFieldNames {
|
||||
associationForeignName := relationship.AssociationForeignDBNames[idx]
|
||||
|
@ -43,6 +43,11 @@ func Update(scope *Scope) {
|
||||
|
||||
if updateAttrs, ok := scope.InstanceGet("gorm:update_attrs"); ok {
|
||||
for key, value := range updateAttrs.(map[string]interface{}) {
|
||||
if !scope.Dialect().SupportUpdatePrimaryKey() {
|
||||
if field, ok := scope.Fields()[key]; ok && field.IsPrimaryKey {
|
||||
continue
|
||||
}
|
||||
}
|
||||
if scope.changeableDBColumn(key) {
|
||||
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(key), scope.AddToVars(value)))
|
||||
}
|
||||
@ -50,6 +55,9 @@ func Update(scope *Scope) {
|
||||
} else {
|
||||
fields := scope.Fields()
|
||||
for _, field := range fields {
|
||||
if field.IsPrimaryKey && !scope.Dialect().SupportUpdatePrimaryKey() {
|
||||
continue
|
||||
}
|
||||
if scope.changeableField(field) && !field.IsPrimaryKey && field.IsNormal {
|
||||
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
|
||||
} else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
|
||||
|
151
cockroach.go
Normal file
151
cockroach.go
Normal file
@ -0,0 +1,151 @@
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"time"
|
||||
)
|
||||
|
||||
type cockroach struct {
|
||||
commonDialect
|
||||
}
|
||||
|
||||
func (cockroach) BinVar(i int) string {
|
||||
return fmt.Sprintf("$%v", i)
|
||||
}
|
||||
|
||||
func (cockroach) SupportLastInsertId() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (cockroach) SupportUniquePrimaryKey() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (cockroach) SupportUpdatePrimaryKey() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (cockroach) NewUniqueKey(scope *Scope) uint64 {
|
||||
rows, err := scope.NewDB().Raw(`SELECT experimental_unique_int()`).Rows()
|
||||
if err != nil {
|
||||
scope.Err(err)
|
||||
return 0
|
||||
}
|
||||
defer rows.Close()
|
||||
var id int64
|
||||
for rows.Next() {
|
||||
if err := rows.Scan(&id); err != nil {
|
||||
scope.Err(err)
|
||||
return 0
|
||||
}
|
||||
}
|
||||
return uint64(id)
|
||||
}
|
||||
|
||||
func (cockroach) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
|
||||
switch value.Kind() {
|
||||
case reflect.Bool:
|
||||
return "BOOLEAN"
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||
if autoIncrease {
|
||||
return "INTEGER PRIMARY KEY"
|
||||
}
|
||||
return "INTEGER"
|
||||
case reflect.Int64, reflect.Uint64:
|
||||
if autoIncrease {
|
||||
return "BIGINT PRIMARY KEY"
|
||||
}
|
||||
return "BIGINT"
|
||||
case reflect.Float32, reflect.Float64:
|
||||
return "FLOAT"
|
||||
case reflect.String:
|
||||
if size > 0 && size < 65532 {
|
||||
return "VARCHAR"
|
||||
}
|
||||
return "TEXT"
|
||||
case reflect.Struct:
|
||||
if _, ok := value.Interface().(time.Time); ok {
|
||||
return "TIMESTAMP"
|
||||
}
|
||||
default:
|
||||
if _, ok := value.Interface().([]byte); ok {
|
||||
return "BYTES"
|
||||
}
|
||||
}
|
||||
panic(fmt.Sprintf("invalid sql type %s (%s) for cockroach", value.Type().Name(), value.Kind().String()))
|
||||
}
|
||||
|
||||
func (s cockroach) HasTable(scope *Scope, tableName string) bool {
|
||||
rows, err := scope.NewDB().Raw("show tables").Rows()
|
||||
if err != nil {
|
||||
scope.Err(err)
|
||||
return false
|
||||
}
|
||||
defer rows.Close()
|
||||
var name string
|
||||
for rows.Next() {
|
||||
if err := rows.Scan(&name); err != nil {
|
||||
scope.Err(err)
|
||||
return false
|
||||
}
|
||||
if name == tableName {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (s cockroach) HasColumn(scope *Scope, tableName string, columnName string) bool {
|
||||
rows, err := scope.NewDB().Raw(fmt.Sprintf("show columns from %s", tableName)).Rows()
|
||||
if err != nil {
|
||||
scope.Err(err)
|
||||
return false
|
||||
}
|
||||
defer rows.Close()
|
||||
var column string
|
||||
var typ, null, defaultVal interface{}
|
||||
for rows.Next() {
|
||||
if err := rows.Scan(&column, &typ, &null, &defaultVal); err != nil {
|
||||
scope.Err(err)
|
||||
return false
|
||||
}
|
||||
if column == columnName {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (s cockroach) HasIndex(scope *Scope, tableName string, indexName string) bool {
|
||||
rows, err := scope.NewDB().Raw(fmt.Sprintf("show index from %s", tableName)).Rows()
|
||||
if err != nil {
|
||||
scope.Err(err)
|
||||
return false
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var table, name, column, direction string
|
||||
var unique, storing bool
|
||||
var seq int
|
||||
for rows.Next() {
|
||||
if err := rows.Scan(&table, &name, &unique, &seq, &column, &direction, &storing); err != nil {
|
||||
scope.Err(err)
|
||||
return false
|
||||
}
|
||||
if name == indexName {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (cockroach) RemoveIndex(scope *Scope, indexName string) {
|
||||
scope.Err(scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v@%v", scope.TableName(), indexName)).Error)
|
||||
}
|
||||
|
||||
func (s cockroach) CurrentDatabase(scope *Scope) string {
|
||||
var name string
|
||||
s.RawScanString(scope, &name, "SHOW DATABASE")
|
||||
return name
|
||||
}
|
@ -16,6 +16,18 @@ func (commonDialect) SupportLastInsertId() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (commonDialect) SupportUniquePrimaryKey() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (commonDialect) SupportUpdatePrimaryKey() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (commonDialect) NewUniqueKey(scope *Scope) uint64 {
|
||||
panic("NewUniqueKey not supported by commonDialect")
|
||||
}
|
||||
|
||||
func (commonDialect) HasTop() bool {
|
||||
return false
|
||||
}
|
||||
|
@ -8,6 +8,9 @@ import (
|
||||
type Dialect interface {
|
||||
BinVar(i int) string
|
||||
SupportLastInsertId() bool
|
||||
SupportUniquePrimaryKey() bool
|
||||
SupportUpdatePrimaryKey() bool
|
||||
NewUniqueKey(scope *Scope) uint64
|
||||
HasTop() bool
|
||||
SqlTag(value reflect.Value, size int, autoIncrease bool) string
|
||||
ReturningStr(tableName, key string) string
|
||||
@ -23,6 +26,8 @@ type Dialect interface {
|
||||
func NewDialect(driver string) Dialect {
|
||||
var d Dialect
|
||||
switch driver {
|
||||
case "cockroach":
|
||||
d = &cockroach{}
|
||||
case "postgres":
|
||||
d = &postgres{}
|
||||
case "foundation":
|
||||
|
@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
_ "github.com/cockroachdb/cockroach/sql/driver"
|
||||
_ "github.com/denisenkom/go-mssqldb"
|
||||
testdb "github.com/erikstmartin/go-testdb"
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
@ -53,6 +54,9 @@ func OpenTestConnection() (db gorm.DB, err error) {
|
||||
case "postgres":
|
||||
fmt.Println("testing postgres...")
|
||||
db, err = gorm.Open("postgres", "user=gorm DB.name=gorm sslmode=disable")
|
||||
case "cockroach":
|
||||
fmt.Println("testing cockroach...")
|
||||
db, err = gorm.Open("cockroach", "http://localhost:26257?database=gorm&user=root")
|
||||
case "foundation":
|
||||
fmt.Println("testing foundation...")
|
||||
db, err = gorm.Open("foundation", "dbname=gorm port=15432 sslmode=disable")
|
||||
@ -463,7 +467,9 @@ func TestJoins(t *testing.T) {
|
||||
DB.Save(&user)
|
||||
|
||||
var result User
|
||||
DB.Joins("left join emails on emails.user_id = users.id").Where("name = ?", "joins").First(&result)
|
||||
if err := DB.Joins("left join emails on emails.user_id = users.id").Where("name = ?", "joins").First(&result).Error; err != nil {
|
||||
t.Errorf("Error while joining: %s", err)
|
||||
}
|
||||
if result.Name != "joins" || result.Id != user.Id {
|
||||
t.Errorf("Should find all two emails with Join")
|
||||
}
|
||||
|
@ -524,11 +524,6 @@ func (scope *Scope) generateSqlTag(field *StructField) string {
|
||||
sqlType = value
|
||||
}
|
||||
|
||||
additionalType := field.TagSettings["NOT NULL"] + " " + field.TagSettings["UNIQUE"]
|
||||
if value, ok := field.TagSettings["DEFAULT"]; ok {
|
||||
additionalType = additionalType + " DEFAULT " + value
|
||||
}
|
||||
|
||||
if field.IsScanner {
|
||||
var getScannerValue func(reflect.Value)
|
||||
getScannerValue = func(value reflect.Value) {
|
||||
@ -558,6 +553,14 @@ func (scope *Scope) generateSqlTag(field *StructField) string {
|
||||
sqlType = scope.Dialect().SqlTag(reflectValue, size, autoIncrease)
|
||||
}
|
||||
|
||||
additionalType := field.TagSettings["NOT NULL"] + " " + field.TagSettings["UNIQUE"]
|
||||
if value, ok := field.TagSettings["DEFAULT"]; ok {
|
||||
if _, ok := scope.Dialect().(*cockroach); ok && strings.TrimSpace(strings.ToLower(value)) == "null" {
|
||||
value = value + "::" + strings.Split(sqlType, " ")[0]
|
||||
}
|
||||
additionalType = additionalType + " DEFAULT " + value
|
||||
}
|
||||
|
||||
if strings.TrimSpace(additionalType) == "" {
|
||||
return sqlType
|
||||
} else {
|
||||
|
@ -31,8 +31,8 @@ func TestFirstAndLast(t *testing.T) {
|
||||
t.Errorf("Find first record as slice")
|
||||
}
|
||||
|
||||
if DB.Joins("left join emails on emails.user_id = users.id").First(&User{}).Error != nil {
|
||||
t.Errorf("Should not raise any error when order with Join table")
|
||||
if err := DB.Joins("left join emails on emails.user_id = users.id").First(&User{}).Error; err != nil {
|
||||
t.Errorf("Should not raise any error when order with Join table: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
@ -52,15 +52,19 @@ func TestFirstAndLastWithNoStdPrimaryKey(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestUIntPrimaryKey(t *testing.T) {
|
||||
insertedAnimal := &Animal{Name: "animalUint1"}
|
||||
insertedAnimal2 := &Animal{Name: "animalUint2"}
|
||||
DB.Save(insertedAnimal)
|
||||
DB.Save(insertedAnimal2)
|
||||
var animal Animal
|
||||
DB.First(&animal, uint64(1))
|
||||
if animal.Counter != 1 {
|
||||
t.Errorf("Fetch a record from with a non-int primary key should work, but failed")
|
||||
DB.First(&animal, insertedAnimal.Counter)
|
||||
if animal.Counter != insertedAnimal.Counter || animal.Counter <= 0 {
|
||||
t.Errorf("Fetch a record from with a non-int primary key should work, but failed; got %d", animal.Counter)
|
||||
}
|
||||
|
||||
DB.Model(Animal{}).Where(Animal{Counter: uint64(2)}).Scan(&animal)
|
||||
if animal.Counter != 2 {
|
||||
t.Errorf("Fetch a record from with a non-int primary key should work, but failed")
|
||||
DB.Model(Animal{}).Where(Animal{Counter: insertedAnimal2.Counter}).Scan(&animal)
|
||||
if animal.Counter != insertedAnimal2.Counter || animal.Counter <= 0 {
|
||||
t.Errorf("Fetch a record from with a non-int primary key should work, but failed; got %d", animal.Counter)
|
||||
}
|
||||
}
|
||||
|
||||
|
5
scope.go
5
scope.go
@ -1,6 +1,7 @@
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"fmt"
|
||||
"regexp"
|
||||
@ -314,13 +315,15 @@ func (scope *Scope) Raw(sql string) *Scope {
|
||||
return scope
|
||||
}
|
||||
|
||||
var _, driverResultNoRows = driver.ResultNoRows.RowsAffected()
|
||||
|
||||
// Exec invoke sql
|
||||
func (scope *Scope) Exec() *Scope {
|
||||
defer scope.Trace(NowFunc())
|
||||
|
||||
if !scope.HasError() {
|
||||
if result, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil {
|
||||
if count, err := result.RowsAffected(); scope.Err(err) == nil {
|
||||
if count, err := result.RowsAffected(); err != nil && err.Error() == driverResultNoRows.Error() || scope.Err(err) == nil {
|
||||
scope.db.RowsAffected = count
|
||||
}
|
||||
}
|
||||
|
@ -8,6 +8,7 @@ import (
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
func (scope *Scope) primaryCondition(value interface{}) string {
|
||||
@ -170,7 +171,12 @@ func (scope *Scope) whereSql() (sql string) {
|
||||
var primaryConditions, andConditions, orConditions []string
|
||||
|
||||
if !scope.Search.Unscoped && scope.Fields()["deleted_at"] != nil {
|
||||
sql := fmt.Sprintf("(%v.deleted_at IS NULL OR %v.deleted_at <= '0001-01-02')", scope.QuotedTableName(), scope.QuotedTableName())
|
||||
time, err := time.Parse("2006-01-02", "0001-01-02")
|
||||
if err != nil {
|
||||
scope.Err(err)
|
||||
return
|
||||
}
|
||||
sql := fmt.Sprintf("(%v.deleted_at IS NULL OR %v.deleted_at <= %v)", scope.QuotedTableName(), scope.QuotedTableName(), scope.AddToVars(time))
|
||||
primaryConditions = append(primaryConditions, sql)
|
||||
}
|
||||
|
||||
|
@ -20,13 +20,13 @@ func TestScannableSlices(t *testing.T) {
|
||||
}
|
||||
|
||||
if err := DB.Save(&r1).Error; err != nil {
|
||||
t.Errorf("Should save record with slice values")
|
||||
t.Errorf("Should save record with slice values; err %s", err)
|
||||
}
|
||||
|
||||
var r2 RecordWithSlice
|
||||
|
||||
if err := DB.Find(&r2).Error; err != nil {
|
||||
t.Errorf("Should fetch record with slice values")
|
||||
t.Errorf("Should fetch record with slice values; err %s", err)
|
||||
}
|
||||
|
||||
if len(r2.Strings) != 3 || r2.Strings[0] != "a" || r2.Strings[1] != "b" || r2.Strings[2] != "c" {
|
||||
|
@ -39,7 +39,7 @@ type User struct {
|
||||
}
|
||||
|
||||
type CreditCard struct {
|
||||
ID int8
|
||||
ID uint64
|
||||
Number string
|
||||
UserId sql.NullInt64
|
||||
CreatedAt time.Time
|
||||
@ -48,7 +48,7 @@ type CreditCard struct {
|
||||
}
|
||||
|
||||
type Email struct {
|
||||
Id int16
|
||||
Id int64
|
||||
UserId int
|
||||
Email string `sql:"type:varchar(100);"`
|
||||
CreatedAt time.Time
|
||||
|
@ -1,4 +1,4 @@
|
||||
dialects=("postgres" "mysql" "sqlite")
|
||||
dialects=("postgres" "mysql" "sqlite" "cockroach")
|
||||
|
||||
for dialect in "${dialects[@]}" ; do
|
||||
GORM_DIALECT=${dialect} go test
|
||||
|
Loading…
x
Reference in New Issue
Block a user