Merge branch 'master' into master
This commit is contained in:
commit
9601a8aa56
@ -72,6 +72,12 @@ func RegisterDialect(name string, dialect Dialect) {
|
|||||||
dialectsMap[name] = dialect
|
dialectsMap[name] = dialect
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetDialect gets the dialect for the specified dialect name
|
||||||
|
func GetDialect(name string) (dialect Dialect, ok bool) {
|
||||||
|
dialect, ok = dialectsMap[name]
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// ParseFieldStructForDialect get field's sql data type
|
// ParseFieldStructForDialect get field's sql data type
|
||||||
var ParseFieldStructForDialect = func(field *StructField, dialect Dialect) (fieldValue reflect.Value, sqlType string, size int, additionalType string) {
|
var ParseFieldStructForDialect = func(field *StructField, dialect Dialect) (fieldValue reflect.Value, sqlType string, size int, additionalType string) {
|
||||||
// Get redirected field type
|
// Get redirected field type
|
||||||
|
@ -130,7 +130,14 @@ func (s mssql) RemoveIndex(tableName string, indexName string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s mssql) HasForeignKey(tableName string, foreignKeyName string) bool {
|
func (s mssql) HasForeignKey(tableName string, foreignKeyName string) bool {
|
||||||
return false
|
var count int
|
||||||
|
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
|
||||||
|
s.db.QueryRow(`SELECT count(*)
|
||||||
|
FROM sys.foreign_keys as F inner join sys.tables as T on F.parent_object_id=T.object_id
|
||||||
|
inner join information_schema.tables as I on I.TABLE_NAME = T.name
|
||||||
|
WHERE F.name = ?
|
||||||
|
AND T.Name = ? AND I.TABLE_CATALOG = ?;`, foreignKeyName, tableName, currentDatabase).Scan(&count)
|
||||||
|
return count > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s mssql) HasTable(tableName string) bool {
|
func (s mssql) HasTable(tableName string) bool {
|
||||||
|
8
main.go
8
main.go
@ -61,6 +61,8 @@ func Open(dialect string, args ...interface{}) (db *DB, err error) {
|
|||||||
dbSQL, err = sql.Open(driver, source)
|
dbSQL, err = sql.Open(driver, source)
|
||||||
case SQLCommon:
|
case SQLCommon:
|
||||||
dbSQL = value
|
dbSQL = value
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("invalid database source: %v is not a valid type", value)
|
||||||
}
|
}
|
||||||
|
|
||||||
db = &DB{
|
db = &DB{
|
||||||
@ -491,7 +493,8 @@ func (s *DB) Begin() *DB {
|
|||||||
|
|
||||||
// Commit commit a transaction
|
// Commit commit a transaction
|
||||||
func (s *DB) Commit() *DB {
|
func (s *DB) Commit() *DB {
|
||||||
if db, ok := s.db.(sqlTx); ok && db != nil {
|
var emptySQLTx *sql.Tx
|
||||||
|
if db, ok := s.db.(sqlTx); ok && db != nil && db != emptySQLTx {
|
||||||
s.AddError(db.Commit())
|
s.AddError(db.Commit())
|
||||||
} else {
|
} else {
|
||||||
s.AddError(ErrInvalidTransaction)
|
s.AddError(ErrInvalidTransaction)
|
||||||
@ -501,7 +504,8 @@ func (s *DB) Commit() *DB {
|
|||||||
|
|
||||||
// Rollback rollback a transaction
|
// Rollback rollback a transaction
|
||||||
func (s *DB) Rollback() *DB {
|
func (s *DB) Rollback() *DB {
|
||||||
if db, ok := s.db.(sqlTx); ok && db != nil {
|
var emptySQLTx *sql.Tx
|
||||||
|
if db, ok := s.db.(sqlTx); ok && db != nil && db != emptySQLTx {
|
||||||
s.AddError(db.Rollback())
|
s.AddError(db.Rollback())
|
||||||
} else {
|
} else {
|
||||||
s.AddError(ErrInvalidTransaction)
|
s.AddError(ErrInvalidTransaction)
|
||||||
|
17
main_test.go
17
main_test.go
@ -8,6 +8,7 @@ import (
|
|||||||
"path/filepath"
|
"path/filepath"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -79,6 +80,22 @@ func OpenTestConnection() (db *gorm.DB, err error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOpen_ReturnsError_WithBadArgs(t *testing.T) {
|
||||||
|
stringRef := "foo"
|
||||||
|
testCases := []interface{}{42, time.Now(), &stringRef}
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(fmt.Sprintf("%v", tc), func(t *testing.T) {
|
||||||
|
_, err := gorm.Open("postgresql", tc)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Should got error with invalid database source")
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(err.Error(), "invalid database source:") {
|
||||||
|
t.Errorf("Should got error starting with \"invalid database source:\", but got %q", err.Error())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestStringPrimaryKey(t *testing.T) {
|
func TestStringPrimaryKey(t *testing.T) {
|
||||||
type UUIDStruct struct {
|
type UUIDStruct struct {
|
||||||
ID string `gorm:"primary_key"`
|
ID string `gorm:"primary_key"`
|
||||||
|
2
scope.go
2
scope.go
@ -1215,7 +1215,7 @@ func (scope *Scope) addForeignKey(field string, dest string, onDelete string, on
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) removeForeignKey(field string, dest string) {
|
func (scope *Scope) removeForeignKey(field string, dest string) {
|
||||||
keyName := scope.Dialect().BuildKeyName(scope.TableName(), field, dest)
|
keyName := scope.Dialect().BuildKeyName(scope.TableName(), field, dest, "foreign")
|
||||||
|
|
||||||
if !scope.Dialect().HasForeignKey(scope.TableName(), keyName) {
|
if !scope.Dialect().HasForeignKey(scope.TableName(), keyName) {
|
||||||
return
|
return
|
||||||
|
Loading…
x
Reference in New Issue
Block a user