Merge branch 'master' into master

This commit is contained in:
Artemij Shepelev 2018-08-03 11:27:59 +03:00 committed by GitHub
commit c94879ea64
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 102 additions and 15 deletions

View File

@ -1,12 +1,16 @@
package mssql package mssql
import ( import (
"database/sql/driver"
"encoding/json"
"errors"
"fmt" "fmt"
"reflect" "reflect"
"strconv" "strconv"
"strings" "strings"
"time" "time"
// Importing mssql driver package only in dialect file, otherwide not needed
_ "github.com/denisenkom/go-mssqldb" _ "github.com/denisenkom/go-mssqldb"
"github.com/jinzhu/gorm" "github.com/jinzhu/gorm"
) )
@ -201,3 +205,27 @@ func currentDatabaseAndTable(dialect gorm.Dialect, tableName string) (string, st
} }
return dialect.CurrentDatabase(), tableName return dialect.CurrentDatabase(), tableName
} }
// JSON type to support easy handling of JSON data in character table fields
// using golang json.RawMessage for deferred decoding/encoding
type JSON struct {
json.RawMessage
}
// Value get value of JSON
func (j JSON) Value() (driver.Value, error) {
if len(j.RawMessage) == 0 {
return nil, nil
}
return j.MarshalJSON()
}
// Scan scan value into JSON
func (j *JSON) Scan(value interface{}) error {
str, ok := value.(string)
if !ok {
return errors.New(fmt.Sprint("Failed to unmarshal JSONB value (strcast):", value))
}
bytes := []byte(str)
return json.Unmarshal(bytes, j)
}

10
main.go
View File

