
Each dialect must define its own foreign key name generator method. The previous default is available as a method on gorm.DefaultForeignKeyNamer and can be embedded in other dialects. When the generated key name is longer than 64 characters, the mysql dialect instead uses the first 24 characters of the sanitized destination followed by the SHA-1 hash of the full key name.
135 lines
3.6 KiB
Go
135 lines
3.6 KiB
Go
package gorm
|
|
|
|
import (
|
|
"crypto/sha1"
|
|
"fmt"
|
|
"reflect"
|
|
"regexp"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
type mysql struct {
|
|
commonDialect
|
|
}
|
|
|
|
func init() {
|
|
RegisterDialect("mysql", &mysql{})
|
|
}
|
|
|
|
func (mysql) GetName() string {
|
|
return "mysql"
|
|
}
|
|
|
|
func (mysql) Quote(key string) string {
|
|
return fmt.Sprintf("`%s`", key)
|
|
}
|
|
|
|
// Get Data Type for MySQL Dialect
|
|
func (mysql) DataTypeOf(field *StructField) string {
|
|
var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field)
|
|
|
|
if sqlType == "" {
|
|
switch dataValue.Kind() {
|
|
case reflect.Bool:
|
|
sqlType = "boolean"
|
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32:
|
|
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
|
|
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
|
|
sqlType = "int AUTO_INCREMENT"
|
|
} else {
|
|
sqlType = "int"
|
|
}
|
|
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
|
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
|
|
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
|
|
sqlType = "int unsigned AUTO_INCREMENT"
|
|
} else {
|
|
sqlType = "int unsigned"
|
|
}
|
|
case reflect.Int64:
|
|
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
|
|
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
|
|
sqlType = "bigint AUTO_INCREMENT"
|
|
} else {
|
|
sqlType = "bigint"
|
|
}
|
|
case reflect.Uint64:
|
|
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
|
|
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
|
|
sqlType = "bigint unsigned AUTO_INCREMENT"
|
|
} else {
|
|
sqlType = "bigint unsigned"
|
|
}
|
|
case reflect.Float32, reflect.Float64:
|
|
sqlType = "double"
|
|
case reflect.String:
|
|
if size > 0 && size < 65532 {
|
|
sqlType = fmt.Sprintf("varchar(%d)", size)
|
|
} else {
|
|
sqlType = "longtext"
|
|
}
|
|
case reflect.Struct:
|
|
if _, ok := dataValue.Interface().(time.Time); ok {
|
|
if _, ok := field.TagSettings["NOT NULL"]; ok {
|
|
sqlType = "timestamp"
|
|
} else {
|
|
sqlType = "timestamp NULL"
|
|
}
|
|
}
|
|
default:
|
|
if _, ok := dataValue.Interface().([]byte); ok {
|
|
if size > 0 && size < 65532 {
|
|
sqlType = fmt.Sprintf("varbinary(%d)", size)
|
|
} else {
|
|
sqlType = "longblob"
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
if sqlType == "" {
|
|
panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", dataValue.Type().Name(), dataValue.Kind().String()))
|
|
}
|
|
|
|
if strings.TrimSpace(additionalType) == "" {
|
|
return sqlType
|
|
}
|
|
return fmt.Sprintf("%v %v", sqlType, additionalType)
|
|
}
|
|
|
|
func (s mysql) RemoveIndex(tableName string, indexName string) error {
|
|
_, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName)))
|
|
return err
|
|
}
|
|
|
|
func (s mysql) HasForeignKey(tableName string, foreignKeyName string) bool {
|
|
var count int
|
|
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE CONSTRAINT_SCHEMA=? AND TABLE_NAME=? AND CONSTRAINT_NAME=? AND CONSTRAINT_TYPE='FOREIGN KEY'", s.currentDatabase(), tableName, foreignKeyName).Scan(&count)
|
|
return count > 0
|
|
}
|
|
|
|
func (s mysql) currentDatabase() (name string) {
|
|
s.db.QueryRow("SELECT DATABASE()").Scan(&name)
|
|
return
|
|
}
|
|
|
|
func (mysql) SelectFromDummyTable() string {
|
|
return "FROM DUAL"
|
|
}
|
|
|
|
func (s mysql) BuildForeignKeyName(tableName, field, dest string) string {
|
|
keyName := s.commonDialect.BuildForeignKeyName(tableName, field, dest)
|
|
if len(keyName) <= 64 {
|
|
return keyName
|
|
}
|
|
h := sha1.New()
|
|
h.Write([]byte(keyName))
|
|
bs := h.Sum(nil)
|
|
|
|
// sha1 is 40 digits, keep first 24 characters of destination
|
|
keyName = regexp.MustCompile("(_*[^a-zA-Z]+_*|_+)").ReplaceAllString(dest, "_")
|
|
|
|
return fmt.Sprintf("%s%x", keyName[:24], bs)
|
|
}
|