
Export sqlCommon as SQLCommon. This allows callers to pass alternate implementations of the database connection, or to wrap the connection with middleware. No usages of the database variables needed to change, because they already called only the functions defined in SQLCommon. This is, however, a breaking change for Dialect, since its interface previously referenced *sql.DB.
155 lines
4.1 KiB
Go
155 lines
4.1 KiB
Go
package mssql
|
|
|
|
import (
|
|
"fmt"
|
|
"reflect"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
_ "github.com/denisenkom/go-mssqldb"
|
|
"github.com/jinzhu/gorm"
|
|
)
|
|
|
|
func setIdentityInsert(scope *gorm.Scope) {
|
|
if scope.Dialect().GetName() == "mssql" {
|
|
scope.NewDB().Exec(fmt.Sprintf("SET IDENTITY_INSERT %v ON", scope.TableName()))
|
|
}
|
|
}
|
|
|
|
func init() {
|
|
gorm.DefaultCallback.Create().After("gorm:begin_transaction").Register("mssql:set_identity_insert", setIdentityInsert)
|
|
gorm.RegisterDialect("mssql", &mssql{})
|
|
}
|
|
|
|
type mssql struct {
|
|
db gorm.SQLCommon
|
|
gorm.DefaultForeignKeyNamer
|
|
}
|
|
|
|
func (mssql) GetName() string {
|
|
return "mssql"
|
|
}
|
|
|
|
func (s *mssql) SetDB(db gorm.SQLCommon) {
|
|
s.db = db
|
|
}
|
|
|
|
func (mssql) BindVar(i int) string {
|
|
return "$$" // ?
|
|
}
|
|
|
|
func (mssql) Quote(key string) string {
|
|
return fmt.Sprintf(`"%s"`, key)
|
|
}
|
|
|
|
func (s *mssql) DataTypeOf(field *gorm.StructField) string {
|
|
var dataValue, sqlType, size, additionalType = gorm.ParseFieldStructForDialect(field, s)
|
|
|
|
if sqlType == "" {
|
|
switch dataValue.Kind() {
|
|
case reflect.Bool:
|
|
sqlType = "bit"
|
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
|
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
|
|
sqlType = "int IDENTITY(1,1)"
|
|
} else {
|
|
sqlType = "int"
|
|
}
|
|
case reflect.Int64, reflect.Uint64:
|
|
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
|
|
sqlType = "bigint IDENTITY(1,1)"
|
|
} else {
|
|
sqlType = "bigint"
|
|
}
|
|
case reflect.Float32, reflect.Float64:
|
|
sqlType = "float"
|
|
case reflect.String:
|
|
if size > 0 && size < 65532 {
|
|
sqlType = fmt.Sprintf("nvarchar(%d)", size)
|
|
} else {
|
|
sqlType = "text"
|
|
}
|
|
case reflect.Struct:
|
|
if _, ok := dataValue.Interface().(time.Time); ok {
|
|
sqlType = "datetime2"
|
|
}
|
|
default:
|
|
if _, ok := dataValue.Interface().([]byte); ok {
|
|
if size > 0 && size < 65532 {
|
|
sqlType = fmt.Sprintf("varchar(%d)", size)
|
|
} else {
|
|
sqlType = "text"
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
if sqlType == "" {
|
|
panic(fmt.Sprintf("invalid sql type %s (%s) for mssql", dataValue.Type().Name(), dataValue.Kind().String()))
|
|
}
|
|
|
|
if strings.TrimSpace(additionalType) == "" {
|
|
return sqlType
|
|
}
|
|
return fmt.Sprintf("%v %v", sqlType, additionalType)
|
|
}
|
|
|
|
func (s mssql) HasIndex(tableName string, indexName string) bool {
|
|
var count int
|
|
s.db.QueryRow("SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName).Scan(&count)
|
|
return count > 0
|
|
}
|
|
|
|
func (s mssql) RemoveIndex(tableName string, indexName string) error {
|
|
_, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName)))
|
|
return err
|
|
}
|
|
|
|
func (s mssql) HasForeignKey(tableName string, foreignKeyName string) bool {
|
|
return false
|
|
}
|
|
|
|
func (s mssql) HasTable(tableName string) bool {
|
|
var count int
|
|
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, s.CurrentDatabase()).Scan(&count)
|
|
return count > 0
|
|
}
|
|
|
|
func (s mssql) HasColumn(tableName string, columnName string) bool {
|
|
var count int
|
|
s.db.QueryRow("SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", s.CurrentDatabase(), tableName, columnName).Scan(&count)
|
|
return count > 0
|
|
}
|
|
|
|
func (s mssql) CurrentDatabase() (name string) {
|
|
s.db.QueryRow("SELECT DB_NAME() AS [Current Database]").Scan(&name)
|
|
return
|
|
}
|
|
|
|
func (mssql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) {
|
|
if offset != nil {
|
|
if parsedOffset, err := strconv.ParseInt(fmt.Sprint(offset), 0, 0); err == nil && parsedOffset > 0 {
|
|
sql += fmt.Sprintf(" OFFSET %d ROWS", parsedOffset)
|
|
}
|
|
}
|
|
if limit != nil {
|
|
if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit > 0 {
|
|
if sql == "" {
|
|
// add default zero offset
|
|
sql += " OFFSET 0 ROWS"
|
|
}
|
|
sql += fmt.Sprintf(" FETCH NEXT %d ROWS ONLY", parsedLimit)
|
|
}
|
|
}
|
|
return
|
|
}
|
|
|
|
func (mssql) SelectFromDummyTable() string {
|
|
return ""
|
|
}
|
|
|
|
func (mssql) LastInsertIDReturningSuffix(tableName, columnName string) string {
|
|
return ""
|
|
}
|