Merge branch 'master' into master

This commit is contained in:
Artemij Shepelev 2018-08-03 11:27:59 +03:00 committed by GitHub
commit c94879ea64
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 102 additions and 15 deletions

View File

@ -1,12 +1,16 @@
package mssql
import (
"database/sql/driver"
"encoding/json"
"errors"
"fmt"
"reflect"
"strconv"
"strings"
"time"
// Importing mssql driver package only in dialect file, otherwide not needed
_ "github.com/denisenkom/go-mssqldb"
"github.com/jinzhu/gorm"
)
@ -201,3 +205,27 @@ func currentDatabaseAndTable(dialect gorm.Dialect, tableName string) (string, st
}
return dialect.CurrentDatabase(), tableName
}
// JSON type to support easy handling of JSON data in character table fields
// using golang json.RawMessage for deferred decoding/encoding
type JSON struct {
json.RawMessage
}
// Value get value of JSON
func (j JSON) Value() (driver.Value, error) {
if len(j.RawMessage) == 0 {
return nil, nil
}
return j.MarshalJSON()
}
// Scan scan value into JSON
func (j *JSON) Scan(value interface{}) error {
str, ok := value.(string)
if !ok {
return errors.New(fmt.Sprint("Failed to unmarshal JSONB value (strcast):", value))
}
bytes := []byte(str)
return json.Unmarshal(bytes, j)
}

10
main.go
View File

