Merge branch 'master' into master
This commit is contained in:
commit
b296778a5a
@ -59,7 +59,7 @@ func createCallback(scope *Scope) {
|
|||||||
|
|
||||||
for _, field := range scope.Fields() {
|
for _, field := range scope.Fields() {
|
||||||
if scope.changeableField(field) {
|
if scope.changeableField(field) {
|
||||||
if field.IsNormal {
|
if field.IsNormal && !field.IsIgnored {
|
||||||
if field.IsBlank && field.HasDefaultValue {
|
if field.IsBlank && field.HasDefaultValue {
|
||||||
blankColumnsWithDefaultValue = append(blankColumnsWithDefaultValue, scope.Quote(field.DBName))
|
blankColumnsWithDefaultValue = append(blankColumnsWithDefaultValue, scope.Quote(field.DBName))
|
||||||
scope.InstanceSet("gorm:blank_columns_with_default_value", blankColumnsWithDefaultValue)
|
scope.InstanceSet("gorm:blank_columns_with_default_value", blankColumnsWithDefaultValue)
|
||||||
|
@ -19,6 +19,11 @@ func queryCallback(scope *Scope) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//we are only preloading relations, dont touch base model
|
||||||
|
if _, skip := scope.InstanceGet("gorm:only_preload"); skip {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
defer scope.trace(NowFunc())
|
defer scope.trace(NowFunc())
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -14,8 +14,14 @@ func preloadCallback(scope *Scope) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, ok := scope.Get("gorm:auto_preload"); ok {
|
if ap, ok := scope.Get("gorm:auto_preload"); ok {
|
||||||
autoPreload(scope)
|
// If gorm:auto_preload IS NOT a bool then auto preload.
|
||||||
|
// Else if it IS a bool, use the value
|
||||||
|
if apb, ok := ap.(bool); !ok {
|
||||||
|
autoPreload(scope)
|
||||||
|
} else if apb {
|
||||||
|
autoPreload(scope)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if scope.Search.preload == nil || scope.HasError() {
|
if scope.Search.preload == nil || scope.HasError() {
|
||||||
@ -94,7 +100,7 @@ func autoPreload(scope *Scope) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if val, ok := field.TagSettings["PRELOAD"]; ok {
|
if val, ok := field.TagSettingsGet("PRELOAD"); ok {
|
||||||
if preload, err := strconv.ParseBool(val); err != nil {
|
if preload, err := strconv.ParseBool(val); err != nil {
|
||||||
scope.Err(errors.New("invalid preload option"))
|
scope.Err(errors.New("invalid preload option"))
|
||||||
return
|
return
|
||||||
@ -155,14 +161,17 @@ func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{})
|
|||||||
)
|
)
|
||||||
|
|
||||||
if indirectScopeValue.Kind() == reflect.Slice {
|
if indirectScopeValue.Kind() == reflect.Slice {
|
||||||
|
foreignValuesToResults := make(map[string]reflect.Value)
|
||||||
|
for i := 0; i < resultsValue.Len(); i++ {
|
||||||
|
result := resultsValue.Index(i)
|
||||||
|
foreignValues := toString(getValueFromFields(result, relation.ForeignFieldNames))
|
||||||
|
foreignValuesToResults[foreignValues] = result
|
||||||
|
}
|
||||||
for j := 0; j < indirectScopeValue.Len(); j++ {
|
for j := 0; j < indirectScopeValue.Len(); j++ {
|
||||||
for i := 0; i < resultsValue.Len(); i++ {
|
indirectValue := indirect(indirectScopeValue.Index(j))
|
||||||
result := resultsValue.Index(i)
|
valueString := toString(getValueFromFields(indirectValue, relation.AssociationForeignFieldNames))
|
||||||
foreignValues := getValueFromFields(result, relation.ForeignFieldNames)
|
if result, found := foreignValuesToResults[valueString]; found {
|
||||||
if indirectValue := indirect(indirectScopeValue.Index(j)); equalAsString(getValueFromFields(indirectValue, relation.AssociationForeignFieldNames), foreignValues) {
|
indirectValue.FieldByName(field.Name).Set(result)
|
||||||
indirectValue.FieldByName(field.Name).Set(result)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@ -249,13 +258,21 @@ func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{
|
|||||||
indirectScopeValue = scope.IndirectValue()
|
indirectScopeValue = scope.IndirectValue()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
foreignFieldToObjects := make(map[string][]*reflect.Value)
|
||||||
|
if indirectScopeValue.Kind() == reflect.Slice {
|
||||||
|
for j := 0; j < indirectScopeValue.Len(); j++ {
|
||||||
|
object := indirect(indirectScopeValue.Index(j))
|
||||||
|
valueString := toString(getValueFromFields(object, relation.ForeignFieldNames))
|
||||||
|
foreignFieldToObjects[valueString] = append(foreignFieldToObjects[valueString], &object)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for i := 0; i < resultsValue.Len(); i++ {
|
for i := 0; i < resultsValue.Len(); i++ {
|
||||||
result := resultsValue.Index(i)
|
result := resultsValue.Index(i)
|
||||||
if indirectScopeValue.Kind() == reflect.Slice {
|
if indirectScopeValue.Kind() == reflect.Slice {
|
||||||
value := getValueFromFields(result, relation.AssociationForeignFieldNames)
|
valueString := toString(getValueFromFields(result, relation.AssociationForeignFieldNames))
|
||||||
for j := 0; j < indirectScopeValue.Len(); j++ {
|
if objects, found := foreignFieldToObjects[valueString]; found {
|
||||||
object := indirect(indirectScopeValue.Index(j))
|
for _, object := range objects {
|
||||||
if equalAsString(getValueFromFields(object, relation.ForeignFieldNames), value) {
|
|
||||||
object.FieldByName(field.Name).Set(result)
|
object.FieldByName(field.Name).Set(result)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -21,9 +21,7 @@ func saveAssociationCheck(scope *Scope, field *Field) (autoUpdate bool, autoCrea
|
|||||||
|
|
||||||
if v, ok := value.(string); ok {
|
if v, ok := value.(string); ok {
|
||||||
v = strings.ToLower(v)
|
v = strings.ToLower(v)
|
||||||
if v == "false" || v != "skip" {
|
return v == "true"
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return true
|
return true
|
||||||
@ -36,26 +34,28 @@ func saveAssociationCheck(scope *Scope, field *Field) (autoUpdate bool, autoCrea
|
|||||||
if value, ok := scope.Get("gorm:save_associations"); ok {
|
if value, ok := scope.Get("gorm:save_associations"); ok {
|
||||||
autoUpdate = checkTruth(value)
|
autoUpdate = checkTruth(value)
|
||||||
autoCreate = autoUpdate
|
autoCreate = autoUpdate
|
||||||
} else if value, ok := field.TagSettings["SAVE_ASSOCIATIONS"]; ok {
|
saveReference = autoUpdate
|
||||||
|
} else if value, ok := field.TagSettingsGet("SAVE_ASSOCIATIONS"); ok {
|
||||||
autoUpdate = checkTruth(value)
|
autoUpdate = checkTruth(value)
|
||||||
autoCreate = autoUpdate
|
autoCreate = autoUpdate
|
||||||
|
saveReference = autoUpdate
|
||||||
}
|
}
|
||||||
|
|
||||||
if value, ok := scope.Get("gorm:association_autoupdate"); ok {
|
if value, ok := scope.Get("gorm:association_autoupdate"); ok {
|
||||||
autoUpdate = checkTruth(value)
|
autoUpdate = checkTruth(value)
|
||||||
} else if value, ok := field.TagSettings["ASSOCIATION_AUTOUPDATE"]; ok {
|
} else if value, ok := field.TagSettingsGet("ASSOCIATION_AUTOUPDATE"); ok {
|
||||||
autoUpdate = checkTruth(value)
|
autoUpdate = checkTruth(value)
|
||||||
}
|
}
|
||||||
|
|
||||||
if value, ok := scope.Get("gorm:association_autocreate"); ok {
|
if value, ok := scope.Get("gorm:association_autocreate"); ok {
|
||||||
autoCreate = checkTruth(value)
|
autoCreate = checkTruth(value)
|
||||||
} else if value, ok := field.TagSettings["ASSOCIATION_AUTOCREATE"]; ok {
|
} else if value, ok := field.TagSettingsGet("ASSOCIATION_AUTOCREATE"); ok {
|
||||||
autoCreate = checkTruth(value)
|
autoCreate = checkTruth(value)
|
||||||
}
|
}
|
||||||
|
|
||||||
if value, ok := scope.Get("gorm:association_save_reference"); ok {
|
if value, ok := scope.Get("gorm:association_save_reference"); ok {
|
||||||
saveReference = checkTruth(value)
|
saveReference = checkTruth(value)
|
||||||
} else if value, ok := field.TagSettings["ASSOCIATION_SAVE_REFERENCE"]; ok {
|
} else if value, ok := field.TagSettingsGet("ASSOCIATION_SAVE_REFERENCE"); ok {
|
||||||
saveReference = checkTruth(value)
|
saveReference = checkTruth(value)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -76,7 +76,9 @@ func updateCallback(scope *Scope) {
|
|||||||
for _, field := range scope.Fields() {
|
for _, field := range scope.Fields() {
|
||||||
if scope.changeableField(field) {
|
if scope.changeableField(field) {
|
||||||
if !field.IsPrimaryKey && field.IsNormal {
|
if !field.IsPrimaryKey && field.IsNormal {
|
||||||
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
|
if !field.IsForeignKey || !field.IsBlank || !field.HasDefaultValue {
|
||||||
|
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
|
||||||
|
}
|
||||||
} else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
|
} else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
|
||||||
for _, foreignKey := range relationship.ForeignDBNames {
|
for _, foreignKey := range relationship.ForeignDBNames {
|
||||||
if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) {
|
if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) {
|
||||||
|
16
dialect.go
16
dialect.go
@ -72,12 +72,18 @@ func RegisterDialect(name string, dialect Dialect) {
|
|||||||
dialectsMap[name] = dialect
|
dialectsMap[name] = dialect
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetDialect gets the dialect for the specified dialect name
|
||||||
|
func GetDialect(name string) (dialect Dialect, ok bool) {
|
||||||
|
dialect, ok = dialectsMap[name]
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// ParseFieldStructForDialect get field's sql data type
|
// ParseFieldStructForDialect get field's sql data type
|
||||||
var ParseFieldStructForDialect = func(field *StructField, dialect Dialect) (fieldValue reflect.Value, sqlType string, size int, additionalType string) {
|
var ParseFieldStructForDialect = func(field *StructField, dialect Dialect) (fieldValue reflect.Value, sqlType string, size int, additionalType string) {
|
||||||
// Get redirected field type
|
// Get redirected field type
|
||||||
var (
|
var (
|
||||||
reflectType = field.Struct.Type
|
reflectType = field.Struct.Type
|
||||||
dataType = field.TagSettings["TYPE"]
|
dataType, _ = field.TagSettingsGet("TYPE")
|
||||||
)
|
)
|
||||||
|
|
||||||
for reflectType.Kind() == reflect.Ptr {
|
for reflectType.Kind() == reflect.Ptr {
|
||||||
@ -106,15 +112,17 @@ var ParseFieldStructForDialect = func(field *StructField, dialect Dialect) (fiel
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Default Size
|
// Default Size
|
||||||
if num, ok := field.TagSettings["SIZE"]; ok {
|
if num, ok := field.TagSettingsGet("SIZE"); ok {
|
||||||
size, _ = strconv.Atoi(num)
|
size, _ = strconv.Atoi(num)
|
||||||
} else {
|
} else {
|
||||||
size = 255
|
size = 255
|
||||||
}
|
}
|
||||||
|
|
||||||
// Default type from tag setting
|
// Default type from tag setting
|
||||||
additionalType = field.TagSettings["NOT NULL"] + " " + field.TagSettings["UNIQUE"]
|
notNull, _ := field.TagSettingsGet("NOT NULL")
|
||||||
if value, ok := field.TagSettings["DEFAULT"]; ok {
|
unique, _ := field.TagSettingsGet("UNIQUE")
|
||||||
|
additionalType = notNull + " " + unique
|
||||||
|
if value, ok := field.TagSettingsGet("DEFAULT"); ok {
|
||||||
additionalType = additionalType + " DEFAULT " + value
|
additionalType = additionalType + " DEFAULT " + value
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -39,7 +39,7 @@ func (commonDialect) Quote(key string) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *commonDialect) fieldCanAutoIncrement(field *StructField) bool {
|
func (s *commonDialect) fieldCanAutoIncrement(field *StructField) bool {
|
||||||
if value, ok := field.TagSettings["AUTO_INCREMENT"]; ok {
|
if value, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok {
|
||||||
return strings.ToLower(value) != "false"
|
return strings.ToLower(value) != "false"
|
||||||
}
|
}
|
||||||
return field.IsPrimaryKey
|
return field.IsPrimaryKey
|
||||||
|
@ -33,9 +33,9 @@ func (s *mysql) DataTypeOf(field *StructField) string {
|
|||||||
|
|
||||||
// MySQL allows only one auto increment column per table, and it must
|
// MySQL allows only one auto increment column per table, and it must
|
||||||
// be a KEY column.
|
// be a KEY column.
|
||||||
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok {
|
if _, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok {
|
||||||
if _, ok = field.TagSettings["INDEX"]; !ok && !field.IsPrimaryKey {
|
if _, ok = field.TagSettingsGet("INDEX"); !ok && !field.IsPrimaryKey {
|
||||||
delete(field.TagSettings, "AUTO_INCREMENT")
|
field.TagSettingsDelete("AUTO_INCREMENT")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -45,42 +45,42 @@ func (s *mysql) DataTypeOf(field *StructField) string {
|
|||||||
sqlType = "boolean"
|
sqlType = "boolean"
|
||||||
case reflect.Int8:
|
case reflect.Int8:
|
||||||
if s.fieldCanAutoIncrement(field) {
|
if s.fieldCanAutoIncrement(field) {
|
||||||
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
|
field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
|
||||||
sqlType = "tinyint AUTO_INCREMENT"
|
sqlType = "tinyint AUTO_INCREMENT"
|
||||||
} else {
|
} else {
|
||||||
sqlType = "tinyint"
|
sqlType = "tinyint"
|
||||||
}
|
}
|
||||||
case reflect.Int, reflect.Int16, reflect.Int32:
|
case reflect.Int, reflect.Int16, reflect.Int32:
|
||||||
if s.fieldCanAutoIncrement(field) {
|
if s.fieldCanAutoIncrement(field) {
|
||||||
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
|
field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
|
||||||
sqlType = "int AUTO_INCREMENT"
|
sqlType = "int AUTO_INCREMENT"
|
||||||
} else {
|
} else {
|
||||||
sqlType = "int"
|
sqlType = "int"
|
||||||
}
|
}
|
||||||
case reflect.Uint8:
|
case reflect.Uint8:
|
||||||
if s.fieldCanAutoIncrement(field) {
|
if s.fieldCanAutoIncrement(field) {
|
||||||
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
|
field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
|
||||||
sqlType = "tinyint unsigned AUTO_INCREMENT"
|
sqlType = "tinyint unsigned AUTO_INCREMENT"
|
||||||
} else {
|
} else {
|
||||||
sqlType = "tinyint unsigned"
|
sqlType = "tinyint unsigned"
|
||||||
}
|
}
|
||||||
case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||||
if s.fieldCanAutoIncrement(field) {
|
if s.fieldCanAutoIncrement(field) {
|
||||||
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
|
field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
|
||||||
sqlType = "int unsigned AUTO_INCREMENT"
|
sqlType = "int unsigned AUTO_INCREMENT"
|
||||||
} else {
|
} else {
|
||||||
sqlType = "int unsigned"
|
sqlType = "int unsigned"
|
||||||
}
|
}
|
||||||
case reflect.Int64:
|
case reflect.Int64:
|
||||||
if s.fieldCanAutoIncrement(field) {
|
if s.fieldCanAutoIncrement(field) {
|
||||||
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
|
field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
|
||||||
sqlType = "bigint AUTO_INCREMENT"
|
sqlType = "bigint AUTO_INCREMENT"
|
||||||
} else {
|
} else {
|
||||||
sqlType = "bigint"
|
sqlType = "bigint"
|
||||||
}
|
}
|
||||||
case reflect.Uint64:
|
case reflect.Uint64:
|
||||||
if s.fieldCanAutoIncrement(field) {
|
if s.fieldCanAutoIncrement(field) {
|
||||||
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
|
field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
|
||||||
sqlType = "bigint unsigned AUTO_INCREMENT"
|
sqlType = "bigint unsigned AUTO_INCREMENT"
|
||||||
} else {
|
} else {
|
||||||
sqlType = "bigint unsigned"
|
sqlType = "bigint unsigned"
|
||||||
@ -96,11 +96,11 @@ func (s *mysql) DataTypeOf(field *StructField) string {
|
|||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
if _, ok := dataValue.Interface().(time.Time); ok {
|
if _, ok := dataValue.Interface().(time.Time); ok {
|
||||||
precision := ""
|
precision := ""
|
||||||
if p, ok := field.TagSettings["PRECISION"]; ok {
|
if p, ok := field.TagSettingsGet("PRECISION"); ok {
|
||||||
precision = fmt.Sprintf("(%s)", p)
|
precision = fmt.Sprintf("(%s)", p)
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, ok := field.TagSettings["NOT NULL"]; ok {
|
if _, ok := field.TagSettingsGet("NOT NULL"); ok {
|
||||||
sqlType = fmt.Sprintf("timestamp%v", precision)
|
sqlType = fmt.Sprintf("timestamp%v", precision)
|
||||||
} else {
|
} else {
|
||||||
sqlType = fmt.Sprintf("timestamp%v NULL", precision)
|
sqlType = fmt.Sprintf("timestamp%v NULL", precision)
|
||||||
|
@ -34,14 +34,14 @@ func (s *postgres) DataTypeOf(field *StructField) string {
|
|||||||
sqlType = "boolean"
|
sqlType = "boolean"
|
||||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uintptr:
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uintptr:
|
||||||
if s.fieldCanAutoIncrement(field) {
|
if s.fieldCanAutoIncrement(field) {
|
||||||
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
|
field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
|
||||||
sqlType = "serial"
|
sqlType = "serial"
|
||||||
} else {
|
} else {
|
||||||
sqlType = "integer"
|
sqlType = "integer"
|
||||||
}
|
}
|
||||||
case reflect.Int64, reflect.Uint32, reflect.Uint64:
|
case reflect.Int64, reflect.Uint32, reflect.Uint64:
|
||||||
if s.fieldCanAutoIncrement(field) {
|
if s.fieldCanAutoIncrement(field) {
|
||||||
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
|
field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
|
||||||
sqlType = "bigserial"
|
sqlType = "bigserial"
|
||||||
} else {
|
} else {
|
||||||
sqlType = "bigint"
|
sqlType = "bigint"
|
||||||
@ -49,7 +49,7 @@ func (s *postgres) DataTypeOf(field *StructField) string {
|
|||||||
case reflect.Float32, reflect.Float64:
|
case reflect.Float32, reflect.Float64:
|
||||||
sqlType = "numeric"
|
sqlType = "numeric"
|
||||||
case reflect.String:
|
case reflect.String:
|
||||||
if _, ok := field.TagSettings["SIZE"]; !ok {
|
if _, ok := field.TagSettingsGet("SIZE"); !ok {
|
||||||
size = 0 // if SIZE haven't been set, use `text` as the default type, as there are no performance different
|
size = 0 // if SIZE haven't been set, use `text` as the default type, as there are no performance different
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -29,14 +29,14 @@ func (s *sqlite3) DataTypeOf(field *StructField) string {
|
|||||||
sqlType = "bool"
|
sqlType = "bool"
|
||||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||||
if s.fieldCanAutoIncrement(field) {
|
if s.fieldCanAutoIncrement(field) {
|
||||||
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
|
field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
|
||||||
sqlType = "integer primary key autoincrement"
|
sqlType = "integer primary key autoincrement"
|
||||||
} else {
|
} else {
|
||||||
sqlType = "integer"
|
sqlType = "integer"
|
||||||
}
|
}
|
||||||
case reflect.Int64, reflect.Uint64:
|
case reflect.Int64, reflect.Uint64:
|
||||||
if s.fieldCanAutoIncrement(field) {
|
if s.fieldCanAutoIncrement(field) {
|
||||||
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
|
field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
|
||||||
sqlType = "integer primary key autoincrement"
|
sqlType = "integer primary key autoincrement"
|
||||||
} else {
|
} else {
|
||||||
sqlType = "bigint"
|
sqlType = "bigint"
|
||||||
|
@ -1,12 +1,16 @@
|
|||||||
package mssql
|
package mssql
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"database/sql/driver"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
// Importing mssql driver package only in dialect file, otherwide not needed
|
||||||
_ "github.com/denisenkom/go-mssqldb"
|
_ "github.com/denisenkom/go-mssqldb"
|
||||||
"github.com/jinzhu/gorm"
|
"github.com/jinzhu/gorm"
|
||||||
)
|
)
|
||||||
@ -14,7 +18,7 @@ import (
|
|||||||
func setIdentityInsert(scope *gorm.Scope) {
|
func setIdentityInsert(scope *gorm.Scope) {
|
||||||
if scope.Dialect().GetName() == "mssql" {
|
if scope.Dialect().GetName() == "mssql" {
|
||||||
for _, field := range scope.PrimaryFields() {
|
for _, field := range scope.PrimaryFields() {
|
||||||
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok && !field.IsBlank {
|
if _, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok && !field.IsBlank {
|
||||||
scope.NewDB().Exec(fmt.Sprintf("SET IDENTITY_INSERT %v ON", scope.TableName()))
|
scope.NewDB().Exec(fmt.Sprintf("SET IDENTITY_INSERT %v ON", scope.TableName()))
|
||||||
scope.InstanceSet("mssql:identity_insert_on", true)
|
scope.InstanceSet("mssql:identity_insert_on", true)
|
||||||
}
|
}
|
||||||
@ -66,14 +70,14 @@ func (s *mssql) DataTypeOf(field *gorm.StructField) string {
|
|||||||
sqlType = "bit"
|
sqlType = "bit"
|
||||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||||
if s.fieldCanAutoIncrement(field) {
|
if s.fieldCanAutoIncrement(field) {
|
||||||
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
|
field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
|
||||||
sqlType = "int IDENTITY(1,1)"
|
sqlType = "int IDENTITY(1,1)"
|
||||||
} else {
|
} else {
|
||||||
sqlType = "int"
|
sqlType = "int"
|
||||||
}
|
}
|
||||||
case reflect.Int64, reflect.Uint64:
|
case reflect.Int64, reflect.Uint64:
|
||||||
if s.fieldCanAutoIncrement(field) {
|
if s.fieldCanAutoIncrement(field) {
|
||||||
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
|
field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
|
||||||
sqlType = "bigint IDENTITY(1,1)"
|
sqlType = "bigint IDENTITY(1,1)"
|
||||||
} else {
|
} else {
|
||||||
sqlType = "bigint"
|
sqlType = "bigint"
|
||||||
@ -112,7 +116,7 @@ func (s *mssql) DataTypeOf(field *gorm.StructField) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s mssql) fieldCanAutoIncrement(field *gorm.StructField) bool {
|
func (s mssql) fieldCanAutoIncrement(field *gorm.StructField) bool {
|
||||||
if value, ok := field.TagSettings["AUTO_INCREMENT"]; ok {
|
if value, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok {
|
||||||
return value != "FALSE"
|
return value != "FALSE"
|
||||||
}
|
}
|
||||||
return field.IsPrimaryKey
|
return field.IsPrimaryKey
|
||||||
@ -130,7 +134,14 @@ func (s mssql) RemoveIndex(tableName string, indexName string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s mssql) HasForeignKey(tableName string, foreignKeyName string) bool {
|
func (s mssql) HasForeignKey(tableName string, foreignKeyName string) bool {
|
||||||
return false
|
var count int
|
||||||
|
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
|
||||||
|
s.db.QueryRow(`SELECT count(*)
|
||||||
|
FROM sys.foreign_keys as F inner join sys.tables as T on F.parent_object_id=T.object_id
|
||||||
|
inner join information_schema.tables as I on I.TABLE_NAME = T.name
|
||||||
|
WHERE F.name = ?
|
||||||
|
AND T.Name = ? AND I.TABLE_CATALOG = ?;`, foreignKeyName, tableName, currentDatabase).Scan(&count)
|
||||||
|
return count > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s mssql) HasTable(tableName string) bool {
|
func (s mssql) HasTable(tableName string) bool {
|
||||||
@ -194,3 +205,27 @@ func currentDatabaseAndTable(dialect gorm.Dialect, tableName string) (string, st
|
|||||||
}
|
}
|
||||||
return dialect.CurrentDatabase(), tableName
|
return dialect.CurrentDatabase(), tableName
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// JSON type to support easy handling of JSON data in character table fields
|
||||||
|
// using golang json.RawMessage for deferred decoding/encoding
|
||||||
|
type JSON struct {
|
||||||
|
json.RawMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
// Value get value of JSON
|
||||||
|
func (j JSON) Value() (driver.Value, error) {
|
||||||
|
if len(j.RawMessage) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return j.MarshalJSON()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Scan scan value into JSON
|
||||||
|
func (j *JSON) Scan(value interface{}) error {
|
||||||
|
str, ok := value.(string)
|
||||||
|
if !ok {
|
||||||
|
return errors.New(fmt.Sprint("Failed to unmarshal JSONB value (strcast):", value))
|
||||||
|
}
|
||||||
|
bytes := []byte(str)
|
||||||
|
return json.Unmarshal(bytes, j)
|
||||||
|
}
|
||||||
|
@ -4,11 +4,11 @@ import (
|
|||||||
"database/sql"
|
"database/sql"
|
||||||
"database/sql/driver"
|
"database/sql/driver"
|
||||||
|
|
||||||
_ "github.com/lib/pq"
|
|
||||||
"github.com/lib/pq/hstore"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
_ "github.com/lib/pq"
|
||||||
|
"github.com/lib/pq/hstore"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Hstore map[string]*string
|
type Hstore map[string]*string
|
||||||
|
@ -6,7 +6,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
// ErrRecordNotFound record not found error, happens when haven't find any matched data when looking up with a struct
|
// ErrRecordNotFound record not found error, happens when only haven't find any matched data when looking up with a struct, finding a slice won't return this error
|
||||||
ErrRecordNotFound = errors.New("record not found")
|
ErrRecordNotFound = errors.New("record not found")
|
||||||
// ErrInvalidSQL invalid SQL error, happens when you passed invalid SQL
|
// ErrInvalidSQL invalid SQL error, happens when you passed invalid SQL
|
||||||
ErrInvalidSQL = errors.New("invalid SQL")
|
ErrInvalidSQL = errors.New("invalid SQL")
|
||||||
|
@ -43,7 +43,7 @@ func TestCalculateField(t *testing.T) {
|
|||||||
|
|
||||||
if field, ok := scope.FieldByName("embedded_name"); !ok {
|
if field, ok := scope.FieldByName("embedded_name"); !ok {
|
||||||
t.Errorf("should find embedded field")
|
t.Errorf("should find embedded field")
|
||||||
} else if _, ok := field.TagSettings["NOT NULL"]; !ok {
|
} else if _, ok := field.TagSettingsGet("NOT NULL"); !ok {
|
||||||
t.Errorf("should find embedded field's tag settings")
|
t.Errorf("should find embedded field's tag settings")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
43
main.go
43
main.go
@ -6,6 +6,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -21,7 +22,7 @@ type DB struct {
|
|||||||
logMode int
|
logMode int
|
||||||
logger logger
|
logger logger
|
||||||
search *search
|
search *search
|
||||||
values map[string]interface{}
|
values sync.Map
|
||||||
|
|
||||||
// global db
|
// global db
|
||||||
parent *DB
|
parent *DB
|
||||||
@ -48,6 +49,7 @@ func Open(dialect string, args ...interface{}) (db *DB, err error) {
|
|||||||
}
|
}
|
||||||
var source string
|
var source string
|
||||||
var dbSQL SQLCommon
|
var dbSQL SQLCommon
|
||||||
|
var ownDbSQL bool
|
||||||
|
|
||||||
switch value := args[0].(type) {
|
switch value := args[0].(type) {
|
||||||
case string:
|
case string:
|
||||||
@ -59,14 +61,17 @@ func Open(dialect string, args ...interface{}) (db *DB, err error) {
|
|||||||
source = args[1].(string)
|
source = args[1].(string)
|
||||||
}
|
}
|
||||||
dbSQL, err = sql.Open(driver, source)
|
dbSQL, err = sql.Open(driver, source)
|
||||||
|
ownDbSQL = true
|
||||||
case SQLCommon:
|
case SQLCommon:
|
||||||
dbSQL = value
|
dbSQL = value
|
||||||
|
ownDbSQL = false
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("invalid database source: %v is not a valid type", value)
|
||||||
}
|
}
|
||||||
|
|
||||||
db = &DB{
|
db = &DB{
|
||||||
db: dbSQL,
|
db: dbSQL,
|
||||||
logger: defaultLogger,
|
logger: defaultLogger,
|
||||||
values: map[string]interface{}{},
|
|
||||||
callbacks: DefaultCallback,
|
callbacks: DefaultCallback,
|
||||||
dialect: newDialect(dialect, dbSQL),
|
dialect: newDialect(dialect, dbSQL),
|
||||||
}
|
}
|
||||||
@ -76,7 +81,7 @@ func Open(dialect string, args ...interface{}) (db *DB, err error) {
|
|||||||
}
|
}
|
||||||
// Send a ping to make sure the database connection is alive.
|
// Send a ping to make sure the database connection is alive.
|
||||||
if d, ok := dbSQL.(*sql.DB); ok {
|
if d, ok := dbSQL.(*sql.DB); ok {
|
||||||
if err = d.Ping(); err != nil {
|
if err = d.Ping(); err != nil && ownDbSQL {
|
||||||
d.Close()
|
d.Close()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -117,7 +122,7 @@ func (s *DB) CommonDB() SQLCommon {
|
|||||||
|
|
||||||
// Dialect get dialect
|
// Dialect get dialect
|
||||||
func (s *DB) Dialect() Dialect {
|
func (s *DB) Dialect() Dialect {
|
||||||
return s.parent.dialect
|
return s.dialect
|
||||||
}
|
}
|
||||||
|
|
||||||
// Callback return `Callbacks` container, you could add/change/delete callbacks with it
|
// Callback return `Callbacks` container, you could add/change/delete callbacks with it
|
||||||
@ -157,7 +162,7 @@ func (s *DB) HasBlockGlobalUpdate() bool {
|
|||||||
|
|
||||||
// SingularTable use singular table by default
|
// SingularTable use singular table by default
|
||||||
func (s *DB) SingularTable(enable bool) {
|
func (s *DB) SingularTable(enable bool) {
|
||||||
modelStructsMap = newModelStructsMap()
|
modelStructsMap = sync.Map{}
|
||||||
s.parent.singularTable = enable
|
s.parent.singularTable = enable
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -309,6 +314,11 @@ func (s *DB) Find(out interface{}, where ...interface{}) *DB {
|
|||||||
return s.NewScope(out).inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db
|
return s.NewScope(out).inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//Preloads preloads relations, don`t touch out
|
||||||
|
func (s *DB) Preloads(out interface{}) *DB {
|
||||||
|
return s.NewScope(out).InstanceSet("gorm:only_preload", 1).callCallbacks(s.parent.callbacks.queries).db
|
||||||
|
}
|
||||||
|
|
||||||
// Scan scan value to a struct
|
// Scan scan value to a struct
|
||||||
func (s *DB) Scan(dest interface{}) *DB {
|
func (s *DB) Scan(dest interface{}) *DB {
|
||||||
return s.NewScope(s.Value).Set("gorm:query_destination", dest).callCallbacks(s.parent.callbacks.queries).db
|
return s.NewScope(s.Value).Set("gorm:query_destination", dest).callCallbacks(s.parent.callbacks.queries).db
|
||||||
@ -482,6 +492,8 @@ func (s *DB) Begin() *DB {
|
|||||||
if db, ok := c.db.(sqlDb); ok && db != nil {
|
if db, ok := c.db.(sqlDb); ok && db != nil {
|
||||||
tx, err := db.Begin()
|
tx, err := db.Begin()
|
||||||
c.db = interface{}(tx).(SQLCommon)
|
c.db = interface{}(tx).(SQLCommon)
|
||||||
|
|
||||||
|
c.dialect.SetDB(c.db)
|
||||||
c.AddError(err)
|
c.AddError(err)
|
||||||
} else {
|
} else {
|
||||||
c.AddError(ErrCantStartTransaction)
|
c.AddError(ErrCantStartTransaction)
|
||||||
@ -491,7 +503,8 @@ func (s *DB) Begin() *DB {
|
|||||||
|
|
||||||
// Commit commit a transaction
|
// Commit commit a transaction
|
||||||
func (s *DB) Commit() *DB {
|
func (s *DB) Commit() *DB {
|
||||||
if db, ok := s.db.(sqlTx); ok && db != nil {
|
var emptySQLTx *sql.Tx
|
||||||
|
if db, ok := s.db.(sqlTx); ok && db != nil && db != emptySQLTx {
|
||||||
s.AddError(db.Commit())
|
s.AddError(db.Commit())
|
||||||
} else {
|
} else {
|
||||||
s.AddError(ErrInvalidTransaction)
|
s.AddError(ErrInvalidTransaction)
|
||||||
@ -501,7 +514,8 @@ func (s *DB) Commit() *DB {
|
|||||||
|
|
||||||
// Rollback rollback a transaction
|
// Rollback rollback a transaction
|
||||||
func (s *DB) Rollback() *DB {
|
func (s *DB) Rollback() *DB {
|
||||||
if db, ok := s.db.(sqlTx); ok && db != nil {
|
var emptySQLTx *sql.Tx
|
||||||
|
if db, ok := s.db.(sqlTx); ok && db != nil && db != emptySQLTx {
|
||||||
s.AddError(db.Rollback())
|
s.AddError(db.Rollback())
|
||||||
} else {
|
} else {
|
||||||
s.AddError(ErrInvalidTransaction)
|
s.AddError(ErrInvalidTransaction)
|
||||||
@ -670,13 +684,13 @@ func (s *DB) Set(name string, value interface{}) *DB {
|
|||||||
|
|
||||||
// InstantSet instant set setting, will affect current db
|
// InstantSet instant set setting, will affect current db
|
||||||
func (s *DB) InstantSet(name string, value interface{}) *DB {
|
func (s *DB) InstantSet(name string, value interface{}) *DB {
|
||||||
s.values[name] = value
|
s.values.Store(name, value)
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get get setting by name
|
// Get get setting by name
|
||||||
func (s *DB) Get(name string) (value interface{}, ok bool) {
|
func (s *DB) Get(name string) (value interface{}, ok bool) {
|
||||||
value, ok = s.values[name]
|
value, ok = s.values.Load(name)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -685,7 +699,7 @@ func (s *DB) SetJoinTableHandler(source interface{}, column string, handler Join
|
|||||||
scope := s.NewScope(source)
|
scope := s.NewScope(source)
|
||||||
for _, field := range scope.GetModelStruct().StructFields {
|
for _, field := range scope.GetModelStruct().StructFields {
|
||||||
if field.Name == column || field.DBName == column {
|
if field.Name == column || field.DBName == column {
|
||||||
if many2many := field.TagSettings["MANY2MANY"]; many2many != "" {
|
if many2many, _ := field.TagSettingsGet("MANY2MANY"); many2many != "" {
|
||||||
source := (&Scope{Value: source}).GetModelStruct().ModelType
|
source := (&Scope{Value: source}).GetModelStruct().ModelType
|
||||||
destination := (&Scope{Value: reflect.New(field.Struct.Type).Interface()}).GetModelStruct().ModelType
|
destination := (&Scope{Value: reflect.New(field.Struct.Type).Interface()}).GetModelStruct().ModelType
|
||||||
handler.Setup(field.Relationship, many2many, source, destination)
|
handler.Setup(field.Relationship, many2many, source, destination)
|
||||||
@ -740,15 +754,16 @@ func (s *DB) clone() *DB {
|
|||||||
parent: s.parent,
|
parent: s.parent,
|
||||||
logger: s.logger,
|
logger: s.logger,
|
||||||
logMode: s.logMode,
|
logMode: s.logMode,
|
||||||
values: map[string]interface{}{},
|
|
||||||
Value: s.Value,
|
Value: s.Value,
|
||||||
Error: s.Error,
|
Error: s.Error,
|
||||||
blockGlobalUpdate: s.blockGlobalUpdate,
|
blockGlobalUpdate: s.blockGlobalUpdate,
|
||||||
|
dialect: newDialect(s.dialect.GetName(), s.db),
|
||||||
}
|
}
|
||||||
|
|
||||||
for key, value := range s.values {
|
s.values.Range(func(k, v interface{}) bool {
|
||||||
db.values[key] = value
|
db.values.Store(k, v)
|
||||||
}
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
if s.search == nil {
|
if s.search == nil {
|
||||||
db.search = &search{limit: -1, offset: -1}
|
db.search = &search{limit: -1, offset: -1}
|
||||||
|
159
main_test.go
159
main_test.go
@ -8,6 +8,7 @@ import (
|
|||||||
"path/filepath"
|
"path/filepath"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -79,6 +80,22 @@ func OpenTestConnection() (db *gorm.DB, err error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOpen_ReturnsError_WithBadArgs(t *testing.T) {
|
||||||
|
stringRef := "foo"
|
||||||
|
testCases := []interface{}{42, time.Now(), &stringRef}
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(fmt.Sprintf("%v", tc), func(t *testing.T) {
|
||||||
|
_, err := gorm.Open("postgresql", tc)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Should got error with invalid database source")
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(err.Error(), "invalid database source:") {
|
||||||
|
t.Errorf("Should got error starting with \"invalid database source:\", but got %q", err.Error())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestStringPrimaryKey(t *testing.T) {
|
func TestStringPrimaryKey(t *testing.T) {
|
||||||
type UUIDStruct struct {
|
type UUIDStruct struct {
|
||||||
ID string `gorm:"primary_key"`
|
ID string `gorm:"primary_key"`
|
||||||
@ -564,6 +581,60 @@ func TestJoins(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type JoinedIds struct {
|
||||||
|
UserID int64 `gorm:"column:id"`
|
||||||
|
BillingAddressID int64 `gorm:"column:id"`
|
||||||
|
EmailID int64 `gorm:"column:id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestScanIdenticalColumnNames(t *testing.T) {
|
||||||
|
var user = User{
|
||||||
|
Name: "joinsIds",
|
||||||
|
Email: "joinIds@example.com",
|
||||||
|
BillingAddress: Address{
|
||||||
|
Address1: "One Park Place",
|
||||||
|
},
|
||||||
|
Emails: []Email{{Email: "join1@example.com"}, {Email: "join2@example.com"}},
|
||||||
|
}
|
||||||
|
DB.Save(&user)
|
||||||
|
|
||||||
|
var users []JoinedIds
|
||||||
|
DB.Select("users.id, addresses.id, emails.id").Table("users").
|
||||||
|
Joins("left join addresses on users.billing_address_id = addresses.id").
|
||||||
|
Joins("left join emails on emails.user_id = users.id").
|
||||||
|
Where("name = ?", "joinsIds").Scan(&users)
|
||||||
|
|
||||||
|
if len(users) != 2 {
|
||||||
|
t.Fatal("should find two rows using left join")
|
||||||
|
}
|
||||||
|
|
||||||
|
if user.Id != users[0].UserID {
|
||||||
|
t.Errorf("Expected result row to contain UserID %d, but got %d", user.Id, users[0].UserID)
|
||||||
|
}
|
||||||
|
if user.Id != users[1].UserID {
|
||||||
|
t.Errorf("Expected result row to contain UserID %d, but got %d", user.Id, users[1].UserID)
|
||||||
|
}
|
||||||
|
|
||||||
|
if user.BillingAddressID.Int64 != users[0].BillingAddressID {
|
||||||
|
t.Errorf("Expected result row to contain BillingAddressID %d, but got %d", user.BillingAddressID.Int64, users[0].BillingAddressID)
|
||||||
|
}
|
||||||
|
if user.BillingAddressID.Int64 != users[1].BillingAddressID {
|
||||||
|
t.Errorf("Expected result row to contain BillingAddressID %d, but got %d", user.BillingAddressID.Int64, users[0].BillingAddressID)
|
||||||
|
}
|
||||||
|
|
||||||
|
if users[0].EmailID == users[1].EmailID {
|
||||||
|
t.Errorf("Email ids should be unique. Got %d and %d", users[0].EmailID, users[1].EmailID)
|
||||||
|
}
|
||||||
|
|
||||||
|
if int64(user.Emails[0].Id) != users[0].EmailID && int64(user.Emails[1].Id) != users[0].EmailID {
|
||||||
|
t.Errorf("Expected result row ID to be either %d or %d, but was %d", user.Emails[0].Id, user.Emails[1].Id, users[0].EmailID)
|
||||||
|
}
|
||||||
|
|
||||||
|
if int64(user.Emails[0].Id) != users[1].EmailID && int64(user.Emails[1].Id) != users[1].EmailID {
|
||||||
|
t.Errorf("Expected result row ID to be either %d or %d, but was %d", user.Emails[0].Id, user.Emails[1].Id, users[1].EmailID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestJoinsWithSelect(t *testing.T) {
|
func TestJoinsWithSelect(t *testing.T) {
|
||||||
type result struct {
|
type result struct {
|
||||||
Name string
|
Name string
|
||||||
@ -862,6 +933,94 @@ func TestOpenWithOneParameter(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSaveAssociations(t *testing.T) {
|
||||||
|
db := DB.New()
|
||||||
|
deltaAddressCount := 0
|
||||||
|
if err := db.Model(&Address{}).Count(&deltaAddressCount).Error; err != nil {
|
||||||
|
t.Errorf("failed to fetch address count")
|
||||||
|
t.FailNow()
|
||||||
|
}
|
||||||
|
|
||||||
|
placeAddress := &Address{
|
||||||
|
Address1: "somewhere on earth",
|
||||||
|
}
|
||||||
|
ownerAddress1 := &Address{
|
||||||
|
Address1: "near place address",
|
||||||
|
}
|
||||||
|
ownerAddress2 := &Address{
|
||||||
|
Address1: "address2",
|
||||||
|
}
|
||||||
|
db.Create(placeAddress)
|
||||||
|
|
||||||
|
addressCountShouldBe := func(t *testing.T, expectedCount int) {
|
||||||
|
countFromDB := 0
|
||||||
|
t.Helper()
|
||||||
|
err := db.Model(&Address{}).Count(&countFromDB).Error
|
||||||
|
if err != nil {
|
||||||
|
t.Error("failed to fetch address count")
|
||||||
|
}
|
||||||
|
if countFromDB != expectedCount {
|
||||||
|
t.Errorf("address count mismatch: %d", countFromDB)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
addressCountShouldBe(t, deltaAddressCount+1)
|
||||||
|
|
||||||
|
// owner address should be created, place address should be reused
|
||||||
|
place1 := &Place{
|
||||||
|
PlaceAddressID: placeAddress.ID,
|
||||||
|
PlaceAddress: placeAddress,
|
||||||
|
OwnerAddress: ownerAddress1,
|
||||||
|
}
|
||||||
|
err := db.Create(place1).Error
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to store place: %s", err.Error())
|
||||||
|
}
|
||||||
|
addressCountShouldBe(t, deltaAddressCount+2)
|
||||||
|
|
||||||
|
// owner address should be created again, place address should be reused
|
||||||
|
place2 := &Place{
|
||||||
|
PlaceAddressID: placeAddress.ID,
|
||||||
|
PlaceAddress: &Address{
|
||||||
|
ID: 777,
|
||||||
|
Address1: "address1",
|
||||||
|
},
|
||||||
|
OwnerAddress: ownerAddress2,
|
||||||
|
OwnerAddressID: 778,
|
||||||
|
}
|
||||||
|
err = db.Create(place2).Error
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to store place: %s", err.Error())
|
||||||
|
}
|
||||||
|
addressCountShouldBe(t, deltaAddressCount+3)
|
||||||
|
|
||||||
|
count := 0
|
||||||
|
db.Model(&Place{}).Where(&Place{
|
||||||
|
PlaceAddressID: placeAddress.ID,
|
||||||
|
OwnerAddressID: ownerAddress1.ID,
|
||||||
|
}).Count(&count)
|
||||||
|
if count != 1 {
|
||||||
|
t.Errorf("only one instance of (%d, %d) should be available, found: %d",
|
||||||
|
placeAddress.ID, ownerAddress1.ID, count)
|
||||||
|
}
|
||||||
|
|
||||||
|
db.Model(&Place{}).Where(&Place{
|
||||||
|
PlaceAddressID: placeAddress.ID,
|
||||||
|
OwnerAddressID: ownerAddress2.ID,
|
||||||
|
}).Count(&count)
|
||||||
|
if count != 1 {
|
||||||
|
t.Errorf("only one instance of (%d, %d) should be available, found: %d",
|
||||||
|
placeAddress.ID, ownerAddress2.ID, count)
|
||||||
|
}
|
||||||
|
|
||||||
|
db.Model(&Place{}).Where(&Place{
|
||||||
|
PlaceAddressID: placeAddress.ID,
|
||||||
|
}).Count(&count)
|
||||||
|
if count != 2 {
|
||||||
|
t.Errorf("two instances of (%d) should be available, found: %d",
|
||||||
|
placeAddress.ID, count)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestBlockGlobalUpdate(t *testing.T) {
|
func TestBlockGlobalUpdate(t *testing.T) {
|
||||||
db := DB.New()
|
db := DB.New()
|
||||||
db.Create(&Toy{Name: "Stuffed Animal", OwnerType: "Nobody"})
|
db.Create(&Toy{Name: "Stuffed Animal", OwnerType: "Nobody"})
|
||||||
|
@ -118,6 +118,14 @@ type Company struct {
|
|||||||
Owner *User `sql:"-"`
|
Owner *User `sql:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type Place struct {
|
||||||
|
Id int64
|
||||||
|
PlaceAddressID int
|
||||||
|
PlaceAddress *Address `gorm:"save_associations:false"`
|
||||||
|
OwnerAddressID int
|
||||||
|
OwnerAddress *Address `gorm:"save_associations:true"`
|
||||||
|
}
|
||||||
|
|
||||||
type EncryptedData []byte
|
type EncryptedData []byte
|
||||||
|
|
||||||
func (data *EncryptedData) Scan(value interface{}) error {
|
func (data *EncryptedData) Scan(value interface{}) error {
|
||||||
@ -284,7 +292,7 @@ func runMigration() {
|
|||||||
DB.Exec(fmt.Sprintf("drop table %v;", table))
|
DB.Exec(fmt.Sprintf("drop table %v;", table))
|
||||||
}
|
}
|
||||||
|
|
||||||
values := []interface{}{&Short{}, &ReallyLongThingThatReferencesShort{}, &ReallyLongTableNameToTestMySQLNameLengthLimit{}, &NotSoLongTableName{}, &Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}, &Animal{}, &User{}, &JoinTable{}, &Post{}, &Category{}, &Comment{}, &Cat{}, &Dog{}, &Hamster{}, &Toy{}, &ElementWithIgnoredField{}}
|
values := []interface{}{&Short{}, &ReallyLongThingThatReferencesShort{}, &ReallyLongTableNameToTestMySQLNameLengthLimit{}, &NotSoLongTableName{}, &Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}, &Animal{}, &User{}, &JoinTable{}, &Post{}, &Category{}, &Comment{}, &Cat{}, &Dog{}, &Hamster{}, &Toy{}, &ElementWithIgnoredField{}, &Place{}}
|
||||||
for _, value := range values {
|
for _, value := range values {
|
||||||
DB.DropTable(value)
|
DB.DropTable(value)
|
||||||
}
|
}
|
||||||
@ -398,6 +406,53 @@ func TestAutoMigration(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCreateAndAutomigrateTransaction(t *testing.T) {
|
||||||
|
tx := DB.Begin()
|
||||||
|
|
||||||
|
func() {
|
||||||
|
type Bar struct {
|
||||||
|
ID uint
|
||||||
|
}
|
||||||
|
DB.DropTableIfExists(&Bar{})
|
||||||
|
|
||||||
|
if ok := DB.HasTable("bars"); ok {
|
||||||
|
t.Errorf("Table should not exist, but does")
|
||||||
|
}
|
||||||
|
|
||||||
|
if ok := tx.HasTable("bars"); ok {
|
||||||
|
t.Errorf("Table should not exist, but does")
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
func() {
|
||||||
|
type Bar struct {
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
err := tx.CreateTable(&Bar{}).Error
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Should have been able to create the table, but couldn't: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if ok := tx.HasTable(&Bar{}); !ok {
|
||||||
|
t.Errorf("The transaction should be able to see the table")
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
func() {
|
||||||
|
type Bar struct {
|
||||||
|
Stuff string
|
||||||
|
}
|
||||||
|
|
||||||
|
err := tx.AutoMigrate(&Bar{}).Error
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Should have been able to alter the table, but couldn't")
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
tx.Rollback()
|
||||||
|
}
|
||||||
|
|
||||||
type MultipleIndexes struct {
|
type MultipleIndexes struct {
|
||||||
ID int64
|
ID int64
|
||||||
UserID int64 `sql:"unique_index:uix_multipleindexes_user_name,uix_multipleindexes_user_email;index:idx_multipleindexes_user_other"`
|
UserID int64 `sql:"unique_index:uix_multipleindexes_user_name,uix_multipleindexes_user_email;index:idx_multipleindexes_user_other"`
|
||||||
|
121
model_struct.go
121
model_struct.go
@ -17,45 +17,29 @@ var DefaultTableNameHandler = func(db *DB, defaultTableName string) string {
|
|||||||
return defaultTableName
|
return defaultTableName
|
||||||
}
|
}
|
||||||
|
|
||||||
type safeModelStructsMap struct {
|
var modelStructsMap sync.Map
|
||||||
m map[reflect.Type]*ModelStruct
|
|
||||||
l *sync.RWMutex
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *safeModelStructsMap) Set(key reflect.Type, value *ModelStruct) {
|
|
||||||
s.l.Lock()
|
|
||||||
defer s.l.Unlock()
|
|
||||||
s.m[key] = value
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *safeModelStructsMap) Get(key reflect.Type) *ModelStruct {
|
|
||||||
s.l.RLock()
|
|
||||||
defer s.l.RUnlock()
|
|
||||||
return s.m[key]
|
|
||||||
}
|
|
||||||
|
|
||||||
func newModelStructsMap() *safeModelStructsMap {
|
|
||||||
return &safeModelStructsMap{l: new(sync.RWMutex), m: make(map[reflect.Type]*ModelStruct)}
|
|
||||||
}
|
|
||||||
|
|
||||||
var modelStructsMap = newModelStructsMap()
|
|
||||||
|
|
||||||
// ModelStruct model definition
|
// ModelStruct model definition
|
||||||
type ModelStruct struct {
|
type ModelStruct struct {
|
||||||
PrimaryFields []*StructField
|
PrimaryFields []*StructField
|
||||||
StructFields []*StructField
|
StructFields []*StructField
|
||||||
ModelType reflect.Type
|
ModelType reflect.Type
|
||||||
|
|
||||||
defaultTableName string
|
defaultTableName string
|
||||||
|
l sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
// TableName get model's table name
|
// TableName returns model's table name
|
||||||
func (s *ModelStruct) TableName(db *DB) string {
|
func (s *ModelStruct) TableName(db *DB) string {
|
||||||
|
s.l.Lock()
|
||||||
|
defer s.l.Unlock()
|
||||||
|
|
||||||
if s.defaultTableName == "" && db != nil && s.ModelType != nil {
|
if s.defaultTableName == "" && db != nil && s.ModelType != nil {
|
||||||
// Set default table name
|
// Set default table name
|
||||||
if tabler, ok := reflect.New(s.ModelType).Interface().(tabler); ok {
|
if tabler, ok := reflect.New(s.ModelType).Interface().(tabler); ok {
|
||||||
s.defaultTableName = tabler.TableName()
|
s.defaultTableName = tabler.TableName()
|
||||||
} else {
|
} else {
|
||||||
tableName := ToDBName(s.ModelType.Name())
|
tableName := ToTableName(s.ModelType.Name())
|
||||||
if db == nil || !db.parent.singularTable {
|
if db == nil || !db.parent.singularTable {
|
||||||
tableName = inflection.Plural(tableName)
|
tableName = inflection.Plural(tableName)
|
||||||
}
|
}
|
||||||
@ -81,6 +65,30 @@ type StructField struct {
|
|||||||
Struct reflect.StructField
|
Struct reflect.StructField
|
||||||
IsForeignKey bool
|
IsForeignKey bool
|
||||||
Relationship *Relationship
|
Relationship *Relationship
|
||||||
|
|
||||||
|
tagSettingsLock sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// TagSettingsSet Sets a tag in the tag settings map
|
||||||
|
func (s *StructField) TagSettingsSet(key, val string) {
|
||||||
|
s.tagSettingsLock.Lock()
|
||||||
|
defer s.tagSettingsLock.Unlock()
|
||||||
|
s.TagSettings[key] = val
|
||||||
|
}
|
||||||
|
|
||||||
|
// TagSettingsGet returns a tag from the tag settings
|
||||||
|
func (s *StructField) TagSettingsGet(key string) (string, bool) {
|
||||||
|
s.tagSettingsLock.RLock()
|
||||||
|
defer s.tagSettingsLock.RUnlock()
|
||||||
|
val, ok := s.TagSettings[key]
|
||||||
|
return val, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// TagSettingsDelete deletes a tag
|
||||||
|
func (s *StructField) TagSettingsDelete(key string) {
|
||||||
|
s.tagSettingsLock.Lock()
|
||||||
|
defer s.tagSettingsLock.Unlock()
|
||||||
|
delete(s.TagSettings, key)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (structField *StructField) clone() *StructField {
|
func (structField *StructField) clone() *StructField {
|
||||||
@ -104,6 +112,9 @@ func (structField *StructField) clone() *StructField {
|
|||||||
clone.Relationship = &relationship
|
clone.Relationship = &relationship
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// copy the struct field tagSettings, they should be read-locked while they are copied
|
||||||
|
structField.tagSettingsLock.Lock()
|
||||||
|
defer structField.tagSettingsLock.Unlock()
|
||||||
for key, value := range structField.TagSettings {
|
for key, value := range structField.TagSettings {
|
||||||
clone.TagSettings[key] = value
|
clone.TagSettings[key] = value
|
||||||
}
|
}
|
||||||
@ -126,7 +137,7 @@ type Relationship struct {
|
|||||||
|
|
||||||
func getForeignField(column string, fields []*StructField) *StructField {
|
func getForeignField(column string, fields []*StructField) *StructField {
|
||||||
for _, field := range fields {
|
for _, field := range fields {
|
||||||
if field.Name == column || field.DBName == column || field.DBName == ToDBName(column) {
|
if field.Name == column || field.DBName == column || field.DBName == ToColumnName(column) {
|
||||||
return field
|
return field
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -152,8 +163,8 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Get Cached model struct
|
// Get Cached model struct
|
||||||
if value := modelStructsMap.Get(reflectType); value != nil {
|
if value, ok := modelStructsMap.Load(reflectType); ok && value != nil {
|
||||||
return value
|
return value.(*ModelStruct)
|
||||||
}
|
}
|
||||||
|
|
||||||
modelStruct.ModelType = reflectType
|
modelStruct.ModelType = reflectType
|
||||||
@ -170,19 +181,19 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// is ignored field
|
// is ignored field
|
||||||
if _, ok := field.TagSettings["-"]; ok {
|
if _, ok := field.TagSettingsGet("-"); ok {
|
||||||
field.IsIgnored = true
|
field.IsIgnored = true
|
||||||
} else {
|
} else {
|
||||||
if _, ok := field.TagSettings["PRIMARY_KEY"]; ok {
|
if _, ok := field.TagSettingsGet("PRIMARY_KEY"); ok {
|
||||||
field.IsPrimaryKey = true
|
field.IsPrimaryKey = true
|
||||||
modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field)
|
modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field)
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, ok := field.TagSettings["DEFAULT"]; ok {
|
if _, ok := field.TagSettingsGet("DEFAULT"); ok {
|
||||||
field.HasDefaultValue = true
|
field.HasDefaultValue = true
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok && !field.IsPrimaryKey {
|
if _, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok && !field.IsPrimaryKey {
|
||||||
field.HasDefaultValue = true
|
field.HasDefaultValue = true
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -198,8 +209,8 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
|||||||
if indirectType.Kind() == reflect.Struct {
|
if indirectType.Kind() == reflect.Struct {
|
||||||
for i := 0; i < indirectType.NumField(); i++ {
|
for i := 0; i < indirectType.NumField(); i++ {
|
||||||
for key, value := range parseTagSetting(indirectType.Field(i).Tag) {
|
for key, value := range parseTagSetting(indirectType.Field(i).Tag) {
|
||||||
if _, ok := field.TagSettings[key]; !ok {
|
if _, ok := field.TagSettingsGet(key); !ok {
|
||||||
field.TagSettings[key] = value
|
field.TagSettingsSet(key, value)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -207,17 +218,17 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
|||||||
} else if _, isTime := fieldValue.(*time.Time); isTime {
|
} else if _, isTime := fieldValue.(*time.Time); isTime {
|
||||||
// is time
|
// is time
|
||||||
field.IsNormal = true
|
field.IsNormal = true
|
||||||
} else if _, ok := field.TagSettings["EMBEDDED"]; ok || fieldStruct.Anonymous {
|
} else if _, ok := field.TagSettingsGet("EMBEDDED"); ok || fieldStruct.Anonymous {
|
||||||
// is embedded struct
|
// is embedded struct
|
||||||
for _, subField := range scope.New(fieldValue).GetModelStruct().StructFields {
|
for _, subField := range scope.New(fieldValue).GetModelStruct().StructFields {
|
||||||
subField = subField.clone()
|
subField = subField.clone()
|
||||||
subField.Names = append([]string{fieldStruct.Name}, subField.Names...)
|
subField.Names = append([]string{fieldStruct.Name}, subField.Names...)
|
||||||
if prefix, ok := field.TagSettings["EMBEDDED_PREFIX"]; ok {
|
if prefix, ok := field.TagSettingsGet("EMBEDDED_PREFIX"); ok {
|
||||||
subField.DBName = prefix + subField.DBName
|
subField.DBName = prefix + subField.DBName
|
||||||
}
|
}
|
||||||
|
|
||||||
if subField.IsPrimaryKey {
|
if subField.IsPrimaryKey {
|
||||||
if _, ok := subField.TagSettings["PRIMARY_KEY"]; ok {
|
if _, ok := subField.TagSettingsGet("PRIMARY_KEY"); ok {
|
||||||
modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, subField)
|
modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, subField)
|
||||||
} else {
|
} else {
|
||||||
subField.IsPrimaryKey = false
|
subField.IsPrimaryKey = false
|
||||||
@ -248,13 +259,13 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
|||||||
elemType = field.Struct.Type
|
elemType = field.Struct.Type
|
||||||
)
|
)
|
||||||
|
|
||||||
if foreignKey := field.TagSettings["FOREIGNKEY"]; foreignKey != "" {
|
if foreignKey, _ := field.TagSettingsGet("FOREIGNKEY"); foreignKey != "" {
|
||||||
foreignKeys = strings.Split(foreignKey, ",")
|
foreignKeys = strings.Split(foreignKey, ",")
|
||||||
}
|
}
|
||||||
|
|
||||||
if foreignKey := field.TagSettings["ASSOCIATION_FOREIGNKEY"]; foreignKey != "" {
|
if foreignKey, _ := field.TagSettingsGet("ASSOCIATION_FOREIGNKEY"); foreignKey != "" {
|
||||||
associationForeignKeys = strings.Split(foreignKey, ",")
|
associationForeignKeys = strings.Split(foreignKey, ",")
|
||||||
} else if foreignKey := field.TagSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" {
|
} else if foreignKey, _ := field.TagSettingsGet("ASSOCIATIONFOREIGNKEY"); foreignKey != "" {
|
||||||
associationForeignKeys = strings.Split(foreignKey, ",")
|
associationForeignKeys = strings.Split(foreignKey, ",")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -263,13 +274,13 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if elemType.Kind() == reflect.Struct {
|
if elemType.Kind() == reflect.Struct {
|
||||||
if many2many := field.TagSettings["MANY2MANY"]; many2many != "" {
|
if many2many, _ := field.TagSettingsGet("MANY2MANY"); many2many != "" {
|
||||||
relationship.Kind = "many_to_many"
|
relationship.Kind = "many_to_many"
|
||||||
|
|
||||||
{ // Foreign Keys for Source
|
{ // Foreign Keys for Source
|
||||||
joinTableDBNames := []string{}
|
joinTableDBNames := []string{}
|
||||||
|
|
||||||
if foreignKey := field.TagSettings["JOINTABLE_FOREIGNKEY"]; foreignKey != "" {
|
if foreignKey, _ := field.TagSettingsGet("JOINTABLE_FOREIGNKEY"); foreignKey != "" {
|
||||||
joinTableDBNames = strings.Split(foreignKey, ",")
|
joinTableDBNames = strings.Split(foreignKey, ",")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -290,7 +301,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
|||||||
// if defined join table's foreign key
|
// if defined join table's foreign key
|
||||||
relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBNames[idx])
|
relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBNames[idx])
|
||||||
} else {
|
} else {
|
||||||
defaultJointableForeignKey := ToDBName(reflectType.Name()) + "_" + foreignField.DBName
|
defaultJointableForeignKey := ToColumnName(reflectType.Name()) + "_" + foreignField.DBName
|
||||||
relationship.ForeignDBNames = append(relationship.ForeignDBNames, defaultJointableForeignKey)
|
relationship.ForeignDBNames = append(relationship.ForeignDBNames, defaultJointableForeignKey)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -300,7 +311,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
|||||||
{ // Foreign Keys for Association (Destination)
|
{ // Foreign Keys for Association (Destination)
|
||||||
associationJoinTableDBNames := []string{}
|
associationJoinTableDBNames := []string{}
|
||||||
|
|
||||||
if foreignKey := field.TagSettings["ASSOCIATION_JOINTABLE_FOREIGNKEY"]; foreignKey != "" {
|
if foreignKey, _ := field.TagSettingsGet("ASSOCIATION_JOINTABLE_FOREIGNKEY"); foreignKey != "" {
|
||||||
associationJoinTableDBNames = strings.Split(foreignKey, ",")
|
associationJoinTableDBNames = strings.Split(foreignKey, ",")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -321,7 +332,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
|||||||
relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationJoinTableDBNames[idx])
|
relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationJoinTableDBNames[idx])
|
||||||
} else {
|
} else {
|
||||||
// join table foreign keys for association
|
// join table foreign keys for association
|
||||||
joinTableDBName := ToDBName(elemType.Name()) + "_" + field.DBName
|
joinTableDBName := ToColumnName(elemType.Name()) + "_" + field.DBName
|
||||||
relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, joinTableDBName)
|
relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, joinTableDBName)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -329,7 +340,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
joinTableHandler := JoinTableHandler{}
|
joinTableHandler := JoinTableHandler{}
|
||||||
joinTableHandler.Setup(relationship, many2many, reflectType, elemType)
|
joinTableHandler.Setup(relationship, ToTableName(many2many), reflectType, elemType)
|
||||||
relationship.JoinTableHandler = &joinTableHandler
|
relationship.JoinTableHandler = &joinTableHandler
|
||||||
field.Relationship = relationship
|
field.Relationship = relationship
|
||||||
} else {
|
} else {
|
||||||
@ -338,7 +349,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
|||||||
var toFields = toScope.GetStructFields()
|
var toFields = toScope.GetStructFields()
|
||||||
relationship.Kind = "has_many"
|
relationship.Kind = "has_many"
|
||||||
|
|
||||||
if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" {
|
if polymorphic, _ := field.TagSettingsGet("POLYMORPHIC"); polymorphic != "" {
|
||||||
// Dog has many toys, tag polymorphic is Owner, then associationType is Owner
|
// Dog has many toys, tag polymorphic is Owner, then associationType is Owner
|
||||||
// Toy use OwnerID, OwnerType ('dogs') as foreign key
|
// Toy use OwnerID, OwnerType ('dogs') as foreign key
|
||||||
if polymorphicType := getForeignField(polymorphic+"Type", toFields); polymorphicType != nil {
|
if polymorphicType := getForeignField(polymorphic+"Type", toFields); polymorphicType != nil {
|
||||||
@ -346,7 +357,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
|||||||
relationship.PolymorphicType = polymorphicType.Name
|
relationship.PolymorphicType = polymorphicType.Name
|
||||||
relationship.PolymorphicDBName = polymorphicType.DBName
|
relationship.PolymorphicDBName = polymorphicType.DBName
|
||||||
// if Dog has multiple set of toys set name of the set (instead of default 'dogs')
|
// if Dog has multiple set of toys set name of the set (instead of default 'dogs')
|
||||||
if value, ok := field.TagSettings["POLYMORPHIC_VALUE"]; ok {
|
if value, ok := field.TagSettingsGet("POLYMORPHIC_VALUE"); ok {
|
||||||
relationship.PolymorphicValue = value
|
relationship.PolymorphicValue = value
|
||||||
} else {
|
} else {
|
||||||
relationship.PolymorphicValue = scope.TableName()
|
relationship.PolymorphicValue = scope.TableName()
|
||||||
@ -428,17 +439,17 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
|||||||
tagAssociationForeignKeys []string
|
tagAssociationForeignKeys []string
|
||||||
)
|
)
|
||||||
|
|
||||||
if foreignKey := field.TagSettings["FOREIGNKEY"]; foreignKey != "" {
|
if foreignKey, _ := field.TagSettingsGet("FOREIGNKEY"); foreignKey != "" {
|
||||||
tagForeignKeys = strings.Split(foreignKey, ",")
|
tagForeignKeys = strings.Split(foreignKey, ",")
|
||||||
}
|
}
|
||||||
|
|
||||||
if foreignKey := field.TagSettings["ASSOCIATION_FOREIGNKEY"]; foreignKey != "" {
|
if foreignKey, _ := field.TagSettingsGet("ASSOCIATION_FOREIGNKEY"); foreignKey != "" {
|
||||||
tagAssociationForeignKeys = strings.Split(foreignKey, ",")
|
tagAssociationForeignKeys = strings.Split(foreignKey, ",")
|
||||||
} else if foreignKey := field.TagSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" {
|
} else if foreignKey, _ := field.TagSettingsGet("ASSOCIATIONFOREIGNKEY"); foreignKey != "" {
|
||||||
tagAssociationForeignKeys = strings.Split(foreignKey, ",")
|
tagAssociationForeignKeys = strings.Split(foreignKey, ",")
|
||||||
}
|
}
|
||||||
|
|
||||||
if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" {
|
if polymorphic, _ := field.TagSettingsGet("POLYMORPHIC"); polymorphic != "" {
|
||||||
// Cat has one toy, tag polymorphic is Owner, then associationType is Owner
|
// Cat has one toy, tag polymorphic is Owner, then associationType is Owner
|
||||||
// Toy use OwnerID, OwnerType ('cats') as foreign key
|
// Toy use OwnerID, OwnerType ('cats') as foreign key
|
||||||
if polymorphicType := getForeignField(polymorphic+"Type", toFields); polymorphicType != nil {
|
if polymorphicType := getForeignField(polymorphic+"Type", toFields); polymorphicType != nil {
|
||||||
@ -446,7 +457,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
|||||||
relationship.PolymorphicType = polymorphicType.Name
|
relationship.PolymorphicType = polymorphicType.Name
|
||||||
relationship.PolymorphicDBName = polymorphicType.DBName
|
relationship.PolymorphicDBName = polymorphicType.DBName
|
||||||
// if Cat has several different types of toys set name for each (instead of default 'cats')
|
// if Cat has several different types of toys set name for each (instead of default 'cats')
|
||||||
if value, ok := field.TagSettings["POLYMORPHIC_VALUE"]; ok {
|
if value, ok := field.TagSettingsGet("POLYMORPHIC_VALUE"); ok {
|
||||||
relationship.PolymorphicValue = value
|
relationship.PolymorphicValue = value
|
||||||
} else {
|
} else {
|
||||||
relationship.PolymorphicValue = scope.TableName()
|
relationship.PolymorphicValue = scope.TableName()
|
||||||
@ -584,10 +595,10 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Even it is ignored, also possible to decode db value into the field
|
// Even it is ignored, also possible to decode db value into the field
|
||||||
if value, ok := field.TagSettings["COLUMN"]; ok {
|
if value, ok := field.TagSettingsGet("COLUMN"); ok {
|
||||||
field.DBName = value
|
field.DBName = value
|
||||||
} else {
|
} else {
|
||||||
field.DBName = ToDBName(fieldStruct.Name)
|
field.DBName = ToColumnName(fieldStruct.Name)
|
||||||
}
|
}
|
||||||
|
|
||||||
modelStruct.StructFields = append(modelStruct.StructFields, field)
|
modelStruct.StructFields = append(modelStruct.StructFields, field)
|
||||||
@ -601,7 +612,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
modelStructsMap.Set(reflectType, &modelStruct)
|
modelStructsMap.Store(reflectType, &modelStruct)
|
||||||
|
|
||||||
return &modelStruct
|
return &modelStruct
|
||||||
}
|
}
|
||||||
|
124
naming.go
Normal file
124
naming.go
Normal file
@ -0,0 +1,124 @@
|
|||||||
|
package gorm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Namer is a function type which is given a string and return a string
|
||||||
|
type Namer func(string) string
|
||||||
|
|
||||||
|
// NamingStrategy represents naming strategies
|
||||||
|
type NamingStrategy struct {
|
||||||
|
DB Namer
|
||||||
|
Table Namer
|
||||||
|
Column Namer
|
||||||
|
}
|
||||||
|
|
||||||
|
// TheNamingStrategy is being initialized with defaultNamingStrategy
|
||||||
|
var TheNamingStrategy = &NamingStrategy{
|
||||||
|
DB: defaultNamer,
|
||||||
|
Table: defaultNamer,
|
||||||
|
Column: defaultNamer,
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddNamingStrategy sets the naming strategy
|
||||||
|
func AddNamingStrategy(ns *NamingStrategy) {
|
||||||
|
if ns.DB == nil {
|
||||||
|
ns.DB = defaultNamer
|
||||||
|
}
|
||||||
|
if ns.Table == nil {
|
||||||
|
ns.Table = defaultNamer
|
||||||
|
}
|
||||||
|
if ns.Column == nil {
|
||||||
|
ns.Column = defaultNamer
|
||||||
|
}
|
||||||
|
TheNamingStrategy = ns
|
||||||
|
}
|
||||||
|
|
||||||
|
// DBName alters the given name by DB
|
||||||
|
func (ns *NamingStrategy) DBName(name string) string {
|
||||||
|
return ns.DB(name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TableName alters the given name by Table
|
||||||
|
func (ns *NamingStrategy) TableName(name string) string {
|
||||||
|
return ns.Table(name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ColumnName alters the given name by Column
|
||||||
|
func (ns *NamingStrategy) ColumnName(name string) string {
|
||||||
|
return ns.Column(name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToDBName convert string to db name
|
||||||
|
func ToDBName(name string) string {
|
||||||
|
return TheNamingStrategy.DBName(name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToTableName convert string to table name
|
||||||
|
func ToTableName(name string) string {
|
||||||
|
return TheNamingStrategy.TableName(name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToColumnName convert string to db name
|
||||||
|
func ToColumnName(name string) string {
|
||||||
|
return TheNamingStrategy.ColumnName(name)
|
||||||
|
}
|
||||||
|
|
||||||
|
var smap = newSafeMap()
|
||||||
|
|
||||||
|
func defaultNamer(name string) string {
|
||||||
|
const (
|
||||||
|
lower = false
|
||||||
|
upper = true
|
||||||
|
)
|
||||||
|
|
||||||
|
if v := smap.Get(name); v != "" {
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
|
||||||
|
if name == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
value = commonInitialismsReplacer.Replace(name)
|
||||||
|
buf = bytes.NewBufferString("")
|
||||||
|
lastCase, currCase, nextCase, nextNumber bool
|
||||||
|
)
|
||||||
|
|
||||||
|
for i, v := range value[:len(value)-1] {
|
||||||
|
nextCase = bool(value[i+1] >= 'A' && value[i+1] <= 'Z')
|
||||||
|
nextNumber = bool(value[i+1] >= '0' && value[i+1] <= '9')
|
||||||
|
|
||||||
|
if i > 0 {
|
||||||
|
if currCase == upper {
|
||||||
|
if lastCase == upper && (nextCase == upper || nextNumber == upper) {
|
||||||
|
buf.WriteRune(v)
|
||||||
|
} else {
|
||||||
|
if value[i-1] != '_' && value[i+1] != '_' {
|
||||||
|
buf.WriteRune('_')
|
||||||
|
}
|
||||||
|
buf.WriteRune(v)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
buf.WriteRune(v)
|
||||||
|
if i == len(value)-2 && (nextCase == upper && nextNumber == lower) {
|
||||||
|
buf.WriteRune('_')
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
currCase = upper
|
||||||
|
buf.WriteRune(v)
|
||||||
|
}
|
||||||
|
lastCase = currCase
|
||||||
|
currCase = nextCase
|
||||||
|
}
|
||||||
|
|
||||||
|
buf.WriteByte(value[len(value)-1])
|
||||||
|
|
||||||
|
s := strings.ToLower(buf.String())
|
||||||
|
smap.Set(name, s)
|
||||||
|
return s
|
||||||
|
}
|
69
naming_test.go
Normal file
69
naming_test.go
Normal file
@ -0,0 +1,69 @@
|
|||||||
|
package gorm_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/jinzhu/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestTheNamingStrategy(t *testing.T) {
|
||||||
|
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
namer gorm.Namer
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{name: "auth", expected: "auth", namer: gorm.TheNamingStrategy.DB},
|
||||||
|
{name: "userRestrictions", expected: "user_restrictions", namer: gorm.TheNamingStrategy.Table},
|
||||||
|
{name: "clientID", expected: "client_id", namer: gorm.TheNamingStrategy.Column},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, c := range cases {
|
||||||
|
t.Run(c.name, func(t *testing.T) {
|
||||||
|
result := c.namer(c.name)
|
||||||
|
if result != c.expected {
|
||||||
|
t.Errorf("error in naming strategy. expected: %v got :%v\n", c.expected, result)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNamingStrategy(t *testing.T) {
|
||||||
|
|
||||||
|
dbNameNS := func(name string) string {
|
||||||
|
return "db_" + name
|
||||||
|
}
|
||||||
|
tableNameNS := func(name string) string {
|
||||||
|
return "tbl_" + name
|
||||||
|
}
|
||||||
|
columnNameNS := func(name string) string {
|
||||||
|
return "col_" + name
|
||||||
|
}
|
||||||
|
|
||||||
|
ns := &gorm.NamingStrategy{
|
||||||
|
DB: dbNameNS,
|
||||||
|
Table: tableNameNS,
|
||||||
|
Column: columnNameNS,
|
||||||
|
}
|
||||||
|
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
namer gorm.Namer
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{name: "auth", expected: "db_auth", namer: ns.DB},
|
||||||
|
{name: "user", expected: "tbl_user", namer: ns.Table},
|
||||||
|
{name: "password", expected: "col_password", namer: ns.Column},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, c := range cases {
|
||||||
|
t.Run(c.name, func(t *testing.T) {
|
||||||
|
result := c.namer(c.name)
|
||||||
|
if result != c.expected {
|
||||||
|
t.Errorf("error in naming strategy. expected: %v got :%v\n", c.expected, result)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -123,6 +123,31 @@ func TestAutoPreload(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAutoPreloadFalseDoesntPreload(t *testing.T) {
|
||||||
|
user1 := getPreloadUser("auto_user1")
|
||||||
|
DB.Save(user1)
|
||||||
|
|
||||||
|
preloadDB := DB.Set("gorm:auto_preload", false).Where("role = ?", "Preload")
|
||||||
|
var user User
|
||||||
|
preloadDB.Find(&user)
|
||||||
|
|
||||||
|
if user.BillingAddress.Address1 != "" {
|
||||||
|
t.Error("AutoPreload was set to fasle, but still fetched data")
|
||||||
|
}
|
||||||
|
|
||||||
|
user2 := getPreloadUser("auto_user2")
|
||||||
|
DB.Save(user2)
|
||||||
|
|
||||||
|
var users []User
|
||||||
|
preloadDB.Find(&users)
|
||||||
|
|
||||||
|
for _, user := range users {
|
||||||
|
if user.BillingAddress.Address1 != "" {
|
||||||
|
t.Error("AutoPreload was set to fasle, but still fetched data")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestNestedPreload1(t *testing.T) {
|
func TestNestedPreload1(t *testing.T) {
|
||||||
type (
|
type (
|
||||||
Level1 struct {
|
Level1 struct {
|
||||||
|
@ -181,17 +181,17 @@ func TestSearchWithPlainSQL(t *testing.T) {
|
|||||||
|
|
||||||
scopedb.Where("birthday > ?", parseTime("2000-1-1")).Find(&users)
|
scopedb.Where("birthday > ?", parseTime("2000-1-1")).Find(&users)
|
||||||
if len(users) != 2 {
|
if len(users) != 2 {
|
||||||
t.Errorf("Should found 2 users's birthday > 2000-1-1, but got %v", len(users))
|
t.Errorf("Should found 2 users' birthday > 2000-1-1, but got %v", len(users))
|
||||||
}
|
}
|
||||||
|
|
||||||
scopedb.Where("birthday > ?", "2002-10-10").Find(&users)
|
scopedb.Where("birthday > ?", "2002-10-10").Find(&users)
|
||||||
if len(users) != 2 {
|
if len(users) != 2 {
|
||||||
t.Errorf("Should found 2 users's birthday >= 2002-10-10, but got %v", len(users))
|
t.Errorf("Should found 2 users' birthday >= 2002-10-10, but got %v", len(users))
|
||||||
}
|
}
|
||||||
|
|
||||||
scopedb.Where("birthday >= ?", "2010-1-1").Where("birthday < ?", "2020-1-1").Find(&users)
|
scopedb.Where("birthday >= ?", "2010-1-1").Where("birthday < ?", "2020-1-1").Find(&users)
|
||||||
if len(users) != 1 {
|
if len(users) != 1 {
|
||||||
t.Errorf("Should found 1 users's birthday < 2020-1-1 and >= 2010-1-1, but got %v", len(users))
|
t.Errorf("Should found 1 users' birthday < 2020-1-1 and >= 2010-1-1, but got %v", len(users))
|
||||||
}
|
}
|
||||||
|
|
||||||
DB.Where("name in (?)", []string{user1.Name, user2.Name}).Find(&users)
|
DB.Where("name in (?)", []string{user1.Name, user2.Name}).Find(&users)
|
||||||
@ -532,28 +532,28 @@ func TestNot(t *testing.T) {
|
|||||||
DB.Table("users").Where("name = ?", "user3").Count(&name3Count)
|
DB.Table("users").Where("name = ?", "user3").Count(&name3Count)
|
||||||
DB.Not("name", "user3").Find(&users4)
|
DB.Not("name", "user3").Find(&users4)
|
||||||
if len(users1)-len(users4) != int(name3Count) {
|
if len(users1)-len(users4) != int(name3Count) {
|
||||||
t.Errorf("Should find all users's name not equal 3")
|
t.Errorf("Should find all users' name not equal 3")
|
||||||
}
|
}
|
||||||
|
|
||||||
DB.Not("name = ?", "user3").Find(&users4)
|
DB.Not("name = ?", "user3").Find(&users4)
|
||||||
if len(users1)-len(users4) != int(name3Count) {
|
if len(users1)-len(users4) != int(name3Count) {
|
||||||
t.Errorf("Should find all users's name not equal 3")
|
t.Errorf("Should find all users' name not equal 3")
|
||||||
}
|
}
|
||||||
|
|
||||||
DB.Not("name <> ?", "user3").Find(&users4)
|
DB.Not("name <> ?", "user3").Find(&users4)
|
||||||
if len(users4) != int(name3Count) {
|
if len(users4) != int(name3Count) {
|
||||||
t.Errorf("Should find all users's name not equal 3")
|
t.Errorf("Should find all users' name not equal 3")
|
||||||
}
|
}
|
||||||
|
|
||||||
DB.Not(User{Name: "user3"}).Find(&users5)
|
DB.Not(User{Name: "user3"}).Find(&users5)
|
||||||
|
|
||||||
if len(users1)-len(users5) != int(name3Count) {
|
if len(users1)-len(users5) != int(name3Count) {
|
||||||
t.Errorf("Should find all users's name not equal 3")
|
t.Errorf("Should find all users' name not equal 3")
|
||||||
}
|
}
|
||||||
|
|
||||||
DB.Not(map[string]interface{}{"name": "user3"}).Find(&users6)
|
DB.Not(map[string]interface{}{"name": "user3"}).Find(&users6)
|
||||||
if len(users1)-len(users6) != int(name3Count) {
|
if len(users1)-len(users6) != int(name3Count) {
|
||||||
t.Errorf("Should find all users's name not equal 3")
|
t.Errorf("Should find all users' name not equal 3")
|
||||||
}
|
}
|
||||||
|
|
||||||
DB.Not(map[string]interface{}{"name": "user3", "company_id": nil}).Find(&users7)
|
DB.Not(map[string]interface{}{"name": "user3", "company_id": nil}).Find(&users7)
|
||||||
@ -563,14 +563,14 @@ func TestNot(t *testing.T) {
|
|||||||
|
|
||||||
DB.Not("name", []string{"user3"}).Find(&users8)
|
DB.Not("name", []string{"user3"}).Find(&users8)
|
||||||
if len(users1)-len(users8) != int(name3Count) {
|
if len(users1)-len(users8) != int(name3Count) {
|
||||||
t.Errorf("Should find all users's name not equal 3")
|
t.Errorf("Should find all users' name not equal 3")
|
||||||
}
|
}
|
||||||
|
|
||||||
var name2Count int64
|
var name2Count int64
|
||||||
DB.Table("users").Where("name = ?", "user2").Count(&name2Count)
|
DB.Table("users").Where("name = ?", "user2").Count(&name2Count)
|
||||||
DB.Not("name", []string{"user3", "user2"}).Find(&users9)
|
DB.Not("name", []string{"user3", "user2"}).Find(&users9)
|
||||||
if len(users1)-len(users9) != (int(name3Count) + int(name2Count)) {
|
if len(users1)-len(users9) != (int(name3Count) + int(name2Count)) {
|
||||||
t.Errorf("Should find all users's name not equal 3")
|
t.Errorf("Should find all users' name not equal 3")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
50
scope.go
50
scope.go
@ -63,7 +63,7 @@ func (scope *Scope) SQLDB() SQLCommon {
|
|||||||
|
|
||||||
// Dialect get dialect
|
// Dialect get dialect
|
||||||
func (scope *Scope) Dialect() Dialect {
|
func (scope *Scope) Dialect() Dialect {
|
||||||
return scope.db.parent.dialect
|
return scope.db.dialect
|
||||||
}
|
}
|
||||||
|
|
||||||
// Quote used to quote string to escape them for database
|
// Quote used to quote string to escape them for database
|
||||||
@ -134,7 +134,7 @@ func (scope *Scope) Fields() []*Field {
|
|||||||
// FieldByName find `gorm.Field` with field name or db name
|
// FieldByName find `gorm.Field` with field name or db name
|
||||||
func (scope *Scope) FieldByName(name string) (field *Field, ok bool) {
|
func (scope *Scope) FieldByName(name string) (field *Field, ok bool) {
|
||||||
var (
|
var (
|
||||||
dbName = ToDBName(name)
|
dbName = ToColumnName(name)
|
||||||
mostMatchedField *Field
|
mostMatchedField *Field
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -486,8 +486,10 @@ func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field) {
|
|||||||
values[index] = &ignored
|
values[index] = &ignored
|
||||||
|
|
||||||
selectFields = fields
|
selectFields = fields
|
||||||
|
offset := 0
|
||||||
if idx, ok := selectedColumnsMap[column]; ok {
|
if idx, ok := selectedColumnsMap[column]; ok {
|
||||||
selectFields = selectFields[idx+1:]
|
offset = idx + 1
|
||||||
|
selectFields = selectFields[offset:]
|
||||||
}
|
}
|
||||||
|
|
||||||
for fieldIndex, field := range selectFields {
|
for fieldIndex, field := range selectFields {
|
||||||
@ -501,7 +503,7 @@ func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field) {
|
|||||||
resetFields[index] = field
|
resetFields[index] = field
|
||||||
}
|
}
|
||||||
|
|
||||||
selectedColumnsMap[column] = fieldIndex
|
selectedColumnsMap[column] = offset + fieldIndex
|
||||||
|
|
||||||
if field.IsNormal {
|
if field.IsNormal {
|
||||||
break
|
break
|
||||||
@ -586,10 +588,10 @@ func (scope *Scope) buildCondition(clause map[string]interface{}, include bool)
|
|||||||
scope.Err(fmt.Errorf("invalid query condition: %v", value))
|
scope.Err(fmt.Errorf("invalid query condition: %v", value))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
scopeQuotedTableName := newScope.QuotedTableName()
|
||||||
for _, field := range newScope.Fields() {
|
for _, field := range newScope.Fields() {
|
||||||
if !field.IsIgnored && !field.IsBlank {
|
if !field.IsIgnored && !field.IsBlank {
|
||||||
sqls = append(sqls, fmt.Sprintf("(%v.%v %s %v)", quotedTableName, scope.Quote(field.DBName), equalSQL, scope.AddToVars(field.Field.Interface())))
|
sqls = append(sqls, fmt.Sprintf("(%v.%v %s %v)", scopeQuotedTableName, scope.Quote(field.DBName), equalSQL, scope.AddToVars(field.Field.Interface())))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return strings.Join(sqls, " AND ")
|
return strings.Join(sqls, " AND ")
|
||||||
@ -853,6 +855,14 @@ func (scope *Scope) inlineCondition(values ...interface{}) *Scope {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) callCallbacks(funcs []*func(s *Scope)) *Scope {
|
func (scope *Scope) callCallbacks(funcs []*func(s *Scope)) *Scope {
|
||||||
|
defer func() {
|
||||||
|
if err := recover(); err != nil {
|
||||||
|
if db, ok := scope.db.db.(sqlTx); ok {
|
||||||
|
db.Rollback()
|
||||||
|
}
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
for _, f := range funcs {
|
for _, f := range funcs {
|
||||||
(*f)(scope)
|
(*f)(scope)
|
||||||
if scope.skipLeft {
|
if scope.skipLeft {
|
||||||
@ -880,7 +890,7 @@ func convertInterfaceToMap(values interface{}, withIgnoredField bool) map[string
|
|||||||
switch reflectValue.Kind() {
|
switch reflectValue.Kind() {
|
||||||
case reflect.Map:
|
case reflect.Map:
|
||||||
for _, key := range reflectValue.MapKeys() {
|
for _, key := range reflectValue.MapKeys() {
|
||||||
attrs[ToDBName(key.Interface().(string))] = reflectValue.MapIndex(key).Interface()
|
attrs[ToColumnName(key.Interface().(string))] = reflectValue.MapIndex(key).Interface()
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
for _, field := range (&Scope{Value: values}).Fields() {
|
for _, field := range (&Scope{Value: values}).Fields() {
|
||||||
@ -907,7 +917,7 @@ func (scope *Scope) updatedAttrsWithValues(value interface{}) (results map[strin
|
|||||||
results[field.DBName] = value
|
results[field.DBName] = value
|
||||||
} else {
|
} else {
|
||||||
err := field.Set(value)
|
err := field.Set(value)
|
||||||
if field.IsNormal {
|
if field.IsNormal && !field.IsIgnored {
|
||||||
hasUpdate = true
|
hasUpdate = true
|
||||||
if err == ErrUnaddressable {
|
if err == ErrUnaddressable {
|
||||||
results[field.DBName] = value
|
results[field.DBName] = value
|
||||||
@ -1113,8 +1123,8 @@ func (scope *Scope) createJoinTable(field *StructField) {
|
|||||||
if field, ok := scope.FieldByName(fieldName); ok {
|
if field, ok := scope.FieldByName(fieldName); ok {
|
||||||
foreignKeyStruct := field.clone()
|
foreignKeyStruct := field.clone()
|
||||||
foreignKeyStruct.IsPrimaryKey = false
|
foreignKeyStruct.IsPrimaryKey = false
|
||||||
foreignKeyStruct.TagSettings["IS_JOINTABLE_FOREIGNKEY"] = "true"
|
foreignKeyStruct.TagSettingsSet("IS_JOINTABLE_FOREIGNKEY", "true")
|
||||||
delete(foreignKeyStruct.TagSettings, "AUTO_INCREMENT")
|
foreignKeyStruct.TagSettingsDelete("AUTO_INCREMENT")
|
||||||
sqlTypes = append(sqlTypes, scope.Quote(relationship.ForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(foreignKeyStruct))
|
sqlTypes = append(sqlTypes, scope.Quote(relationship.ForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(foreignKeyStruct))
|
||||||
primaryKeys = append(primaryKeys, scope.Quote(relationship.ForeignDBNames[idx]))
|
primaryKeys = append(primaryKeys, scope.Quote(relationship.ForeignDBNames[idx]))
|
||||||
}
|
}
|
||||||
@ -1124,8 +1134,8 @@ func (scope *Scope) createJoinTable(field *StructField) {
|
|||||||
if field, ok := toScope.FieldByName(fieldName); ok {
|
if field, ok := toScope.FieldByName(fieldName); ok {
|
||||||
foreignKeyStruct := field.clone()
|
foreignKeyStruct := field.clone()
|
||||||
foreignKeyStruct.IsPrimaryKey = false
|
foreignKeyStruct.IsPrimaryKey = false
|
||||||
foreignKeyStruct.TagSettings["IS_JOINTABLE_FOREIGNKEY"] = "true"
|
foreignKeyStruct.TagSettingsSet("IS_JOINTABLE_FOREIGNKEY", "true")
|
||||||
delete(foreignKeyStruct.TagSettings, "AUTO_INCREMENT")
|
foreignKeyStruct.TagSettingsDelete("AUTO_INCREMENT")
|
||||||
sqlTypes = append(sqlTypes, scope.Quote(relationship.AssociationForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(foreignKeyStruct))
|
sqlTypes = append(sqlTypes, scope.Quote(relationship.AssociationForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(foreignKeyStruct))
|
||||||
primaryKeys = append(primaryKeys, scope.Quote(relationship.AssociationForeignDBNames[idx]))
|
primaryKeys = append(primaryKeys, scope.Quote(relationship.AssociationForeignDBNames[idx]))
|
||||||
}
|
}
|
||||||
@ -1215,12 +1225,18 @@ func (scope *Scope) addForeignKey(field string, dest string, onDelete string, on
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) removeForeignKey(field string, dest string) {
|
func (scope *Scope) removeForeignKey(field string, dest string) {
|
||||||
keyName := scope.Dialect().BuildKeyName(scope.TableName(), field, dest)
|
keyName := scope.Dialect().BuildKeyName(scope.TableName(), field, dest, "foreign")
|
||||||
|
|
||||||
if !scope.Dialect().HasForeignKey(scope.TableName(), keyName) {
|
if !scope.Dialect().HasForeignKey(scope.TableName(), keyName) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
var query = `ALTER TABLE %s DROP CONSTRAINT %s;`
|
var mysql mysql
|
||||||
|
var query string
|
||||||
|
if scope.Dialect().GetName() == mysql.GetName() {
|
||||||
|
query = `ALTER TABLE %s DROP FOREIGN KEY %s;`
|
||||||
|
} else {
|
||||||
|
query = `ALTER TABLE %s DROP CONSTRAINT %s;`
|
||||||
|
}
|
||||||
|
|
||||||
scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName))).Exec()
|
scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName))).Exec()
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1254,7 +1270,7 @@ func (scope *Scope) autoIndex() *Scope {
|
|||||||
var uniqueIndexes = map[string][]string{}
|
var uniqueIndexes = map[string][]string{}
|
||||||
|
|
||||||
for _, field := range scope.GetStructFields() {
|
for _, field := range scope.GetStructFields() {
|
||||||
if name, ok := field.TagSettings["INDEX"]; ok {
|
if name, ok := field.TagSettingsGet("INDEX"); ok {
|
||||||
names := strings.Split(name, ",")
|
names := strings.Split(name, ",")
|
||||||
|
|
||||||
for _, name := range names {
|
for _, name := range names {
|
||||||
@ -1265,7 +1281,7 @@ func (scope *Scope) autoIndex() *Scope {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if name, ok := field.TagSettings["UNIQUE_INDEX"]; ok {
|
if name, ok := field.TagSettingsGet("UNIQUE_INDEX"); ok {
|
||||||
names := strings.Split(name, ",")
|
names := strings.Split(name, ",")
|
||||||
|
|
||||||
for _, name := range names {
|
for _, name := range names {
|
||||||
|
65
utils.go
65
utils.go
@ -1,7 +1,6 @@
|
|||||||
package gorm
|
package gorm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"database/sql/driver"
|
"database/sql/driver"
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
@ -26,8 +25,8 @@ var NowFunc = func() time.Time {
|
|||||||
var commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UID", "UI", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"}
|
var commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UID", "UI", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"}
|
||||||
var commonInitialismsReplacer *strings.Replacer
|
var commonInitialismsReplacer *strings.Replacer
|
||||||
|
|
||||||
var goSrcRegexp = regexp.MustCompile(`jinzhu/gorm/.*.go`)
|
var goSrcRegexp = regexp.MustCompile(`jinzhu/gorm(@.*)?/.*.go`)
|
||||||
var goTestRegexp = regexp.MustCompile(`jinzhu/gorm/.*test.go`)
|
var goTestRegexp = regexp.MustCompile(`jinzhu/gorm(@.*)?/.*test.go`)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
var commonInitialismsForReplacer []string
|
var commonInitialismsForReplacer []string
|
||||||
@ -58,64 +57,6 @@ func newSafeMap() *safeMap {
|
|||||||
return &safeMap{l: new(sync.RWMutex), m: make(map[string]string)}
|
return &safeMap{l: new(sync.RWMutex), m: make(map[string]string)}
|
||||||
}
|
}
|
||||||
|
|
||||||
var smap = newSafeMap()
|
|
||||||
|
|
||||||
type strCase bool
|
|
||||||
|
|
||||||
const (
|
|
||||||
lower strCase = false
|
|
||||||
upper strCase = true
|
|
||||||
)
|
|
||||||
|
|
||||||
// ToDBName convert string to db name
|
|
||||||
func ToDBName(name string) string {
|
|
||||||
if v := smap.Get(name); v != "" {
|
|
||||||
return v
|
|
||||||
}
|
|
||||||
|
|
||||||
if name == "" {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
value = commonInitialismsReplacer.Replace(name)
|
|
||||||
buf = bytes.NewBufferString("")
|
|
||||||
lastCase, currCase, nextCase strCase
|
|
||||||
)
|
|
||||||
|
|
||||||
for i, v := range value[:len(value)-1] {
|
|
||||||
nextCase = strCase(value[i+1] >= 'A' && value[i+1] <= 'Z')
|
|
||||||
if i > 0 {
|
|
||||||
if currCase == upper {
|
|
||||||
if lastCase == upper && nextCase == upper {
|
|
||||||
buf.WriteRune(v)
|
|
||||||
} else {
|
|
||||||
if value[i-1] != '_' && value[i+1] != '_' {
|
|
||||||
buf.WriteRune('_')
|
|
||||||
}
|
|
||||||
buf.WriteRune(v)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
buf.WriteRune(v)
|
|
||||||
if i == len(value)-2 && nextCase == upper {
|
|
||||||
buf.WriteRune('_')
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
currCase = upper
|
|
||||||
buf.WriteRune(v)
|
|
||||||
}
|
|
||||||
lastCase = currCase
|
|
||||||
currCase = nextCase
|
|
||||||
}
|
|
||||||
|
|
||||||
buf.WriteByte(value[len(value)-1])
|
|
||||||
|
|
||||||
s := strings.ToLower(buf.String())
|
|
||||||
smap.Set(name, s)
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// SQL expression
|
// SQL expression
|
||||||
type expr struct {
|
type expr struct {
|
||||||
expr string
|
expr string
|
||||||
@ -265,7 +206,7 @@ func getValueFromFields(value reflect.Value, fieldNames []string) (results []int
|
|||||||
// as FieldByName could panic
|
// as FieldByName could panic
|
||||||
if indirectValue := reflect.Indirect(value); indirectValue.IsValid() {
|
if indirectValue := reflect.Indirect(value); indirectValue.IsValid() {
|
||||||
for _, fieldName := range fieldNames {
|
for _, fieldName := range fieldNames {
|
||||||
if fieldValue := indirectValue.FieldByName(fieldName); fieldValue.IsValid() {
|
if fieldValue := reflect.Indirect(indirectValue.FieldByName(fieldName)); fieldValue.IsValid() {
|
||||||
result := fieldValue.Interface()
|
result := fieldValue.Interface()
|
||||||
if r, ok := result.(driver.Valuer); ok {
|
if r, ok := result.(driver.Valuer); ok {
|
||||||
result, _ = r.Value()
|
result, _ = r.Value()
|
||||||
|
@ -1,32 +0,0 @@
|
|||||||
package gorm_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/jinzhu/gorm"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestToDBNameGenerateFriendlyName(t *testing.T) {
|
|
||||||
var maps = map[string]string{
|
|
||||||
"": "",
|
|
||||||
"X": "x",
|
|
||||||
"ThisIsATest": "this_is_a_test",
|
|
||||||
"PFAndESI": "pf_and_esi",
|
|
||||||
"AbcAndJkl": "abc_and_jkl",
|
|
||||||
"EmployeeID": "employee_id",
|
|
||||||
"SKU_ID": "sku_id",
|
|
||||||
"FieldX": "field_x",
|
|
||||||
"HTTPAndSMTP": "http_and_smtp",
|
|
||||||
"HTTPServerHandlerForURLID": "http_server_handler_for_url_id",
|
|
||||||
"UUID": "uuid",
|
|
||||||
"HTTPURL": "http_url",
|
|
||||||
"HTTP_URL": "http_url",
|
|
||||||
"ThisIsActuallyATestSoWeMayBeAbleToUseThisCodeInGormPackageAlsoIdCanBeUsedAtTheEndAsID": "this_is_actually_a_test_so_we_may_be_able_to_use_this_code_in_gorm_package_also_id_can_be_used_at_the_end_as_id",
|
|
||||||
}
|
|
||||||
|
|
||||||
for key, value := range maps {
|
|
||||||
if gorm.ToDBName(key) != value {
|
|
||||||
t.Errorf("%v ToDBName should equal %v, but got %v", key, value, gorm.ToDBName(key))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
Loading…
x
Reference in New Issue
Block a user