Merge branch 'master' into master
This commit is contained in:
commit
c94879ea64
@ -1,12 +1,16 @@
|
|||||||
package mssql
|
package mssql
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"database/sql/driver"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
// Importing mssql driver package only in dialect file, otherwide not needed
|
||||||
_ "github.com/denisenkom/go-mssqldb"
|
_ "github.com/denisenkom/go-mssqldb"
|
||||||
"github.com/jinzhu/gorm"
|
"github.com/jinzhu/gorm"
|
||||||
)
|
)
|
||||||
@ -201,3 +205,27 @@ func currentDatabaseAndTable(dialect gorm.Dialect, tableName string) (string, st
|
|||||||
}
|
}
|
||||||
return dialect.CurrentDatabase(), tableName
|
return dialect.CurrentDatabase(), tableName
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// JSON type to support easy handling of JSON data in character table fields
|
||||||
|
// using golang json.RawMessage for deferred decoding/encoding
|
||||||
|
type JSON struct {
|
||||||
|
json.RawMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
// Value get value of JSON
|
||||||
|
func (j JSON) Value() (driver.Value, error) {
|
||||||
|
if len(j.RawMessage) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return j.MarshalJSON()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Scan scan value into JSON
|
||||||
|
func (j *JSON) Scan(value interface{}) error {
|
||||||
|
str, ok := value.(string)
|
||||||
|
if !ok {
|
||||||
|
return errors.New(fmt.Sprint("Failed to unmarshal JSONB value (strcast):", value))
|
||||||
|
}
|
||||||
|
bytes := []byte(str)
|
||||||
|
return json.Unmarshal(bytes, j)
|
||||||
|
}
|
||||||
|
10
main.go
10
main.go
@ -48,6 +48,7 @@ func Open(dialect string, args ...interface{}) (db *DB, err error) {
|
|||||||
}
|
}
|
||||||
var source string
|
var source string
|
||||||
var dbSQL SQLCommon
|
var dbSQL SQLCommon
|
||||||
|
var ownDbSQL bool
|
||||||
|
|
||||||
switch value := args[0].(type) {
|
switch value := args[0].(type) {
|
||||||
case string:
|
case string:
|
||||||
@ -59,8 +60,10 @@ func Open(dialect string, args ...interface{}) (db *DB, err error) {
|
|||||||
source = args[1].(string)
|
source = args[1].(string)
|
||||||
}
|
}
|
||||||
dbSQL, err = sql.Open(driver, source)
|
dbSQL, err = sql.Open(driver, source)
|
||||||
|
ownDbSQL = true
|
||||||
case SQLCommon:
|
case SQLCommon:
|
||||||
dbSQL = value
|
dbSQL = value
|
||||||
|
ownDbSQL = false
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("invalid database source: %v is not a valid type", value)
|
return nil, fmt.Errorf("invalid database source: %v is not a valid type", value)
|
||||||
}
|
}
|
||||||
@ -78,7 +81,7 @@ func Open(dialect string, args ...interface{}) (db *DB, err error) {
|
|||||||
}
|
}
|
||||||
// Send a ping to make sure the database connection is alive.
|
// Send a ping to make sure the database connection is alive.
|
||||||
if d, ok := dbSQL.(*sql.DB); ok {
|
if d, ok := dbSQL.(*sql.DB); ok {
|
||||||
if err = d.Ping(); err != nil {
|
if err = d.Ping(); err != nil && ownDbSQL {
|
||||||
d.Close()
|
d.Close()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -119,7 +122,7 @@ func (s *DB) CommonDB() SQLCommon {
|
|||||||
|
|
||||||
// Dialect get dialect
|
// Dialect get dialect
|
||||||
func (s *DB) Dialect() Dialect {
|
func (s *DB) Dialect() Dialect {
|
||||||
return s.parent.dialect
|
return s.dialect
|
||||||
}
|
}
|
||||||
|
|
||||||
// Callback return `Callbacks` container, you could add/change/delete callbacks with it
|
// Callback return `Callbacks` container, you could add/change/delete callbacks with it
|
||||||
@ -484,6 +487,8 @@ func (s *DB) Begin() *DB {
|
|||||||
if db, ok := c.db.(sqlDb); ok && db != nil {
|
if db, ok := c.db.(sqlDb); ok && db != nil {
|
||||||
tx, err := db.Begin()
|
tx, err := db.Begin()
|
||||||
c.db = interface{}(tx).(SQLCommon)
|
c.db = interface{}(tx).(SQLCommon)
|
||||||
|
|
||||||
|
c.dialect.SetDB(c.db)
|
||||||
c.AddError(err)
|
c.AddError(err)
|
||||||
} else {
|
} else {
|
||||||
c.AddError(ErrCantStartTransaction)
|
c.AddError(ErrCantStartTransaction)
|
||||||
@ -748,6 +753,7 @@ func (s *DB) clone() *DB {
|
|||||||
Value: s.Value,
|
Value: s.Value,
|
||||||
Error: s.Error,
|
Error: s.Error,
|
||||||
blockGlobalUpdate: s.blockGlobalUpdate,
|
blockGlobalUpdate: s.blockGlobalUpdate,
|
||||||
|
dialect: newDialect(s.dialect.GetName(), s.db),
|
||||||
}
|
}
|
||||||
|
|
||||||
for key, value := range s.values {
|
for key, value := range s.values {
|
||||||
|
@ -398,6 +398,53 @@ func TestAutoMigration(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCreateAndAutomigrateTransaction(t *testing.T) {
|
||||||
|
tx := DB.Begin()
|
||||||
|
|
||||||
|
func() {
|
||||||
|
type Bar struct {
|
||||||
|
ID uint
|
||||||
|
}
|
||||||
|
DB.DropTableIfExists(&Bar{})
|
||||||
|
|
||||||
|
if ok := DB.HasTable("bars"); ok {
|
||||||
|
t.Errorf("Table should not exist, but does")
|
||||||
|
}
|
||||||
|
|
||||||
|
if ok := tx.HasTable("bars"); ok {
|
||||||
|
t.Errorf("Table should not exist, but does")
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
func() {
|
||||||
|
type Bar struct {
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
err := tx.CreateTable(&Bar{}).Error
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Should have been able to create the table, but couldn't: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if ok := tx.HasTable(&Bar{}); !ok {
|
||||||
|
t.Errorf("The transaction should be able to see the table")
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
func() {
|
||||||
|
type Bar struct {
|
||||||
|
Stuff string
|
||||||
|
}
|
||||||
|
|
||||||
|
err := tx.AutoMigrate(&Bar{}).Error
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Should have been able to alter the table, but couldn't")
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
tx.Rollback()
|
||||||
|
}
|
||||||
|
|
||||||
type MultipleIndexes struct {
|
type MultipleIndexes struct {
|
||||||
ID int64
|
ID int64
|
||||||
UserID int64 `sql:"unique_index:uix_multipleindexes_user_name,uix_multipleindexes_user_email;index:idx_multipleindexes_user_other"`
|
UserID int64 `sql:"unique_index:uix_multipleindexes_user_name,uix_multipleindexes_user_email;index:idx_multipleindexes_user_other"`
|
||||||
|
@ -181,17 +181,17 @@ func TestSearchWithPlainSQL(t *testing.T) {
|
|||||||
|
|
||||||
scopedb.Where("birthday > ?", parseTime("2000-1-1")).Find(&users)
|
scopedb.Where("birthday > ?", parseTime("2000-1-1")).Find(&users)
|
||||||
if len(users) != 2 {
|
if len(users) != 2 {
|
||||||
t.Errorf("Should found 2 users's birthday > 2000-1-1, but got %v", len(users))
|
t.Errorf("Should found 2 users' birthday > 2000-1-1, but got %v", len(users))
|
||||||
}
|
}
|
||||||
|
|
||||||
scopedb.Where("birthday > ?", "2002-10-10").Find(&users)
|
scopedb.Where("birthday > ?", "2002-10-10").Find(&users)
|
||||||
if len(users) != 2 {
|
if len(users) != 2 {
|
||||||
t.Errorf("Should found 2 users's birthday >= 2002-10-10, but got %v", len(users))
|
t.Errorf("Should found 2 users' birthday >= 2002-10-10, but got %v", len(users))
|
||||||
}
|
}
|
||||||
|
|
||||||
scopedb.Where("birthday >= ?", "2010-1-1").Where("birthday < ?", "2020-1-1").Find(&users)
|
scopedb.Where("birthday >= ?", "2010-1-1").Where("birthday < ?", "2020-1-1").Find(&users)
|
||||||
if len(users) != 1 {
|
if len(users) != 1 {
|
||||||
t.Errorf("Should found 1 users's birthday < 2020-1-1 and >= 2010-1-1, but got %v", len(users))
|
t.Errorf("Should found 1 users' birthday < 2020-1-1 and >= 2010-1-1, but got %v", len(users))
|
||||||
}
|
}
|
||||||
|
|
||||||
DB.Where("name in (?)", []string{user1.Name, user2.Name}).Find(&users)
|
DB.Where("name in (?)", []string{user1.Name, user2.Name}).Find(&users)
|
||||||
@ -532,28 +532,28 @@ func TestNot(t *testing.T) {
|
|||||||
DB.Table("users").Where("name = ?", "user3").Count(&name3Count)
|
DB.Table("users").Where("name = ?", "user3").Count(&name3Count)
|
||||||
DB.Not("name", "user3").Find(&users4)
|
DB.Not("name", "user3").Find(&users4)
|
||||||
if len(users1)-len(users4) != int(name3Count) {
|
if len(users1)-len(users4) != int(name3Count) {
|
||||||
t.Errorf("Should find all users's name not equal 3")
|
t.Errorf("Should find all users' name not equal 3")
|
||||||
}
|
}
|
||||||
|
|
||||||
DB.Not("name = ?", "user3").Find(&users4)
|
DB.Not("name = ?", "user3").Find(&users4)
|
||||||
if len(users1)-len(users4) != int(name3Count) {
|
if len(users1)-len(users4) != int(name3Count) {
|
||||||
t.Errorf("Should find all users's name not equal 3")
|
t.Errorf("Should find all users' name not equal 3")
|
||||||
}
|
}
|
||||||
|
|
||||||
DB.Not("name <> ?", "user3").Find(&users4)
|
DB.Not("name <> ?", "user3").Find(&users4)
|
||||||
if len(users4) != int(name3Count) {
|
if len(users4) != int(name3Count) {
|
||||||
t.Errorf("Should find all users's name not equal 3")
|
t.Errorf("Should find all users' name not equal 3")
|
||||||
}
|
}
|
||||||
|
|
||||||
DB.Not(User{Name: "user3"}).Find(&users5)
|
DB.Not(User{Name: "user3"}).Find(&users5)
|
||||||
|
|
||||||
if len(users1)-len(users5) != int(name3Count) {
|
if len(users1)-len(users5) != int(name3Count) {
|
||||||
t.Errorf("Should find all users's name not equal 3")
|
t.Errorf("Should find all users' name not equal 3")
|
||||||
}
|
}
|
||||||
|
|
||||||
DB.Not(map[string]interface{}{"name": "user3"}).Find(&users6)
|
DB.Not(map[string]interface{}{"name": "user3"}).Find(&users6)
|
||||||
if len(users1)-len(users6) != int(name3Count) {
|
if len(users1)-len(users6) != int(name3Count) {
|
||||||
t.Errorf("Should find all users's name not equal 3")
|
t.Errorf("Should find all users' name not equal 3")
|
||||||
}
|
}
|
||||||
|
|
||||||
DB.Not(map[string]interface{}{"name": "user3", "company_id": nil}).Find(&users7)
|
DB.Not(map[string]interface{}{"name": "user3", "company_id": nil}).Find(&users7)
|
||||||
@ -563,14 +563,14 @@ func TestNot(t *testing.T) {
|
|||||||
|
|
||||||
DB.Not("name", []string{"user3"}).Find(&users8)
|
DB.Not("name", []string{"user3"}).Find(&users8)
|
||||||
if len(users1)-len(users8) != int(name3Count) {
|
if len(users1)-len(users8) != int(name3Count) {
|
||||||
t.Errorf("Should find all users's name not equal 3")
|
t.Errorf("Should find all users' name not equal 3")
|
||||||
}
|
}
|
||||||
|
|
||||||
var name2Count int64
|
var name2Count int64
|
||||||
DB.Table("users").Where("name = ?", "user2").Count(&name2Count)
|
DB.Table("users").Where("name = ?", "user2").Count(&name2Count)
|
||||||
DB.Not("name", []string{"user3", "user2"}).Find(&users9)
|
DB.Not("name", []string{"user3", "user2"}).Find(&users9)
|
||||||
if len(users1)-len(users9) != (int(name3Count) + int(name2Count)) {
|
if len(users1)-len(users9) != (int(name3Count) + int(name2Count)) {
|
||||||
t.Errorf("Should find all users's name not equal 3")
|
t.Errorf("Should find all users' name not equal 3")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
12
scope.go
12
scope.go
@ -63,7 +63,7 @@ func (scope *Scope) SQLDB() SQLCommon {
|
|||||||
|
|
||||||
// Dialect get dialect
|
// Dialect get dialect
|
||||||
func (scope *Scope) Dialect() Dialect {
|
func (scope *Scope) Dialect() Dialect {
|
||||||
return scope.db.parent.dialect
|
return scope.db.dialect
|
||||||
}
|
}
|
||||||
|
|
||||||
// Quote used to quote string to escape them for database
|
// Quote used to quote string to escape them for database
|
||||||
@ -1216,11 +1216,17 @@ func (scope *Scope) addForeignKey(field string, dest string, onDelete string, on
|
|||||||
|
|
||||||
func (scope *Scope) removeForeignKey(field string, dest string) {
|
func (scope *Scope) removeForeignKey(field string, dest string) {
|
||||||
keyName := scope.Dialect().BuildKeyName(scope.TableName(), field, dest, "foreign")
|
keyName := scope.Dialect().BuildKeyName(scope.TableName(), field, dest, "foreign")
|
||||||
|
|
||||||
if !scope.Dialect().HasForeignKey(scope.TableName(), keyName) {
|
if !scope.Dialect().HasForeignKey(scope.TableName(), keyName) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
var query = `ALTER TABLE %s DROP CONSTRAINT %s;`
|
var mysql mysql
|
||||||
|
var query string
|
||||||
|
if scope.Dialect().GetName() == mysql.GetName() {
|
||||||
|
query = `ALTER TABLE %s DROP FOREIGN KEY %s;`
|
||||||
|
} else {
|
||||||
|
query = `ALTER TABLE %s DROP CONSTRAINT %s;`
|
||||||
|
}
|
||||||
|
|
||||||
scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName))).Exec()
|
scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName))).Exec()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user