Merge branch 'master' into master
This commit is contained in:
commit
4cd0932d73
@ -1,11 +0,0 @@
|
|||||||
---
|
|
||||||
engines:
|
|
||||||
gofmt:
|
|
||||||
enabled: true
|
|
||||||
govet:
|
|
||||||
enabled: true
|
|
||||||
golint:
|
|
||||||
enabled: true
|
|
||||||
ratings:
|
|
||||||
paths:
|
|
||||||
- "**.go"
|
|
1
.gitignore
vendored
1
.gitignore
vendored
@ -1,2 +1,3 @@
|
|||||||
documents
|
documents
|
||||||
|
coverage.txt
|
||||||
_book
|
_book
|
||||||
|
@ -4,10 +4,11 @@ The fantastic ORM library for Golang, aims to be developer friendly.
|
|||||||
|
|
||||||
[](https://goreportcard.com/report/github.com/jinzhu/gorm)
|
[](https://goreportcard.com/report/github.com/jinzhu/gorm)
|
||||||
[](https://app.wercker.com/project/byKey/8596cace912c9947dd9c8542ecc8cb8b)
|
[](https://app.wercker.com/project/byKey/8596cace912c9947dd9c8542ecc8cb8b)
|
||||||
|
[](https://codecov.io/gh/jinzhu/gorm)
|
||||||
[](https://gitter.im/jinzhu/gorm?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)
|
[](https://gitter.im/jinzhu/gorm?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)
|
||||||
[](https://opencollective.com/gorm)
|
[](https://opencollective.com/gorm)
|
||||||
[](https://opencollective.com/gorm)
|
[](https://opencollective.com/gorm)
|
||||||
[](http://opensource.org/licenses/MIT)
|
[](https://opensource.org/licenses/MIT)
|
||||||
[](https://godoc.org/github.com/jinzhu/gorm)
|
[](https://godoc.org/github.com/jinzhu/gorm)
|
||||||
|
|
||||||
## Overview
|
## Overview
|
||||||
@ -27,11 +28,11 @@ The fantastic ORM library for Golang, aims to be developer friendly.
|
|||||||
|
|
||||||
## Getting Started
|
## Getting Started
|
||||||
|
|
||||||
* GORM Guides [http://gorm.io](http://gorm.io)
|
* GORM Guides [https://gorm.io](https://gorm.io)
|
||||||
|
|
||||||
## Contributing
|
## Contributing
|
||||||
|
|
||||||
[You can help to deliver a better GORM, check out things you can do](http://gorm.io/contribute.html)
|
[You can help to deliver a better GORM, check out things you can do](https://gorm.io/contribute.html)
|
||||||
|
|
||||||
## License
|
## License
|
||||||
|
|
||||||
|
@ -267,15 +267,16 @@ func (association *Association) Count() int {
|
|||||||
query = scope.DB()
|
query = scope.DB()
|
||||||
)
|
)
|
||||||
|
|
||||||
if relationship.Kind == "many_to_many" {
|
switch relationship.Kind {
|
||||||
|
case "many_to_many":
|
||||||
query = relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, query, scope.Value)
|
query = relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, query, scope.Value)
|
||||||
} else if relationship.Kind == "has_many" || relationship.Kind == "has_one" {
|
case "has_many", "has_one":
|
||||||
primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, scope.Value)
|
primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, scope.Value)
|
||||||
query = query.Where(
|
query = query.Where(
|
||||||
fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)),
|
fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)),
|
||||||
toQueryValues(primaryKeys)...,
|
toQueryValues(primaryKeys)...,
|
||||||
)
|
)
|
||||||
} else if relationship.Kind == "belongs_to" {
|
case "belongs_to":
|
||||||
primaryKeys := scope.getColumnAsArray(relationship.ForeignFieldNames, scope.Value)
|
primaryKeys := scope.getColumnAsArray(relationship.ForeignFieldNames, scope.Value)
|
||||||
query = query.Where(
|
query = query.Where(
|
||||||
fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(primaryKeys)),
|
fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(primaryKeys)),
|
||||||
@ -367,6 +368,7 @@ func (association *Association) saveAssociations(values ...interface{}) *Associa
|
|||||||
return association
|
return association
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// setErr set error when the error is not nil. And return Association.
|
||||||
func (association *Association) setErr(err error) *Association {
|
func (association *Association) setErr(err error) *Association {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
association.Error = err
|
association.Error = err
|
||||||
|
39
callback.go
39
callback.go
@ -1,6 +1,6 @@
|
|||||||
package gorm
|
package gorm
|
||||||
|
|
||||||
import "log"
|
import "fmt"
|
||||||
|
|
||||||
// DefaultCallback default callbacks defined by gorm
|
// DefaultCallback default callbacks defined by gorm
|
||||||
var DefaultCallback = &Callback{}
|
var DefaultCallback = &Callback{}
|
||||||
@ -13,6 +13,7 @@ var DefaultCallback = &Callback{}
|
|||||||
// Field `rowQueries` contains callbacks will be call when querying object with Row, Rows...
|
// Field `rowQueries` contains callbacks will be call when querying object with Row, Rows...
|
||||||
// Field `processors` contains all callback processors, will be used to generate above callbacks in order
|
// Field `processors` contains all callback processors, will be used to generate above callbacks in order
|
||||||
type Callback struct {
|
type Callback struct {
|
||||||
|
logger logger
|
||||||
creates []*func(scope *Scope)
|
creates []*func(scope *Scope)
|
||||||
updates []*func(scope *Scope)
|
updates []*func(scope *Scope)
|
||||||
deletes []*func(scope *Scope)
|
deletes []*func(scope *Scope)
|
||||||
@ -23,6 +24,7 @@ type Callback struct {
|
|||||||
|
|
||||||
// CallbackProcessor contains callback informations
|
// CallbackProcessor contains callback informations
|
||||||
type CallbackProcessor struct {
|
type CallbackProcessor struct {
|
||||||
|
logger logger
|
||||||
name string // current callback's name
|
name string // current callback's name
|
||||||
before string // register current callback before a callback
|
before string // register current callback before a callback
|
||||||
after string // register current callback after a callback
|
after string // register current callback after a callback
|
||||||
@ -33,8 +35,9 @@ type CallbackProcessor struct {
|
|||||||
parent *Callback
|
parent *Callback
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Callback) clone() *Callback {
|
func (c *Callback) clone(logger logger) *Callback {
|
||||||
return &Callback{
|
return &Callback{
|
||||||
|
logger: logger,
|
||||||
creates: c.creates,
|
creates: c.creates,
|
||||||
updates: c.updates,
|
updates: c.updates,
|
||||||
deletes: c.deletes,
|
deletes: c.deletes,
|
||||||
@ -53,28 +56,28 @@ func (c *Callback) clone() *Callback {
|
|||||||
// scope.Err(errors.New("error"))
|
// scope.Err(errors.New("error"))
|
||||||
// })
|
// })
|
||||||
func (c *Callback) Create() *CallbackProcessor {
|
func (c *Callback) Create() *CallbackProcessor {
|
||||||
return &CallbackProcessor{kind: "create", parent: c}
|
return &CallbackProcessor{logger: c.logger, kind: "create", parent: c}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update could be used to register callbacks for updating object, refer `Create` for usage
|
// Update could be used to register callbacks for updating object, refer `Create` for usage
|
||||||
func (c *Callback) Update() *CallbackProcessor {
|
func (c *Callback) Update() *CallbackProcessor {
|
||||||
return &CallbackProcessor{kind: "update", parent: c}
|
return &CallbackProcessor{logger: c.logger, kind: "update", parent: c}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Delete could be used to register callbacks for deleting object, refer `Create` for usage
|
// Delete could be used to register callbacks for deleting object, refer `Create` for usage
|
||||||
func (c *Callback) Delete() *CallbackProcessor {
|
func (c *Callback) Delete() *CallbackProcessor {
|
||||||
return &CallbackProcessor{kind: "delete", parent: c}
|
return &CallbackProcessor{logger: c.logger, kind: "delete", parent: c}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Query could be used to register callbacks for querying objects with query methods like `Find`, `First`, `Related`, `Association`...
|
// Query could be used to register callbacks for querying objects with query methods like `Find`, `First`, `Related`, `Association`...
|
||||||
// Refer `Create` for usage
|
// Refer `Create` for usage
|
||||||
func (c *Callback) Query() *CallbackProcessor {
|
func (c *Callback) Query() *CallbackProcessor {
|
||||||
return &CallbackProcessor{kind: "query", parent: c}
|
return &CallbackProcessor{logger: c.logger, kind: "query", parent: c}
|
||||||
}
|
}
|
||||||
|
|
||||||
// RowQuery could be used to register callbacks for querying objects with `Row`, `Rows`, refer `Create` for usage
|
// RowQuery could be used to register callbacks for querying objects with `Row`, `Rows`, refer `Create` for usage
|
||||||
func (c *Callback) RowQuery() *CallbackProcessor {
|
func (c *Callback) RowQuery() *CallbackProcessor {
|
||||||
return &CallbackProcessor{kind: "row_query", parent: c}
|
return &CallbackProcessor{logger: c.logger, kind: "row_query", parent: c}
|
||||||
}
|
}
|
||||||
|
|
||||||
// After insert a new callback after callback `callbackName`, refer `Callbacks.Create`
|
// After insert a new callback after callback `callbackName`, refer `Callbacks.Create`
|
||||||
@ -93,7 +96,7 @@ func (cp *CallbackProcessor) Before(callbackName string) *CallbackProcessor {
|
|||||||
func (cp *CallbackProcessor) Register(callbackName string, callback func(scope *Scope)) {
|
func (cp *CallbackProcessor) Register(callbackName string, callback func(scope *Scope)) {
|
||||||
if cp.kind == "row_query" {
|
if cp.kind == "row_query" {
|
||||||
if cp.before == "" && cp.after == "" && callbackName != "gorm:row_query" {
|
if cp.before == "" && cp.after == "" && callbackName != "gorm:row_query" {
|
||||||
log.Printf("Registing RowQuery callback %v without specify order with Before(), After(), applying Before('gorm:row_query') by default for compatibility...\n", callbackName)
|
cp.logger.Print("info", fmt.Sprintf("Registering RowQuery callback %v without specify order with Before(), After(), applying Before('gorm:row_query') by default for compatibility...", callbackName))
|
||||||
cp.before = "gorm:row_query"
|
cp.before = "gorm:row_query"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -107,7 +110,7 @@ func (cp *CallbackProcessor) Register(callbackName string, callback func(scope *
|
|||||||
// Remove a registered callback
|
// Remove a registered callback
|
||||||
// db.Callback().Create().Remove("gorm:update_time_stamp_when_create")
|
// db.Callback().Create().Remove("gorm:update_time_stamp_when_create")
|
||||||
func (cp *CallbackProcessor) Remove(callbackName string) {
|
func (cp *CallbackProcessor) Remove(callbackName string) {
|
||||||
log.Printf("[info] removing callback `%v` from %v\n", callbackName, fileWithLineNum())
|
cp.logger.Print("info", fmt.Sprintf("[info] removing callback `%v` from %v", callbackName, fileWithLineNum()))
|
||||||
cp.name = callbackName
|
cp.name = callbackName
|
||||||
cp.remove = true
|
cp.remove = true
|
||||||
cp.parent.processors = append(cp.parent.processors, cp)
|
cp.parent.processors = append(cp.parent.processors, cp)
|
||||||
@ -116,11 +119,11 @@ func (cp *CallbackProcessor) Remove(callbackName string) {
|
|||||||
|
|
||||||
// Replace a registered callback with new callback
|
// Replace a registered callback with new callback
|
||||||
// db.Callback().Create().Replace("gorm:update_time_stamp_when_create", func(*Scope) {
|
// db.Callback().Create().Replace("gorm:update_time_stamp_when_create", func(*Scope) {
|
||||||
// scope.SetColumn("Created", now)
|
// scope.SetColumn("CreatedAt", now)
|
||||||
// scope.SetColumn("Updated", now)
|
// scope.SetColumn("UpdatedAt", now)
|
||||||
// })
|
// })
|
||||||
func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *Scope)) {
|
func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *Scope)) {
|
||||||
log.Printf("[info] replacing callback `%v` from %v\n", callbackName, fileWithLineNum())
|
cp.logger.Print("info", fmt.Sprintf("[info] replacing callback `%v` from %v", callbackName, fileWithLineNum()))
|
||||||
cp.name = callbackName
|
cp.name = callbackName
|
||||||
cp.processor = &callback
|
cp.processor = &callback
|
||||||
cp.replace = true
|
cp.replace = true
|
||||||
@ -132,11 +135,15 @@ func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *S
|
|||||||
// db.Callback().Create().Get("gorm:create")
|
// db.Callback().Create().Get("gorm:create")
|
||||||
func (cp *CallbackProcessor) Get(callbackName string) (callback func(scope *Scope)) {
|
func (cp *CallbackProcessor) Get(callbackName string) (callback func(scope *Scope)) {
|
||||||
for _, p := range cp.parent.processors {
|
for _, p := range cp.parent.processors {
|
||||||
if p.name == callbackName && p.kind == cp.kind && !cp.remove {
|
if p.name == callbackName && p.kind == cp.kind {
|
||||||
return *p.processor
|
if p.remove {
|
||||||
|
callback = nil
|
||||||
|
} else {
|
||||||
|
callback = *p.processor
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// getRIndex get right index from string slice
|
// getRIndex get right index from string slice
|
||||||
@ -159,7 +166,7 @@ func sortProcessors(cps []*CallbackProcessor) []*func(scope *Scope) {
|
|||||||
for _, cp := range cps {
|
for _, cp := range cps {
|
||||||
// show warning message the callback name already exists
|
// show warning message the callback name already exists
|
||||||
if index := getRIndex(allNames, cp.name); index > -1 && !cp.replace && !cp.remove {
|
if index := getRIndex(allNames, cp.name); index > -1 && !cp.replace && !cp.remove {
|
||||||
log.Printf("[warning] duplicated callback `%v` from %v\n", cp.name, fileWithLineNum())
|
cp.logger.Print("warning", fmt.Sprintf("[warning] duplicated callback `%v` from %v", cp.name, fileWithLineNum()))
|
||||||
}
|
}
|
||||||
allNames = append(allNames, cp.name)
|
allNames = append(allNames, cp.name)
|
||||||
}
|
}
|
||||||
|
@ -31,7 +31,7 @@ func beforeCreateCallback(scope *Scope) {
|
|||||||
// updateTimeStampForCreateCallback will set `CreatedAt`, `UpdatedAt` when creating
|
// updateTimeStampForCreateCallback will set `CreatedAt`, `UpdatedAt` when creating
|
||||||
func updateTimeStampForCreateCallback(scope *Scope) {
|
func updateTimeStampForCreateCallback(scope *Scope) {
|
||||||
if !scope.HasError() {
|
if !scope.HasError() {
|
||||||
now := NowFunc()
|
now := scope.db.nowFunc()
|
||||||
|
|
||||||
if createdAtField, ok := scope.FieldByName("CreatedAt"); ok {
|
if createdAtField, ok := scope.FieldByName("CreatedAt"); ok {
|
||||||
if createdAtField.IsBlank {
|
if createdAtField.IsBlank {
|
||||||
@ -50,7 +50,7 @@ func updateTimeStampForCreateCallback(scope *Scope) {
|
|||||||
// createCallback the callback used to insert data into database
|
// createCallback the callback used to insert data into database
|
||||||
func createCallback(scope *Scope) {
|
func createCallback(scope *Scope) {
|
||||||
if !scope.HasError() {
|
if !scope.HasError() {
|
||||||
defer scope.trace(NowFunc())
|
defer scope.trace(scope.db.nowFunc())
|
||||||
|
|
||||||
var (
|
var (
|
||||||
columns, placeholders []string
|
columns, placeholders []string
|
||||||
@ -59,7 +59,7 @@ func createCallback(scope *Scope) {
|
|||||||
|
|
||||||
for _, field := range scope.Fields() {
|
for _, field := range scope.Fields() {
|
||||||
if scope.changeableField(field) {
|
if scope.changeableField(field) {
|
||||||
if field.IsNormal {
|
if field.IsNormal && !field.IsIgnored {
|
||||||
if field.IsBlank && field.HasDefaultValue {
|
if field.IsBlank && field.HasDefaultValue {
|
||||||
blankColumnsWithDefaultValue = append(blankColumnsWithDefaultValue, scope.Quote(field.DBName))
|
blankColumnsWithDefaultValue = append(blankColumnsWithDefaultValue, scope.Quote(field.DBName))
|
||||||
scope.InstanceSet("gorm:blank_columns_with_default_value", blankColumnsWithDefaultValue)
|
scope.InstanceSet("gorm:blank_columns_with_default_value", blankColumnsWithDefaultValue)
|
||||||
@ -83,21 +83,30 @@ func createCallback(scope *Scope) {
|
|||||||
quotedTableName = scope.QuotedTableName()
|
quotedTableName = scope.QuotedTableName()
|
||||||
primaryField = scope.PrimaryField()
|
primaryField = scope.PrimaryField()
|
||||||
extraOption string
|
extraOption string
|
||||||
|
insertModifier string
|
||||||
)
|
)
|
||||||
|
|
||||||
if str, ok := scope.Get("gorm:insert_option"); ok {
|
if str, ok := scope.Get("gorm:insert_option"); ok {
|
||||||
extraOption = fmt.Sprint(str)
|
extraOption = fmt.Sprint(str)
|
||||||
}
|
}
|
||||||
|
if str, ok := scope.Get("gorm:insert_modifier"); ok {
|
||||||
|
insertModifier = strings.ToUpper(fmt.Sprint(str))
|
||||||
|
if insertModifier == "INTO" {
|
||||||
|
insertModifier = ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if primaryField != nil {
|
if primaryField != nil {
|
||||||
returningColumn = scope.Quote(primaryField.DBName)
|
returningColumn = scope.Quote(primaryField.DBName)
|
||||||
}
|
}
|
||||||
|
|
||||||
lastInsertIDReturningSuffix := scope.Dialect().LastInsertIDReturningSuffix(quotedTableName, returningColumn)
|
lastInsertIDReturningSuffix := scope.Dialect().LastInsertIDReturningSuffix(quotedTableName, returningColumn)
|
||||||
|
lastInsertIDOutputInterstitial := scope.Dialect().LastInsertIDOutputInterstitial(quotedTableName, returningColumn, columns)
|
||||||
|
|
||||||
if len(columns) == 0 {
|
if len(columns) == 0 {
|
||||||
scope.Raw(fmt.Sprintf(
|
scope.Raw(fmt.Sprintf(
|
||||||
"INSERT INTO %v %v%v%v",
|
"INSERT%v INTO %v %v%v%v",
|
||||||
|
addExtraSpaceIfExist(insertModifier),
|
||||||
quotedTableName,
|
quotedTableName,
|
||||||
scope.Dialect().DefaultValueStr(),
|
scope.Dialect().DefaultValueStr(),
|
||||||
addExtraSpaceIfExist(extraOption),
|
addExtraSpaceIfExist(extraOption),
|
||||||
@ -105,17 +114,19 @@ func createCallback(scope *Scope) {
|
|||||||
))
|
))
|
||||||
} else {
|
} else {
|
||||||
scope.Raw(fmt.Sprintf(
|
scope.Raw(fmt.Sprintf(
|
||||||
"INSERT INTO %v (%v) VALUES (%v)%v%v",
|
"INSERT%v INTO %v (%v)%v VALUES (%v)%v%v",
|
||||||
|
addExtraSpaceIfExist(insertModifier),
|
||||||
scope.QuotedTableName(),
|
scope.QuotedTableName(),
|
||||||
strings.Join(columns, ","),
|
strings.Join(columns, ","),
|
||||||
|
addExtraSpaceIfExist(lastInsertIDOutputInterstitial),
|
||||||
strings.Join(placeholders, ","),
|
strings.Join(placeholders, ","),
|
||||||
addExtraSpaceIfExist(extraOption),
|
addExtraSpaceIfExist(extraOption),
|
||||||
addExtraSpaceIfExist(lastInsertIDReturningSuffix),
|
addExtraSpaceIfExist(lastInsertIDReturningSuffix),
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
// execute create sql
|
// execute create sql: no primaryField
|
||||||
if lastInsertIDReturningSuffix == "" || primaryField == nil {
|
if primaryField == nil {
|
||||||
if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
|
if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
|
||||||
// set rows affected count
|
// set rows affected count
|
||||||
scope.db.RowsAffected, _ = result.RowsAffected()
|
scope.db.RowsAffected, _ = result.RowsAffected()
|
||||||
@ -127,16 +138,35 @@ func createCallback(scope *Scope) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
return
|
||||||
if primaryField.Field.CanAddr() {
|
|
||||||
if err := scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil {
|
|
||||||
primaryField.IsBlank = false
|
|
||||||
scope.db.RowsAffected = 1
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
scope.Err(ErrUnaddressable)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// execute create sql: lastInsertID implemention for majority of dialects
|
||||||
|
if lastInsertIDReturningSuffix == "" && lastInsertIDOutputInterstitial == "" {
|
||||||
|
if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
|
||||||
|
// set rows affected count
|
||||||
|
scope.db.RowsAffected, _ = result.RowsAffected()
|
||||||
|
|
||||||
|
// set primary value to primary field
|
||||||
|
if primaryField != nil && primaryField.IsBlank {
|
||||||
|
if primaryValue, err := result.LastInsertId(); scope.Err(err) == nil {
|
||||||
|
scope.Err(primaryField.Set(primaryValue))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// execute create sql: dialects with additional lastInsertID requirements (currently postgres & mssql)
|
||||||
|
if primaryField.Field.CanAddr() {
|
||||||
|
if err := scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil {
|
||||||
|
primaryField.IsBlank = false
|
||||||
|
scope.db.RowsAffected = 1
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
scope.Err(ErrUnaddressable)
|
||||||
|
}
|
||||||
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -17,7 +17,7 @@ func init() {
|
|||||||
// beforeDeleteCallback will invoke `BeforeDelete` method before deleting
|
// beforeDeleteCallback will invoke `BeforeDelete` method before deleting
|
||||||
func beforeDeleteCallback(scope *Scope) {
|
func beforeDeleteCallback(scope *Scope) {
|
||||||
if scope.DB().HasBlockGlobalUpdate() && !scope.hasConditions() {
|
if scope.DB().HasBlockGlobalUpdate() && !scope.hasConditions() {
|
||||||
scope.Err(errors.New("Missing WHERE clause while deleting"))
|
scope.Err(errors.New("missing WHERE clause while deleting"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if !scope.HasError() {
|
if !scope.HasError() {
|
||||||
@ -40,7 +40,7 @@ func deleteCallback(scope *Scope) {
|
|||||||
"UPDATE %v SET %v=%v%v%v",
|
"UPDATE %v SET %v=%v%v%v",
|
||||||
scope.QuotedTableName(),
|
scope.QuotedTableName(),
|
||||||
scope.Quote(deletedAtField.DBName),
|
scope.Quote(deletedAtField.DBName),
|
||||||
scope.AddToVars(NowFunc()),
|
scope.AddToVars(scope.db.nowFunc()),
|
||||||
addExtraSpaceIfExist(scope.CombinedConditionSql()),
|
addExtraSpaceIfExist(scope.CombinedConditionSql()),
|
||||||
addExtraSpaceIfExist(extraOption),
|
addExtraSpaceIfExist(extraOption),
|
||||||
)).Exec()
|
)).Exec()
|
||||||
|
@ -19,7 +19,12 @@ func queryCallback(scope *Scope) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
defer scope.trace(NowFunc())
|
//we are only preloading relations, dont touch base model
|
||||||
|
if _, skip := scope.InstanceGet("gorm:only_preload"); skip {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
defer scope.trace(scope.db.nowFunc())
|
||||||
|
|
||||||
var (
|
var (
|
||||||
isSlice, isPtr bool
|
isSlice, isPtr bool
|
||||||
|
@ -100,7 +100,7 @@ func autoPreload(scope *Scope) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if val, ok := field.TagSettings["PRELOAD"]; ok {
|
if val, ok := field.TagSettingsGet("PRELOAD"); ok {
|
||||||
if preload, err := strconv.ParseBool(val); err != nil {
|
if preload, err := strconv.ParseBool(val); err != nil {
|
||||||
scope.Err(errors.New("invalid preload option"))
|
scope.Err(errors.New("invalid preload option"))
|
||||||
return
|
return
|
||||||
@ -161,14 +161,17 @@ func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{})
|
|||||||
)
|
)
|
||||||
|
|
||||||
if indirectScopeValue.Kind() == reflect.Slice {
|
if indirectScopeValue.Kind() == reflect.Slice {
|
||||||
|
foreignValuesToResults := make(map[string]reflect.Value)
|
||||||
|
for i := 0; i < resultsValue.Len(); i++ {
|
||||||
|
result := resultsValue.Index(i)
|
||||||
|
foreignValues := toString(getValueFromFields(result, relation.ForeignFieldNames))
|
||||||
|
foreignValuesToResults[foreignValues] = result
|
||||||
|
}
|
||||||
for j := 0; j < indirectScopeValue.Len(); j++ {
|
for j := 0; j < indirectScopeValue.Len(); j++ {
|
||||||
for i := 0; i < resultsValue.Len(); i++ {
|
indirectValue := indirect(indirectScopeValue.Index(j))
|
||||||
result := resultsValue.Index(i)
|
valueString := toString(getValueFromFields(indirectValue, relation.AssociationForeignFieldNames))
|
||||||
foreignValues := getValueFromFields(result, relation.ForeignFieldNames)
|
if result, found := foreignValuesToResults[valueString]; found {
|
||||||
if indirectValue := indirect(indirectScopeValue.Index(j)); equalAsString(getValueFromFields(indirectValue, relation.AssociationForeignFieldNames), foreignValues) {
|
indirectValue.FieldByName(field.Name).Set(result)
|
||||||
indirectValue.FieldByName(field.Name).Set(result)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@ -255,13 +258,21 @@ func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{
|
|||||||
indirectScopeValue = scope.IndirectValue()
|
indirectScopeValue = scope.IndirectValue()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
foreignFieldToObjects := make(map[string][]*reflect.Value)
|
||||||
|
if indirectScopeValue.Kind() == reflect.Slice {
|
||||||
|
for j := 0; j < indirectScopeValue.Len(); j++ {
|
||||||
|
object := indirect(indirectScopeValue.Index(j))
|
||||||
|
valueString := toString(getValueFromFields(object, relation.ForeignFieldNames))
|
||||||
|
foreignFieldToObjects[valueString] = append(foreignFieldToObjects[valueString], &object)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for i := 0; i < resultsValue.Len(); i++ {
|
for i := 0; i < resultsValue.Len(); i++ {
|
||||||
result := resultsValue.Index(i)
|
result := resultsValue.Index(i)
|
||||||
if indirectScopeValue.Kind() == reflect.Slice {
|
if indirectScopeValue.Kind() == reflect.Slice {
|
||||||
value := getValueFromFields(result, relation.AssociationForeignFieldNames)
|
valueString := toString(getValueFromFields(result, relation.AssociationForeignFieldNames))
|
||||||
for j := 0; j < indirectScopeValue.Len(); j++ {
|
if objects, found := foreignFieldToObjects[valueString]; found {
|
||||||
object := indirect(indirectScopeValue.Index(j))
|
for _, object := range objects {
|
||||||
if equalAsString(getValueFromFields(object, relation.ForeignFieldNames), value) {
|
|
||||||
object.FieldByName(field.Name).Set(result)
|
object.FieldByName(field.Name).Set(result)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -380,14 +391,20 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface
|
|||||||
key := toString(getValueFromFields(indirectScopeValue, foreignFieldNames))
|
key := toString(getValueFromFields(indirectScopeValue, foreignFieldNames))
|
||||||
fieldsSourceMap[key] = append(fieldsSourceMap[key], indirectScopeValue.FieldByName(field.Name))
|
fieldsSourceMap[key] = append(fieldsSourceMap[key], indirectScopeValue.FieldByName(field.Name))
|
||||||
}
|
}
|
||||||
for source, link := range linkHash {
|
|
||||||
for i, field := range fieldsSourceMap[source] {
|
for source, fields := range fieldsSourceMap {
|
||||||
|
for _, f := range fields {
|
||||||
//If not 0 this means Value is a pointer and we already added preloaded models to it
|
//If not 0 this means Value is a pointer and we already added preloaded models to it
|
||||||
if fieldsSourceMap[source][i].Len() != 0 {
|
if f.Len() != 0 {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
field.Set(reflect.Append(fieldsSourceMap[source][i], link...))
|
|
||||||
}
|
|
||||||
|
|
||||||
|
v := reflect.MakeSlice(f.Type(), 0, 0)
|
||||||
|
if len(linkHash[source]) > 0 {
|
||||||
|
v = reflect.Append(f, linkHash[source]...)
|
||||||
|
}
|
||||||
|
|
||||||
|
f.Set(v)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,6 +1,9 @@
|
|||||||
package gorm
|
package gorm
|
||||||
|
|
||||||
import "database/sql"
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
// Define callbacks for row query
|
// Define callbacks for row query
|
||||||
func init() {
|
func init() {
|
||||||
@ -20,6 +23,9 @@ type RowsQueryResult struct {
|
|||||||
func rowQueryCallback(scope *Scope) {
|
func rowQueryCallback(scope *Scope) {
|
||||||
if result, ok := scope.InstanceGet("row_query_result"); ok {
|
if result, ok := scope.InstanceGet("row_query_result"); ok {
|
||||||
scope.prepareQuerySQL()
|
scope.prepareQuerySQL()
|
||||||
|
if str, ok := scope.Get("gorm:query_option"); ok {
|
||||||
|
scope.SQL += addExtraSpaceIfExist(fmt.Sprint(str))
|
||||||
|
}
|
||||||
|
|
||||||
if rowResult, ok := result.(*RowQueryResult); ok {
|
if rowResult, ok := result.(*RowQueryResult); ok {
|
||||||
rowResult.Row = scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...)
|
rowResult.Row = scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...)
|
||||||
|
@ -21,9 +21,7 @@ func saveAssociationCheck(scope *Scope, field *Field) (autoUpdate bool, autoCrea
|
|||||||
|
|
||||||
if v, ok := value.(string); ok {
|
if v, ok := value.(string); ok {
|
||||||
v = strings.ToLower(v)
|
v = strings.ToLower(v)
|
||||||
if v == "false" || v != "skip" {
|
return v == "true"
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return true
|
return true
|
||||||
@ -36,26 +34,28 @@ func saveAssociationCheck(scope *Scope, field *Field) (autoUpdate bool, autoCrea
|
|||||||
if value, ok := scope.Get("gorm:save_associations"); ok {
|
if value, ok := scope.Get("gorm:save_associations"); ok {
|
||||||
autoUpdate = checkTruth(value)
|
autoUpdate = checkTruth(value)
|
||||||
autoCreate = autoUpdate
|
autoCreate = autoUpdate
|
||||||
} else if value, ok := field.TagSettings["SAVE_ASSOCIATIONS"]; ok {
|
saveReference = autoUpdate
|
||||||
|
} else if value, ok := field.TagSettingsGet("SAVE_ASSOCIATIONS"); ok {
|
||||||
autoUpdate = checkTruth(value)
|
autoUpdate = checkTruth(value)
|
||||||
autoCreate = autoUpdate
|
autoCreate = autoUpdate
|
||||||
|
saveReference = autoUpdate
|
||||||
}
|
}
|
||||||
|
|
||||||
if value, ok := scope.Get("gorm:association_autoupdate"); ok {
|
if value, ok := scope.Get("gorm:association_autoupdate"); ok {
|
||||||
autoUpdate = checkTruth(value)
|
autoUpdate = checkTruth(value)
|
||||||
} else if value, ok := field.TagSettings["ASSOCIATION_AUTOUPDATE"]; ok {
|
} else if value, ok := field.TagSettingsGet("ASSOCIATION_AUTOUPDATE"); ok {
|
||||||
autoUpdate = checkTruth(value)
|
autoUpdate = checkTruth(value)
|
||||||
}
|
}
|
||||||
|
|
||||||
if value, ok := scope.Get("gorm:association_autocreate"); ok {
|
if value, ok := scope.Get("gorm:association_autocreate"); ok {
|
||||||
autoCreate = checkTruth(value)
|
autoCreate = checkTruth(value)
|
||||||
} else if value, ok := field.TagSettings["ASSOCIATION_AUTOCREATE"]; ok {
|
} else if value, ok := field.TagSettingsGet("ASSOCIATION_AUTOCREATE"); ok {
|
||||||
autoCreate = checkTruth(value)
|
autoCreate = checkTruth(value)
|
||||||
}
|
}
|
||||||
|
|
||||||
if value, ok := scope.Get("gorm:association_save_reference"); ok {
|
if value, ok := scope.Get("gorm:association_save_reference"); ok {
|
||||||
saveReference = checkTruth(value)
|
saveReference = checkTruth(value)
|
||||||
} else if value, ok := field.TagSettings["ASSOCIATION_SAVE_REFERENCE"]; ok {
|
} else if value, ok := field.TagSettingsGet("ASSOCIATION_SAVE_REFERENCE"); ok {
|
||||||
saveReference = checkTruth(value)
|
saveReference = checkTruth(value)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -23,7 +23,7 @@ func afterCreate1(s *Scope) {}
|
|||||||
func afterCreate2(s *Scope) {}
|
func afterCreate2(s *Scope) {}
|
||||||
|
|
||||||
func TestRegisterCallback(t *testing.T) {
|
func TestRegisterCallback(t *testing.T) {
|
||||||
var callback = &Callback{}
|
var callback = &Callback{logger: defaultLogger}
|
||||||
|
|
||||||
callback.Create().Register("before_create1", beforeCreate1)
|
callback.Create().Register("before_create1", beforeCreate1)
|
||||||
callback.Create().Register("before_create2", beforeCreate2)
|
callback.Create().Register("before_create2", beforeCreate2)
|
||||||
@ -37,7 +37,7 @@ func TestRegisterCallback(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRegisterCallbackWithOrder(t *testing.T) {
|
func TestRegisterCallbackWithOrder(t *testing.T) {
|
||||||
var callback1 = &Callback{}
|
var callback1 = &Callback{logger: defaultLogger}
|
||||||
callback1.Create().Register("before_create1", beforeCreate1)
|
callback1.Create().Register("before_create1", beforeCreate1)
|
||||||
callback1.Create().Register("create", create)
|
callback1.Create().Register("create", create)
|
||||||
callback1.Create().Register("after_create1", afterCreate1)
|
callback1.Create().Register("after_create1", afterCreate1)
|
||||||
@ -46,7 +46,7 @@ func TestRegisterCallbackWithOrder(t *testing.T) {
|
|||||||
t.Errorf("register callback with order")
|
t.Errorf("register callback with order")
|
||||||
}
|
}
|
||||||
|
|
||||||
var callback2 = &Callback{}
|
var callback2 = &Callback{logger: defaultLogger}
|
||||||
|
|
||||||
callback2.Update().Register("create", create)
|
callback2.Update().Register("create", create)
|
||||||
callback2.Update().Before("create").Register("before_create1", beforeCreate1)
|
callback2.Update().Before("create").Register("before_create1", beforeCreate1)
|
||||||
@ -60,7 +60,7 @@ func TestRegisterCallbackWithOrder(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRegisterCallbackWithComplexOrder(t *testing.T) {
|
func TestRegisterCallbackWithComplexOrder(t *testing.T) {
|
||||||
var callback1 = &Callback{}
|
var callback1 = &Callback{logger: defaultLogger}
|
||||||
|
|
||||||
callback1.Query().Before("after_create1").After("before_create1").Register("create", create)
|
callback1.Query().Before("after_create1").After("before_create1").Register("create", create)
|
||||||
callback1.Query().Register("before_create1", beforeCreate1)
|
callback1.Query().Register("before_create1", beforeCreate1)
|
||||||
@ -70,7 +70,7 @@ func TestRegisterCallbackWithComplexOrder(t *testing.T) {
|
|||||||
t.Errorf("register callback with order")
|
t.Errorf("register callback with order")
|
||||||
}
|
}
|
||||||
|
|
||||||
var callback2 = &Callback{}
|
var callback2 = &Callback{logger: defaultLogger}
|
||||||
|
|
||||||
callback2.Delete().Before("after_create1").After("before_create1").Register("create", create)
|
callback2.Delete().Before("after_create1").After("before_create1").Register("create", create)
|
||||||
callback2.Delete().Before("create").Register("before_create1", beforeCreate1)
|
callback2.Delete().Before("create").Register("before_create1", beforeCreate1)
|
||||||
@ -86,7 +86,7 @@ func TestRegisterCallbackWithComplexOrder(t *testing.T) {
|
|||||||
func replaceCreate(s *Scope) {}
|
func replaceCreate(s *Scope) {}
|
||||||
|
|
||||||
func TestReplaceCallback(t *testing.T) {
|
func TestReplaceCallback(t *testing.T) {
|
||||||
var callback = &Callback{}
|
var callback = &Callback{logger: defaultLogger}
|
||||||
|
|
||||||
callback.Create().Before("after_create1").After("before_create1").Register("create", create)
|
callback.Create().Before("after_create1").After("before_create1").Register("create", create)
|
||||||
callback.Create().Register("before_create1", beforeCreate1)
|
callback.Create().Register("before_create1", beforeCreate1)
|
||||||
@ -99,7 +99,7 @@ func TestReplaceCallback(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRemoveCallback(t *testing.T) {
|
func TestRemoveCallback(t *testing.T) {
|
||||||
var callback = &Callback{}
|
var callback = &Callback{logger: defaultLogger}
|
||||||
|
|
||||||
callback.Create().Before("after_create1").After("before_create1").Register("create", create)
|
callback.Create().Before("after_create1").After("before_create1").Register("create", create)
|
||||||
callback.Create().Register("before_create1", beforeCreate1)
|
callback.Create().Register("before_create1", beforeCreate1)
|
||||||
|
@ -34,7 +34,7 @@ func assignUpdatingAttributesCallback(scope *Scope) {
|
|||||||
// beforeUpdateCallback will invoke `BeforeSave`, `BeforeUpdate` method before updating
|
// beforeUpdateCallback will invoke `BeforeSave`, `BeforeUpdate` method before updating
|
||||||
func beforeUpdateCallback(scope *Scope) {
|
func beforeUpdateCallback(scope *Scope) {
|
||||||
if scope.DB().HasBlockGlobalUpdate() && !scope.hasConditions() {
|
if scope.DB().HasBlockGlobalUpdate() && !scope.hasConditions() {
|
||||||
scope.Err(errors.New("Missing WHERE clause while updating"))
|
scope.Err(errors.New("missing WHERE clause while updating"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if _, ok := scope.Get("gorm:update_column"); !ok {
|
if _, ok := scope.Get("gorm:update_column"); !ok {
|
||||||
@ -50,7 +50,7 @@ func beforeUpdateCallback(scope *Scope) {
|
|||||||
// updateTimeStampForUpdateCallback will set `UpdatedAt` when updating
|
// updateTimeStampForUpdateCallback will set `UpdatedAt` when updating
|
||||||
func updateTimeStampForUpdateCallback(scope *Scope) {
|
func updateTimeStampForUpdateCallback(scope *Scope) {
|
||||||
if _, ok := scope.Get("gorm:update_column"); !ok {
|
if _, ok := scope.Get("gorm:update_column"); !ok {
|
||||||
scope.SetColumn("UpdatedAt", NowFunc())
|
scope.SetColumn("UpdatedAt", scope.db.nowFunc())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -75,8 +75,10 @@ func updateCallback(scope *Scope) {
|
|||||||
} else {
|
} else {
|
||||||
for _, field := range scope.Fields() {
|
for _, field := range scope.Fields() {
|
||||||
if scope.changeableField(field) {
|
if scope.changeableField(field) {
|
||||||
if !field.IsPrimaryKey && field.IsNormal {
|
if !field.IsPrimaryKey && field.IsNormal && (field.Name != "CreatedAt" || !field.IsBlank) {
|
||||||
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
|
if !field.IsForeignKey || !field.IsBlank || !field.HasDefaultValue {
|
||||||
|
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
|
||||||
|
}
|
||||||
} else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
|
} else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
|
||||||
for _, foreignKey := range relationship.ForeignDBNames {
|
for _, foreignKey := range relationship.ForeignDBNames {
|
||||||
if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) {
|
if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) {
|
||||||
|
@ -2,11 +2,10 @@ package gorm_test
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
|
|
||||||
"github.com/jinzhu/gorm"
|
|
||||||
|
|
||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/jinzhu/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (s *Product) BeforeCreate() (err error) {
|
func (s *Product) BeforeCreate() (err error) {
|
||||||
@ -175,3 +174,46 @@ func TestCallbacksWithErrors(t *testing.T) {
|
|||||||
t.Errorf("Record shouldn't be deleted because of an error happened in after delete callback")
|
t.Errorf("Record shouldn't be deleted because of an error happened in after delete callback")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGetCallback(t *testing.T) {
|
||||||
|
scope := DB.NewScope(nil)
|
||||||
|
|
||||||
|
if DB.Callback().Create().Get("gorm:test_callback") != nil {
|
||||||
|
t.Errorf("`gorm:test_callback` should be nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Callback().Create().Register("gorm:test_callback", func(scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 1) })
|
||||||
|
callback := DB.Callback().Create().Get("gorm:test_callback")
|
||||||
|
if callback == nil {
|
||||||
|
t.Errorf("`gorm:test_callback` should be non-nil")
|
||||||
|
}
|
||||||
|
callback(scope)
|
||||||
|
if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 1 {
|
||||||
|
t.Errorf("`gorm:test_callback_value` should be `1, true` but `%v, %v`", v, ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Callback().Create().Replace("gorm:test_callback", func(scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 2) })
|
||||||
|
callback = DB.Callback().Create().Get("gorm:test_callback")
|
||||||
|
if callback == nil {
|
||||||
|
t.Errorf("`gorm:test_callback` should be non-nil")
|
||||||
|
}
|
||||||
|
callback(scope)
|
||||||
|
if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 2 {
|
||||||
|
t.Errorf("`gorm:test_callback_value` should be `2, true` but `%v, %v`", v, ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Callback().Create().Remove("gorm:test_callback")
|
||||||
|
if DB.Callback().Create().Get("gorm:test_callback") != nil {
|
||||||
|
t.Errorf("`gorm:test_callback` should be nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Callback().Create().Register("gorm:test_callback", func(scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 3) })
|
||||||
|
callback = DB.Callback().Create().Get("gorm:test_callback")
|
||||||
|
if callback == nil {
|
||||||
|
t.Errorf("`gorm:test_callback` should be non-nil")
|
||||||
|
}
|
||||||
|
callback(scope)
|
||||||
|
if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 3 {
|
||||||
|
t.Errorf("`gorm:test_callback_value` should be `3, true` but `%v, %v`", v, ok)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -101,6 +101,46 @@ func TestCreateWithExistingTimestamp(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCreateWithNowFuncOverride(t *testing.T) {
|
||||||
|
user1 := User{Name: "CreateUserTimestampOverride"}
|
||||||
|
|
||||||
|
timeA := now.MustParse("2016-01-01")
|
||||||
|
|
||||||
|
// do DB.New() because we don't want this test to affect other tests
|
||||||
|
db1 := DB.New()
|
||||||
|
// set the override to use static timeA
|
||||||
|
db1.SetNowFuncOverride(func() time.Time {
|
||||||
|
return timeA
|
||||||
|
})
|
||||||
|
// call .New again to check the override is carried over as well during clone
|
||||||
|
db1 = db1.New()
|
||||||
|
|
||||||
|
db1.Save(&user1)
|
||||||
|
|
||||||
|
if user1.CreatedAt.UTC().Format(time.RFC3339) != timeA.UTC().Format(time.RFC3339) {
|
||||||
|
t.Errorf("CreatedAt be using the nowFuncOverride")
|
||||||
|
}
|
||||||
|
if user1.UpdatedAt.UTC().Format(time.RFC3339) != timeA.UTC().Format(time.RFC3339) {
|
||||||
|
t.Errorf("UpdatedAt be using the nowFuncOverride")
|
||||||
|
}
|
||||||
|
|
||||||
|
// now create another user with a fresh DB.Now() that doesn't have the nowFuncOverride set
|
||||||
|
// to make sure that setting it only affected the above instance
|
||||||
|
|
||||||
|
user2 := User{Name: "CreateUserTimestampOverrideNoMore"}
|
||||||
|
|
||||||
|
db2 := DB.New()
|
||||||
|
|
||||||
|
db2.Save(&user2)
|
||||||
|
|
||||||
|
if user2.CreatedAt.UTC().Format(time.RFC3339) == timeA.UTC().Format(time.RFC3339) {
|
||||||
|
t.Errorf("CreatedAt no longer be using the nowFuncOverride")
|
||||||
|
}
|
||||||
|
if user2.UpdatedAt.UTC().Format(time.RFC3339) == timeA.UTC().Format(time.RFC3339) {
|
||||||
|
t.Errorf("UpdatedAt no longer be using the nowFuncOverride")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type AutoIncrementUser struct {
|
type AutoIncrementUser struct {
|
||||||
User
|
User
|
||||||
Sequence uint `gorm:"AUTO_INCREMENT"`
|
Sequence uint `gorm:"AUTO_INCREMENT"`
|
||||||
@ -229,3 +269,20 @@ func TestOmitWithCreate(t *testing.T) {
|
|||||||
t.Errorf("Should not create omitted relationships")
|
t.Errorf("Should not create omitted relationships")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCreateIgnore(t *testing.T) {
|
||||||
|
float := 35.03554004971999
|
||||||
|
now := time.Now()
|
||||||
|
user := User{Name: "CreateUser", Age: 18, Birthday: &now, UserNum: Num(111), PasswordHash: []byte{'f', 'a', 'k', '4'}, Latitude: float}
|
||||||
|
|
||||||
|
if !DB.NewRecord(user) || !DB.NewRecord(&user) {
|
||||||
|
t.Error("User should be new record before create")
|
||||||
|
}
|
||||||
|
|
||||||
|
if count := DB.Create(&user).RowsAffected; count != 1 {
|
||||||
|
t.Error("There should be one record be affected when create record")
|
||||||
|
}
|
||||||
|
if DB.Dialect().GetName() == "mysql" && DB.Set("gorm:insert_modifier", "IGNORE").Create(&user).Error != nil {
|
||||||
|
t.Error("Should ignore duplicate user insert by insert modifier:IGNORE ")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -289,6 +289,9 @@ type SelfReferencingUser struct {
|
|||||||
func TestSelfReferencingMany2ManyColumn(t *testing.T) {
|
func TestSelfReferencingMany2ManyColumn(t *testing.T) {
|
||||||
DB.DropTable(&SelfReferencingUser{}, "UserFriends")
|
DB.DropTable(&SelfReferencingUser{}, "UserFriends")
|
||||||
DB.AutoMigrate(&SelfReferencingUser{})
|
DB.AutoMigrate(&SelfReferencingUser{})
|
||||||
|
if !DB.HasTable("UserFriends") {
|
||||||
|
t.Errorf("auto migrate error, table UserFriends should be created")
|
||||||
|
}
|
||||||
|
|
||||||
friend1 := SelfReferencingUser{Name: "friend1_m2m"}
|
friend1 := SelfReferencingUser{Name: "friend1_m2m"}
|
||||||
if err := DB.Create(&friend1).Error; err != nil {
|
if err := DB.Create(&friend1).Error; err != nil {
|
||||||
@ -313,6 +316,14 @@ func TestSelfReferencingMany2ManyColumn(t *testing.T) {
|
|||||||
t.Errorf("Should find created friends correctly")
|
t.Errorf("Should find created friends correctly")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var count int
|
||||||
|
if err := DB.Table("UserFriends").Count(&count).Error; err != nil {
|
||||||
|
t.Errorf("no error should happen, but got %v", err)
|
||||||
|
}
|
||||||
|
if count == 0 {
|
||||||
|
t.Errorf("table UserFriends should have records")
|
||||||
|
}
|
||||||
|
|
||||||
var newUser = SelfReferencingUser{}
|
var newUser = SelfReferencingUser{}
|
||||||
|
|
||||||
if err := DB.Preload("Friends").First(&newUser, "id = ?", user.ID).Error; err != nil {
|
if err := DB.Preload("Friends").First(&newUser, "id = ?", user.ID).Error; err != nil {
|
||||||
|
19
dialect.go
19
dialect.go
@ -40,6 +40,8 @@ type Dialect interface {
|
|||||||
LimitAndOffsetSQL(limit, offset interface{}) string
|
LimitAndOffsetSQL(limit, offset interface{}) string
|
||||||
// SelectFromDummyTable return select values, for most dbs, `SELECT values` just works, mysql needs `SELECT value FROM DUAL`
|
// SelectFromDummyTable return select values, for most dbs, `SELECT values` just works, mysql needs `SELECT value FROM DUAL`
|
||||||
SelectFromDummyTable() string
|
SelectFromDummyTable() string
|
||||||
|
// LastInsertIDOutputInterstitial most dbs support LastInsertId, but mssql needs to use `OUTPUT`
|
||||||
|
LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string
|
||||||
// LastInsertIdReturningSuffix most dbs support LastInsertId, but postgres needs to use `RETURNING`
|
// LastInsertIdReturningSuffix most dbs support LastInsertId, but postgres needs to use `RETURNING`
|
||||||
LastInsertIDReturningSuffix(tableName, columnName string) string
|
LastInsertIDReturningSuffix(tableName, columnName string) string
|
||||||
// DefaultValueStr
|
// DefaultValueStr
|
||||||
@ -48,6 +50,9 @@ type Dialect interface {
|
|||||||
// BuildKeyName returns a valid key name (foreign key, index key) for the given table, field and reference
|
// BuildKeyName returns a valid key name (foreign key, index key) for the given table, field and reference
|
||||||
BuildKeyName(kind, tableName string, fields ...string) string
|
BuildKeyName(kind, tableName string, fields ...string) string
|
||||||
|
|
||||||
|
// NormalizeIndexAndColumn returns valid index name and column name depending on each dialect
|
||||||
|
NormalizeIndexAndColumn(indexName, columnName string) (string, string)
|
||||||
|
|
||||||
// CurrentDatabase return current database name
|
// CurrentDatabase return current database name
|
||||||
CurrentDatabase() string
|
CurrentDatabase() string
|
||||||
}
|
}
|
||||||
@ -83,7 +88,7 @@ var ParseFieldStructForDialect = func(field *StructField, dialect Dialect) (fiel
|
|||||||
// Get redirected field type
|
// Get redirected field type
|
||||||
var (
|
var (
|
||||||
reflectType = field.Struct.Type
|
reflectType = field.Struct.Type
|
||||||
dataType = field.TagSettings["TYPE"]
|
dataType, _ = field.TagSettingsGet("TYPE")
|
||||||
)
|
)
|
||||||
|
|
||||||
for reflectType.Kind() == reflect.Ptr {
|
for reflectType.Kind() == reflect.Ptr {
|
||||||
@ -112,18 +117,24 @@ var ParseFieldStructForDialect = func(field *StructField, dialect Dialect) (fiel
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Default Size
|
// Default Size
|
||||||
if num, ok := field.TagSettings["SIZE"]; ok {
|
if num, ok := field.TagSettingsGet("SIZE"); ok {
|
||||||
size, _ = strconv.Atoi(num)
|
size, _ = strconv.Atoi(num)
|
||||||
} else {
|
} else {
|
||||||
size = 255
|
size = 255
|
||||||
}
|
}
|
||||||
|
|
||||||
// Default type from tag setting
|
// Default type from tag setting
|
||||||
additionalType = field.TagSettings["NOT NULL"] + " " + field.TagSettings["UNIQUE"]
|
notNull, _ := field.TagSettingsGet("NOT NULL")
|
||||||
if value, ok := field.TagSettings["DEFAULT"]; ok {
|
unique, _ := field.TagSettingsGet("UNIQUE")
|
||||||
|
additionalType = notNull + " " + unique
|
||||||
|
if value, ok := field.TagSettingsGet("DEFAULT"); ok {
|
||||||
additionalType = additionalType + " DEFAULT " + value
|
additionalType = additionalType + " DEFAULT " + value
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if value, ok := field.TagSettingsGet("COMMENT"); ok {
|
||||||
|
additionalType = additionalType + " COMMENT " + value
|
||||||
|
}
|
||||||
|
|
||||||
return fieldValue, dataType, size, strings.TrimSpace(additionalType)
|
return fieldValue, dataType, size, strings.TrimSpace(additionalType)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -9,6 +9,8 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var keyNameRegex = regexp.MustCompile("[^a-zA-Z0-9]+")
|
||||||
|
|
||||||
// DefaultForeignKeyNamer contains the default foreign key name generator method
|
// DefaultForeignKeyNamer contains the default foreign key name generator method
|
||||||
type DefaultForeignKeyNamer struct {
|
type DefaultForeignKeyNamer struct {
|
||||||
}
|
}
|
||||||
@ -39,7 +41,7 @@ func (commonDialect) Quote(key string) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *commonDialect) fieldCanAutoIncrement(field *StructField) bool {
|
func (s *commonDialect) fieldCanAutoIncrement(field *StructField) bool {
|
||||||
if value, ok := field.TagSettings["AUTO_INCREMENT"]; ok {
|
if value, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok {
|
||||||
return strings.ToLower(value) != "false"
|
return strings.ToLower(value) != "false"
|
||||||
}
|
}
|
||||||
return field.IsPrimaryKey
|
return field.IsPrimaryKey
|
||||||
@ -155,6 +157,10 @@ func (commonDialect) SelectFromDummyTable() string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (commonDialect) LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
func (commonDialect) LastInsertIDReturningSuffix(tableName, columnName string) string {
|
func (commonDialect) LastInsertIDReturningSuffix(tableName, columnName string) string {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
@ -166,10 +172,15 @@ func (commonDialect) DefaultValueStr() string {
|
|||||||
// BuildKeyName returns a valid key name (foreign key, index key) for the given table, field and reference
|
// BuildKeyName returns a valid key name (foreign key, index key) for the given table, field and reference
|
||||||
func (DefaultForeignKeyNamer) BuildKeyName(kind, tableName string, fields ...string) string {
|
func (DefaultForeignKeyNamer) BuildKeyName(kind, tableName string, fields ...string) string {
|
||||||
keyName := fmt.Sprintf("%s_%s_%s", kind, tableName, strings.Join(fields, "_"))
|
keyName := fmt.Sprintf("%s_%s_%s", kind, tableName, strings.Join(fields, "_"))
|
||||||
keyName = regexp.MustCompile("[^a-zA-Z0-9]+").ReplaceAllString(keyName, "_")
|
keyName = keyNameRegex.ReplaceAllString(keyName, "_")
|
||||||
return keyName
|
return keyName
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NormalizeIndexAndColumn returns argument's index name and column name without doing anything
|
||||||
|
func (commonDialect) NormalizeIndexAndColumn(indexName, columnName string) (string, string) {
|
||||||
|
return indexName, columnName
|
||||||
|
}
|
||||||
|
|
||||||
// IsByteArrayOrSlice returns true of the reflected value is an array or slice
|
// IsByteArrayOrSlice returns true of the reflected value is an array or slice
|
||||||
func IsByteArrayOrSlice(value reflect.Value) bool {
|
func IsByteArrayOrSlice(value reflect.Value) bool {
|
||||||
return (value.Kind() == reflect.Array || value.Kind() == reflect.Slice) && value.Type().Elem() == reflect.TypeOf(uint8(0))
|
return (value.Kind() == reflect.Array || value.Kind() == reflect.Slice) && value.Type().Elem() == reflect.TypeOf(uint8(0))
|
||||||
|
@ -11,6 +11,8 @@ import (
|
|||||||
"unicode/utf8"
|
"unicode/utf8"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var mysqlIndexRegex = regexp.MustCompile(`^(.+)\((\d+)\)$`)
|
||||||
|
|
||||||
type mysql struct {
|
type mysql struct {
|
||||||
commonDialect
|
commonDialect
|
||||||
}
|
}
|
||||||
@ -33,9 +35,9 @@ func (s *mysql) DataTypeOf(field *StructField) string {
|
|||||||
|
|
||||||
// MySQL allows only one auto increment column per table, and it must
|
// MySQL allows only one auto increment column per table, and it must
|
||||||
// be a KEY column.
|
// be a KEY column.
|
||||||
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok {
|
if _, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok {
|
||||||
if _, ok = field.TagSettings["INDEX"]; !ok && !field.IsPrimaryKey {
|
if _, ok = field.TagSettingsGet("INDEX"); !ok && !field.IsPrimaryKey {
|
||||||
delete(field.TagSettings, "AUTO_INCREMENT")
|
field.TagSettingsDelete("AUTO_INCREMENT")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -45,42 +47,42 @@ func (s *mysql) DataTypeOf(field *StructField) string {
|
|||||||
sqlType = "boolean"
|
sqlType = "boolean"
|
||||||
case reflect.Int8:
|
case reflect.Int8:
|
||||||
if s.fieldCanAutoIncrement(field) {
|
if s.fieldCanAutoIncrement(field) {
|
||||||
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
|
field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
|
||||||
sqlType = "tinyint AUTO_INCREMENT"
|
sqlType = "tinyint AUTO_INCREMENT"
|
||||||
} else {
|
} else {
|
||||||
sqlType = "tinyint"
|
sqlType = "tinyint"
|
||||||
}
|
}
|
||||||
case reflect.Int, reflect.Int16, reflect.Int32:
|
case reflect.Int, reflect.Int16, reflect.Int32:
|
||||||
if s.fieldCanAutoIncrement(field) {
|
if s.fieldCanAutoIncrement(field) {
|
||||||
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
|
field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
|
||||||
sqlType = "int AUTO_INCREMENT"
|
sqlType = "int AUTO_INCREMENT"
|
||||||
} else {
|
} else {
|
||||||
sqlType = "int"
|
sqlType = "int"
|
||||||
}
|
}
|
||||||
case reflect.Uint8:
|
case reflect.Uint8:
|
||||||
if s.fieldCanAutoIncrement(field) {
|
if s.fieldCanAutoIncrement(field) {
|
||||||
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
|
field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
|
||||||
sqlType = "tinyint unsigned AUTO_INCREMENT"
|
sqlType = "tinyint unsigned AUTO_INCREMENT"
|
||||||
} else {
|
} else {
|
||||||
sqlType = "tinyint unsigned"
|
sqlType = "tinyint unsigned"
|
||||||
}
|
}
|
||||||
case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||||
if s.fieldCanAutoIncrement(field) {
|
if s.fieldCanAutoIncrement(field) {
|
||||||
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
|
field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
|
||||||
sqlType = "int unsigned AUTO_INCREMENT"
|
sqlType = "int unsigned AUTO_INCREMENT"
|
||||||
} else {
|
} else {
|
||||||
sqlType = "int unsigned"
|
sqlType = "int unsigned"
|
||||||
}
|
}
|
||||||
case reflect.Int64:
|
case reflect.Int64:
|
||||||
if s.fieldCanAutoIncrement(field) {
|
if s.fieldCanAutoIncrement(field) {
|
||||||
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
|
field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
|
||||||
sqlType = "bigint AUTO_INCREMENT"
|
sqlType = "bigint AUTO_INCREMENT"
|
||||||
} else {
|
} else {
|
||||||
sqlType = "bigint"
|
sqlType = "bigint"
|
||||||
}
|
}
|
||||||
case reflect.Uint64:
|
case reflect.Uint64:
|
||||||
if s.fieldCanAutoIncrement(field) {
|
if s.fieldCanAutoIncrement(field) {
|
||||||
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
|
field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
|
||||||
sqlType = "bigint unsigned AUTO_INCREMENT"
|
sqlType = "bigint unsigned AUTO_INCREMENT"
|
||||||
} else {
|
} else {
|
||||||
sqlType = "bigint unsigned"
|
sqlType = "bigint unsigned"
|
||||||
@ -96,11 +98,11 @@ func (s *mysql) DataTypeOf(field *StructField) string {
|
|||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
if _, ok := dataValue.Interface().(time.Time); ok {
|
if _, ok := dataValue.Interface().(time.Time); ok {
|
||||||
precision := ""
|
precision := ""
|
||||||
if p, ok := field.TagSettings["PRECISION"]; ok {
|
if p, ok := field.TagSettingsGet("PRECISION"); ok {
|
||||||
precision = fmt.Sprintf("(%s)", p)
|
precision = fmt.Sprintf("(%s)", p)
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, ok := field.TagSettings["NOT NULL"]; ok {
|
if _, ok := field.TagSettings["NOT NULL"]; ok || field.IsPrimaryKey {
|
||||||
sqlType = fmt.Sprintf("DATETIME%v", precision)
|
sqlType = fmt.Sprintf("DATETIME%v", precision)
|
||||||
} else {
|
} else {
|
||||||
sqlType = fmt.Sprintf("DATETIME%v NULL", precision)
|
sqlType = fmt.Sprintf("DATETIME%v NULL", precision)
|
||||||
@ -118,7 +120,7 @@ func (s *mysql) DataTypeOf(field *StructField) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if sqlType == "" {
|
if sqlType == "" {
|
||||||
panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", dataValue.Type().Name(), dataValue.Kind().String()))
|
panic(fmt.Sprintf("invalid sql type %s (%s) in field %s for mysql", dataValue.Type().Name(), dataValue.Kind().String(), field.Name))
|
||||||
}
|
}
|
||||||
|
|
||||||
if strings.TrimSpace(additionalType) == "" {
|
if strings.TrimSpace(additionalType) == "" {
|
||||||
@ -178,7 +180,7 @@ func (s mysql) BuildKeyName(kind, tableName string, fields ...string) string {
|
|||||||
bs := h.Sum(nil)
|
bs := h.Sum(nil)
|
||||||
|
|
||||||
// sha1 is 40 characters, keep first 24 characters of destination
|
// sha1 is 40 characters, keep first 24 characters of destination
|
||||||
destRunes := []rune(regexp.MustCompile("[^a-zA-Z0-9]+").ReplaceAllString(fields[0], "_"))
|
destRunes := []rune(keyNameRegex.ReplaceAllString(fields[0], "_"))
|
||||||
if len(destRunes) > 24 {
|
if len(destRunes) > 24 {
|
||||||
destRunes = destRunes[:24]
|
destRunes = destRunes[:24]
|
||||||
}
|
}
|
||||||
@ -186,6 +188,17 @@ func (s mysql) BuildKeyName(kind, tableName string, fields ...string) string {
|
|||||||
return fmt.Sprintf("%s%x", string(destRunes), bs)
|
return fmt.Sprintf("%s%x", string(destRunes), bs)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NormalizeIndexAndColumn returns index name and column name for specify an index prefix length if needed
|
||||||
|
func (mysql) NormalizeIndexAndColumn(indexName, columnName string) (string, string) {
|
||||||
|
submatch := mysqlIndexRegex.FindStringSubmatch(indexName)
|
||||||
|
if len(submatch) != 3 {
|
||||||
|
return indexName, columnName
|
||||||
|
}
|
||||||
|
indexName = submatch[1]
|
||||||
|
columnName = fmt.Sprintf("%s(%s)", columnName, submatch[2])
|
||||||
|
return indexName, columnName
|
||||||
|
}
|
||||||
|
|
||||||
func (mysql) DefaultValueStr() string {
|
func (mysql) DefaultValueStr() string {
|
||||||
return "VALUES()"
|
return "VALUES()"
|
||||||
}
|
}
|
||||||
|
@ -34,14 +34,14 @@ func (s *postgres) DataTypeOf(field *StructField) string {
|
|||||||
sqlType = "boolean"
|
sqlType = "boolean"
|
||||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uintptr:
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uintptr:
|
||||||
if s.fieldCanAutoIncrement(field) {
|
if s.fieldCanAutoIncrement(field) {
|
||||||
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
|
field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
|
||||||
sqlType = "serial"
|
sqlType = "serial"
|
||||||
} else {
|
} else {
|
||||||
sqlType = "integer"
|
sqlType = "integer"
|
||||||
}
|
}
|
||||||
case reflect.Int64, reflect.Uint32, reflect.Uint64:
|
case reflect.Int64, reflect.Uint32, reflect.Uint64:
|
||||||
if s.fieldCanAutoIncrement(field) {
|
if s.fieldCanAutoIncrement(field) {
|
||||||
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
|
field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
|
||||||
sqlType = "bigserial"
|
sqlType = "bigserial"
|
||||||
} else {
|
} else {
|
||||||
sqlType = "bigint"
|
sqlType = "bigint"
|
||||||
@ -49,7 +49,7 @@ func (s *postgres) DataTypeOf(field *StructField) string {
|
|||||||
case reflect.Float32, reflect.Float64:
|
case reflect.Float32, reflect.Float64:
|
||||||
sqlType = "numeric"
|
sqlType = "numeric"
|
||||||
case reflect.String:
|
case reflect.String:
|
||||||
if _, ok := field.TagSettings["SIZE"]; !ok {
|
if _, ok := field.TagSettingsGet("SIZE"); !ok {
|
||||||
size = 0 // if SIZE haven't been set, use `text` as the default type, as there are no performance different
|
size = 0 // if SIZE haven't been set, use `text` as the default type, as there are no performance different
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -120,6 +120,10 @@ func (s postgres) CurrentDatabase() (name string) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s postgres) LastInsertIDOutputInterstitial(tableName, key string, columns []string) string {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
func (s postgres) LastInsertIDReturningSuffix(tableName, key string) string {
|
func (s postgres) LastInsertIDReturningSuffix(tableName, key string) string {
|
||||||
return fmt.Sprintf("RETURNING %v.%v", tableName, key)
|
return fmt.Sprintf("RETURNING %v.%v", tableName, key)
|
||||||
}
|
}
|
||||||
|
@ -29,14 +29,14 @@ func (s *sqlite3) DataTypeOf(field *StructField) string {
|
|||||||
sqlType = "bool"
|
sqlType = "bool"
|
||||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||||
if s.fieldCanAutoIncrement(field) {
|
if s.fieldCanAutoIncrement(field) {
|
||||||
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
|
field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
|
||||||
sqlType = "integer primary key autoincrement"
|
sqlType = "integer primary key autoincrement"
|
||||||
} else {
|
} else {
|
||||||
sqlType = "integer"
|
sqlType = "integer"
|
||||||
}
|
}
|
||||||
case reflect.Int64, reflect.Uint64:
|
case reflect.Int64, reflect.Uint64:
|
||||||
if s.fieldCanAutoIncrement(field) {
|
if s.fieldCanAutoIncrement(field) {
|
||||||
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
|
field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
|
||||||
sqlType = "integer primary key autoincrement"
|
sqlType = "integer primary key autoincrement"
|
||||||
} else {
|
} else {
|
||||||
sqlType = "bigint"
|
sqlType = "bigint"
|
||||||
|
@ -18,7 +18,7 @@ import (
|
|||||||
func setIdentityInsert(scope *gorm.Scope) {
|
func setIdentityInsert(scope *gorm.Scope) {
|
||||||
if scope.Dialect().GetName() == "mssql" {
|
if scope.Dialect().GetName() == "mssql" {
|
||||||
for _, field := range scope.PrimaryFields() {
|
for _, field := range scope.PrimaryFields() {
|
||||||
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok && !field.IsBlank {
|
if _, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok && !field.IsBlank {
|
||||||
scope.NewDB().Exec(fmt.Sprintf("SET IDENTITY_INSERT %v ON", scope.TableName()))
|
scope.NewDB().Exec(fmt.Sprintf("SET IDENTITY_INSERT %v ON", scope.TableName()))
|
||||||
scope.InstanceSet("mssql:identity_insert_on", true)
|
scope.InstanceSet("mssql:identity_insert_on", true)
|
||||||
}
|
}
|
||||||
@ -70,14 +70,14 @@ func (s *mssql) DataTypeOf(field *gorm.StructField) string {
|
|||||||
sqlType = "bit"
|
sqlType = "bit"
|
||||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||||
if s.fieldCanAutoIncrement(field) {
|
if s.fieldCanAutoIncrement(field) {
|
||||||
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
|
field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
|
||||||
sqlType = "int IDENTITY(1,1)"
|
sqlType = "int IDENTITY(1,1)"
|
||||||
} else {
|
} else {
|
||||||
sqlType = "int"
|
sqlType = "int"
|
||||||
}
|
}
|
||||||
case reflect.Int64, reflect.Uint64:
|
case reflect.Int64, reflect.Uint64:
|
||||||
if s.fieldCanAutoIncrement(field) {
|
if s.fieldCanAutoIncrement(field) {
|
||||||
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
|
field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
|
||||||
sqlType = "bigint IDENTITY(1,1)"
|
sqlType = "bigint IDENTITY(1,1)"
|
||||||
} else {
|
} else {
|
||||||
sqlType = "bigint"
|
sqlType = "bigint"
|
||||||
@ -116,7 +116,7 @@ func (s *mssql) DataTypeOf(field *gorm.StructField) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s mssql) fieldCanAutoIncrement(field *gorm.StructField) bool {
|
func (s mssql) fieldCanAutoIncrement(field *gorm.StructField) bool {
|
||||||
if value, ok := field.TagSettings["AUTO_INCREMENT"]; ok {
|
if value, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok {
|
||||||
return value != "FALSE"
|
return value != "FALSE"
|
||||||
}
|
}
|
||||||
return field.IsPrimaryKey
|
return field.IsPrimaryKey
|
||||||
@ -190,6 +190,14 @@ func (mssql) SelectFromDummyTable() string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (mssql) LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string {
|
||||||
|
if len(columns) == 0 {
|
||||||
|
// No OUTPUT to query
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("OUTPUT Inserted.%v", columnName)
|
||||||
|
}
|
||||||
|
|
||||||
func (mssql) LastInsertIDReturningSuffix(tableName, columnName string) string {
|
func (mssql) LastInsertIDReturningSuffix(tableName, columnName string) string {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
@ -198,6 +206,11 @@ func (mssql) DefaultValueStr() string {
|
|||||||
return "DEFAULT VALUES"
|
return "DEFAULT VALUES"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NormalizeIndexAndColumn returns argument's index name and column name without doing anything
|
||||||
|
func (mssql) NormalizeIndexAndColumn(indexName, columnName string) (string, string) {
|
||||||
|
return indexName, columnName
|
||||||
|
}
|
||||||
|
|
||||||
func currentDatabaseAndTable(dialect gorm.Dialect, tableName string) (string, string) {
|
func currentDatabaseAndTable(dialect gorm.Dialect, tableName string) (string, string) {
|
||||||
if strings.Contains(tableName, ".") {
|
if strings.Contains(tableName, ".") {
|
||||||
splitStrings := strings.SplitN(tableName, ".", 2)
|
splitStrings := strings.SplitN(tableName, ".", 2)
|
||||||
|
@ -7,6 +7,7 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
_ "github.com/lib/pq"
|
_ "github.com/lib/pq"
|
||||||
"github.com/lib/pq/hstore"
|
"github.com/lib/pq/hstore"
|
||||||
)
|
)
|
||||||
|
14
errors.go
14
errors.go
@ -6,11 +6,11 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
// ErrRecordNotFound record not found error, happens when only haven't find any matched data when looking up with a struct, finding a slice won't return this error
|
// ErrRecordNotFound returns a "record not found error". Occurs only when attempting to query the database with a struct; querying with a slice won't return this error
|
||||||
ErrRecordNotFound = errors.New("record not found")
|
ErrRecordNotFound = errors.New("record not found")
|
||||||
// ErrInvalidSQL invalid SQL error, happens when you passed invalid SQL
|
// ErrInvalidSQL occurs when you attempt a query with invalid SQL
|
||||||
ErrInvalidSQL = errors.New("invalid SQL")
|
ErrInvalidSQL = errors.New("invalid SQL")
|
||||||
// ErrInvalidTransaction invalid transaction when you are trying to `Commit` or `Rollback`
|
// ErrInvalidTransaction occurs when you are trying to `Commit` or `Rollback`
|
||||||
ErrInvalidTransaction = errors.New("no valid transaction")
|
ErrInvalidTransaction = errors.New("no valid transaction")
|
||||||
// ErrCantStartTransaction can't start transaction when you are trying to start one with `Begin`
|
// ErrCantStartTransaction can't start transaction when you are trying to start one with `Begin`
|
||||||
ErrCantStartTransaction = errors.New("can't start transaction")
|
ErrCantStartTransaction = errors.New("can't start transaction")
|
||||||
@ -21,7 +21,7 @@ var (
|
|||||||
// Errors contains all happened errors
|
// Errors contains all happened errors
|
||||||
type Errors []error
|
type Errors []error
|
||||||
|
|
||||||
// IsRecordNotFoundError returns current error has record not found error or not
|
// IsRecordNotFoundError returns true if error contains a RecordNotFound error
|
||||||
func IsRecordNotFoundError(err error) bool {
|
func IsRecordNotFoundError(err error) bool {
|
||||||
if errs, ok := err.(Errors); ok {
|
if errs, ok := err.(Errors); ok {
|
||||||
for _, err := range errs {
|
for _, err := range errs {
|
||||||
@ -33,12 +33,12 @@ func IsRecordNotFoundError(err error) bool {
|
|||||||
return err == ErrRecordNotFound
|
return err == ErrRecordNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetErrors gets all happened errors
|
// GetErrors gets all errors that have occurred and returns a slice of errors (Error type)
|
||||||
func (errs Errors) GetErrors() []error {
|
func (errs Errors) GetErrors() []error {
|
||||||
return errs
|
return errs
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add adds an error
|
// Add adds an error to a given slice of errors
|
||||||
func (errs Errors) Add(newErrors ...error) Errors {
|
func (errs Errors) Add(newErrors ...error) Errors {
|
||||||
for _, err := range newErrors {
|
for _, err := range newErrors {
|
||||||
if err == nil {
|
if err == nil {
|
||||||
@ -62,7 +62,7 @@ func (errs Errors) Add(newErrors ...error) Errors {
|
|||||||
return errs
|
return errs
|
||||||
}
|
}
|
||||||
|
|
||||||
// Error format happened errors
|
// Error takes a slice of all errors that have occurred and returns it as a formatted string
|
||||||
func (errs Errors) Error() string {
|
func (errs Errors) Error() string {
|
||||||
var errors = []string{}
|
var errors = []string{}
|
||||||
for _, e := range errs {
|
for _, e := range errs {
|
||||||
|
10
field.go
10
field.go
@ -2,6 +2,7 @@ package gorm
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"database/sql/driver"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
@ -44,7 +45,14 @@ func (field *Field) Set(value interface{}) (err error) {
|
|||||||
if reflectValue.Type().ConvertibleTo(fieldValue.Type()) {
|
if reflectValue.Type().ConvertibleTo(fieldValue.Type()) {
|
||||||
fieldValue.Set(reflectValue.Convert(fieldValue.Type()))
|
fieldValue.Set(reflectValue.Convert(fieldValue.Type()))
|
||||||
} else if scanner, ok := fieldValue.Addr().Interface().(sql.Scanner); ok {
|
} else if scanner, ok := fieldValue.Addr().Interface().(sql.Scanner); ok {
|
||||||
err = scanner.Scan(reflectValue.Interface())
|
v := reflectValue.Interface()
|
||||||
|
if valuer, ok := v.(driver.Valuer); ok {
|
||||||
|
if v, err = valuer.Value(); err == nil {
|
||||||
|
err = scanner.Scan(v)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
err = scanner.Scan(v)
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
err = fmt.Errorf("could not convert argument of field %s from %s to %s", field.Name, reflectValue.Type(), fieldValue.Type())
|
err = fmt.Errorf("could not convert argument of field %s from %s to %s", field.Name, reflectValue.Type(), fieldValue.Type())
|
||||||
}
|
}
|
||||||
|
@ -1,6 +1,9 @@
|
|||||||
package gorm_test
|
package gorm_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"database/sql/driver"
|
||||||
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/jinzhu/gorm"
|
"github.com/jinzhu/gorm"
|
||||||
@ -43,7 +46,85 @@ func TestCalculateField(t *testing.T) {
|
|||||||
|
|
||||||
if field, ok := scope.FieldByName("embedded_name"); !ok {
|
if field, ok := scope.FieldByName("embedded_name"); !ok {
|
||||||
t.Errorf("should find embedded field")
|
t.Errorf("should find embedded field")
|
||||||
} else if _, ok := field.TagSettings["NOT NULL"]; !ok {
|
} else if _, ok := field.TagSettingsGet("NOT NULL"); !ok {
|
||||||
t.Errorf("should find embedded field's tag settings")
|
t.Errorf("should find embedded field's tag settings")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type UUID [16]byte
|
||||||
|
|
||||||
|
type NullUUID struct {
|
||||||
|
UUID
|
||||||
|
Valid bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func FromString(input string) (u UUID) {
|
||||||
|
src := []byte(input)
|
||||||
|
return FromBytes(src)
|
||||||
|
}
|
||||||
|
|
||||||
|
func FromBytes(src []byte) (u UUID) {
|
||||||
|
dst := u[:]
|
||||||
|
hex.Decode(dst[0:4], src[0:8])
|
||||||
|
hex.Decode(dst[4:6], src[9:13])
|
||||||
|
hex.Decode(dst[6:8], src[14:18])
|
||||||
|
hex.Decode(dst[8:10], src[19:23])
|
||||||
|
hex.Decode(dst[10:], src[24:])
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u UUID) String() string {
|
||||||
|
buf := make([]byte, 36)
|
||||||
|
src := u[:]
|
||||||
|
hex.Encode(buf[0:8], src[0:4])
|
||||||
|
buf[8] = '-'
|
||||||
|
hex.Encode(buf[9:13], src[4:6])
|
||||||
|
buf[13] = '-'
|
||||||
|
hex.Encode(buf[14:18], src[6:8])
|
||||||
|
buf[18] = '-'
|
||||||
|
hex.Encode(buf[19:23], src[8:10])
|
||||||
|
buf[23] = '-'
|
||||||
|
hex.Encode(buf[24:], src[10:])
|
||||||
|
return string(buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u UUID) Value() (driver.Value, error) {
|
||||||
|
return u.String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *UUID) Scan(src interface{}) error {
|
||||||
|
switch src := src.(type) {
|
||||||
|
case UUID: // support gorm convert from UUID to NullUUID
|
||||||
|
*u = src
|
||||||
|
return nil
|
||||||
|
case []byte:
|
||||||
|
*u = FromBytes(src)
|
||||||
|
return nil
|
||||||
|
case string:
|
||||||
|
*u = FromString(src)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return fmt.Errorf("uuid: cannot convert %T to UUID", src)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *NullUUID) Scan(src interface{}) error {
|
||||||
|
u.Valid = true
|
||||||
|
return u.UUID.Scan(src)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFieldSet(t *testing.T) {
|
||||||
|
type TestFieldSetNullUUID struct {
|
||||||
|
NullUUID NullUUID
|
||||||
|
}
|
||||||
|
scope := DB.NewScope(&TestFieldSetNullUUID{})
|
||||||
|
field := scope.Fields()[0]
|
||||||
|
err := field.Set(FromString("3034d44a-da03-11e8-b366-4a00070b9f00"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if id, ok := field.Field.Addr().Interface().(*NullUUID); !ok {
|
||||||
|
t.Fatal()
|
||||||
|
} else if !id.Valid || id.UUID.String() != "3034d44a-da03-11e8-b366-4a00070b9f00" {
|
||||||
|
t.Fatal(id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
13
go.mod
Normal file
13
go.mod
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
module github.com/jinzhu/gorm
|
||||||
|
|
||||||
|
go 1.12
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/denisenkom/go-mssqldb v0.0.0-20190515213511-eb9f6a1743f3
|
||||||
|
github.com/erikstmartin/go-testdb v0.0.0-20160219214506-8d10e4a1bae5
|
||||||
|
github.com/go-sql-driver/mysql v1.4.1
|
||||||
|
github.com/jinzhu/inflection v1.0.0
|
||||||
|
github.com/jinzhu/now v1.0.1
|
||||||
|
github.com/lib/pq v1.1.1
|
||||||
|
github.com/mattn/go-sqlite3 v1.11.0
|
||||||
|
)
|
131
go.sum
Normal file
131
go.sum
Normal file
@ -0,0 +1,131 @@
|
|||||||
|
cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
|
||||||
|
cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
|
||||||
|
cloud.google.com/go v0.37.4 h1:glPeL3BQJsbF6aIIYfZizMwc5LTYz250bDMjttbBGAU=
|
||||||
|
cloud.google.com/go v0.37.4/go.mod h1:NHPJ89PdicEuT9hdPXMROBD91xc5uRDxsMtSB16k7hw=
|
||||||
|
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
|
||||||
|
github.com/Shopify/sarama v1.19.0/go.mod h1:FVkBWblsNy7DGZRfXLU0O9RCGt5g3g3yEuWXgklEdEo=
|
||||||
|
github.com/Shopify/toxiproxy v2.1.4+incompatible/go.mod h1:OXgGpZ6Cli1/URJOF1DMxUHB2q5Ap20/P/eIdh4G0pI=
|
||||||
|
github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
|
||||||
|
github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0=
|
||||||
|
github.com/apache/thrift v0.12.0/go.mod h1:cp2SuWMxlEZw2r+iP2GNCdIi4C1qmUzdZFSVb+bacwQ=
|
||||||
|
github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q=
|
||||||
|
github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
|
||||||
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/denisenkom/go-mssqldb v0.0.0-20190515213511-eb9f6a1743f3 h1:tkum0XDgfR0jcVVXuTsYv/erY2NnEDqwRojbxR1rBYA=
|
||||||
|
github.com/denisenkom/go-mssqldb v0.0.0-20190515213511-eb9f6a1743f3/go.mod h1:zAg7JM8CkOJ43xKXIj7eRO9kmWm/TW578qo+oDO6tuM=
|
||||||
|
github.com/eapache/go-resiliency v1.1.0/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs=
|
||||||
|
github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU=
|
||||||
|
github.com/eapache/queue v1.1.0/go.mod h1:6eCeP0CKFpHLu8blIFXhExK/dRa7WDZfr6jVFPTqq+I=
|
||||||
|
github.com/erikstmartin/go-testdb v0.0.0-20160219214506-8d10e4a1bae5 h1:Yzb9+7DPaBjB8zlTR87/ElzFsnQfuHnVUVqpZZIcV5Y=
|
||||||
|
github.com/erikstmartin/go-testdb v0.0.0-20160219214506-8d10e4a1bae5/go.mod h1:a2zkGnVExMxdzMo3M0Hi/3sEU+cWnZpSni0O6/Yb/P0=
|
||||||
|
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
|
||||||
|
github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
|
||||||
|
github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE=
|
||||||
|
github.com/go-sql-driver/mysql v1.4.1 h1:g24URVg0OFbNUTx9qqY1IRZ9D9z3iPyi5zKhQZpNwpA=
|
||||||
|
github.com/go-sql-driver/mysql v1.4.1/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w=
|
||||||
|
github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
|
||||||
|
github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
|
||||||
|
github.com/gogo/protobuf v1.2.0/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
|
||||||
|
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
|
||||||
|
github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
|
||||||
|
github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
|
||||||
|
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
||||||
|
github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
|
||||||
|
github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
|
||||||
|
github.com/google/go-cmp v0.2.0 h1:+dTQ8DZQJz0Mb/HjFlkptS1FeQ4cWSnN941F8aEG4SQ=
|
||||||
|
github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
|
||||||
|
github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs=
|
||||||
|
github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc=
|
||||||
|
github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg=
|
||||||
|
github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg=
|
||||||
|
github.com/gorilla/mux v1.6.2/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs=
|
||||||
|
github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
|
||||||
|
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
|
||||||
|
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
||||||
|
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
||||||
|
github.com/jinzhu/now v1.0.1 h1:HjfetcXq097iXP0uoPCdnM4Efp5/9MsM0/M+XOTeR3M=
|
||||||
|
github.com/jinzhu/now v1.0.1/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
||||||
|
github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU=
|
||||||
|
github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w=
|
||||||
|
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
|
||||||
|
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
|
||||||
|
github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc=
|
||||||
|
github.com/lib/pq v1.1.1 h1:sJZmqHoEaY7f+NPP8pgLB/WxulyR3fewgCM2qaSlBb4=
|
||||||
|
github.com/lib/pq v1.1.1/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
|
||||||
|
github.com/mattn/go-sqlite3 v1.11.0 h1:LDdKkqtYlom37fkvqs8rMPFKAMe8+SgjbwZ6ex1/A/Q=
|
||||||
|
github.com/mattn/go-sqlite3 v1.11.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
|
||||||
|
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
|
||||||
|
github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
|
||||||
|
github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
|
||||||
|
github.com/onsi/ginkgo v1.7.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
|
||||||
|
github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY=
|
||||||
|
github.com/openzipkin/zipkin-go v0.1.6/go.mod h1:QgAqvLzwWbR/WpD4A3cGpPtJrZXNIiJc5AZX7/PBEpw=
|
||||||
|
github.com/pierrec/lz4 v2.0.5+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY=
|
||||||
|
github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
|
github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw=
|
||||||
|
github.com/prometheus/client_golang v0.9.3-0.20190127221311-3c4408c8b829/go.mod h1:p2iRAGwDERtqlqzRXnrOVns+ignqQo//hLXqYxZYVNs=
|
||||||
|
github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo=
|
||||||
|
github.com/prometheus/client_model v0.0.0-20190115171406-56726106282f/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo=
|
||||||
|
github.com/prometheus/common v0.2.0/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4=
|
||||||
|
github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk=
|
||||||
|
github.com/prometheus/procfs v0.0.0-20190117184657-bf6a532e95b1/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk=
|
||||||
|
github.com/rcrowley/go-metrics v0.0.0-20181016184325-3113b8401b8a/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4=
|
||||||
|
github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
|
||||||
|
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||||
|
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
|
||||||
|
go.opencensus.io v0.20.1/go.mod h1:6WKK9ahsWS3RSO+PY9ZHZUfv2irvY6gN279GOPZjmmk=
|
||||||
|
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
|
||||||
|
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||||
|
golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c h1:Vj5n4GlwjmQteupaxJ9+0FNOmBrHfq7vN4btdGoDZgI=
|
||||||
|
golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||||
|
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||||
|
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
|
||||||
|
golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU=
|
||||||
|
golang.org/x/lint v0.0.0-20190301231843-5614ed5bae6f/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
|
||||||
|
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||||
|
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||||
|
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||||
|
golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||||
|
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||||
|
golang.org/x/net v0.0.0-20190125091013-d26f9f9a57f3/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||||
|
golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||||
|
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||||
|
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
||||||
|
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||||
|
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
|
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
|
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
|
golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
|
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||||
|
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||||
|
golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||||
|
golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||||
|
golang.org/x/sys v0.0.0-20181122145206-62eef0e2fa9b/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||||
|
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||||
|
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||||
|
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||||
|
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||||
|
golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||||
|
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||||
|
golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY=
|
||||||
|
golang.org/x/tools v0.0.0-20190312170243-e65039ee4138/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
|
||||||
|
google.golang.org/api v0.3.1/go.mod h1:6wY9I6uQWHQ8EM57III9mq/AjF+i8G65rmVagqKMtkk=
|
||||||
|
google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
|
||||||
|
google.golang.org/appengine v1.4.0 h1:/wp5JvzpHIxhs/dumFmF7BXTf3Z+dd4uXta4kVyO508=
|
||||||
|
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
|
||||||
|
google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc=
|
||||||
|
google.golang.org/genproto v0.0.0-20190307195333-5fe7a883aa19/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE=
|
||||||
|
google.golang.org/genproto v0.0.0-20190404172233-64821d5d2107/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE=
|
||||||
|
google.golang.org/grpc v1.17.0/go.mod h1:6QZJwpn2B+Zp71q/5VxRsJ6NXXVCE5NRUHRo+f3cWCs=
|
||||||
|
google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
|
||||||
|
gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
|
||||||
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
|
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
|
gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys=
|
||||||
|
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw=
|
||||||
|
gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||||
|
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||||
|
honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
|
||||||
|
honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
|
||||||
|
honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
|
@ -1,6 +1,9 @@
|
|||||||
package gorm
|
package gorm
|
||||||
|
|
||||||
import "database/sql"
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
)
|
||||||
|
|
||||||
// SQLCommon is the minimal database connection functionality gorm requires. Implemented by *sql.DB.
|
// SQLCommon is the minimal database connection functionality gorm requires. Implemented by *sql.DB.
|
||||||
type SQLCommon interface {
|
type SQLCommon interface {
|
||||||
@ -12,6 +15,7 @@ type SQLCommon interface {
|
|||||||
|
|
||||||
type sqlDb interface {
|
type sqlDb interface {
|
||||||
Begin() (*sql.Tx, error)
|
Begin() (*sql.Tx, error)
|
||||||
|
BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type sqlTx interface {
|
type sqlTx interface {
|
||||||
|
13
logger.go
13
logger.go
@ -49,7 +49,11 @@ var LogFormatter = func(values ...interface{}) (messages []interface{}) {
|
|||||||
if indirectValue.IsValid() {
|
if indirectValue.IsValid() {
|
||||||
value = indirectValue.Interface()
|
value = indirectValue.Interface()
|
||||||
if t, ok := value.(time.Time); ok {
|
if t, ok := value.(time.Time); ok {
|
||||||
formattedValues = append(formattedValues, fmt.Sprintf("'%v'", t.Format("2006-01-02 15:04:05")))
|
if t.IsZero() {
|
||||||
|
formattedValues = append(formattedValues, fmt.Sprintf("'%v'", "0000-00-00 00:00:00"))
|
||||||
|
} else {
|
||||||
|
formattedValues = append(formattedValues, fmt.Sprintf("'%v'", t.Format("2006-01-02 15:04:05")))
|
||||||
|
}
|
||||||
} else if b, ok := value.([]byte); ok {
|
} else if b, ok := value.([]byte); ok {
|
||||||
if str := string(b); isPrintable(str) {
|
if str := string(b); isPrintable(str) {
|
||||||
formattedValues = append(formattedValues, fmt.Sprintf("'%v'", str))
|
formattedValues = append(formattedValues, fmt.Sprintf("'%v'", str))
|
||||||
@ -63,7 +67,12 @@ var LogFormatter = func(values ...interface{}) (messages []interface{}) {
|
|||||||
formattedValues = append(formattedValues, "NULL")
|
formattedValues = append(formattedValues, "NULL")
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
formattedValues = append(formattedValues, fmt.Sprintf("'%v'", value))
|
switch value.(type) {
|
||||||
|
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64, bool:
|
||||||
|
formattedValues = append(formattedValues, fmt.Sprintf("%v", value))
|
||||||
|
default:
|
||||||
|
formattedValues = append(formattedValues, fmt.Sprintf("'%v'", value))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
formattedValues = append(formattedValues, "NULL")
|
formattedValues = append(formattedValues, "NULL")
|
||||||
|
112
main.go
112
main.go
@ -1,6 +1,7 @@
|
|||||||
package gorm
|
package gorm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
@ -12,6 +13,7 @@ import (
|
|||||||
|
|
||||||
// DB contains information for current db connection
|
// DB contains information for current db connection
|
||||||
type DB struct {
|
type DB struct {
|
||||||
|
sync.RWMutex
|
||||||
Value interface{}
|
Value interface{}
|
||||||
Error error
|
Error error
|
||||||
RowsAffected int64
|
RowsAffected int64
|
||||||
@ -19,18 +21,29 @@ type DB struct {
|
|||||||
// single db
|
// single db
|
||||||
db SQLCommon
|
db SQLCommon
|
||||||
blockGlobalUpdate bool
|
blockGlobalUpdate bool
|
||||||
logMode int
|
logMode logModeValue
|
||||||
logger logger
|
logger logger
|
||||||
search *search
|
search *search
|
||||||
values map[string]interface{}
|
values sync.Map
|
||||||
|
|
||||||
// global db
|
// global db
|
||||||
parent *DB
|
parent *DB
|
||||||
callbacks *Callback
|
callbacks *Callback
|
||||||
dialect Dialect
|
dialect Dialect
|
||||||
singularTable bool
|
singularTable bool
|
||||||
|
|
||||||
|
// function to be used to override the creating of a new timestamp
|
||||||
|
nowFuncOverride func() time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type logModeValue int
|
||||||
|
|
||||||
|
const (
|
||||||
|
defaultLogMode logModeValue = iota
|
||||||
|
noLogMode
|
||||||
|
detailedLogMode
|
||||||
|
)
|
||||||
|
|
||||||
// Open initialize a new db connection, need to import driver first, e.g:
|
// Open initialize a new db connection, need to import driver first, e.g:
|
||||||
//
|
//
|
||||||
// import _ "github.com/go-sql-driver/mysql"
|
// import _ "github.com/go-sql-driver/mysql"
|
||||||
@ -72,7 +85,6 @@ func Open(dialect string, args ...interface{}) (db *DB, err error) {
|
|||||||
db = &DB{
|
db = &DB{
|
||||||
db: dbSQL,
|
db: dbSQL,
|
||||||
logger: defaultLogger,
|
logger: defaultLogger,
|
||||||
values: map[string]interface{}{},
|
|
||||||
callbacks: DefaultCallback,
|
callbacks: DefaultCallback,
|
||||||
dialect: newDialect(dialect, dbSQL),
|
dialect: newDialect(dialect, dbSQL),
|
||||||
}
|
}
|
||||||
@ -130,7 +142,7 @@ func (s *DB) Dialect() Dialect {
|
|||||||
// db.Callback().Create().Register("update_created_at", updateCreated)
|
// db.Callback().Create().Register("update_created_at", updateCreated)
|
||||||
// Refer https://jinzhu.github.io/gorm/development.html#callbacks
|
// Refer https://jinzhu.github.io/gorm/development.html#callbacks
|
||||||
func (s *DB) Callback() *Callback {
|
func (s *DB) Callback() *Callback {
|
||||||
s.parent.callbacks = s.parent.callbacks.clone()
|
s.parent.callbacks = s.parent.callbacks.clone(s.logger)
|
||||||
return s.parent.callbacks
|
return s.parent.callbacks
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -142,13 +154,29 @@ func (s *DB) SetLogger(log logger) {
|
|||||||
// LogMode set log mode, `true` for detailed logs, `false` for no log, default, will only print error logs
|
// LogMode set log mode, `true` for detailed logs, `false` for no log, default, will only print error logs
|
||||||
func (s *DB) LogMode(enable bool) *DB {
|
func (s *DB) LogMode(enable bool) *DB {
|
||||||
if enable {
|
if enable {
|
||||||
s.logMode = 2
|
s.logMode = detailedLogMode
|
||||||
} else {
|
} else {
|
||||||
s.logMode = 1
|
s.logMode = noLogMode
|
||||||
}
|
}
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetNowFuncOverride set the function to be used when creating a new timestamp
|
||||||
|
func (s *DB) SetNowFuncOverride(nowFuncOverride func() time.Time) *DB {
|
||||||
|
s.nowFuncOverride = nowFuncOverride
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get a new timestamp, using the provided nowFuncOverride on the DB instance if set,
|
||||||
|
// otherwise defaults to the global NowFunc()
|
||||||
|
func (s *DB) nowFunc() time.Time {
|
||||||
|
if s.nowFuncOverride != nil {
|
||||||
|
return s.nowFuncOverride()
|
||||||
|
}
|
||||||
|
|
||||||
|
return NowFunc()
|
||||||
|
}
|
||||||
|
|
||||||
// BlockGlobalUpdate if true, generates an error on update/delete without where clause.
|
// BlockGlobalUpdate if true, generates an error on update/delete without where clause.
|
||||||
// This is to prevent eventual error with empty objects updates/deletions
|
// This is to prevent eventual error with empty objects updates/deletions
|
||||||
func (s *DB) BlockGlobalUpdate(enable bool) *DB {
|
func (s *DB) BlockGlobalUpdate(enable bool) *DB {
|
||||||
@ -163,7 +191,8 @@ func (s *DB) HasBlockGlobalUpdate() bool {
|
|||||||
|
|
||||||
// SingularTable use singular table by default
|
// SingularTable use singular table by default
|
||||||
func (s *DB) SingularTable(enable bool) {
|
func (s *DB) SingularTable(enable bool) {
|
||||||
modelStructsMap = sync.Map{}
|
s.parent.Lock()
|
||||||
|
defer s.parent.Unlock()
|
||||||
s.parent.singularTable = enable
|
s.parent.singularTable = enable
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -171,7 +200,13 @@ func (s *DB) SingularTable(enable bool) {
|
|||||||
func (s *DB) NewScope(value interface{}) *Scope {
|
func (s *DB) NewScope(value interface{}) *Scope {
|
||||||
dbClone := s.clone()
|
dbClone := s.clone()
|
||||||
dbClone.Value = value
|
dbClone.Value = value
|
||||||
return &Scope{db: dbClone, Search: dbClone.search.clone(), Value: value}
|
scope := &Scope{db: dbClone, Value: value}
|
||||||
|
if s.search != nil {
|
||||||
|
scope.Search = s.search.clone()
|
||||||
|
} else {
|
||||||
|
scope.Search = &search{}
|
||||||
|
}
|
||||||
|
return scope
|
||||||
}
|
}
|
||||||
|
|
||||||
// QueryExpr returns the query as expr object
|
// QueryExpr returns the query as expr object
|
||||||
@ -291,6 +326,7 @@ func (s *DB) Assign(attrs ...interface{}) *DB {
|
|||||||
func (s *DB) First(out interface{}, where ...interface{}) *DB {
|
func (s *DB) First(out interface{}, where ...interface{}) *DB {
|
||||||
newScope := s.NewScope(out)
|
newScope := s.NewScope(out)
|
||||||
newScope.Search.Limit(1)
|
newScope.Search.Limit(1)
|
||||||
|
|
||||||
return newScope.Set("gorm:order_by_primary_key", "ASC").
|
return newScope.Set("gorm:order_by_primary_key", "ASC").
|
||||||
inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db
|
inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db
|
||||||
}
|
}
|
||||||
@ -315,6 +351,11 @@ func (s *DB) Find(out interface{}, where ...interface{}) *DB {
|
|||||||
return s.NewScope(out).inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db
|
return s.NewScope(out).inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//Preloads preloads relations, don`t touch out
|
||||||
|
func (s *DB) Preloads(out interface{}) *DB {
|
||||||
|
return s.NewScope(out).InstanceSet("gorm:only_preload", 1).callCallbacks(s.parent.callbacks.queries).db
|
||||||
|
}
|
||||||
|
|
||||||
// Scan scan value to a struct
|
// Scan scan value to a struct
|
||||||
func (s *DB) Scan(dest interface{}) *DB {
|
func (s *DB) Scan(dest interface{}) *DB {
|
||||||
return s.NewScope(s.Value).Set("gorm:query_destination", dest).callCallbacks(s.parent.callbacks.queries).db
|
return s.NewScope(s.Value).Set("gorm:query_destination", dest).callCallbacks(s.parent.callbacks.queries).db
|
||||||
@ -425,7 +466,7 @@ func (s *DB) Save(value interface{}) *DB {
|
|||||||
if !scope.PrimaryKeyZero() {
|
if !scope.PrimaryKeyZero() {
|
||||||
newDB := scope.callCallbacks(s.parent.callbacks.updates).db
|
newDB := scope.callCallbacks(s.parent.callbacks.updates).db
|
||||||
if newDB.Error == nil && newDB.RowsAffected == 0 {
|
if newDB.Error == nil && newDB.RowsAffected == 0 {
|
||||||
return s.New().FirstOrCreate(value)
|
return s.New().Table(scope.TableName()).FirstOrCreate(value)
|
||||||
}
|
}
|
||||||
return newDB
|
return newDB
|
||||||
}
|
}
|
||||||
@ -482,11 +523,16 @@ func (s *DB) Debug() *DB {
|
|||||||
return s.clone().LogMode(true)
|
return s.clone().LogMode(true)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Begin begin a transaction
|
// Begin begins a transaction
|
||||||
func (s *DB) Begin() *DB {
|
func (s *DB) Begin() *DB {
|
||||||
|
return s.BeginTx(context.Background(), &sql.TxOptions{})
|
||||||
|
}
|
||||||
|
|
||||||
|
// BeginTx begins a transaction with options
|
||||||
|
func (s *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) *DB {
|
||||||
c := s.clone()
|
c := s.clone()
|
||||||
if db, ok := c.db.(sqlDb); ok && db != nil {
|
if db, ok := c.db.(sqlDb); ok && db != nil {
|
||||||
tx, err := db.Begin()
|
tx, err := db.BeginTx(ctx, opts)
|
||||||
c.db = interface{}(tx).(SQLCommon)
|
c.db = interface{}(tx).(SQLCommon)
|
||||||
|
|
||||||
c.dialect.SetDB(c.db)
|
c.dialect.SetDB(c.db)
|
||||||
@ -512,7 +558,26 @@ func (s *DB) Commit() *DB {
|
|||||||
func (s *DB) Rollback() *DB {
|
func (s *DB) Rollback() *DB {
|
||||||
var emptySQLTx *sql.Tx
|
var emptySQLTx *sql.Tx
|
||||||
if db, ok := s.db.(sqlTx); ok && db != nil && db != emptySQLTx {
|
if db, ok := s.db.(sqlTx); ok && db != nil && db != emptySQLTx {
|
||||||
s.AddError(db.Rollback())
|
if err := db.Rollback(); err != nil && err != sql.ErrTxDone {
|
||||||
|
s.AddError(err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
s.AddError(ErrInvalidTransaction)
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// RollbackUnlessCommitted rollback a transaction if it has not yet been
|
||||||
|
// committed.
|
||||||
|
func (s *DB) RollbackUnlessCommitted() *DB {
|
||||||
|
var emptySQLTx *sql.Tx
|
||||||
|
if db, ok := s.db.(sqlTx); ok && db != nil && db != emptySQLTx {
|
||||||
|
err := db.Rollback()
|
||||||
|
// Ignore the error indicating that the transaction has already
|
||||||
|
// been committed.
|
||||||
|
if err != sql.ErrTxDone {
|
||||||
|
s.AddError(err)
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
s.AddError(ErrInvalidTransaction)
|
s.AddError(ErrInvalidTransaction)
|
||||||
}
|
}
|
||||||
@ -680,13 +745,13 @@ func (s *DB) Set(name string, value interface{}) *DB {
|
|||||||
|
|
||||||
// InstantSet instant set setting, will affect current db
|
// InstantSet instant set setting, will affect current db
|
||||||
func (s *DB) InstantSet(name string, value interface{}) *DB {
|
func (s *DB) InstantSet(name string, value interface{}) *DB {
|
||||||
s.values[name] = value
|
s.values.Store(name, value)
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get get setting by name
|
// Get get setting by name
|
||||||
func (s *DB) Get(name string) (value interface{}, ok bool) {
|
func (s *DB) Get(name string) (value interface{}, ok bool) {
|
||||||
value, ok = s.values[name]
|
value, ok = s.values.Load(name)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -695,7 +760,7 @@ func (s *DB) SetJoinTableHandler(source interface{}, column string, handler Join
|
|||||||
scope := s.NewScope(source)
|
scope := s.NewScope(source)
|
||||||
for _, field := range scope.GetModelStruct().StructFields {
|
for _, field := range scope.GetModelStruct().StructFields {
|
||||||
if field.Name == column || field.DBName == column {
|
if field.Name == column || field.DBName == column {
|
||||||
if many2many := field.TagSettings["MANY2MANY"]; many2many != "" {
|
if many2many, _ := field.TagSettingsGet("MANY2MANY"); many2many != "" {
|
||||||
source := (&Scope{Value: source}).GetModelStruct().ModelType
|
source := (&Scope{Value: source}).GetModelStruct().ModelType
|
||||||
destination := (&Scope{Value: reflect.New(field.Struct.Type).Interface()}).GetModelStruct().ModelType
|
destination := (&Scope{Value: reflect.New(field.Struct.Type).Interface()}).GetModelStruct().ModelType
|
||||||
handler.Setup(field.Relationship, many2many, source, destination)
|
handler.Setup(field.Relationship, many2many, source, destination)
|
||||||
@ -712,8 +777,8 @@ func (s *DB) SetJoinTableHandler(source interface{}, column string, handler Join
|
|||||||
func (s *DB) AddError(err error) error {
|
func (s *DB) AddError(err error) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err != ErrRecordNotFound {
|
if err != ErrRecordNotFound {
|
||||||
if s.logMode == 0 {
|
if s.logMode == defaultLogMode {
|
||||||
go s.print(fileWithLineNum(), err)
|
go s.print("error", fileWithLineNum(), err)
|
||||||
} else {
|
} else {
|
||||||
s.log(err)
|
s.log(err)
|
||||||
}
|
}
|
||||||
@ -750,16 +815,17 @@ func (s *DB) clone() *DB {
|
|||||||
parent: s.parent,
|
parent: s.parent,
|
||||||
logger: s.logger,
|
logger: s.logger,
|
||||||
logMode: s.logMode,
|
logMode: s.logMode,
|
||||||
values: map[string]interface{}{},
|
|
||||||
Value: s.Value,
|
Value: s.Value,
|
||||||
Error: s.Error,
|
Error: s.Error,
|
||||||
blockGlobalUpdate: s.blockGlobalUpdate,
|
blockGlobalUpdate: s.blockGlobalUpdate,
|
||||||
dialect: newDialect(s.dialect.GetName(), s.db),
|
dialect: newDialect(s.dialect.GetName(), s.db),
|
||||||
|
nowFuncOverride: s.nowFuncOverride,
|
||||||
}
|
}
|
||||||
|
|
||||||
for key, value := range s.values {
|
s.values.Range(func(k, v interface{}) bool {
|
||||||
db.values[key] = value
|
db.values.Store(k, v)
|
||||||
}
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
if s.search == nil {
|
if s.search == nil {
|
||||||
db.search = &search{limit: -1, offset: -1}
|
db.search = &search{limit: -1, offset: -1}
|
||||||
@ -776,13 +842,13 @@ func (s *DB) print(v ...interface{}) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) log(v ...interface{}) {
|
func (s *DB) log(v ...interface{}) {
|
||||||
if s != nil && s.logMode == 2 {
|
if s != nil && s.logMode == detailedLogMode {
|
||||||
s.print(append([]interface{}{"log", fileWithLineNum()}, v...)...)
|
s.print(append([]interface{}{"log", fileWithLineNum()}, v...)...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) slog(sql string, t time.Time, vars ...interface{}) {
|
func (s *DB) slog(sql string, t time.Time, vars ...interface{}) {
|
||||||
if s.logMode == 2 {
|
if s.logMode == detailedLogMode {
|
||||||
s.print("sql", fileWithLineNum(), NowFunc().Sub(t), sql, vars, s.RowsAffected)
|
s.print("sql", fileWithLineNum(), NowFunc().Sub(t), sql, vars, s.RowsAffected)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
383
main_test.go
383
main_test.go
@ -1,6 +1,11 @@
|
|||||||
package gorm_test
|
package gorm_test
|
||||||
|
|
||||||
|
// Run tests
|
||||||
|
// $ docker-compose up
|
||||||
|
// $ ./test_all.sh
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"database/sql/driver"
|
"database/sql/driver"
|
||||||
"fmt"
|
"fmt"
|
||||||
@ -9,6 +14,7 @@ import (
|
|||||||
"reflect"
|
"reflect"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -176,6 +182,15 @@ func TestSetTable(t *testing.T) {
|
|||||||
t.Errorf("Query from specified table")
|
t.Errorf("Query from specified table")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var user User
|
||||||
|
DB.Table("deleted_users").First(&user, "name = ?", "DeletedUser")
|
||||||
|
|
||||||
|
user.Age = 20
|
||||||
|
DB.Table("deleted_users").Save(&user)
|
||||||
|
if DB.Table("deleted_users").First(&user, "name = ? AND age = ?", "DeletedUser", 20).RecordNotFound() {
|
||||||
|
t.Errorf("Failed to found updated user")
|
||||||
|
}
|
||||||
|
|
||||||
DB.Save(getPreparedUser("normal_user", "reset_table"))
|
DB.Save(getPreparedUser("normal_user", "reset_table"))
|
||||||
DB.Table("deleted_users").Save(getPreparedUser("deleted_user", "reset_table"))
|
DB.Table("deleted_users").Save(getPreparedUser("deleted_user", "reset_table"))
|
||||||
var user1, user2, user3 User
|
var user1, user2, user3 User
|
||||||
@ -277,6 +292,30 @@ func TestTableName(t *testing.T) {
|
|||||||
DB.SingularTable(false)
|
DB.SingularTable(false)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestTableNameConcurrently(t *testing.T) {
|
||||||
|
DB := DB.Model("")
|
||||||
|
if DB.NewScope(Order{}).TableName() != "orders" {
|
||||||
|
t.Errorf("Order's table name should be orders")
|
||||||
|
}
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(10)
|
||||||
|
|
||||||
|
for i := 1; i <= 10; i++ {
|
||||||
|
go func(db *gorm.DB) {
|
||||||
|
DB.SingularTable(true)
|
||||||
|
wg.Done()
|
||||||
|
}(DB)
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
if DB.NewScope(Order{}).TableName() != "order" {
|
||||||
|
t.Errorf("Order's singular table name should be order")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.SingularTable(false)
|
||||||
|
}
|
||||||
|
|
||||||
func TestNullValues(t *testing.T) {
|
func TestNullValues(t *testing.T) {
|
||||||
DB.DropTable(&NullValue{})
|
DB.DropTable(&NullValue{})
|
||||||
DB.AutoMigrate(&NullValue{})
|
DB.AutoMigrate(&NullValue{})
|
||||||
@ -394,6 +433,90 @@ func TestTransaction(t *testing.T) {
|
|||||||
if err := DB.First(&User{}, "name = ?", "transcation-2").Error; err != nil {
|
if err := DB.First(&User{}, "name = ?", "transcation-2").Error; err != nil {
|
||||||
t.Errorf("Should be able to find committed record")
|
t.Errorf("Should be able to find committed record")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
tx3 := DB.Begin()
|
||||||
|
u3 := User{Name: "transcation-3"}
|
||||||
|
if err := tx3.Save(&u3).Error; err != nil {
|
||||||
|
t.Errorf("No error should raise")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx3.First(&User{}, "name = ?", "transcation-3").Error; err != nil {
|
||||||
|
t.Errorf("Should find saved record")
|
||||||
|
}
|
||||||
|
|
||||||
|
tx3.RollbackUnlessCommitted()
|
||||||
|
|
||||||
|
if err := tx.First(&User{}, "name = ?", "transcation").Error; err == nil {
|
||||||
|
t.Errorf("Should not find record after rollback")
|
||||||
|
}
|
||||||
|
|
||||||
|
tx4 := DB.Begin()
|
||||||
|
u4 := User{Name: "transcation-4"}
|
||||||
|
if err := tx4.Save(&u4).Error; err != nil {
|
||||||
|
t.Errorf("No error should raise")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx4.First(&User{}, "name = ?", "transcation-4").Error; err != nil {
|
||||||
|
t.Errorf("Should find saved record")
|
||||||
|
}
|
||||||
|
|
||||||
|
tx4.Commit()
|
||||||
|
|
||||||
|
tx4.RollbackUnlessCommitted()
|
||||||
|
|
||||||
|
if err := DB.First(&User{}, "name = ?", "transcation-4").Error; err != nil {
|
||||||
|
t.Errorf("Should be able to find committed record")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTransaction_NoErrorOnRollbackAfterCommit(t *testing.T) {
|
||||||
|
tx := DB.Begin()
|
||||||
|
u := User{Name: "transcation"}
|
||||||
|
if err := tx.Save(&u).Error; err != nil {
|
||||||
|
t.Errorf("No error should raise")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx.Commit().Error; err != nil {
|
||||||
|
t.Errorf("Commit should not raise error")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx.Rollback().Error; err != nil {
|
||||||
|
t.Errorf("Rollback should not raise error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTransactionReadonly(t *testing.T) {
|
||||||
|
dialect := os.Getenv("GORM_DIALECT")
|
||||||
|
if dialect == "" {
|
||||||
|
dialect = "sqlite"
|
||||||
|
}
|
||||||
|
switch dialect {
|
||||||
|
case "mssql", "sqlite":
|
||||||
|
t.Skipf("%s does not support readonly transactions\n", dialect)
|
||||||
|
}
|
||||||
|
|
||||||
|
tx := DB.Begin()
|
||||||
|
u := User{Name: "transcation"}
|
||||||
|
if err := tx.Save(&u).Error; err != nil {
|
||||||
|
t.Errorf("No error should raise")
|
||||||
|
}
|
||||||
|
tx.Commit()
|
||||||
|
|
||||||
|
tx = DB.BeginTx(context.Background(), &sql.TxOptions{ReadOnly: true})
|
||||||
|
if err := tx.First(&User{}, "name = ?", "transcation").Error; err != nil {
|
||||||
|
t.Errorf("Should find saved record")
|
||||||
|
}
|
||||||
|
|
||||||
|
if sqlTx, ok := tx.CommonDB().(*sql.Tx); !ok || sqlTx == nil {
|
||||||
|
t.Errorf("Should return the underlying sql.Tx")
|
||||||
|
}
|
||||||
|
|
||||||
|
u = User{Name: "transcation-2"}
|
||||||
|
if err := tx.Save(&u).Error; err == nil {
|
||||||
|
t.Errorf("Error should have been raised in a readonly transaction")
|
||||||
|
}
|
||||||
|
|
||||||
|
tx.Rollback()
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRow(t *testing.T) {
|
func TestRow(t *testing.T) {
|
||||||
@ -581,6 +704,60 @@ func TestJoins(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type JoinedIds struct {
|
||||||
|
UserID int64 `gorm:"column:id"`
|
||||||
|
BillingAddressID int64 `gorm:"column:id"`
|
||||||
|
EmailID int64 `gorm:"column:id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestScanIdenticalColumnNames(t *testing.T) {
|
||||||
|
var user = User{
|
||||||
|
Name: "joinsIds",
|
||||||
|
Email: "joinIds@example.com",
|
||||||
|
BillingAddress: Address{
|
||||||
|
Address1: "One Park Place",
|
||||||
|
},
|
||||||
|
Emails: []Email{{Email: "join1@example.com"}, {Email: "join2@example.com"}},
|
||||||
|
}
|
||||||
|
DB.Save(&user)
|
||||||
|
|
||||||
|
var users []JoinedIds
|
||||||
|
DB.Select("users.id, addresses.id, emails.id").Table("users").
|
||||||
|
Joins("left join addresses on users.billing_address_id = addresses.id").
|
||||||
|
Joins("left join emails on emails.user_id = users.id").
|
||||||
|
Where("name = ?", "joinsIds").Scan(&users)
|
||||||
|
|
||||||
|
if len(users) != 2 {
|
||||||
|
t.Fatal("should find two rows using left join")
|
||||||
|
}
|
||||||
|
|
||||||
|
if user.Id != users[0].UserID {
|
||||||
|
t.Errorf("Expected result row to contain UserID %d, but got %d", user.Id, users[0].UserID)
|
||||||
|
}
|
||||||
|
if user.Id != users[1].UserID {
|
||||||
|
t.Errorf("Expected result row to contain UserID %d, but got %d", user.Id, users[1].UserID)
|
||||||
|
}
|
||||||
|
|
||||||
|
if user.BillingAddressID.Int64 != users[0].BillingAddressID {
|
||||||
|
t.Errorf("Expected result row to contain BillingAddressID %d, but got %d", user.BillingAddressID.Int64, users[0].BillingAddressID)
|
||||||
|
}
|
||||||
|
if user.BillingAddressID.Int64 != users[1].BillingAddressID {
|
||||||
|
t.Errorf("Expected result row to contain BillingAddressID %d, but got %d", user.BillingAddressID.Int64, users[0].BillingAddressID)
|
||||||
|
}
|
||||||
|
|
||||||
|
if users[0].EmailID == users[1].EmailID {
|
||||||
|
t.Errorf("Email ids should be unique. Got %d and %d", users[0].EmailID, users[1].EmailID)
|
||||||
|
}
|
||||||
|
|
||||||
|
if int64(user.Emails[0].Id) != users[0].EmailID && int64(user.Emails[1].Id) != users[0].EmailID {
|
||||||
|
t.Errorf("Expected result row ID to be either %d or %d, but was %d", user.Emails[0].Id, user.Emails[1].Id, users[0].EmailID)
|
||||||
|
}
|
||||||
|
|
||||||
|
if int64(user.Emails[0].Id) != users[1].EmailID && int64(user.Emails[1].Id) != users[1].EmailID {
|
||||||
|
t.Errorf("Expected result row ID to be either %d or %d, but was %d", user.Emails[0].Id, user.Emails[1].Id, users[1].EmailID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestJoinsWithSelect(t *testing.T) {
|
func TestJoinsWithSelect(t *testing.T) {
|
||||||
type result struct {
|
type result struct {
|
||||||
Name string
|
Name string
|
||||||
@ -879,6 +1056,94 @@ func TestOpenWithOneParameter(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSaveAssociations(t *testing.T) {
|
||||||
|
db := DB.New()
|
||||||
|
deltaAddressCount := 0
|
||||||
|
if err := db.Model(&Address{}).Count(&deltaAddressCount).Error; err != nil {
|
||||||
|
t.Errorf("failed to fetch address count")
|
||||||
|
t.FailNow()
|
||||||
|
}
|
||||||
|
|
||||||
|
placeAddress := &Address{
|
||||||
|
Address1: "somewhere on earth",
|
||||||
|
}
|
||||||
|
ownerAddress1 := &Address{
|
||||||
|
Address1: "near place address",
|
||||||
|
}
|
||||||
|
ownerAddress2 := &Address{
|
||||||
|
Address1: "address2",
|
||||||
|
}
|
||||||
|
db.Create(placeAddress)
|
||||||
|
|
||||||
|
addressCountShouldBe := func(t *testing.T, expectedCount int) {
|
||||||
|
countFromDB := 0
|
||||||
|
t.Helper()
|
||||||
|
err := db.Model(&Address{}).Count(&countFromDB).Error
|
||||||
|
if err != nil {
|
||||||
|
t.Error("failed to fetch address count")
|
||||||
|
}
|
||||||
|
if countFromDB != expectedCount {
|
||||||
|
t.Errorf("address count mismatch: %d", countFromDB)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
addressCountShouldBe(t, deltaAddressCount+1)
|
||||||
|
|
||||||
|
// owner address should be created, place address should be reused
|
||||||
|
place1 := &Place{
|
||||||
|
PlaceAddressID: placeAddress.ID,
|
||||||
|
PlaceAddress: placeAddress,
|
||||||
|
OwnerAddress: ownerAddress1,
|
||||||
|
}
|
||||||
|
err := db.Create(place1).Error
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to store place: %s", err.Error())
|
||||||
|
}
|
||||||
|
addressCountShouldBe(t, deltaAddressCount+2)
|
||||||
|
|
||||||
|
// owner address should be created again, place address should be reused
|
||||||
|
place2 := &Place{
|
||||||
|
PlaceAddressID: placeAddress.ID,
|
||||||
|
PlaceAddress: &Address{
|
||||||
|
ID: 777,
|
||||||
|
Address1: "address1",
|
||||||
|
},
|
||||||
|
OwnerAddress: ownerAddress2,
|
||||||
|
OwnerAddressID: 778,
|
||||||
|
}
|
||||||
|
err = db.Create(place2).Error
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to store place: %s", err.Error())
|
||||||
|
}
|
||||||
|
addressCountShouldBe(t, deltaAddressCount+3)
|
||||||
|
|
||||||
|
count := 0
|
||||||
|
db.Model(&Place{}).Where(&Place{
|
||||||
|
PlaceAddressID: placeAddress.ID,
|
||||||
|
OwnerAddressID: ownerAddress1.ID,
|
||||||
|
}).Count(&count)
|
||||||
|
if count != 1 {
|
||||||
|
t.Errorf("only one instance of (%d, %d) should be available, found: %d",
|
||||||
|
placeAddress.ID, ownerAddress1.ID, count)
|
||||||
|
}
|
||||||
|
|
||||||
|
db.Model(&Place{}).Where(&Place{
|
||||||
|
PlaceAddressID: placeAddress.ID,
|
||||||
|
OwnerAddressID: ownerAddress2.ID,
|
||||||
|
}).Count(&count)
|
||||||
|
if count != 1 {
|
||||||
|
t.Errorf("only one instance of (%d, %d) should be available, found: %d",
|
||||||
|
placeAddress.ID, ownerAddress2.ID, count)
|
||||||
|
}
|
||||||
|
|
||||||
|
db.Model(&Place{}).Where(&Place{
|
||||||
|
PlaceAddressID: placeAddress.ID,
|
||||||
|
}).Count(&count)
|
||||||
|
if count != 2 {
|
||||||
|
t.Errorf("two instances of (%d) should be available, found: %d",
|
||||||
|
placeAddress.ID, count)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestBlockGlobalUpdate(t *testing.T) {
|
func TestBlockGlobalUpdate(t *testing.T) {
|
||||||
db := DB.New()
|
db := DB.New()
|
||||||
db.Create(&Toy{Name: "Stuffed Animal", OwnerType: "Nobody"})
|
db.Create(&Toy{Name: "Stuffed Animal", OwnerType: "Nobody"})
|
||||||
@ -917,6 +1182,124 @@ func TestBlockGlobalUpdate(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCountWithHaving(t *testing.T) {
|
||||||
|
db := DB.New()
|
||||||
|
db.Delete(User{})
|
||||||
|
defer db.Delete(User{})
|
||||||
|
|
||||||
|
DB.Create(getPreparedUser("user1", "pluck_user"))
|
||||||
|
DB.Create(getPreparedUser("user2", "pluck_user"))
|
||||||
|
user3 := getPreparedUser("user3", "pluck_user")
|
||||||
|
user3.Languages = []Language{}
|
||||||
|
DB.Create(user3)
|
||||||
|
|
||||||
|
var count int
|
||||||
|
err := db.Model(User{}).Select("users.id").
|
||||||
|
Joins("LEFT JOIN user_languages ON user_languages.user_id = users.id").
|
||||||
|
Joins("LEFT JOIN languages ON user_languages.language_id = languages.id").
|
||||||
|
Group("users.id").Having("COUNT(languages.id) > 1").Count(&count).Error
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Error("Unexpected error on query count with having")
|
||||||
|
}
|
||||||
|
|
||||||
|
if count != 2 {
|
||||||
|
t.Error("Unexpected result on query count with having")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPluck(t *testing.T) {
|
||||||
|
db := DB.New()
|
||||||
|
db.Delete(User{})
|
||||||
|
defer db.Delete(User{})
|
||||||
|
|
||||||
|
DB.Create(&User{Id: 1, Name: "user1"})
|
||||||
|
DB.Create(&User{Id: 2, Name: "user2"})
|
||||||
|
DB.Create(&User{Id: 3, Name: "user3"})
|
||||||
|
|
||||||
|
var ids []int64
|
||||||
|
err := db.Model(User{}).Order("id").Pluck("id", &ids).Error
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Error("Unexpected error on pluck")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(ids) != 3 || ids[0] != 1 || ids[1] != 2 || ids[2] != 3 {
|
||||||
|
t.Error("Unexpected result on pluck")
|
||||||
|
}
|
||||||
|
|
||||||
|
err = db.Model(User{}).Order("id").Pluck("id", &ids).Error
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Error("Unexpected error on pluck again")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(ids) != 3 || ids[0] != 1 || ids[1] != 2 || ids[2] != 3 {
|
||||||
|
t.Error("Unexpected result on pluck again")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCountWithQueryOption(t *testing.T) {
|
||||||
|
db := DB.New()
|
||||||
|
db.Delete(User{})
|
||||||
|
defer db.Delete(User{})
|
||||||
|
|
||||||
|
DB.Create(&User{Name: "user1"})
|
||||||
|
DB.Create(&User{Name: "user2"})
|
||||||
|
DB.Create(&User{Name: "user3"})
|
||||||
|
|
||||||
|
var count int
|
||||||
|
err := db.Model(User{}).Select("users.id").
|
||||||
|
Set("gorm:query_option", "WHERE users.name='user2'").
|
||||||
|
Count(&count).Error
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Error("Unexpected error on query count with query_option")
|
||||||
|
}
|
||||||
|
|
||||||
|
if count != 1 {
|
||||||
|
t.Error("Unexpected result on query count with query_option")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFloatColumnPrecision(t *testing.T) {
|
||||||
|
if dialect := os.Getenv("GORM_DIALECT"); dialect != "mysql" && dialect != "sqlite" {
|
||||||
|
t.Skip()
|
||||||
|
}
|
||||||
|
|
||||||
|
type FloatTest struct {
|
||||||
|
ID string `gorm:"primary_key"`
|
||||||
|
FloatValue float64 `gorm:"column:float_value" sql:"type:float(255,5);"`
|
||||||
|
}
|
||||||
|
DB.DropTable(&FloatTest{})
|
||||||
|
DB.AutoMigrate(&FloatTest{})
|
||||||
|
|
||||||
|
data := FloatTest{ID: "uuid", FloatValue: 112.57315}
|
||||||
|
if err := DB.Save(&data).Error; err != nil || data.ID != "uuid" || data.FloatValue != 112.57315 {
|
||||||
|
t.Errorf("Float value should not lose precision")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWhereUpdates(t *testing.T) {
|
||||||
|
type OwnerEntity struct {
|
||||||
|
gorm.Model
|
||||||
|
OwnerID uint
|
||||||
|
OwnerType string
|
||||||
|
}
|
||||||
|
|
||||||
|
type SomeEntity struct {
|
||||||
|
gorm.Model
|
||||||
|
Name string
|
||||||
|
OwnerEntity OwnerEntity `gorm:"polymorphic:Owner"`
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.DropTable(&SomeEntity{})
|
||||||
|
DB.AutoMigrate(&SomeEntity{})
|
||||||
|
|
||||||
|
a := SomeEntity{Name: "test"}
|
||||||
|
DB.Model(&a).Where(a).Updates(SomeEntity{Name: "test2"})
|
||||||
|
}
|
||||||
|
|
||||||
func BenchmarkGorm(b *testing.B) {
|
func BenchmarkGorm(b *testing.B) {
|
||||||
b.N = 2000
|
b.N = 2000
|
||||||
for x := 0; x < b.N; x++ {
|
for x := 0; x < b.N; x++ {
|
||||||
|
@ -118,6 +118,14 @@ type Company struct {
|
|||||||
Owner *User `sql:"-"`
|
Owner *User `sql:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type Place struct {
|
||||||
|
Id int64
|
||||||
|
PlaceAddressID int
|
||||||
|
PlaceAddress *Address `gorm:"save_associations:false"`
|
||||||
|
OwnerAddressID int
|
||||||
|
OwnerAddress *Address `gorm:"save_associations:true"`
|
||||||
|
}
|
||||||
|
|
||||||
type EncryptedData []byte
|
type EncryptedData []byte
|
||||||
|
|
||||||
func (data *EncryptedData) Scan(value interface{}) error {
|
func (data *EncryptedData) Scan(value interface{}) error {
|
||||||
@ -284,7 +292,7 @@ func runMigration() {
|
|||||||
DB.Exec(fmt.Sprintf("drop table %v;", table))
|
DB.Exec(fmt.Sprintf("drop table %v;", table))
|
||||||
}
|
}
|
||||||
|
|
||||||
values := []interface{}{&Short{}, &ReallyLongThingThatReferencesShort{}, &ReallyLongTableNameToTestMySQLNameLengthLimit{}, &NotSoLongTableName{}, &Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}, &Animal{}, &User{}, &JoinTable{}, &Post{}, &Category{}, &Comment{}, &Cat{}, &Dog{}, &Hamster{}, &Toy{}, &ElementWithIgnoredField{}}
|
values := []interface{}{&Short{}, &ReallyLongThingThatReferencesShort{}, &ReallyLongTableNameToTestMySQLNameLengthLimit{}, &NotSoLongTableName{}, &Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}, &Animal{}, &User{}, &JoinTable{}, &Post{}, &Category{}, &Comment{}, &Cat{}, &Dog{}, &Hamster{}, &Toy{}, &ElementWithIgnoredField{}, &Place{}}
|
||||||
for _, value := range values {
|
for _, value := range values {
|
||||||
DB.DropTable(value)
|
DB.DropTable(value)
|
||||||
}
|
}
|
||||||
@ -530,3 +538,42 @@ func TestModifyColumnType(t *testing.T) {
|
|||||||
t.Errorf("No error should happen when ModifyColumn, but got %v", err)
|
t.Errorf("No error should happen when ModifyColumn, but got %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestIndexWithPrefixLength(t *testing.T) {
|
||||||
|
if dialect := os.Getenv("GORM_DIALECT"); dialect != "mysql" {
|
||||||
|
t.Skip("Skipping this because only mysql support setting an index prefix length")
|
||||||
|
}
|
||||||
|
|
||||||
|
type IndexWithPrefix struct {
|
||||||
|
gorm.Model
|
||||||
|
Name string
|
||||||
|
Description string `gorm:"type:text;index:idx_index_with_prefixes_length(100)"`
|
||||||
|
}
|
||||||
|
type IndexesWithPrefix struct {
|
||||||
|
gorm.Model
|
||||||
|
Name string
|
||||||
|
Description1 string `gorm:"type:text;index:idx_index_with_prefixes_length(100)"`
|
||||||
|
Description2 string `gorm:"type:text;index:idx_index_with_prefixes_length(100)"`
|
||||||
|
}
|
||||||
|
type IndexesWithPrefixAndWithoutPrefix struct {
|
||||||
|
gorm.Model
|
||||||
|
Name string `gorm:"index:idx_index_with_prefixes_length"`
|
||||||
|
Description string `gorm:"type:text;index:idx_index_with_prefixes_length(100)"`
|
||||||
|
}
|
||||||
|
tables := []interface{}{&IndexWithPrefix{}, &IndexesWithPrefix{}, &IndexesWithPrefixAndWithoutPrefix{}}
|
||||||
|
for _, table := range tables {
|
||||||
|
scope := DB.NewScope(table)
|
||||||
|
tableName := scope.TableName()
|
||||||
|
t.Run(fmt.Sprintf("Create index with prefix length: %s", tableName), func(t *testing.T) {
|
||||||
|
if err := DB.DropTableIfExists(table).Error; err != nil {
|
||||||
|
t.Errorf("Failed to drop %s table: %v", tableName, err)
|
||||||
|
}
|
||||||
|
if err := DB.CreateTable(table).Error; err != nil {
|
||||||
|
t.Errorf("Failed to create %s table: %v", tableName, err)
|
||||||
|
}
|
||||||
|
if !scope.Dialect().HasIndex(tableName, "idx_index_with_prefixes_length") {
|
||||||
|
t.Errorf("Failed to create %s table index:", tableName)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
146
model_struct.go
146
model_struct.go
@ -21,23 +21,30 @@ var modelStructsMap sync.Map
|
|||||||
|
|
||||||
// ModelStruct model definition
|
// ModelStruct model definition
|
||||||
type ModelStruct struct {
|
type ModelStruct struct {
|
||||||
PrimaryFields []*StructField
|
PrimaryFields []*StructField
|
||||||
StructFields []*StructField
|
StructFields []*StructField
|
||||||
ModelType reflect.Type
|
ModelType reflect.Type
|
||||||
|
|
||||||
defaultTableName string
|
defaultTableName string
|
||||||
|
l sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
// TableName returns model's table name
|
// TableName returns model's table name
|
||||||
func (s *ModelStruct) TableName(db *DB) string {
|
func (s *ModelStruct) TableName(db *DB) string {
|
||||||
|
s.l.Lock()
|
||||||
|
defer s.l.Unlock()
|
||||||
|
|
||||||
if s.defaultTableName == "" && db != nil && s.ModelType != nil {
|
if s.defaultTableName == "" && db != nil && s.ModelType != nil {
|
||||||
// Set default table name
|
// Set default table name
|
||||||
if tabler, ok := reflect.New(s.ModelType).Interface().(tabler); ok {
|
if tabler, ok := reflect.New(s.ModelType).Interface().(tabler); ok {
|
||||||
s.defaultTableName = tabler.TableName()
|
s.defaultTableName = tabler.TableName()
|
||||||
} else {
|
} else {
|
||||||
tableName := ToDBName(s.ModelType.Name())
|
tableName := ToTableName(s.ModelType.Name())
|
||||||
if db == nil || !db.parent.singularTable {
|
db.parent.RLock()
|
||||||
|
if db == nil || (db.parent != nil && !db.parent.singularTable) {
|
||||||
tableName = inflection.Plural(tableName)
|
tableName = inflection.Plural(tableName)
|
||||||
}
|
}
|
||||||
|
db.parent.RUnlock()
|
||||||
s.defaultTableName = tableName
|
s.defaultTableName = tableName
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -60,30 +67,57 @@ type StructField struct {
|
|||||||
Struct reflect.StructField
|
Struct reflect.StructField
|
||||||
IsForeignKey bool
|
IsForeignKey bool
|
||||||
Relationship *Relationship
|
Relationship *Relationship
|
||||||
|
|
||||||
|
tagSettingsLock sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func (structField *StructField) clone() *StructField {
|
// TagSettingsSet Sets a tag in the tag settings map
|
||||||
|
func (sf *StructField) TagSettingsSet(key, val string) {
|
||||||
|
sf.tagSettingsLock.Lock()
|
||||||
|
defer sf.tagSettingsLock.Unlock()
|
||||||
|
sf.TagSettings[key] = val
|
||||||
|
}
|
||||||
|
|
||||||
|
// TagSettingsGet returns a tag from the tag settings
|
||||||
|
func (sf *StructField) TagSettingsGet(key string) (string, bool) {
|
||||||
|
sf.tagSettingsLock.RLock()
|
||||||
|
defer sf.tagSettingsLock.RUnlock()
|
||||||
|
val, ok := sf.TagSettings[key]
|
||||||
|
return val, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// TagSettingsDelete deletes a tag
|
||||||
|
func (sf *StructField) TagSettingsDelete(key string) {
|
||||||
|
sf.tagSettingsLock.Lock()
|
||||||
|
defer sf.tagSettingsLock.Unlock()
|
||||||
|
delete(sf.TagSettings, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sf *StructField) clone() *StructField {
|
||||||
clone := &StructField{
|
clone := &StructField{
|
||||||
DBName: structField.DBName,
|
DBName: sf.DBName,
|
||||||
Name: structField.Name,
|
Name: sf.Name,
|
||||||
Names: structField.Names,
|
Names: sf.Names,
|
||||||
IsPrimaryKey: structField.IsPrimaryKey,
|
IsPrimaryKey: sf.IsPrimaryKey,
|
||||||
IsNormal: structField.IsNormal,
|
IsNormal: sf.IsNormal,
|
||||||
IsIgnored: structField.IsIgnored,
|
IsIgnored: sf.IsIgnored,
|
||||||
IsScanner: structField.IsScanner,
|
IsScanner: sf.IsScanner,
|
||||||
HasDefaultValue: structField.HasDefaultValue,
|
HasDefaultValue: sf.HasDefaultValue,
|
||||||
Tag: structField.Tag,
|
Tag: sf.Tag,
|
||||||
TagSettings: map[string]string{},
|
TagSettings: map[string]string{},
|
||||||
Struct: structField.Struct,
|
Struct: sf.Struct,
|
||||||
IsForeignKey: structField.IsForeignKey,
|
IsForeignKey: sf.IsForeignKey,
|
||||||
}
|
}
|
||||||
|
|
||||||
if structField.Relationship != nil {
|
if sf.Relationship != nil {
|
||||||
relationship := *structField.Relationship
|
relationship := *sf.Relationship
|
||||||
clone.Relationship = &relationship
|
clone.Relationship = &relationship
|
||||||
}
|
}
|
||||||
|
|
||||||
for key, value := range structField.TagSettings {
|
// copy the struct field tagSettings, they should be read-locked while they are copied
|
||||||
|
sf.tagSettingsLock.Lock()
|
||||||
|
defer sf.tagSettingsLock.Unlock()
|
||||||
|
for key, value := range sf.TagSettings {
|
||||||
clone.TagSettings[key] = value
|
clone.TagSettings[key] = value
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -105,7 +139,7 @@ type Relationship struct {
|
|||||||
|
|
||||||
func getForeignField(column string, fields []*StructField) *StructField {
|
func getForeignField(column string, fields []*StructField) *StructField {
|
||||||
for _, field := range fields {
|
for _, field := range fields {
|
||||||
if field.Name == column || field.DBName == column || field.DBName == ToDBName(column) {
|
if field.Name == column || field.DBName == column || field.DBName == ToColumnName(column) {
|
||||||
return field
|
return field
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -131,7 +165,18 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Get Cached model struct
|
// Get Cached model struct
|
||||||
if value, ok := modelStructsMap.Load(reflectType); ok && value != nil {
|
isSingularTable := false
|
||||||
|
if scope.db != nil && scope.db.parent != nil {
|
||||||
|
scope.db.parent.RLock()
|
||||||
|
isSingularTable = scope.db.parent.singularTable
|
||||||
|
scope.db.parent.RUnlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
hashKey := struct {
|
||||||
|
singularTable bool
|
||||||
|
reflectType reflect.Type
|
||||||
|
}{isSingularTable, reflectType}
|
||||||
|
if value, ok := modelStructsMap.Load(hashKey); ok && value != nil {
|
||||||
return value.(*ModelStruct)
|
return value.(*ModelStruct)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -149,19 +194,19 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// is ignored field
|
// is ignored field
|
||||||
if _, ok := field.TagSettings["-"]; ok {
|
if _, ok := field.TagSettingsGet("-"); ok {
|
||||||
field.IsIgnored = true
|
field.IsIgnored = true
|
||||||
} else {
|
} else {
|
||||||
if _, ok := field.TagSettings["PRIMARY_KEY"]; ok {
|
if _, ok := field.TagSettingsGet("PRIMARY_KEY"); ok {
|
||||||
field.IsPrimaryKey = true
|
field.IsPrimaryKey = true
|
||||||
modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field)
|
modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field)
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, ok := field.TagSettings["DEFAULT"]; ok {
|
if _, ok := field.TagSettingsGet("DEFAULT"); ok && !field.IsPrimaryKey {
|
||||||
field.HasDefaultValue = true
|
field.HasDefaultValue = true
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok && !field.IsPrimaryKey {
|
if _, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok && !field.IsPrimaryKey {
|
||||||
field.HasDefaultValue = true
|
field.HasDefaultValue = true
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -177,8 +222,8 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
|||||||
if indirectType.Kind() == reflect.Struct {
|
if indirectType.Kind() == reflect.Struct {
|
||||||
for i := 0; i < indirectType.NumField(); i++ {
|
for i := 0; i < indirectType.NumField(); i++ {
|
||||||
for key, value := range parseTagSetting(indirectType.Field(i).Tag) {
|
for key, value := range parseTagSetting(indirectType.Field(i).Tag) {
|
||||||
if _, ok := field.TagSettings[key]; !ok {
|
if _, ok := field.TagSettingsGet(key); !ok {
|
||||||
field.TagSettings[key] = value
|
field.TagSettingsSet(key, value)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -186,17 +231,17 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
|||||||
} else if _, isTime := fieldValue.(*time.Time); isTime {
|
} else if _, isTime := fieldValue.(*time.Time); isTime {
|
||||||
// is time
|
// is time
|
||||||
field.IsNormal = true
|
field.IsNormal = true
|
||||||
} else if _, ok := field.TagSettings["EMBEDDED"]; ok || fieldStruct.Anonymous {
|
} else if _, ok := field.TagSettingsGet("EMBEDDED"); ok || fieldStruct.Anonymous {
|
||||||
// is embedded struct
|
// is embedded struct
|
||||||
for _, subField := range scope.New(fieldValue).GetModelStruct().StructFields {
|
for _, subField := range scope.New(fieldValue).GetModelStruct().StructFields {
|
||||||
subField = subField.clone()
|
subField = subField.clone()
|
||||||
subField.Names = append([]string{fieldStruct.Name}, subField.Names...)
|
subField.Names = append([]string{fieldStruct.Name}, subField.Names...)
|
||||||
if prefix, ok := field.TagSettings["EMBEDDED_PREFIX"]; ok {
|
if prefix, ok := field.TagSettingsGet("EMBEDDED_PREFIX"); ok {
|
||||||
subField.DBName = prefix + subField.DBName
|
subField.DBName = prefix + subField.DBName
|
||||||
}
|
}
|
||||||
|
|
||||||
if subField.IsPrimaryKey {
|
if subField.IsPrimaryKey {
|
||||||
if _, ok := subField.TagSettings["PRIMARY_KEY"]; ok {
|
if _, ok := subField.TagSettingsGet("PRIMARY_KEY"); ok {
|
||||||
modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, subField)
|
modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, subField)
|
||||||
} else {
|
} else {
|
||||||
subField.IsPrimaryKey = false
|
subField.IsPrimaryKey = false
|
||||||
@ -227,13 +272,13 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
|||||||
elemType = field.Struct.Type
|
elemType = field.Struct.Type
|
||||||
)
|
)
|
||||||
|
|
||||||
if foreignKey := field.TagSettings["FOREIGNKEY"]; foreignKey != "" {
|
if foreignKey, _ := field.TagSettingsGet("FOREIGNKEY"); foreignKey != "" {
|
||||||
foreignKeys = strings.Split(foreignKey, ",")
|
foreignKeys = strings.Split(foreignKey, ",")
|
||||||
}
|
}
|
||||||
|
|
||||||
if foreignKey := field.TagSettings["ASSOCIATION_FOREIGNKEY"]; foreignKey != "" {
|
if foreignKey, _ := field.TagSettingsGet("ASSOCIATION_FOREIGNKEY"); foreignKey != "" {
|
||||||
associationForeignKeys = strings.Split(foreignKey, ",")
|
associationForeignKeys = strings.Split(foreignKey, ",")
|
||||||
} else if foreignKey := field.TagSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" {
|
} else if foreignKey, _ := field.TagSettingsGet("ASSOCIATIONFOREIGNKEY"); foreignKey != "" {
|
||||||
associationForeignKeys = strings.Split(foreignKey, ",")
|
associationForeignKeys = strings.Split(foreignKey, ",")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -242,13 +287,13 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if elemType.Kind() == reflect.Struct {
|
if elemType.Kind() == reflect.Struct {
|
||||||
if many2many := field.TagSettings["MANY2MANY"]; many2many != "" {
|
if many2many, _ := field.TagSettingsGet("MANY2MANY"); many2many != "" {
|
||||||
relationship.Kind = "many_to_many"
|
relationship.Kind = "many_to_many"
|
||||||
|
|
||||||
{ // Foreign Keys for Source
|
{ // Foreign Keys for Source
|
||||||
joinTableDBNames := []string{}
|
joinTableDBNames := []string{}
|
||||||
|
|
||||||
if foreignKey := field.TagSettings["JOINTABLE_FOREIGNKEY"]; foreignKey != "" {
|
if foreignKey, _ := field.TagSettingsGet("JOINTABLE_FOREIGNKEY"); foreignKey != "" {
|
||||||
joinTableDBNames = strings.Split(foreignKey, ",")
|
joinTableDBNames = strings.Split(foreignKey, ",")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -269,7 +314,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
|||||||
// if defined join table's foreign key
|
// if defined join table's foreign key
|
||||||
relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBNames[idx])
|
relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBNames[idx])
|
||||||
} else {
|
} else {
|
||||||
defaultJointableForeignKey := ToDBName(reflectType.Name()) + "_" + foreignField.DBName
|
defaultJointableForeignKey := ToColumnName(reflectType.Name()) + "_" + foreignField.DBName
|
||||||
relationship.ForeignDBNames = append(relationship.ForeignDBNames, defaultJointableForeignKey)
|
relationship.ForeignDBNames = append(relationship.ForeignDBNames, defaultJointableForeignKey)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -279,7 +324,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
|||||||
{ // Foreign Keys for Association (Destination)
|
{ // Foreign Keys for Association (Destination)
|
||||||
associationJoinTableDBNames := []string{}
|
associationJoinTableDBNames := []string{}
|
||||||
|
|
||||||
if foreignKey := field.TagSettings["ASSOCIATION_JOINTABLE_FOREIGNKEY"]; foreignKey != "" {
|
if foreignKey, _ := field.TagSettingsGet("ASSOCIATION_JOINTABLE_FOREIGNKEY"); foreignKey != "" {
|
||||||
associationJoinTableDBNames = strings.Split(foreignKey, ",")
|
associationJoinTableDBNames = strings.Split(foreignKey, ",")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -300,7 +345,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
|||||||
relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationJoinTableDBNames[idx])
|
relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationJoinTableDBNames[idx])
|
||||||
} else {
|
} else {
|
||||||
// join table foreign keys for association
|
// join table foreign keys for association
|
||||||
joinTableDBName := ToDBName(elemType.Name()) + "_" + field.DBName
|
joinTableDBName := ToColumnName(elemType.Name()) + "_" + field.DBName
|
||||||
relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, joinTableDBName)
|
relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, joinTableDBName)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -317,7 +362,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
|||||||
var toFields = toScope.GetStructFields()
|
var toFields = toScope.GetStructFields()
|
||||||
relationship.Kind = "has_many"
|
relationship.Kind = "has_many"
|
||||||
|
|
||||||
if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" {
|
if polymorphic, _ := field.TagSettingsGet("POLYMORPHIC"); polymorphic != "" {
|
||||||
// Dog has many toys, tag polymorphic is Owner, then associationType is Owner
|
// Dog has many toys, tag polymorphic is Owner, then associationType is Owner
|
||||||
// Toy use OwnerID, OwnerType ('dogs') as foreign key
|
// Toy use OwnerID, OwnerType ('dogs') as foreign key
|
||||||
if polymorphicType := getForeignField(polymorphic+"Type", toFields); polymorphicType != nil {
|
if polymorphicType := getForeignField(polymorphic+"Type", toFields); polymorphicType != nil {
|
||||||
@ -325,7 +370,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
|||||||
relationship.PolymorphicType = polymorphicType.Name
|
relationship.PolymorphicType = polymorphicType.Name
|
||||||
relationship.PolymorphicDBName = polymorphicType.DBName
|
relationship.PolymorphicDBName = polymorphicType.DBName
|
||||||
// if Dog has multiple set of toys set name of the set (instead of default 'dogs')
|
// if Dog has multiple set of toys set name of the set (instead of default 'dogs')
|
||||||
if value, ok := field.TagSettings["POLYMORPHIC_VALUE"]; ok {
|
if value, ok := field.TagSettingsGet("POLYMORPHIC_VALUE"); ok {
|
||||||
relationship.PolymorphicValue = value
|
relationship.PolymorphicValue = value
|
||||||
} else {
|
} else {
|
||||||
relationship.PolymorphicValue = scope.TableName()
|
relationship.PolymorphicValue = scope.TableName()
|
||||||
@ -407,17 +452,17 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
|||||||
tagAssociationForeignKeys []string
|
tagAssociationForeignKeys []string
|
||||||
)
|
)
|
||||||
|
|
||||||
if foreignKey := field.TagSettings["FOREIGNKEY"]; foreignKey != "" {
|
if foreignKey, _ := field.TagSettingsGet("FOREIGNKEY"); foreignKey != "" {
|
||||||
tagForeignKeys = strings.Split(foreignKey, ",")
|
tagForeignKeys = strings.Split(foreignKey, ",")
|
||||||
}
|
}
|
||||||
|
|
||||||
if foreignKey := field.TagSettings["ASSOCIATION_FOREIGNKEY"]; foreignKey != "" {
|
if foreignKey, _ := field.TagSettingsGet("ASSOCIATION_FOREIGNKEY"); foreignKey != "" {
|
||||||
tagAssociationForeignKeys = strings.Split(foreignKey, ",")
|
tagAssociationForeignKeys = strings.Split(foreignKey, ",")
|
||||||
} else if foreignKey := field.TagSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" {
|
} else if foreignKey, _ := field.TagSettingsGet("ASSOCIATIONFOREIGNKEY"); foreignKey != "" {
|
||||||
tagAssociationForeignKeys = strings.Split(foreignKey, ",")
|
tagAssociationForeignKeys = strings.Split(foreignKey, ",")
|
||||||
}
|
}
|
||||||
|
|
||||||
if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" {
|
if polymorphic, _ := field.TagSettingsGet("POLYMORPHIC"); polymorphic != "" {
|
||||||
// Cat has one toy, tag polymorphic is Owner, then associationType is Owner
|
// Cat has one toy, tag polymorphic is Owner, then associationType is Owner
|
||||||
// Toy use OwnerID, OwnerType ('cats') as foreign key
|
// Toy use OwnerID, OwnerType ('cats') as foreign key
|
||||||
if polymorphicType := getForeignField(polymorphic+"Type", toFields); polymorphicType != nil {
|
if polymorphicType := getForeignField(polymorphic+"Type", toFields); polymorphicType != nil {
|
||||||
@ -425,7 +470,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
|||||||
relationship.PolymorphicType = polymorphicType.Name
|
relationship.PolymorphicType = polymorphicType.Name
|
||||||
relationship.PolymorphicDBName = polymorphicType.DBName
|
relationship.PolymorphicDBName = polymorphicType.DBName
|
||||||
// if Cat has several different types of toys set name for each (instead of default 'cats')
|
// if Cat has several different types of toys set name for each (instead of default 'cats')
|
||||||
if value, ok := field.TagSettings["POLYMORPHIC_VALUE"]; ok {
|
if value, ok := field.TagSettingsGet("POLYMORPHIC_VALUE"); ok {
|
||||||
relationship.PolymorphicValue = value
|
relationship.PolymorphicValue = value
|
||||||
} else {
|
} else {
|
||||||
relationship.PolymorphicValue = scope.TableName()
|
relationship.PolymorphicValue = scope.TableName()
|
||||||
@ -563,10 +608,10 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Even it is ignored, also possible to decode db value into the field
|
// Even it is ignored, also possible to decode db value into the field
|
||||||
if value, ok := field.TagSettings["COLUMN"]; ok {
|
if value, ok := field.TagSettingsGet("COLUMN"); ok {
|
||||||
field.DBName = value
|
field.DBName = value
|
||||||
} else {
|
} else {
|
||||||
field.DBName = ToDBName(fieldStruct.Name)
|
field.DBName = ToColumnName(fieldStruct.Name)
|
||||||
}
|
}
|
||||||
|
|
||||||
modelStruct.StructFields = append(modelStruct.StructFields, field)
|
modelStruct.StructFields = append(modelStruct.StructFields, field)
|
||||||
@ -580,7 +625,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
modelStructsMap.Store(reflectType, &modelStruct)
|
modelStructsMap.Store(hashKey, &modelStruct)
|
||||||
|
|
||||||
return &modelStruct
|
return &modelStruct
|
||||||
}
|
}
|
||||||
@ -593,6 +638,9 @@ func (scope *Scope) GetStructFields() (fields []*StructField) {
|
|||||||
func parseTagSetting(tags reflect.StructTag) map[string]string {
|
func parseTagSetting(tags reflect.StructTag) map[string]string {
|
||||||
setting := map[string]string{}
|
setting := map[string]string{}
|
||||||
for _, str := range []string{tags.Get("sql"), tags.Get("gorm")} {
|
for _, str := range []string{tags.Get("sql"), tags.Get("gorm")} {
|
||||||
|
if str == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
tags := strings.Split(str, ";")
|
tags := strings.Split(str, ";")
|
||||||
for _, value := range tags {
|
for _, value := range tags {
|
||||||
v := strings.Split(value, ":")
|
v := strings.Split(value, ":")
|
||||||
|
124
naming.go
Normal file
124
naming.go
Normal file
@ -0,0 +1,124 @@
|
|||||||
|
package gorm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Namer is a function type which is given a string and return a string
|
||||||
|
type Namer func(string) string
|
||||||
|
|
||||||
|
// NamingStrategy represents naming strategies
|
||||||
|
type NamingStrategy struct {
|
||||||
|
DB Namer
|
||||||
|
Table Namer
|
||||||
|
Column Namer
|
||||||
|
}
|
||||||
|
|
||||||
|
// TheNamingStrategy is being initialized with defaultNamingStrategy
|
||||||
|
var TheNamingStrategy = &NamingStrategy{
|
||||||
|
DB: defaultNamer,
|
||||||
|
Table: defaultNamer,
|
||||||
|
Column: defaultNamer,
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddNamingStrategy sets the naming strategy
|
||||||
|
func AddNamingStrategy(ns *NamingStrategy) {
|
||||||
|
if ns.DB == nil {
|
||||||
|
ns.DB = defaultNamer
|
||||||
|
}
|
||||||
|
if ns.Table == nil {
|
||||||
|
ns.Table = defaultNamer
|
||||||
|
}
|
||||||
|
if ns.Column == nil {
|
||||||
|
ns.Column = defaultNamer
|
||||||
|
}
|
||||||
|
TheNamingStrategy = ns
|
||||||
|
}
|
||||||
|
|
||||||
|
// DBName alters the given name by DB
|
||||||
|
func (ns *NamingStrategy) DBName(name string) string {
|
||||||
|
return ns.DB(name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TableName alters the given name by Table
|
||||||
|
func (ns *NamingStrategy) TableName(name string) string {
|
||||||
|
return ns.Table(name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ColumnName alters the given name by Column
|
||||||
|
func (ns *NamingStrategy) ColumnName(name string) string {
|
||||||
|
return ns.Column(name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToDBName convert string to db name
|
||||||
|
func ToDBName(name string) string {
|
||||||
|
return TheNamingStrategy.DBName(name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToTableName convert string to table name
|
||||||
|
func ToTableName(name string) string {
|
||||||
|
return TheNamingStrategy.TableName(name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToColumnName convert string to db name
|
||||||
|
func ToColumnName(name string) string {
|
||||||
|
return TheNamingStrategy.ColumnName(name)
|
||||||
|
}
|
||||||
|
|
||||||
|
var smap = newSafeMap()
|
||||||
|
|
||||||
|
func defaultNamer(name string) string {
|
||||||
|
const (
|
||||||
|
lower = false
|
||||||
|
upper = true
|
||||||
|
)
|
||||||
|
|
||||||
|
if v := smap.Get(name); v != "" {
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
|
||||||
|
if name == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
value = commonInitialismsReplacer.Replace(name)
|
||||||
|
buf = bytes.NewBufferString("")
|
||||||
|
lastCase, currCase, nextCase, nextNumber bool
|
||||||
|
)
|
||||||
|
|
||||||
|
for i, v := range value[:len(value)-1] {
|
||||||
|
nextCase = bool(value[i+1] >= 'A' && value[i+1] <= 'Z')
|
||||||
|
nextNumber = bool(value[i+1] >= '0' && value[i+1] <= '9')
|
||||||
|
|
||||||
|
if i > 0 {
|
||||||
|
if currCase == upper {
|
||||||
|
if lastCase == upper && (nextCase == upper || nextNumber == upper) {
|
||||||
|
buf.WriteRune(v)
|
||||||
|
} else {
|
||||||
|
if value[i-1] != '_' && value[i+1] != '_' {
|
||||||
|
buf.WriteRune('_')
|
||||||
|
}
|
||||||
|
buf.WriteRune(v)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
buf.WriteRune(v)
|
||||||
|
if i == len(value)-2 && (nextCase == upper && nextNumber == lower) {
|
||||||
|
buf.WriteRune('_')
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
currCase = upper
|
||||||
|
buf.WriteRune(v)
|
||||||
|
}
|
||||||
|
lastCase = currCase
|
||||||
|
currCase = nextCase
|
||||||
|
}
|
||||||
|
|
||||||
|
buf.WriteByte(value[len(value)-1])
|
||||||
|
|
||||||
|
s := strings.ToLower(buf.String())
|
||||||
|
smap.Set(name, s)
|
||||||
|
return s
|
||||||
|
}
|
69
naming_test.go
Normal file
69
naming_test.go
Normal file
@ -0,0 +1,69 @@
|
|||||||
|
package gorm_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/jinzhu/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestTheNamingStrategy(t *testing.T) {
|
||||||
|
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
namer gorm.Namer
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{name: "auth", expected: "auth", namer: gorm.TheNamingStrategy.DB},
|
||||||
|
{name: "userRestrictions", expected: "user_restrictions", namer: gorm.TheNamingStrategy.Table},
|
||||||
|
{name: "clientID", expected: "client_id", namer: gorm.TheNamingStrategy.Column},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, c := range cases {
|
||||||
|
t.Run(c.name, func(t *testing.T) {
|
||||||
|
result := c.namer(c.name)
|
||||||
|
if result != c.expected {
|
||||||
|
t.Errorf("error in naming strategy. expected: %v got :%v\n", c.expected, result)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNamingStrategy(t *testing.T) {
|
||||||
|
|
||||||
|
dbNameNS := func(name string) string {
|
||||||
|
return "db_" + name
|
||||||
|
}
|
||||||
|
tableNameNS := func(name string) string {
|
||||||
|
return "tbl_" + name
|
||||||
|
}
|
||||||
|
columnNameNS := func(name string) string {
|
||||||
|
return "col_" + name
|
||||||
|
}
|
||||||
|
|
||||||
|
ns := &gorm.NamingStrategy{
|
||||||
|
DB: dbNameNS,
|
||||||
|
Table: tableNameNS,
|
||||||
|
Column: columnNameNS,
|
||||||
|
}
|
||||||
|
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
namer gorm.Namer
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{name: "auth", expected: "db_auth", namer: ns.DB},
|
||||||
|
{name: "user", expected: "tbl_user", namer: ns.Table},
|
||||||
|
{name: "password", expected: "col_password", namer: ns.Column},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, c := range cases {
|
||||||
|
t.Run(c.name, func(t *testing.T) {
|
||||||
|
result := c.namer(c.name)
|
||||||
|
if result != c.expected {
|
||||||
|
t.Errorf("error in naming strategy. expected: %v got :%v\n", c.expected, result)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -771,6 +771,7 @@ func TestNestedPreload11(t *testing.T) {
|
|||||||
levelB3 := &LevelB3{
|
levelB3 := &LevelB3{
|
||||||
Value: "bar",
|
Value: "bar",
|
||||||
LevelB1ID: sql.NullInt64{Valid: true, Int64: int64(levelB1.ID)},
|
LevelB1ID: sql.NullInt64{Valid: true, Int64: int64(levelB1.ID)},
|
||||||
|
LevelB2s: []*LevelB2{},
|
||||||
}
|
}
|
||||||
if err := DB.Create(levelB3).Error; err != nil {
|
if err := DB.Create(levelB3).Error; err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
@ -1676,7 +1677,7 @@ func TestPreloadManyToManyCallbacks(t *testing.T) {
|
|||||||
lvl := Level1{
|
lvl := Level1{
|
||||||
Name: "l1",
|
Name: "l1",
|
||||||
Level2s: []Level2{
|
Level2s: []Level2{
|
||||||
Level2{Name: "l2-1"}, Level2{Name: "l2-2"},
|
{Name: "l2-1"}, {Name: "l2-2"},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
DB.Save(&lvl)
|
DB.Save(&lvl)
|
||||||
|
92
scope.go
92
scope.go
@ -68,7 +68,7 @@ func (scope *Scope) Dialect() Dialect {
|
|||||||
|
|
||||||
// Quote used to quote string to escape them for database
|
// Quote used to quote string to escape them for database
|
||||||
func (scope *Scope) Quote(str string) string {
|
func (scope *Scope) Quote(str string) string {
|
||||||
if strings.Index(str, ".") != -1 {
|
if strings.Contains(str, ".") {
|
||||||
newStrs := []string{}
|
newStrs := []string{}
|
||||||
for _, str := range strings.Split(str, ".") {
|
for _, str := range strings.Split(str, ".") {
|
||||||
newStrs = append(newStrs, scope.Dialect().Quote(str))
|
newStrs = append(newStrs, scope.Dialect().Quote(str))
|
||||||
@ -134,7 +134,7 @@ func (scope *Scope) Fields() []*Field {
|
|||||||
// FieldByName find `gorm.Field` with field name or db name
|
// FieldByName find `gorm.Field` with field name or db name
|
||||||
func (scope *Scope) FieldByName(name string) (field *Field, ok bool) {
|
func (scope *Scope) FieldByName(name string) (field *Field, ok bool) {
|
||||||
var (
|
var (
|
||||||
dbName = ToDBName(name)
|
dbName = ToColumnName(name)
|
||||||
mostMatchedField *Field
|
mostMatchedField *Field
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -330,7 +330,7 @@ func (scope *Scope) TableName() string {
|
|||||||
// QuotedTableName return quoted table name
|
// QuotedTableName return quoted table name
|
||||||
func (scope *Scope) QuotedTableName() (name string) {
|
func (scope *Scope) QuotedTableName() (name string) {
|
||||||
if scope.Search != nil && len(scope.Search.tableName) > 0 {
|
if scope.Search != nil && len(scope.Search.tableName) > 0 {
|
||||||
if strings.Index(scope.Search.tableName, " ") != -1 {
|
if strings.Contains(scope.Search.tableName, " ") {
|
||||||
return scope.Search.tableName
|
return scope.Search.tableName
|
||||||
}
|
}
|
||||||
return scope.Quote(scope.Search.tableName)
|
return scope.Quote(scope.Search.tableName)
|
||||||
@ -358,7 +358,7 @@ func (scope *Scope) Raw(sql string) *Scope {
|
|||||||
|
|
||||||
// Exec perform generated SQL
|
// Exec perform generated SQL
|
||||||
func (scope *Scope) Exec() *Scope {
|
func (scope *Scope) Exec() *Scope {
|
||||||
defer scope.trace(NowFunc())
|
defer scope.trace(scope.db.nowFunc())
|
||||||
|
|
||||||
if !scope.HasError() {
|
if !scope.HasError() {
|
||||||
if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
|
if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
|
||||||
@ -402,7 +402,7 @@ func (scope *Scope) InstanceGet(name string) (interface{}, bool) {
|
|||||||
// Begin start a transaction
|
// Begin start a transaction
|
||||||
func (scope *Scope) Begin() *Scope {
|
func (scope *Scope) Begin() *Scope {
|
||||||
if db, ok := scope.SQLDB().(sqlDb); ok {
|
if db, ok := scope.SQLDB().(sqlDb); ok {
|
||||||
if tx, err := db.Begin(); err == nil {
|
if tx, err := db.Begin(); scope.Err(err) == nil {
|
||||||
scope.db.db = interface{}(tx).(SQLCommon)
|
scope.db.db = interface{}(tx).(SQLCommon)
|
||||||
scope.InstanceSet("gorm:started_transaction", true)
|
scope.InstanceSet("gorm:started_transaction", true)
|
||||||
}
|
}
|
||||||
@ -486,8 +486,10 @@ func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field) {
|
|||||||
values[index] = &ignored
|
values[index] = &ignored
|
||||||
|
|
||||||
selectFields = fields
|
selectFields = fields
|
||||||
|
offset := 0
|
||||||
if idx, ok := selectedColumnsMap[column]; ok {
|
if idx, ok := selectedColumnsMap[column]; ok {
|
||||||
selectFields = selectFields[idx+1:]
|
offset = idx + 1
|
||||||
|
selectFields = selectFields[offset:]
|
||||||
}
|
}
|
||||||
|
|
||||||
for fieldIndex, field := range selectFields {
|
for fieldIndex, field := range selectFields {
|
||||||
@ -501,7 +503,7 @@ func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field) {
|
|||||||
resetFields[index] = field
|
resetFields[index] = field
|
||||||
}
|
}
|
||||||
|
|
||||||
selectedColumnsMap[column] = fieldIndex
|
selectedColumnsMap[column] = offset + fieldIndex
|
||||||
|
|
||||||
if field.IsNormal {
|
if field.IsNormal {
|
||||||
break
|
break
|
||||||
@ -853,6 +855,14 @@ func (scope *Scope) inlineCondition(values ...interface{}) *Scope {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) callCallbacks(funcs []*func(s *Scope)) *Scope {
|
func (scope *Scope) callCallbacks(funcs []*func(s *Scope)) *Scope {
|
||||||
|
defer func() {
|
||||||
|
if err := recover(); err != nil {
|
||||||
|
if db, ok := scope.db.db.(sqlTx); ok {
|
||||||
|
db.Rollback()
|
||||||
|
}
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
for _, f := range funcs {
|
for _, f := range funcs {
|
||||||
(*f)(scope)
|
(*f)(scope)
|
||||||
if scope.skipLeft {
|
if scope.skipLeft {
|
||||||
@ -862,7 +872,7 @@ func (scope *Scope) callCallbacks(funcs []*func(s *Scope)) *Scope {
|
|||||||
return scope
|
return scope
|
||||||
}
|
}
|
||||||
|
|
||||||
func convertInterfaceToMap(values interface{}, withIgnoredField bool) map[string]interface{} {
|
func convertInterfaceToMap(values interface{}, withIgnoredField bool, db *DB) map[string]interface{} {
|
||||||
var attrs = map[string]interface{}{}
|
var attrs = map[string]interface{}{}
|
||||||
|
|
||||||
switch value := values.(type) {
|
switch value := values.(type) {
|
||||||
@ -870,7 +880,7 @@ func convertInterfaceToMap(values interface{}, withIgnoredField bool) map[string
|
|||||||
return value
|
return value
|
||||||
case []interface{}:
|
case []interface{}:
|
||||||
for _, v := range value {
|
for _, v := range value {
|
||||||
for key, value := range convertInterfaceToMap(v, withIgnoredField) {
|
for key, value := range convertInterfaceToMap(v, withIgnoredField, db) {
|
||||||
attrs[key] = value
|
attrs[key] = value
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -880,10 +890,10 @@ func convertInterfaceToMap(values interface{}, withIgnoredField bool) map[string
|
|||||||
switch reflectValue.Kind() {
|
switch reflectValue.Kind() {
|
||||||
case reflect.Map:
|
case reflect.Map:
|
||||||
for _, key := range reflectValue.MapKeys() {
|
for _, key := range reflectValue.MapKeys() {
|
||||||
attrs[ToDBName(key.Interface().(string))] = reflectValue.MapIndex(key).Interface()
|
attrs[ToColumnName(key.Interface().(string))] = reflectValue.MapIndex(key).Interface()
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
for _, field := range (&Scope{Value: values}).Fields() {
|
for _, field := range (&Scope{Value: values, db: db}).Fields() {
|
||||||
if !field.IsBlank && (withIgnoredField || !field.IsIgnored) {
|
if !field.IsBlank && (withIgnoredField || !field.IsIgnored) {
|
||||||
attrs[field.DBName] = field.Field.Interface()
|
attrs[field.DBName] = field.Field.Interface()
|
||||||
}
|
}
|
||||||
@ -895,19 +905,19 @@ func convertInterfaceToMap(values interface{}, withIgnoredField bool) map[string
|
|||||||
|
|
||||||
func (scope *Scope) updatedAttrsWithValues(value interface{}) (results map[string]interface{}, hasUpdate bool) {
|
func (scope *Scope) updatedAttrsWithValues(value interface{}) (results map[string]interface{}, hasUpdate bool) {
|
||||||
if scope.IndirectValue().Kind() != reflect.Struct {
|
if scope.IndirectValue().Kind() != reflect.Struct {
|
||||||
return convertInterfaceToMap(value, false), true
|
return convertInterfaceToMap(value, false, scope.db), true
|
||||||
}
|
}
|
||||||
|
|
||||||
results = map[string]interface{}{}
|
results = map[string]interface{}{}
|
||||||
|
|
||||||
for key, value := range convertInterfaceToMap(value, true) {
|
for key, value := range convertInterfaceToMap(value, true, scope.db) {
|
||||||
if field, ok := scope.FieldByName(key); ok && scope.changeableField(field) {
|
if field, ok := scope.FieldByName(key); ok && scope.changeableField(field) {
|
||||||
if _, ok := value.(*expr); ok {
|
if _, ok := value.(*expr); ok {
|
||||||
hasUpdate = true
|
hasUpdate = true
|
||||||
results[field.DBName] = value
|
results[field.DBName] = value
|
||||||
} else {
|
} else {
|
||||||
err := field.Set(value)
|
err := field.Set(value)
|
||||||
if field.IsNormal {
|
if field.IsNormal && !field.IsIgnored {
|
||||||
hasUpdate = true
|
hasUpdate = true
|
||||||
if err == ErrUnaddressable {
|
if err == ErrUnaddressable {
|
||||||
results[field.DBName] = value
|
results[field.DBName] = value
|
||||||
@ -922,7 +932,7 @@ func (scope *Scope) updatedAttrsWithValues(value interface{}) (results map[strin
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) row() *sql.Row {
|
func (scope *Scope) row() *sql.Row {
|
||||||
defer scope.trace(NowFunc())
|
defer scope.trace(scope.db.nowFunc())
|
||||||
|
|
||||||
result := &RowQueryResult{}
|
result := &RowQueryResult{}
|
||||||
scope.InstanceSet("row_query_result", result)
|
scope.InstanceSet("row_query_result", result)
|
||||||
@ -932,7 +942,7 @@ func (scope *Scope) row() *sql.Row {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) rows() (*sql.Rows, error) {
|
func (scope *Scope) rows() (*sql.Rows, error) {
|
||||||
defer scope.trace(NowFunc())
|
defer scope.trace(scope.db.nowFunc())
|
||||||
|
|
||||||
result := &RowsQueryResult{}
|
result := &RowsQueryResult{}
|
||||||
scope.InstanceSet("row_query_result", result)
|
scope.InstanceSet("row_query_result", result)
|
||||||
@ -974,6 +984,10 @@ func (scope *Scope) pluck(column string, value interface{}) *Scope {
|
|||||||
return scope
|
return scope
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if dest.Len() > 0 {
|
||||||
|
dest.Set(reflect.Zero(dest.Type()))
|
||||||
|
}
|
||||||
|
|
||||||
if query, ok := scope.Search.selects["query"]; !ok || !scope.isQueryForColumn(query, column) {
|
if query, ok := scope.Search.selects["query"]; !ok || !scope.isQueryForColumn(query, column) {
|
||||||
scope.Search.Select(column)
|
scope.Search.Select(column)
|
||||||
}
|
}
|
||||||
@ -997,8 +1011,15 @@ func (scope *Scope) pluck(column string, value interface{}) *Scope {
|
|||||||
func (scope *Scope) count(value interface{}) *Scope {
|
func (scope *Scope) count(value interface{}) *Scope {
|
||||||
if query, ok := scope.Search.selects["query"]; !ok || !countingQueryRegexp.MatchString(fmt.Sprint(query)) {
|
if query, ok := scope.Search.selects["query"]; !ok || !countingQueryRegexp.MatchString(fmt.Sprint(query)) {
|
||||||
if len(scope.Search.group) != 0 {
|
if len(scope.Search.group) != 0 {
|
||||||
scope.Search.Select("count(*) FROM ( SELECT count(*) as name ")
|
if len(scope.Search.havingConditions) != 0 {
|
||||||
scope.Search.group += " ) AS count_table"
|
scope.prepareQuerySQL()
|
||||||
|
scope.Search = &search{}
|
||||||
|
scope.Search.Select("count(*)")
|
||||||
|
scope.Search.Table(fmt.Sprintf("( %s ) AS count_table", scope.SQL))
|
||||||
|
} else {
|
||||||
|
scope.Search.Select("count(*) FROM ( SELECT count(*) as name ")
|
||||||
|
scope.Search.group += " ) AS count_table"
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
scope.Search.Select("count(*)")
|
scope.Search.Select("count(*)")
|
||||||
}
|
}
|
||||||
@ -1113,8 +1134,8 @@ func (scope *Scope) createJoinTable(field *StructField) {
|
|||||||
if field, ok := scope.FieldByName(fieldName); ok {
|
if field, ok := scope.FieldByName(fieldName); ok {
|
||||||
foreignKeyStruct := field.clone()
|
foreignKeyStruct := field.clone()
|
||||||
foreignKeyStruct.IsPrimaryKey = false
|
foreignKeyStruct.IsPrimaryKey = false
|
||||||
foreignKeyStruct.TagSettings["IS_JOINTABLE_FOREIGNKEY"] = "true"
|
foreignKeyStruct.TagSettingsSet("IS_JOINTABLE_FOREIGNKEY", "true")
|
||||||
delete(foreignKeyStruct.TagSettings, "AUTO_INCREMENT")
|
foreignKeyStruct.TagSettingsDelete("AUTO_INCREMENT")
|
||||||
sqlTypes = append(sqlTypes, scope.Quote(relationship.ForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(foreignKeyStruct))
|
sqlTypes = append(sqlTypes, scope.Quote(relationship.ForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(foreignKeyStruct))
|
||||||
primaryKeys = append(primaryKeys, scope.Quote(relationship.ForeignDBNames[idx]))
|
primaryKeys = append(primaryKeys, scope.Quote(relationship.ForeignDBNames[idx]))
|
||||||
}
|
}
|
||||||
@ -1124,8 +1145,8 @@ func (scope *Scope) createJoinTable(field *StructField) {
|
|||||||
if field, ok := toScope.FieldByName(fieldName); ok {
|
if field, ok := toScope.FieldByName(fieldName); ok {
|
||||||
foreignKeyStruct := field.clone()
|
foreignKeyStruct := field.clone()
|
||||||
foreignKeyStruct.IsPrimaryKey = false
|
foreignKeyStruct.IsPrimaryKey = false
|
||||||
foreignKeyStruct.TagSettings["IS_JOINTABLE_FOREIGNKEY"] = "true"
|
foreignKeyStruct.TagSettingsSet("IS_JOINTABLE_FOREIGNKEY", "true")
|
||||||
delete(foreignKeyStruct.TagSettings, "AUTO_INCREMENT")
|
foreignKeyStruct.TagSettingsDelete("AUTO_INCREMENT")
|
||||||
sqlTypes = append(sqlTypes, scope.Quote(relationship.AssociationForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(foreignKeyStruct))
|
sqlTypes = append(sqlTypes, scope.Quote(relationship.AssociationForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(foreignKeyStruct))
|
||||||
primaryKeys = append(primaryKeys, scope.Quote(relationship.AssociationForeignDBNames[idx]))
|
primaryKeys = append(primaryKeys, scope.Quote(relationship.AssociationForeignDBNames[idx]))
|
||||||
}
|
}
|
||||||
@ -1173,7 +1194,7 @@ func (scope *Scope) createTable() *Scope {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) dropTable() *Scope {
|
func (scope *Scope) dropTable() *Scope {
|
||||||
scope.Raw(fmt.Sprintf("DROP TABLE %v%s", scope.QuotedTableName(), scope.getTableOptions())).Exec()
|
scope.Raw(fmt.Sprintf("DROP TABLE %v", scope.QuotedTableName())).Exec()
|
||||||
return scope
|
return scope
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1260,25 +1281,27 @@ func (scope *Scope) autoIndex() *Scope {
|
|||||||
var uniqueIndexes = map[string][]string{}
|
var uniqueIndexes = map[string][]string{}
|
||||||
|
|
||||||
for _, field := range scope.GetStructFields() {
|
for _, field := range scope.GetStructFields() {
|
||||||
if name, ok := field.TagSettings["INDEX"]; ok {
|
if name, ok := field.TagSettingsGet("INDEX"); ok {
|
||||||
names := strings.Split(name, ",")
|
names := strings.Split(name, ",")
|
||||||
|
|
||||||
for _, name := range names {
|
for _, name := range names {
|
||||||
if name == "INDEX" || name == "" {
|
if name == "INDEX" || name == "" {
|
||||||
name = scope.Dialect().BuildKeyName("idx", scope.TableName(), field.DBName)
|
name = scope.Dialect().BuildKeyName("idx", scope.TableName(), field.DBName)
|
||||||
}
|
}
|
||||||
indexes[name] = append(indexes[name], field.DBName)
|
name, column := scope.Dialect().NormalizeIndexAndColumn(name, field.DBName)
|
||||||
|
indexes[name] = append(indexes[name], column)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if name, ok := field.TagSettings["UNIQUE_INDEX"]; ok {
|
if name, ok := field.TagSettingsGet("UNIQUE_INDEX"); ok {
|
||||||
names := strings.Split(name, ",")
|
names := strings.Split(name, ",")
|
||||||
|
|
||||||
for _, name := range names {
|
for _, name := range names {
|
||||||
if name == "UNIQUE_INDEX" || name == "" {
|
if name == "UNIQUE_INDEX" || name == "" {
|
||||||
name = scope.Dialect().BuildKeyName("uix", scope.TableName(), field.DBName)
|
name = scope.Dialect().BuildKeyName("uix", scope.TableName(), field.DBName)
|
||||||
}
|
}
|
||||||
uniqueIndexes[name] = append(uniqueIndexes[name], field.DBName)
|
name, column := scope.Dialect().NormalizeIndexAndColumn(name, field.DBName)
|
||||||
|
uniqueIndexes[name] = append(uniqueIndexes[name], column)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1299,6 +1322,7 @@ func (scope *Scope) autoIndex() *Scope {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) getColumnAsArray(columns []string, values ...interface{}) (results [][]interface{}) {
|
func (scope *Scope) getColumnAsArray(columns []string, values ...interface{}) (results [][]interface{}) {
|
||||||
|
resultMap := make(map[string][]interface{})
|
||||||
for _, value := range values {
|
for _, value := range values {
|
||||||
indirectValue := indirect(reflect.ValueOf(value))
|
indirectValue := indirect(reflect.ValueOf(value))
|
||||||
|
|
||||||
@ -1317,7 +1341,10 @@ func (scope *Scope) getColumnAsArray(columns []string, values ...interface{}) (r
|
|||||||
}
|
}
|
||||||
|
|
||||||
if hasValue {
|
if hasValue {
|
||||||
results = append(results, result)
|
h := fmt.Sprint(result...)
|
||||||
|
if _, exist := resultMap[h]; !exist {
|
||||||
|
resultMap[h] = result
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
@ -1332,11 +1359,16 @@ func (scope *Scope) getColumnAsArray(columns []string, values ...interface{}) (r
|
|||||||
}
|
}
|
||||||
|
|
||||||
if hasValue {
|
if hasValue {
|
||||||
results = append(results, result)
|
h := fmt.Sprint(result...)
|
||||||
|
if _, exist := resultMap[h]; !exist {
|
||||||
|
resultMap[h] = result
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
for _, v := range resultMap {
|
||||||
|
results = append(results, v)
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -78,3 +78,16 @@ func TestFailedValuer(t *testing.T) {
|
|||||||
t.Errorf("The error should be returned from Valuer, but get %v", err)
|
t.Errorf("The error should be returned from Valuer, but get %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDropTableWithTableOptions(t *testing.T) {
|
||||||
|
type UserWithOptions struct {
|
||||||
|
gorm.Model
|
||||||
|
}
|
||||||
|
DB.AutoMigrate(&UserWithOptions{})
|
||||||
|
|
||||||
|
DB = DB.Set("gorm:table_options", "CHARSET=utf8")
|
||||||
|
err := DB.DropTable(&UserWithOptions{}).Error
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Table must be dropped, got error %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
67
utils.go
67
utils.go
@ -1,7 +1,6 @@
|
|||||||
package gorm
|
package gorm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"database/sql/driver"
|
"database/sql/driver"
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
@ -26,8 +25,8 @@ var NowFunc = func() time.Time {
|
|||||||
var commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UID", "UI", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"}
|
var commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UID", "UI", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"}
|
||||||
var commonInitialismsReplacer *strings.Replacer
|
var commonInitialismsReplacer *strings.Replacer
|
||||||
|
|
||||||
var goSrcRegexp = regexp.MustCompile(`jinzhu/gorm/.*.go`)
|
var goSrcRegexp = regexp.MustCompile(`jinzhu/gorm(@.*)?/.*.go`)
|
||||||
var goTestRegexp = regexp.MustCompile(`jinzhu/gorm/.*test.go`)
|
var goTestRegexp = regexp.MustCompile(`jinzhu/gorm(@.*)?/.*test.go`)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
var commonInitialismsForReplacer []string
|
var commonInitialismsForReplacer []string
|
||||||
@ -58,66 +57,6 @@ func newSafeMap() *safeMap {
|
|||||||
return &safeMap{l: new(sync.RWMutex), m: make(map[string]string)}
|
return &safeMap{l: new(sync.RWMutex), m: make(map[string]string)}
|
||||||
}
|
}
|
||||||
|
|
||||||
var smap = newSafeMap()
|
|
||||||
|
|
||||||
type strCase bool
|
|
||||||
|
|
||||||
const (
|
|
||||||
lower strCase = false
|
|
||||||
upper strCase = true
|
|
||||||
)
|
|
||||||
|
|
||||||
// ToDBName convert string to db name
|
|
||||||
func ToDBName(name string) string {
|
|
||||||
if v := smap.Get(name); v != "" {
|
|
||||||
return v
|
|
||||||
}
|
|
||||||
|
|
||||||
if name == "" {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
value = commonInitialismsReplacer.Replace(name)
|
|
||||||
buf = bytes.NewBufferString("")
|
|
||||||
lastCase, currCase, nextCase, nextNumber strCase
|
|
||||||
)
|
|
||||||
|
|
||||||
for i, v := range value[:len(value)-1] {
|
|
||||||
nextCase = strCase(value[i+1] >= 'A' && value[i+1] <= 'Z')
|
|
||||||
nextNumber = strCase(value[i+1] >= '0' && value[i+1] <= '9')
|
|
||||||
|
|
||||||
if i > 0 {
|
|
||||||
if currCase == upper {
|
|
||||||
if lastCase == upper && (nextCase == upper || nextNumber == upper) {
|
|
||||||
buf.WriteRune(v)
|
|
||||||
} else {
|
|
||||||
if value[i-1] != '_' && value[i+1] != '_' {
|
|
||||||
buf.WriteRune('_')
|
|
||||||
}
|
|
||||||
buf.WriteRune(v)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
buf.WriteRune(v)
|
|
||||||
if i == len(value)-2 && (nextCase == upper && nextNumber == lower) {
|
|
||||||
buf.WriteRune('_')
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
currCase = upper
|
|
||||||
buf.WriteRune(v)
|
|
||||||
}
|
|
||||||
lastCase = currCase
|
|
||||||
currCase = nextCase
|
|
||||||
}
|
|
||||||
|
|
||||||
buf.WriteByte(value[len(value)-1])
|
|
||||||
|
|
||||||
s := strings.ToLower(buf.String())
|
|
||||||
smap.Set(name, s)
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// SQL expression
|
// SQL expression
|
||||||
type expr struct {
|
type expr struct {
|
||||||
expr string
|
expr string
|
||||||
@ -267,7 +206,7 @@ func getValueFromFields(value reflect.Value, fieldNames []string) (results []int
|
|||||||
// as FieldByName could panic
|
// as FieldByName could panic
|
||||||
if indirectValue := reflect.Indirect(value); indirectValue.IsValid() {
|
if indirectValue := reflect.Indirect(value); indirectValue.IsValid() {
|
||||||
for _, fieldName := range fieldNames {
|
for _, fieldName := range fieldNames {
|
||||||
if fieldValue := indirectValue.FieldByName(fieldName); fieldValue.IsValid() {
|
if fieldValue := reflect.Indirect(indirectValue.FieldByName(fieldName)); fieldValue.IsValid() {
|
||||||
result := fieldValue.Interface()
|
result := fieldValue.Interface()
|
||||||
if r, ok := result.(driver.Valuer); ok {
|
if r, ok := result.(driver.Valuer); ok {
|
||||||
result, _ = r.Value()
|
result, _ = r.Value()
|
||||||
|
@ -1,35 +0,0 @@
|
|||||||
package gorm_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/jinzhu/gorm"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestToDBNameGenerateFriendlyName(t *testing.T) {
|
|
||||||
var maps = map[string]string{
|
|
||||||
"": "",
|
|
||||||
"X": "x",
|
|
||||||
"ThisIsATest": "this_is_a_test",
|
|
||||||
"PFAndESI": "pf_and_esi",
|
|
||||||
"AbcAndJkl": "abc_and_jkl",
|
|
||||||
"EmployeeID": "employee_id",
|
|
||||||
"SKU_ID": "sku_id",
|
|
||||||
"UTF8": "utf8",
|
|
||||||
"Level1": "level1",
|
|
||||||
"SHA256Hash": "sha256_hash",
|
|
||||||
"FieldX": "field_x",
|
|
||||||
"HTTPAndSMTP": "http_and_smtp",
|
|
||||||
"HTTPServerHandlerForURLID": "http_server_handler_for_url_id",
|
|
||||||
"UUID": "uuid",
|
|
||||||
"HTTPURL": "http_url",
|
|
||||||
"HTTP_URL": "http_url",
|
|
||||||
"ThisIsActuallyATestSoWeMayBeAbleToUseThisCodeInGormPackageAlsoIdCanBeUsedAtTheEndAsID": "this_is_actually_a_test_so_we_may_be_able_to_use_this_code_in_gorm_package_also_id_can_be_used_at_the_end_as_id",
|
|
||||||
}
|
|
||||||
|
|
||||||
for key, value := range maps {
|
|
||||||
if gorm.ToDBName(key) != value {
|
|
||||||
t.Errorf("%v ToDBName should equal %v, but got %v", key, value, gorm.ToDBName(key))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
52
wercker.yml
52
wercker.yml
@ -9,6 +9,13 @@ services:
|
|||||||
MYSQL_USER: gorm
|
MYSQL_USER: gorm
|
||||||
MYSQL_PASSWORD: gorm
|
MYSQL_PASSWORD: gorm
|
||||||
MYSQL_RANDOM_ROOT_PASSWORD: "yes"
|
MYSQL_RANDOM_ROOT_PASSWORD: "yes"
|
||||||
|
- name: mysql
|
||||||
|
id: mysql:latest
|
||||||
|
env:
|
||||||
|
MYSQL_DATABASE: gorm
|
||||||
|
MYSQL_USER: gorm
|
||||||
|
MYSQL_PASSWORD: gorm
|
||||||
|
MYSQL_RANDOM_ROOT_PASSWORD: "yes"
|
||||||
- name: mysql57
|
- name: mysql57
|
||||||
id: mysql:5.7
|
id: mysql:5.7
|
||||||
env:
|
env:
|
||||||
@ -23,13 +30,6 @@ services:
|
|||||||
MYSQL_USER: gorm
|
MYSQL_USER: gorm
|
||||||
MYSQL_PASSWORD: gorm
|
MYSQL_PASSWORD: gorm
|
||||||
MYSQL_RANDOM_ROOT_PASSWORD: "yes"
|
MYSQL_RANDOM_ROOT_PASSWORD: "yes"
|
||||||
- name: mysql55
|
|
||||||
id: mysql:5.5
|
|
||||||
env:
|
|
||||||
MYSQL_DATABASE: gorm
|
|
||||||
MYSQL_USER: gorm
|
|
||||||
MYSQL_PASSWORD: gorm
|
|
||||||
MYSQL_RANDOM_ROOT_PASSWORD: "yes"
|
|
||||||
- name: postgres
|
- name: postgres
|
||||||
id: postgres:latest
|
id: postgres:latest
|
||||||
env:
|
env:
|
||||||
@ -83,7 +83,7 @@ build:
|
|||||||
code: |
|
code: |
|
||||||
cd $WERCKER_SOURCE_DIR
|
cd $WERCKER_SOURCE_DIR
|
||||||
go version
|
go version
|
||||||
go get -t ./...
|
go get -t -v ./...
|
||||||
|
|
||||||
# Build the project
|
# Build the project
|
||||||
- script:
|
- script:
|
||||||
@ -95,54 +95,60 @@ build:
|
|||||||
- script:
|
- script:
|
||||||
name: test sqlite
|
name: test sqlite
|
||||||
code: |
|
code: |
|
||||||
go test ./...
|
go test -race -v ./...
|
||||||
|
|
||||||
- script:
|
- script:
|
||||||
name: test mariadb
|
name: test mariadb
|
||||||
code: |
|
code: |
|
||||||
GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mariadb:3306)/gorm?charset=utf8&parseTime=True" go test ./...
|
GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mariadb:3306)/gorm?charset=utf8&parseTime=True" go test -race ./...
|
||||||
|
|
||||||
|
- script:
|
||||||
|
name: test mysql
|
||||||
|
code: |
|
||||||
|
GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql:3306)/gorm?charset=utf8&parseTime=True" go test -race ./...
|
||||||
|
|
||||||
- script:
|
- script:
|
||||||
name: test mysql5.7
|
name: test mysql5.7
|
||||||
code: |
|
code: |
|
||||||
GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql57:3306)/gorm?charset=utf8&parseTime=True" go test ./...
|
GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql57:3306)/gorm?charset=utf8&parseTime=True" go test -race ./...
|
||||||
|
|
||||||
- script:
|
- script:
|
||||||
name: test mysql5.6
|
name: test mysql5.6
|
||||||
code: |
|
code: |
|
||||||
GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql56:3306)/gorm?charset=utf8&parseTime=True" go test ./...
|
GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql56:3306)/gorm?charset=utf8&parseTime=True" go test -race ./...
|
||||||
|
|
||||||
- script:
|
|
||||||
name: test mysql5.5
|
|
||||||
code: |
|
|
||||||
GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql55:3306)/gorm?charset=utf8&parseTime=True" go test ./...
|
|
||||||
|
|
||||||
- script:
|
- script:
|
||||||
name: test postgres
|
name: test postgres
|
||||||
code: |
|
code: |
|
||||||
GORM_DIALECT=postgres GORM_DSN="host=postgres user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test ./...
|
GORM_DIALECT=postgres GORM_DSN="host=postgres user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test -race ./...
|
||||||
|
|
||||||
- script:
|
- script:
|
||||||
name: test postgres96
|
name: test postgres96
|
||||||
code: |
|
code: |
|
||||||
GORM_DIALECT=postgres GORM_DSN="host=postgres96 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test ./...
|
GORM_DIALECT=postgres GORM_DSN="host=postgres96 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test -race ./...
|
||||||
|
|
||||||
- script:
|
- script:
|
||||||
name: test postgres95
|
name: test postgres95
|
||||||
code: |
|
code: |
|
||||||
GORM_DIALECT=postgres GORM_DSN="host=postgres95 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test ./...
|
GORM_DIALECT=postgres GORM_DSN="host=postgres95 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test -race ./...
|
||||||
|
|
||||||
- script:
|
- script:
|
||||||
name: test postgres94
|
name: test postgres94
|
||||||
code: |
|
code: |
|
||||||
GORM_DIALECT=postgres GORM_DSN="host=postgres94 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test ./...
|
GORM_DIALECT=postgres GORM_DSN="host=postgres94 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test -race ./...
|
||||||
|
|
||||||
- script:
|
- script:
|
||||||
name: test postgres93
|
name: test postgres93
|
||||||
code: |
|
code: |
|
||||||
GORM_DIALECT=postgres GORM_DSN="host=postgres93 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test ./...
|
GORM_DIALECT=postgres GORM_DSN="host=postgres93 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test -race ./...
|
||||||
|
|
||||||
- script:
|
- script:
|
||||||
name: test mssql
|
name: test mssql
|
||||||
code: |
|
code: |
|
||||||
GORM_DIALECT=mssql GORM_DSN="sqlserver://gorm:LoremIpsum86@mssql:1433?database=gorm" go test ./...
|
GORM_DIALECT=mssql GORM_DSN="sqlserver://gorm:LoremIpsum86@mssql:1433?database=gorm" go test -race ./...
|
||||||
|
|
||||||
|
- script:
|
||||||
|
name: codecov
|
||||||
|
code: |
|
||||||
|
go test -race -coverprofile=coverage.txt -covermode=atomic ./...
|
||||||
|
bash <(curl -s https://codecov.io/bash)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user