Merge branch 'master' into master
This commit is contained in:
commit
c94879ea64
@ -1,12 +1,16 @@
|
||||
package mssql
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
// Importing mssql driver package only in dialect file, otherwide not needed
|
||||
_ "github.com/denisenkom/go-mssqldb"
|
||||
"github.com/jinzhu/gorm"
|
||||
)
|
||||
@ -201,3 +205,27 @@ func currentDatabaseAndTable(dialect gorm.Dialect, tableName string) (string, st
|
||||
}
|
||||
return dialect.CurrentDatabase(), tableName
|
||||
}
|
||||
|
||||
// JSON type to support easy handling of JSON data in character table fields
|
||||
// using golang json.RawMessage for deferred decoding/encoding
|
||||
type JSON struct {
|
||||
json.RawMessage
|
||||
}
|
||||
|
||||
// Value get value of JSON
|
||||
func (j JSON) Value() (driver.Value, error) {
|
||||
if len(j.RawMessage) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
return j.MarshalJSON()
|
||||
}
|
||||
|
||||
// Scan scan value into JSON
|
||||
func (j *JSON) Scan(value interface{}) error {
|
||||
str, ok := value.(string)
|
||||
if !ok {
|
||||
return errors.New(fmt.Sprint("Failed to unmarshal JSONB value (strcast):", value))
|
||||
}
|
||||
bytes := []byte(str)
|
||||
return json.Unmarshal(bytes, j)
|
||||
}
|
||||
|
10
main.go
10
main.go
@ -48,6 +48,7 @@ func Open(dialect string, args ...interface{}) (db *DB, err error) {
|
||||
}
|
||||
var source string
|
||||
var dbSQL SQLCommon
|
||||
var ownDbSQL bool
|
||||
|
||||
switch value := args[0].(type) {
|
||||
case string:
|
||||
@ -59,8 +60,10 @@ func Open(dialect string, args ...interface{}) (db *DB, err error) {
|
||||
source = args[1].(string)
|
||||
}
|
||||
dbSQL, err = sql.Open(driver, source)
|
||||
ownDbSQL = true
|
||||
case SQLCommon:
|
||||
dbSQL = value
|
||||
ownDbSQL = false
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid database source: %v is not a valid type", value)
|
||||
}
|
||||
@ -78,7 +81,7 @@ func Open(dialect string, args ...interface{}) (db *DB, err error) {
|
||||
}
|
||||
// Send a ping to make sure the database connection is alive.
|
||||
if d, ok := dbSQL.(*sql.DB); ok {
|
||||
if err = d.Ping(); err != nil {
|
||||
if err = d.Ping(); err != nil && ownDbSQL {
|
||||
d.Close()
|
||||
}
|
||||
}
|
||||
@ -119,7 +122,7 @@ func (s *DB) CommonDB() SQLCommon {
|
||||
|
||||
// Dialect get dialect
|
||||
func (s *DB) Dialect() Dialect {
|
||||
return s.parent.dialect
|
||||
return s.dialect
|
||||
}
|
||||
|
||||
// Callback return `Callbacks` container, you could add/change/delete callbacks with it
|
||||
@ -484,6 +487,8 @@ func (s *DB) Begin() *DB {
|
||||
if db, ok := c.db.(sqlDb); ok && db != nil {
|
||||
tx, err := db.Begin()
|
||||
c.db = interface{}(tx).(SQLCommon)
|
||||
|
||||
c.dialect.SetDB(c.db)
|
||||
c.AddError(err)
|
||||
} else {
|
||||
c.AddError(ErrCantStartTransaction)
|
||||
@ -748,6 +753,7 @@ func (s *DB) clone() *DB {
|
||||
Value: s.Value,
|
||||
Error: s.Error,
|
||||
blockGlobalUpdate: s.blockGlobalUpdate,
|
||||
dialect: newDialect(s.dialect.GetName(), s.db),
|
||||
}
|
||||
|
||||
for key, value := range s.values {
|
||||
|
@ -398,6 +398,53 @@ func TestAutoMigration(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateAndAutomigrateTransaction(t *testing.T) {
|
||||
tx := DB.Begin()
|
||||
|
||||
func() {
|
||||
type Bar struct {
|
||||
ID uint
|
||||
}
|
||||
DB.DropTableIfExists(&Bar{})
|
||||
|
||||
if ok := DB.HasTable("bars"); ok {
|
||||
t.Errorf("Table should not exist, but does")
|
||||
}
|
||||
|
||||
if ok := tx.HasTable("bars"); ok {
|
||||
t.Errorf("Table should not exist, but does")
|
||||
}
|
||||
}()
|
||||
|
||||
func() {
|
||||
type Bar struct {
|
||||
Name string
|
||||
}
|
||||
err := tx.CreateTable(&Bar{}).Error
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Should have been able to create the table, but couldn't: %s", err)
|
||||
}
|
||||
|
||||
if ok := tx.HasTable(&Bar{}); !ok {
|
||||
t.Errorf("The transaction should be able to see the table")
|
||||
}
|
||||
}()
|
||||
|
||||
func() {
|
||||
type Bar struct {
|
||||
Stuff string
|
||||
}
|
||||
|
||||
err := tx.AutoMigrate(&Bar{}).Error
|
||||
if err != nil {
|
||||
t.Errorf("Should have been able to alter the table, but couldn't")
|
||||
}
|
||||
}()
|
||||
|
||||
tx.Rollback()
|
||||
}
|
||||
|
||||
type MultipleIndexes struct {
|
||||
ID int64
|
||||
UserID int64 `sql:"unique_index:uix_multipleindexes_user_name,uix_multipleindexes_user_email;index:idx_multipleindexes_user_other"`
|
||||
|
@ -181,17 +181,17 @@ func TestSearchWithPlainSQL(t *testing.T) {
|
||||
|
||||
scopedb.Where("birthday > ?", parseTime("2000-1-1")).Find(&users)
|
||||
if len(users) != 2 {
|
||||
t.Errorf("Should found 2 users's birthday > 2000-1-1, but got %v", len(users))
|
||||
t.Errorf("Should found 2 users' birthday > 2000-1-1, but got %v", len(users))
|
||||
}
|
||||
|
||||
scopedb.Where("birthday > ?", "2002-10-10").Find(&users)
|
||||
if len(users) != 2 {
|
||||
t.Errorf("Should found 2 users's birthday >= 2002-10-10, but got %v", len(users))
|
||||
t.Errorf("Should found 2 users' birthday >= 2002-10-10, but got %v", len(users))
|
||||
}
|
||||
|
||||
scopedb.Where("birthday >= ?", "2010-1-1").Where("birthday < ?", "2020-1-1").Find(&users)
|
||||
if len(users) != 1 {
|
||||
t.Errorf("Should found 1 users's birthday < 2020-1-1 and >= 2010-1-1, but got %v", len(users))
|
||||
t.Errorf("Should found 1 users' birthday < 2020-1-1 and >= 2010-1-1, but got %v", len(users))
|
||||
}
|
||||
|
||||
DB.Where("name in (?)", []string{user1.Name, user2.Name}).Find(&users)
|
||||
@ -532,28 +532,28 @@ func TestNot(t *testing.T) {
|
||||
DB.Table("users").Where("name = ?", "user3").Count(&name3Count)
|
||||
DB.Not("name", "user3").Find(&users4)
|
||||
if len(users1)-len(users4) != int(name3Count) {
|
||||
t.Errorf("Should find all users's name not equal 3")
|
||||
t.Errorf("Should find all users' name not equal 3")
|
||||
}
|
||||
|
||||
DB.Not("name = ?", "user3").Find(&users4)
|
||||
if len(users1)-len(users4) != int(name3Count) {
|
||||
t.Errorf("Should find all users's name not equal 3")
|
||||
t.Errorf("Should find all users' name not equal 3")
|
||||
}
|
||||
|
||||
DB.Not("name <> ?", "user3").Find(&users4)
|
||||
if len(users4) != int(name3Count) {
|
||||
t.Errorf("Should find all users's name not equal 3")
|
||||
t.Errorf("Should find all users' name not equal 3")
|
||||
}
|
||||
|
||||
DB.Not(User{Name: "user3"}).Find(&users5)
|
||||
|
||||
if len(users1)-len(users5) != int(name3Count) {
|
||||
t.Errorf("Should find all users's name not equal 3")
|
||||
t.Errorf("Should find all users' name not equal 3")
|
||||
}
|
||||
|
||||
DB.Not(map[string]interface{}{"name": "user3"}).Find(&users6)
|
||||
if len(users1)-len(users6) != int(name3Count) {
|
||||
t.Errorf("Should find all users's name not equal 3")
|
||||
t.Errorf("Should find all users' name not equal 3")
|
||||
}
|
||||
|
||||
DB.Not(map[string]interface{}{"name": "user3", "company_id": nil}).Find(&users7)
|
||||
@ -563,14 +563,14 @@ func TestNot(t *testing.T) {
|
||||
|
||||
DB.Not("name", []string{"user3"}).Find(&users8)
|
||||
if len(users1)-len(users8) != int(name3Count) {
|
||||
t.Errorf("Should find all users's name not equal 3")
|
||||
t.Errorf("Should find all users' name not equal 3")
|
||||
}
|
||||
|
||||
var name2Count int64
|
||||
DB.Table("users").Where("name = ?", "user2").Count(&name2Count)
|
||||
DB.Not("name", []string{"user3", "user2"}).Find(&users9)
|
||||
if len(users1)-len(users9) != (int(name3Count) + int(name2Count)) {
|
||||
t.Errorf("Should find all users's name not equal 3")
|
||||
t.Errorf("Should find all users' name not equal 3")
|
||||
}
|
||||
}
|
||||
|
||||
|
12
scope.go
12
scope.go
@ -63,7 +63,7 @@ func (scope *Scope) SQLDB() SQLCommon {
|
||||
|
||||
// Dialect get dialect
|
||||
func (scope *Scope) Dialect() Dialect {
|
||||
return scope.db.parent.dialect
|
||||
return scope.db.dialect
|
||||
}
|
||||
|
||||
// Quote used to quote string to escape them for database
|
||||
@ -1216,11 +1216,17 @@ func (scope *Scope) addForeignKey(field string, dest string, onDelete string, on
|
||||
|
||||
func (scope *Scope) removeForeignKey(field string, dest string) {
|
||||
keyName := scope.Dialect().BuildKeyName(scope.TableName(), field, dest, "foreign")
|
||||
|
||||
if !scope.Dialect().HasForeignKey(scope.TableName(), keyName) {
|
||||
return
|
||||
}
|
||||
var query = `ALTER TABLE %s DROP CONSTRAINT %s;`
|
||||
var mysql mysql
|
||||
var query string
|
||||
if scope.Dialect().GetName() == mysql.GetName() {
|
||||
query = `ALTER TABLE %s DROP FOREIGN KEY %s;`
|
||||
} else {
|
||||
query = `ALTER TABLE %s DROP CONSTRAINT %s;`
|
||||
}
|
||||
|
||||
scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName))).Exec()
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user