@ -48,6 +48,7 @@ func Open(dialect string, args ...interface{}) (db *DB, err error) {
}
var source string
var dbSQL SQLCommon
var ownDbSQL bool
switch value := args[0].(type) {
case string:
@ -59,8 +60,10 @@ func Open(dialect string, args ...interface{}) (db *DB, err error) {
source = args[1].(string)
}
dbSQL, err = sql.Open(driver, source)
ownDbSQL = true
case SQLCommon:
dbSQL = value
ownDbSQL = false
default:
return nil, fmt.Errorf("invalid database source: %v is not a valid type", value)
}
@ -78,7 +81,7 @@ func Open(dialect string, args ...interface{}) (db *DB, err error) {
}
// Send a ping to make sure the database connection is alive.
if d, ok := dbSQL.(*sql.DB); ok {
if err = d.Ping(); err != nil {
if err = d.Ping(); err != nil && ownDbSQL {
d.Close()
}
}
@ -119,7 +122,7 @@ func (s *DB) CommonDB() SQLCommon {
// Dialect get dialect
func (s *DB) Dialect() Dialect {
return s.parent.dialect
return s.dialect
}
// Callback return `Callbacks` container, you could add/change/delete callbacks with it
@ -484,6 +487,8 @@ func (s *DB) Begin() *DB {
if db, ok := c.db.(sqlDb); ok && db != nil {
tx, err := db.Begin()
c.db = interface{}(tx).(SQLCommon)
c.dialect.SetDB(c.db)
c.AddError(err)
} else {
c.AddError(ErrCantStartTransaction)
@ -748,6 +753,7 @@ func (s *DB) clone() *DB {
Value: s.Value,
Error: s.Error,
blockGlobalUpdate: s.blockGlobalUpdate,
dialect: newDialect(s.dialect.GetName(), s.db),
}
for key, value := range s.values {

View File

@ -398,6 +398,53 @@ func TestAutoMigration(t *testing.T) {
}
}
func TestCreateAndAutomigrateTransaction(t *testing.T) {
tx := DB.Begin()
func() {
type Bar struct {
ID uint
}
DB.DropTableIfExists(&Bar{})
if ok := DB.HasTable("bars"); ok {
t.Errorf("Table should not exist, but does")
}
if ok := tx.HasTable("bars"); ok {
t.Errorf("Table should not exist, but does")
}
}()
func() {
type Bar struct {
Name string
}
err := tx.CreateTable(&Bar{}).Error
if err != nil {
t.Errorf("Should have been able to create the table, but couldn't: %s", err)
}
if ok := tx.HasTable(&Bar{}); !ok {
t.Errorf("The transaction should be able to see the table")
}
}()
func() {
type Bar struct {
Stuff string
}
err := tx.AutoMigrate(&Bar{}).Error
if err != nil {
t.Errorf("Should have been able to alter the table, but couldn't")
}
}()
tx.Rollback()
}
type MultipleIndexes struct {
ID int64
UserID int64 `sql:"unique_index:uix_multipleindexes_user_name,uix_multipleindexes_user_email;index:idx_multipleindexes_user_other"`

View File

@ -181,17 +181,17 @@ func TestSearchWithPlainSQL(t *testing.T) {
scopedb.Where("birthday > ?", parseTime("2000-1-1")).Find(&users)
if len(users) != 2 {
t.Errorf("Should found 2 users's birthday > 2000-1-1, but got %v", len(users))
t.Errorf("Should found 2 users' birthday > 2000-1-1, but got %v", len(users))
}
scopedb.Where("birthday > ?", "2002-10-10").Find(&users)
if len(users) != 2 {
t.Errorf("Should found 2 users's birthday >= 2002-10-10, but got %v", len(users))
t.Errorf("Should found 2 users' birthday >= 2002-10-10, but got %v", len(users))
}
scopedb.Where("birthday >= ?", "2010-1-1").Where("birthday < ?", "2020-1-1").Find(&users)
if len(users) != 1 {
t.Errorf("Should found 1 users's birthday < 2020-1-1 and >= 2010-1-1, but got %v", len(users))
t.Errorf("Should found 1 users' birthday < 2020-1-1 and >= 2010-1-1, but got %v", len(users))
}
DB.Where("name in (?)", []string{user1.Name, user2.Name}).Find(&users)
@ -532,28 +532,28 @@ func TestNot(t *testing.T) {
DB.Table("users").Where("name = ?", "user3").Count(&name3Count)
DB.Not("name", "user3").Find(&users4)
if len(users1)-len(users4) != int(name3Count) {
t.Errorf("Should find all users's name not equal 3")
t.Errorf("Should find all users' name not equal 3")
}
DB.Not("name = ?", "user3").Find(&users4)
if len(users1)-len(users4) != int(name3Count) {
t.Errorf("Should find all users's name not equal 3")
t.Errorf("Should find all users' name not equal 3")
}
DB.Not("name <> ?", "user3").Find(&users4)
if len(users4) != int(name3Count) {
t.Errorf("Should find all users's name not equal 3")
t.Errorf("Should find all users' name not equal 3")
}
DB.Not(User{Name: "user3"}).Find(&users5)
if len(users1)-len(users5) != int(name3Count) {
t.Errorf("Should find all users's name not equal 3")
t.Errorf("Should find all users' name not equal 3")
}
DB.Not(map[string]interface{}{"name": "user3"}).Find(&users6)
if len(users1)-len(users6) != int(name3Count) {
t.Errorf("Should find all users's name not equal 3")
t.Errorf("Should find all users' name not equal 3")
}
DB.Not(map[string]interface{}{"name": "user3", "company_id": nil}).Find(&users7)
@ -563,14 +563,14 @@ func TestNot(t *testing.T) {
DB.Not("name", []string{"user3"}).Find(&users8)
if len(users1)-len(users8) != int(name3Count) {
t.Errorf("Should find all users's name not equal 3")
t.Errorf("Should find all users' name not equal 3")
}
var name2Count int64
DB.Table("users").Where("name = ?", "user2").Count(&name2Count)
DB.Not("name", []string{"user3", "user2"}).Find(&users9)
if len(users1)-len(users9) != (int(name3Count) + int(name2Count)) {
t.Errorf("Should find all users's name not equal 3")
t.Errorf("Should find all users' name not equal 3")
}
}

View File

@ -63,7 +63,7 @@ func (scope *Scope) SQLDB() SQLCommon {
// Dialect get dialect
func (scope *Scope) Dialect() Dialect {
return scope.db.parent.dialect
return scope.db.dialect
}
// Quote used to quote string to escape them for database
@ -1216,11 +1216,17 @@ func (scope *Scope) addForeignKey(field string, dest string, onDelete string, on
func (scope *Scope) removeForeignKey(field string, dest string) {
keyName := scope.Dialect().BuildKeyName(scope.TableName(), field, dest, "foreign")
if !scope.Dialect().HasForeignKey(scope.TableName(), keyName) {
return
}
var query = `ALTER TABLE %s DROP CONSTRAINT %s;`
var mysql mysql
var query string
if scope.Dialect().GetName() == mysql.GetName() {
query = `ALTER TABLE %s DROP FOREIGN KEY %s;`
} else {
query = `ALTER TABLE %s DROP CONSTRAINT %s;`
}
scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName))).Exec()
}