@ -48,6 +48,7 @@ func Open(dialect string, args ...interface{}) (db *DB, err error) {
} }
var source string var source string
var dbSQL SQLCommon var dbSQL SQLCommon
var ownDbSQL bool
switch value := args[0].(type) { switch value := args[0].(type) {
case string: case string:
@ -59,8 +60,10 @@ func Open(dialect string, args ...interface{}) (db *DB, err error) {
source = args[1].(string) source = args[1].(string)
} }
dbSQL, err = sql.Open(driver, source) dbSQL, err = sql.Open(driver, source)
ownDbSQL = true
case SQLCommon: case SQLCommon:
dbSQL = value dbSQL = value
ownDbSQL = false
default: default:
return nil, fmt.Errorf("invalid database source: %v is not a valid type", value) return nil, fmt.Errorf("invalid database source: %v is not a valid type", value)
} }
@ -78,7 +81,7 @@ func Open(dialect string, args ...interface{}) (db *DB, err error) {
} }
// Send a ping to make sure the database connection is alive. // Send a ping to make sure the database connection is alive.
if d, ok := dbSQL.(*sql.DB); ok { if d, ok := dbSQL.(*sql.DB); ok {
if err = d.Ping(); err != nil { if err = d.Ping(); err != nil && ownDbSQL {
d.Close() d.Close()
} }
} }
@ -119,7 +122,7 @@ func (s *DB) CommonDB() SQLCommon {
// Dialect get dialect // Dialect get dialect
func (s *DB) Dialect() Dialect { func (s *DB) Dialect() Dialect {
return s.parent.dialect return s.dialect
} }
// Callback return `Callbacks` container, you could add/change/delete callbacks with it // Callback return `Callbacks` container, you could add/change/delete callbacks with it
@ -484,6 +487,8 @@ func (s *DB) Begin() *DB {
if db, ok := c.db.(sqlDb); ok && db != nil { if db, ok := c.db.(sqlDb); ok && db != nil {
tx, err := db.Begin() tx, err := db.Begin()
c.db = interface{}(tx).(SQLCommon) c.db = interface{}(tx).(SQLCommon)
c.dialect.SetDB(c.db)
c.AddError(err) c.AddError(err)
} else { } else {
c.AddError(ErrCantStartTransaction) c.AddError(ErrCantStartTransaction)
@ -748,6 +753,7 @@ func (s *DB) clone() *DB {
Value: s.Value, Value: s.Value,
Error: s.Error, Error: s.Error,
blockGlobalUpdate: s.blockGlobalUpdate, blockGlobalUpdate: s.blockGlobalUpdate,
dialect: newDialect(s.dialect.GetName(), s.db),
} }
for key, value := range s.values { for key, value := range s.values {

View File

@ -398,6 +398,53 @@ func TestAutoMigration(t *testing.T) {
} }
} }
func TestCreateAndAutomigrateTransaction(t *testing.T) {
tx := DB.Begin()
func() {
type Bar struct {
ID uint
}
DB.DropTableIfExists(&Bar{})
if ok := DB.HasTable("bars"); ok {
t.Errorf("Table should not exist, but does")
}
if ok := tx.HasTable("bars"); ok {
t.Errorf("Table should not exist, but does")
}
}()
func() {
type Bar struct {
Name string
}
err := tx.CreateTable(&Bar{}).Error
if err != nil {
t.Errorf("Should have been able to create the table, but couldn't: %s", err)
}
if ok := tx.HasTable(&Bar{}); !ok {
t.Errorf("The transaction should be able to see the table")
}
}()
func() {
type Bar struct {
Stuff string
}
err := tx.AutoMigrate(&Bar{}).Error
if err != nil {
t.Errorf("Should have been able to alter the table, but couldn't")
}
}()
tx.Rollback()
}
type MultipleIndexes struct { type MultipleIndexes struct {
ID int64 ID int64
UserID int64 `sql:"unique_index:uix_multipleindexes_user_name,uix_multipleindexes_user_email;index:idx_multipleindexes_user_other"` UserID int64 `sql:"unique_index:uix_multipleindexes_user_name,uix_multipleindexes_user_email;index:idx_multipleindexes_user_other"`

View File

@ -181,17 +181,17 @@ func TestSearchWithPlainSQL(t *testing.T) {
scopedb.Where("birthday > ?", parseTime("2000-1-1")).Find(&users) scopedb.Where("birthday > ?", parseTime("2000-1-1")).Find(&users)
if len(users) != 2 { if len(users) != 2 {
t.Errorf("Should found 2 users's birthday > 2000-1-1, but got %v", len(users)) t.Errorf("Should found 2 users' birthday > 2000-1-1, but got %v", len(users))
} }
scopedb.Where("birthday > ?", "2002-10-10").Find(&users) scopedb.Where("birthday > ?", "2002-10-10").Find(&users)
if len(users) != 2 { if len(users) != 2 {
t.Errorf("Should found 2 users's birthday >= 2002-10-10, but got %v", len(users)) t.Errorf("Should found 2 users' birthday >= 2002-10-10, but got %v", len(users))
} }
scopedb.Where("birthday >= ?", "2010-1-1").Where("birthday < ?", "2020-1-1").Find(&users) scopedb.Where("birthday >= ?", "2010-1-1").Where("birthday < ?", "2020-1-1").Find(&users)
if len(users) != 1 { if len(users) != 1 {
t.Errorf("Should found 1 users's birthday < 2020-1-1 and >= 2010-1-1, but got %v", len(users)) t.Errorf("Should found 1 users' birthday < 2020-1-1 and >= 2010-1-1, but got %v", len(users))
} }
DB.Where("name in (?)", []string{user1.Name, user2.Name}).Find(&users) DB.Where("name in (?)", []string{user1.Name, user2.Name}).Find(&users)
@ -532,28 +532,28 @@ func TestNot(t *testing.T) {
DB.Table("users").Where("name = ?", "user3").Count(&name3Count) DB.Table("users").Where("name = ?", "user3").Count(&name3Count)
DB.Not("name", "user3").Find(&users4) DB.Not("name", "user3").Find(&users4)
if len(users1)-len(users4) != int(name3Count) { if len(users1)-len(users4) != int(name3Count) {
t.Errorf("Should find all users's name not equal 3") t.Errorf("Should find all users' name not equal 3")
} }
DB.Not("name = ?", "user3").Find(&users4) DB.Not("name = ?", "user3").Find(&users4)
if len(users1)-len(users4) != int(name3Count) { if len(users1)-len(users4) != int(name3Count) {
t.Errorf("Should find all users's name not equal 3") t.Errorf("Should find all users' name not equal 3")
} }
DB.Not("name <> ?", "user3").Find(&users4) DB.Not("name <> ?", "user3").Find(&users4)
if len(users4) != int(name3Count) { if len(users4) != int(name3Count) {
t.Errorf("Should find all users's name not equal 3") t.Errorf("Should find all users' name not equal 3")
} }
DB.Not(User{Name: "user3"}).Find(&users5) DB.Not(User{Name: "user3"}).Find(&users5)
if len(users1)-len(users5) != int(name3Count) { if len(users1)-len(users5) != int(name3Count) {
t.Errorf("Should find all users's name not equal 3") t.Errorf("Should find all users' name not equal 3")
} }
DB.Not(map[string]interface{}{"name": "user3"}).Find(&users6) DB.Not(map[string]interface{}{"name": "user3"}).Find(&users6)
if len(users1)-len(users6) != int(name3Count) { if len(users1)-len(users6) != int(name3Count) {
t.Errorf("Should find all users's name not equal 3") t.Errorf("Should find all users' name not equal 3")
} }
DB.Not(map[string]interface{}{"name": "user3", "company_id": nil}).Find(&users7) DB.Not(map[string]interface{}{"name": "user3", "company_id": nil}).Find(&users7)
@ -563,14 +563,14 @@ func TestNot(t *testing.T) {
DB.Not("name", []string{"user3"}).Find(&users8) DB.Not("name", []string{"user3"}).Find(&users8)
if len(users1)-len(users8) != int(name3Count) { if len(users1)-len(users8) != int(name3Count) {
t.Errorf("Should find all users's name not equal 3") t.Errorf("Should find all users' name not equal 3")
} }
var name2Count int64 var name2Count int64
DB.Table("users").Where("name = ?", "user2").Count(&name2Count) DB.Table("users").Where("name = ?", "user2").Count(&name2Count)
DB.Not("name", []string{"user3", "user2"}).Find(&users9) DB.Not("name", []string{"user3", "user2"}).Find(&users9)
if len(users1)-len(users9) != (int(name3Count) + int(name2Count)) { if len(users1)-len(users9) != (int(name3Count) + int(name2Count)) {
t.Errorf("Should find all users's name not equal 3") t.Errorf("Should find all users' name not equal 3")
} }
} }

View File

@ -63,7 +63,7 @@ func (scope *Scope) SQLDB() SQLCommon {
// Dialect get dialect // Dialect get dialect
func (scope *Scope) Dialect() Dialect { func (scope *Scope) Dialect() Dialect {
return scope.db.parent.dialect return scope.db.dialect
} }
// Quote used to quote string to escape them for database // Quote used to quote string to escape them for database
@ -1216,11 +1216,17 @@ func (scope *Scope) addForeignKey(field string, dest string, onDelete string, on
func (scope *Scope) removeForeignKey(field string, dest string) { func (scope *Scope) removeForeignKey(field string, dest string) {
keyName := scope.Dialect().BuildKeyName(scope.TableName(), field, dest, "foreign") keyName := scope.Dialect().BuildKeyName(scope.TableName(), field, dest, "foreign")
if !scope.Dialect().HasForeignKey(scope.TableName(), keyName) { if !scope.Dialect().HasForeignKey(scope.TableName(), keyName) {
return return
} }
var query = `ALTER TABLE %s DROP CONSTRAINT %s;` var mysql mysql
var query string
if scope.Dialect().GetName() == mysql.GetName() {
query = `ALTER TABLE %s DROP FOREIGN KEY %s;`
} else {
query = `ALTER TABLE %s DROP CONSTRAINT %s;`
}
scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName))).Exec() scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName))).Exec()
} }