diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go index a8d3c45a..731721cb 100644 --- a/dialects/mssql/mssql.go +++ b/dialects/mssql/mssql.go @@ -1,12 +1,16 @@ package mssql import ( + "database/sql/driver" + "encoding/json" + "errors" "fmt" "reflect" "strconv" "strings" "time" + // Importing mssql driver package only in dialect file, otherwide not needed _ "github.com/denisenkom/go-mssqldb" "github.com/jinzhu/gorm" ) @@ -201,3 +205,27 @@ func currentDatabaseAndTable(dialect gorm.Dialect, tableName string) (string, st } return dialect.CurrentDatabase(), tableName } + +// JSON type to support easy handling of JSON data in character table fields +// using golang json.RawMessage for deferred decoding/encoding +type JSON struct { + json.RawMessage +} + +// Value get value of JSON +func (j JSON) Value() (driver.Value, error) { + if len(j.RawMessage) == 0 { + return nil, nil + } + return j.MarshalJSON() +} + +// Scan scan value into JSON +func (j *JSON) Scan(value interface{}) error { + str, ok := value.(string) + if !ok { + return errors.New(fmt.Sprint("Failed to unmarshal JSONB value (strcast):", value)) + } + bytes := []byte(str) + return json.Unmarshal(bytes, j) +} diff --git a/main.go b/main.go index 25c3a06b..de6ce428 100644 --- a/main.go +++ b/main.go @@ -48,6 +48,7 @@ func Open(dialect string, args ...interface{}) (db *DB, err error) { } var source string var dbSQL SQLCommon + var ownDbSQL bool switch value := args[0].(type) { case string: @@ -59,8 +60,10 @@ func Open(dialect string, args ...interface{}) (db *DB, err error) { source = args[1].(string) } dbSQL, err = sql.Open(driver, source) + ownDbSQL = true case SQLCommon: dbSQL = value + ownDbSQL = false default: return nil, fmt.Errorf("invalid database source: %v is not a valid type", value) } @@ -78,7 +81,7 @@ func Open(dialect string, args ...interface{}) (db *DB, err error) { } // Send a ping to make sure the database connection is alive. if d, ok := dbSQL.(*sql.DB); ok { - if err = d.Ping(); err != nil { + if err = d.Ping(); err != nil && ownDbSQL { d.Close() } } @@ -119,7 +122,7 @@ func (s *DB) CommonDB() SQLCommon { // Dialect get dialect func (s *DB) Dialect() Dialect { - return s.parent.dialect + return s.dialect } // Callback return `Callbacks` container, you could add/change/delete callbacks with it @@ -484,6 +487,8 @@ func (s *DB) Begin() *DB { if db, ok := c.db.(sqlDb); ok && db != nil { tx, err := db.Begin() c.db = interface{}(tx).(SQLCommon) + + c.dialect.SetDB(c.db) c.AddError(err) } else { c.AddError(ErrCantStartTransaction) @@ -748,6 +753,7 @@ func (s *DB) clone() *DB { Value: s.Value, Error: s.Error, blockGlobalUpdate: s.blockGlobalUpdate, + dialect: newDialect(s.dialect.GetName(), s.db), } for key, value := range s.values { diff --git a/migration_test.go b/migration_test.go index 7c694485..78555dcc 100644 --- a/migration_test.go +++ b/migration_test.go @@ -398,6 +398,53 @@ func TestAutoMigration(t *testing.T) { } } +func TestCreateAndAutomigrateTransaction(t *testing.T) { + tx := DB.Begin() + + func() { + type Bar struct { + ID uint + } + DB.DropTableIfExists(&Bar{}) + + if ok := DB.HasTable("bars"); ok { + t.Errorf("Table should not exist, but does") + } + + if ok := tx.HasTable("bars"); ok { + t.Errorf("Table should not exist, but does") + } + }() + + func() { + type Bar struct { + Name string + } + err := tx.CreateTable(&Bar{}).Error + + if err != nil { + t.Errorf("Should have been able to create the table, but couldn't: %s", err) + } + + if ok := tx.HasTable(&Bar{}); !ok { + t.Errorf("The transaction should be able to see the table") + } + }() + + func() { + type Bar struct { + Stuff string + } + + err := tx.AutoMigrate(&Bar{}).Error + if err != nil { + t.Errorf("Should have been able to alter the table, but couldn't") + } + }() + + tx.Rollback() +} + type MultipleIndexes struct { ID int64 UserID int64 `sql:"unique_index:uix_multipleindexes_user_name,uix_multipleindexes_user_email;index:idx_multipleindexes_user_other"` diff --git a/query_test.go b/query_test.go index fac7d4d8..15bf8b3c 100644 --- a/query_test.go +++ b/query_test.go @@ -181,17 +181,17 @@ func TestSearchWithPlainSQL(t *testing.T) { scopedb.Where("birthday > ?", parseTime("2000-1-1")).Find(&users) if len(users) != 2 { - t.Errorf("Should found 2 users's birthday > 2000-1-1, but got %v", len(users)) + t.Errorf("Should found 2 users' birthday > 2000-1-1, but got %v", len(users)) } scopedb.Where("birthday > ?", "2002-10-10").Find(&users) if len(users) != 2 { - t.Errorf("Should found 2 users's birthday >= 2002-10-10, but got %v", len(users)) + t.Errorf("Should found 2 users' birthday >= 2002-10-10, but got %v", len(users)) } scopedb.Where("birthday >= ?", "2010-1-1").Where("birthday < ?", "2020-1-1").Find(&users) if len(users) != 1 { - t.Errorf("Should found 1 users's birthday < 2020-1-1 and >= 2010-1-1, but got %v", len(users)) + t.Errorf("Should found 1 users' birthday < 2020-1-1 and >= 2010-1-1, but got %v", len(users)) } DB.Where("name in (?)", []string{user1.Name, user2.Name}).Find(&users) @@ -532,28 +532,28 @@ func TestNot(t *testing.T) { DB.Table("users").Where("name = ?", "user3").Count(&name3Count) DB.Not("name", "user3").Find(&users4) if len(users1)-len(users4) != int(name3Count) { - t.Errorf("Should find all users's name not equal 3") + t.Errorf("Should find all users' name not equal 3") } DB.Not("name = ?", "user3").Find(&users4) if len(users1)-len(users4) != int(name3Count) { - t.Errorf("Should find all users's name not equal 3") + t.Errorf("Should find all users' name not equal 3") } DB.Not("name <> ?", "user3").Find(&users4) if len(users4) != int(name3Count) { - t.Errorf("Should find all users's name not equal 3") + t.Errorf("Should find all users' name not equal 3") } DB.Not(User{Name: "user3"}).Find(&users5) if len(users1)-len(users5) != int(name3Count) { - t.Errorf("Should find all users's name not equal 3") + t.Errorf("Should find all users' name not equal 3") } DB.Not(map[string]interface{}{"name": "user3"}).Find(&users6) if len(users1)-len(users6) != int(name3Count) { - t.Errorf("Should find all users's name not equal 3") + t.Errorf("Should find all users' name not equal 3") } DB.Not(map[string]interface{}{"name": "user3", "company_id": nil}).Find(&users7) @@ -563,14 +563,14 @@ func TestNot(t *testing.T) { DB.Not("name", []string{"user3"}).Find(&users8) if len(users1)-len(users8) != int(name3Count) { - t.Errorf("Should find all users's name not equal 3") + t.Errorf("Should find all users' name not equal 3") } var name2Count int64 DB.Table("users").Where("name = ?", "user2").Count(&name2Count) DB.Not("name", []string{"user3", "user2"}).Find(&users9) if len(users1)-len(users9) != (int(name3Count) + int(name2Count)) { - t.Errorf("Should find all users's name not equal 3") + t.Errorf("Should find all users' name not equal 3") } } diff --git a/scope.go b/scope.go index 397ccf0b..a05c1d61 100644 --- a/scope.go +++ b/scope.go @@ -63,7 +63,7 @@ func (scope *Scope) SQLDB() SQLCommon { // Dialect get dialect func (scope *Scope) Dialect() Dialect { - return scope.db.parent.dialect + return scope.db.dialect } // Quote used to quote string to escape them for database @@ -1216,11 +1216,17 @@ func (scope *Scope) addForeignKey(field string, dest string, onDelete string, on func (scope *Scope) removeForeignKey(field string, dest string) { keyName := scope.Dialect().BuildKeyName(scope.TableName(), field, dest, "foreign") - if !scope.Dialect().HasForeignKey(scope.TableName(), keyName) { return } - var query = `ALTER TABLE %s DROP CONSTRAINT %s;` + var mysql mysql + var query string + if scope.Dialect().GetName() == mysql.GetName() { + query = `ALTER TABLE %s DROP FOREIGN KEY %s;` + } else { + query = `ALTER TABLE %s DROP CONSTRAINT %s;` + } + scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName))).Exec() }