Merge commit '48e41440afa6a741a3e345f2cfbabca08f6fb1ac' into wing

# Conflicts:
#	scope.go
This commit is contained in:
Wing Gao 2018-02-09 17:01:15 +08:00
commit 00df68d2d0
9 changed files with 61 additions and 27 deletions

View File

@ -6,7 +6,7 @@ Make sure these boxes checked before submitting your pull request.
- [] Write good commit message, try to squash your commits into a single one - [] Write good commit message, try to squash your commits into a single one
- [] Run `./build.sh` in `gh-pages` branch for document changes - [] Run `./build.sh` in `gh-pages` branch for document changes
For significant changes like big bug fixes, new features, please open an issue to make a agreement on an implementation design/plan first before starting it. For significant changes like big bug fixes, new features, please open an issue to make an agreement on an implementation design/plan first before starting it.
Thank you. Thank you.

View File

@ -3,6 +3,7 @@
The fantastic ORM library for Golang, aims to be developer friendly. The fantastic ORM library for Golang, aims to be developer friendly.
[![Join the chat at https://gitter.im/jinzhu/gorm](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/jinzhu/gorm?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) [![Join the chat at https://gitter.im/jinzhu/gorm](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/jinzhu/gorm?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)
[![go report card](https://goreportcard.com/badge/github.com/jinzhu/gorm "go report card")](https://goreportcard.com/report/github.com/jinzhu/gorm)
[![wercker status](https://app.wercker.com/status/0cb7bb1039e21b74f8274941428e0921/s/master "wercker status")](https://app.wercker.com/project/bykey/0cb7bb1039e21b74f8274941428e0921) [![wercker status](https://app.wercker.com/status/0cb7bb1039e21b74f8274941428e0921/s/master "wercker status")](https://app.wercker.com/project/bykey/0cb7bb1039e21b74f8274941428e0921)
[![GoDoc](https://godoc.org/github.com/jinzhu/gorm?status.svg)](https://godoc.org/github.com/jinzhu/gorm) [![GoDoc](https://godoc.org/github.com/jinzhu/gorm?status.svg)](https://godoc.org/github.com/jinzhu/gorm)
@ -31,7 +32,7 @@ The fantastic ORM library for Golang, aims to be developer friendly.
## Supporting the project ## Supporting the project
[![http://patreon.com/jinzhu](http://patreon_public_assets.s3.amazonaws.com/sized/becomeAPatronBanner.png)](http://patreon.com/jinzhu) [![http://patreon.com/jinzhu](https://c5.patreon.com/external/logo/become_a_patron_button.png)](http://patreon.com/jinzhu)
## Author ## Author

View File

@ -41,8 +41,8 @@ type Dialect interface {
// LastInsertIdReturningSuffix most dbs support LastInsertId, but postgres needs to use `RETURNING` // LastInsertIdReturningSuffix most dbs support LastInsertId, but postgres needs to use `RETURNING`
LastInsertIDReturningSuffix(tableName, columnName string) string LastInsertIDReturningSuffix(tableName, columnName string) string
// BuildForeignKeyName returns a foreign key name for the given table, field and reference // BuildKeyName returns a valid key name (foreign key, index key) for the given table, field and reference
BuildForeignKeyName(tableName, field, dest string) string BuildKeyName(kind, tableName string, fields ...string) string
// CurrentDatabase return current database name // CurrentDatabase return current database name
CurrentDatabase() string CurrentDatabase() string
@ -114,3 +114,11 @@ var ParseFieldStructForDialect = func(field *StructField, dialect Dialect) (fiel
return fieldValue, dataType, size, strings.TrimSpace(additionalType) return fieldValue, dataType, size, strings.TrimSpace(additionalType)
} }
func currentDatabaseAndTable(dialect Dialect, tableName string) (string, string) {
if strings.Contains(tableName, ".") {
splitStrings := strings.SplitN(tableName, ".", 2)
return splitStrings[0], splitStrings[1]
}
return dialect.CurrentDatabase(), tableName
}

View File

@ -102,7 +102,8 @@ func (s *commonDialect) DataTypeOf(field *StructField) string {
func (s commonDialect) HasIndex(tableName string, indexName string) bool { func (s commonDialect) HasIndex(tableName string, indexName string) bool {
var count int var count int
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", s.CurrentDatabase(), tableName, indexName).Scan(&count) currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", currentDatabase, tableName, indexName).Scan(&count)
return count > 0 return count > 0
} }
@ -117,13 +118,15 @@ func (s commonDialect) HasForeignKey(tableName string, foreignKeyName string) bo
func (s commonDialect) HasTable(tableName string) bool { func (s commonDialect) HasTable(tableName string) bool {
var count int var count int
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", s.CurrentDatabase(), tableName).Scan(&count) currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", currentDatabase, tableName).Scan(&count)
return count > 0 return count > 0
} }
func (s commonDialect) HasColumn(tableName string, columnName string) bool { func (s commonDialect) HasColumn(tableName string, columnName string) bool {
var count int var count int
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", s.CurrentDatabase(), tableName, columnName).Scan(&count) currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", currentDatabase, tableName, columnName).Scan(&count)
return count > 0 return count > 0
} }
@ -154,9 +157,10 @@ func (commonDialect) LastInsertIDReturningSuffix(tableName, columnName string) s
return "" return ""
} }
func (DefaultForeignKeyNamer) BuildForeignKeyName(tableName, field, dest string) string { // BuildKeyName returns a valid key name (foreign key, index key) for the given table, field and reference
keyName := fmt.Sprintf("%s_%s_%s_foreign", tableName, field, dest) func (DefaultForeignKeyNamer) BuildKeyName(kind, tableName string, fields ...string) string {
keyName = regexp.MustCompile("(_*[^a-zA-Z]+_*|_+)").ReplaceAllString(keyName, "_") keyName := fmt.Sprintf("%s_%s_%s", kind, tableName, strings.Join(fields, "_"))
keyName = regexp.MustCompile("[^a-zA-Z0-9]+").ReplaceAllString(keyName, "_")
return keyName return keyName
} }

View File

@ -144,7 +144,8 @@ func (s mysql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) {
func (s mysql) HasForeignKey(tableName string, foreignKeyName string) bool { func (s mysql) HasForeignKey(tableName string, foreignKeyName string) bool {
var count int var count int
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE CONSTRAINT_SCHEMA=? AND TABLE_NAME=? AND CONSTRAINT_NAME=? AND CONSTRAINT_TYPE='FOREIGN KEY'", s.CurrentDatabase(), tableName, foreignKeyName).Scan(&count) currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE CONSTRAINT_SCHEMA=? AND TABLE_NAME=? AND CONSTRAINT_NAME=? AND CONSTRAINT_TYPE='FOREIGN KEY'", currentDatabase, tableName, foreignKeyName).Scan(&count)
return count > 0 return count > 0
} }
@ -157,8 +158,8 @@ func (mysql) SelectFromDummyTable() string {
return "FROM DUAL" return "FROM DUAL"
} }
func (s mysql) BuildForeignKeyName(tableName, field, dest string) string { func (s mysql) BuildKeyName(kind, tableName string, fields ...string) string {
keyName := s.commonDialect.BuildForeignKeyName(tableName, field, dest) keyName := s.commonDialect.BuildKeyName(kind, tableName, fields...)
if utf8.RuneCountInString(keyName) <= 64 { if utf8.RuneCountInString(keyName) <= 64 {
return keyName return keyName
} }
@ -166,8 +167,8 @@ func (s mysql) BuildForeignKeyName(tableName, field, dest string) string {
h.Write([]byte(keyName)) h.Write([]byte(keyName))
bs := h.Sum(nil) bs := h.Sum(nil)
// sha1 is 40 digits, keep first 24 characters of destination // sha1 is 40 characters, keep first 24 characters of destination
destRunes := []rune(regexp.MustCompile("(_*[^a-zA-Z]+_*|_+)").ReplaceAllString(dest, "_")) destRunes := []rune(regexp.MustCompile("[^a-zA-Z0-9]+").ReplaceAllString(fields[0], "_"))
if len(destRunes) > 24 { if len(destRunes) > 24 {
destRunes = destRunes[:24] destRunes = destRunes[:24]
} }

View File

@ -1,6 +1,7 @@
package gorm package gorm
import ( import (
"encoding/json"
"fmt" "fmt"
"reflect" "reflect"
"strings" "strings"
@ -68,9 +69,14 @@ func (s *postgres) DataTypeOf(field *StructField) string {
default: default:
if IsByteArrayOrSlice(dataValue) { if IsByteArrayOrSlice(dataValue) {
sqlType = "bytea" sqlType = "bytea"
if isUUID(dataValue) { if isUUID(dataValue) {
sqlType = "uuid" sqlType = "uuid"
} }
if isJSON(dataValue) {
sqlType = "jsonb"
}
} }
} }
} }
@ -130,3 +136,8 @@ func isUUID(value reflect.Value) bool {
lower := strings.ToLower(typename) lower := strings.ToLower(typename)
return "uuid" == lower || "guid" == lower return "uuid" == lower || "guid" == lower
} }
func isJSON(value reflect.Value) bool {
_, ok := value.Interface().(json.RawMessage)
return ok
}

View File

@ -128,13 +128,15 @@ func (s mssql) HasForeignKey(tableName string, foreignKeyName string) bool {
func (s mssql) HasTable(tableName string) bool { func (s mssql) HasTable(tableName string) bool {
var count int var count int
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, s.CurrentDatabase()).Scan(&count) currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, currentDatabase).Scan(&count)
return count > 0 return count > 0
} }
func (s mssql) HasColumn(tableName string, columnName string) bool { func (s mssql) HasColumn(tableName string, columnName string) bool {
var count int var count int
s.db.QueryRow("SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", s.CurrentDatabase(), tableName, columnName).Scan(&count) currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
s.db.QueryRow("SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", currentDatabase, tableName, columnName).Scan(&count)
return count > 0 return count > 0
} }
@ -168,3 +170,11 @@ func (mssql) SelectFromDummyTable() string {
func (mssql) LastInsertIDReturningSuffix(tableName, columnName string) string { func (mssql) LastInsertIDReturningSuffix(tableName, columnName string) string {
return "" return ""
} }
func currentDatabaseAndTable(dialect gorm.Dialect, tableName string) (string, string) {
if strings.Contains(tableName, ".") {
splitStrings := strings.SplitN(tableName, ".", 2)
return splitStrings[0], splitStrings[1]
}
return dialect.CurrentDatabase(), tableName
}

View File

@ -1165,7 +1165,8 @@ func (scope *Scope) addIndex(unique bool, indexName string, column ...string) {
} }
func (scope *Scope) addForeignKey(field string, dest string, onDelete string, onUpdate string) { func (scope *Scope) addForeignKey(field string, dest string, onDelete string, onUpdate string) {
keyName := scope.Dialect().BuildForeignKeyName(scope.TableName(), field, dest) // Compatible with old generated key
keyName := scope.Dialect().BuildKeyName(scope.TableName(), field, dest, "foreign")
if scope.Dialect().HasForeignKey(scope.TableName(), keyName) { if scope.Dialect().HasForeignKey(scope.TableName(), keyName) {
return return
@ -1209,7 +1210,7 @@ func (scope *Scope) autoIndex() *Scope {
for _, name := range names { for _, name := range names {
if name == "INDEX" || name == "" { if name == "INDEX" || name == "" {
name = fmt.Sprintf("idx_%v_%v", scope.TableName(), field.DBName) name = scope.Dialect().BuildKeyName("idx", scope.TableName(), field.DBName)
} }
indexes[name] = append(indexes[name], field.DBName) indexes[name] = append(indexes[name], field.DBName)
} }
@ -1220,7 +1221,7 @@ func (scope *Scope) autoIndex() *Scope {
for _, name := range names { for _, name := range names {
if name == "UNIQUE_INDEX" || name == "" { if name == "UNIQUE_INDEX" || name == "" {
name = fmt.Sprintf("uix_%v_%v", scope.TableName(), field.DBName) name = scope.Dialect().BuildKeyName("uix", scope.TableName(), field.DBName)
} }
uniqueIndexes[name] = append(uniqueIndexes[name], field.DBName) uniqueIndexes[name] = append(uniqueIndexes[name], field.DBName)
} }
@ -1228,17 +1229,15 @@ func (scope *Scope) autoIndex() *Scope {
} }
for name, columns := range indexes { for name, columns := range indexes {
db := scope.NewDB().Model(scope.Value).AddIndex(name, columns...) if db := scope.NewDB().Model(scope.Value).AddIndex(name, columns...); db.Error != nil {
if db.Error != nil { scope.db.AddError(db.Error)
scope.db.Error = db.Error
return scope return scope
} }
} }
for name, columns := range uniqueIndexes { for name, columns := range uniqueIndexes {
db := scope.NewDB().Model(scope.Value).AddUniqueIndex(name, columns...) if db := scope.NewDB().Model(scope.Value).AddUniqueIndex(name, columns...); db.Error != nil {
if db.Error != nil { scope.db.AddError(db.Error)
scope.db.Error = db.Error
return scope return scope
} }
} }

View File

@ -23,7 +23,7 @@ var NowFunc = func() time.Time {
} }
// Copied from golint // Copied from golint
var commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UI", "UID", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"} var commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UID", "UI", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"}
var commonInitialismsReplacer *strings.Replacer var commonInitialismsReplacer *strings.Replacer
var goSrcRegexp = regexp.MustCompile(`jinzhu/gorm/.*.go`) var goSrcRegexp = regexp.MustCompile(`jinzhu/gorm/.*.go`)