Merge branch 'master' into tag-settings-lock

This commit is contained in:
Jinzhu 2018-09-10 07:10:41 +08:00 committed by GitHub
commit 253b5a4875
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
20 changed files with 593 additions and 185 deletions

View File

@ -59,7 +59,7 @@ func createCallback(scope *Scope) {
for _, field := range scope.Fields() { for _, field := range scope.Fields() {
if scope.changeableField(field) { if scope.changeableField(field) {
if field.IsNormal { if field.IsNormal && !field.IsIgnored {
if field.IsBlank && field.HasDefaultValue { if field.IsBlank && field.HasDefaultValue {
blankColumnsWithDefaultValue = append(blankColumnsWithDefaultValue, scope.Quote(field.DBName)) blankColumnsWithDefaultValue = append(blankColumnsWithDefaultValue, scope.Quote(field.DBName))
scope.InstanceSet("gorm:blank_columns_with_default_value", blankColumnsWithDefaultValue) scope.InstanceSet("gorm:blank_columns_with_default_value", blankColumnsWithDefaultValue)

View File

@ -18,6 +18,11 @@ func queryCallback(scope *Scope) {
if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip { if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip {
return return
} }
//we are only preloading relations, dont touch base model
if _, skip := scope.InstanceGet("gorm:only_preload"); skip {
return
}
defer scope.trace(NowFunc()) defer scope.trace(NowFunc())

View File

@ -14,8 +14,14 @@ func preloadCallback(scope *Scope) {
return return
} }
if _, ok := scope.Get("gorm:auto_preload"); ok { if ap, ok := scope.Get("gorm:auto_preload"); ok {
autoPreload(scope) // If gorm:auto_preload IS NOT a bool then auto preload.
// Else if it IS a bool, use the value
if apb, ok := ap.(bool); !ok {
autoPreload(scope)
} else if apb {
autoPreload(scope)
}
} }
if scope.Search.preload == nil || scope.HasError() { if scope.Search.preload == nil || scope.HasError() {
@ -155,14 +161,17 @@ func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{})
) )
if indirectScopeValue.Kind() == reflect.Slice { if indirectScopeValue.Kind() == reflect.Slice {
foreignValuesToResults := make(map[string]reflect.Value)
for i := 0; i < resultsValue.Len(); i++ {
result := resultsValue.Index(i)
foreignValues := toString(getValueFromFields(result, relation.ForeignFieldNames))
foreignValuesToResults[foreignValues] = result
}
for j := 0; j < indirectScopeValue.Len(); j++ { for j := 0; j < indirectScopeValue.Len(); j++ {
for i := 0; i < resultsValue.Len(); i++ { indirectValue := indirect(indirectScopeValue.Index(j))
result := resultsValue.Index(i) valueString := toString(getValueFromFields(indirectValue, relation.AssociationForeignFieldNames))
foreignValues := getValueFromFields(result, relation.ForeignFieldNames) if result, found := foreignValuesToResults[valueString]; found {
if indirectValue := indirect(indirectScopeValue.Index(j)); equalAsString(getValueFromFields(indirectValue, relation.AssociationForeignFieldNames), foreignValues) { indirectValue.FieldByName(field.Name).Set(result)
indirectValue.FieldByName(field.Name).Set(result)
break
}
} }
} }
} else { } else {
@ -249,13 +258,21 @@ func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{
indirectScopeValue = scope.IndirectValue() indirectScopeValue = scope.IndirectValue()
) )
foreignFieldToObjects := make(map[string][]*reflect.Value)
if indirectScopeValue.Kind() == reflect.Slice {
for j := 0; j < indirectScopeValue.Len(); j++ {
object := indirect(indirectScopeValue.Index(j))
valueString := toString(getValueFromFields(object, relation.ForeignFieldNames))
foreignFieldToObjects[valueString] = append(foreignFieldToObjects[valueString], &object)
}
}
for i := 0; i < resultsValue.Len(); i++ { for i := 0; i < resultsValue.Len(); i++ {
result := resultsValue.Index(i) result := resultsValue.Index(i)
if indirectScopeValue.Kind() == reflect.Slice { if indirectScopeValue.Kind() == reflect.Slice {
value := getValueFromFields(result, relation.AssociationForeignFieldNames) valueString := toString(getValueFromFields(result, relation.AssociationForeignFieldNames))
for j := 0; j < indirectScopeValue.Len(); j++ { if objects, found := foreignFieldToObjects[valueString]; found {
object := indirect(indirectScopeValue.Index(j)) for _, object := range objects {
if equalAsString(getValueFromFields(object, relation.ForeignFieldNames), value) {
object.FieldByName(field.Name).Set(result) object.FieldByName(field.Name).Set(result)
} }
} }

View File

@ -21,9 +21,7 @@ func saveAssociationCheck(scope *Scope, field *Field) (autoUpdate bool, autoCrea
if v, ok := value.(string); ok { if v, ok := value.(string); ok {
v = strings.ToLower(v) v = strings.ToLower(v)
if v == "false" || v != "skip" { return v == "true"
return false
}
} }
return true return true
@ -36,9 +34,11 @@ func saveAssociationCheck(scope *Scope, field *Field) (autoUpdate bool, autoCrea
if value, ok := scope.Get("gorm:save_associations"); ok { if value, ok := scope.Get("gorm:save_associations"); ok {
autoUpdate = checkTruth(value) autoUpdate = checkTruth(value)
autoCreate = autoUpdate autoCreate = autoUpdate
saveReference = autoUpdate
} else if value, ok := field.TagSettingsGet("SAVE_ASSOCIATIONS"); ok { } else if value, ok := field.TagSettingsGet("SAVE_ASSOCIATIONS"); ok {
autoUpdate = checkTruth(value) autoUpdate = checkTruth(value)
autoCreate = autoUpdate autoCreate = autoUpdate
saveReference = autoUpdate
} }
if value, ok := scope.Get("gorm:association_autoupdate"); ok { if value, ok := scope.Get("gorm:association_autoupdate"); ok {

View File

@ -76,7 +76,9 @@ func updateCallback(scope *Scope) {
for _, field := range scope.Fields() { for _, field := range scope.Fields() {
if scope.changeableField(field) { if scope.changeableField(field) {
if !field.IsPrimaryKey && field.IsNormal { if !field.IsPrimaryKey && field.IsNormal {
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) if !field.IsForeignKey || !field.IsBlank || !field.HasDefaultValue {
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
}
} else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" { } else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
for _, foreignKey := range relationship.ForeignDBNames { for _, foreignKey := range relationship.ForeignDBNames {
if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) { if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) {

View File

@ -72,6 +72,12 @@ func RegisterDialect(name string, dialect Dialect) {
dialectsMap[name] = dialect dialectsMap[name] = dialect
} }
// GetDialect gets the dialect for the specified dialect name
func GetDialect(name string) (dialect Dialect, ok bool) {
dialect, ok = dialectsMap[name]
return
}
// ParseFieldStructForDialect get field's sql data type // ParseFieldStructForDialect get field's sql data type
var ParseFieldStructForDialect = func(field *StructField, dialect Dialect) (fieldValue reflect.Value, sqlType string, size int, additionalType string) { var ParseFieldStructForDialect = func(field *StructField, dialect Dialect) (fieldValue reflect.Value, sqlType string, size int, additionalType string) {
// Get redirected field type // Get redirected field type

View File

@ -1,12 +1,16 @@
package mssql package mssql
import ( import (
"database/sql/driver"
"encoding/json"
"errors"
"fmt" "fmt"
"reflect" "reflect"
"strconv" "strconv"
"strings" "strings"
"time" "time"
// Importing mssql driver package only in dialect file, otherwide not needed
_ "github.com/denisenkom/go-mssqldb" _ "github.com/denisenkom/go-mssqldb"
"github.com/jinzhu/gorm" "github.com/jinzhu/gorm"
) )
@ -130,7 +134,14 @@ func (s mssql) RemoveIndex(tableName string, indexName string) error {
} }
func (s mssql) HasForeignKey(tableName string, foreignKeyName string) bool { func (s mssql) HasForeignKey(tableName string, foreignKeyName string) bool {
return false var count int
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
s.db.QueryRow(`SELECT count(*)
FROM sys.foreign_keys as F inner join sys.tables as T on F.parent_object_id=T.object_id
inner join information_schema.tables as I on I.TABLE_NAME = T.name
WHERE F.name = ?
AND T.Name = ? AND I.TABLE_CATALOG = ?;`, foreignKeyName, tableName, currentDatabase).Scan(&count)
return count > 0
} }
func (s mssql) HasTable(tableName string) bool { func (s mssql) HasTable(tableName string) bool {
@ -194,3 +205,27 @@ func currentDatabaseAndTable(dialect gorm.Dialect, tableName string) (string, st
} }
return dialect.CurrentDatabase(), tableName return dialect.CurrentDatabase(), tableName
} }
// JSON type to support easy handling of JSON data in character table fields
// using golang json.RawMessage for deferred decoding/encoding
type JSON struct {
json.RawMessage
}
// Value get value of JSON
func (j JSON) Value() (driver.Value, error) {
if len(j.RawMessage) == 0 {
return nil, nil
}
return j.MarshalJSON()
}
// Scan scan value into JSON
func (j *JSON) Scan(value interface{}) error {
str, ok := value.(string)
if !ok {
return errors.New(fmt.Sprint("Failed to unmarshal JSONB value (strcast):", value))
}
bytes := []byte(str)
return json.Unmarshal(bytes, j)
}

View File

@ -4,11 +4,11 @@ import (
"database/sql" "database/sql"
"database/sql/driver" "database/sql/driver"
_ "github.com/lib/pq"
"github.com/lib/pq/hstore"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
_ "github.com/lib/pq"
"github.com/lib/pq/hstore"
) )
type Hstore map[string]*string type Hstore map[string]*string

View File

@ -6,7 +6,7 @@ import (
) )
var ( var (
// ErrRecordNotFound record not found error, happens when haven't find any matched data when looking up with a struct // ErrRecordNotFound record not found error, happens when only haven't find any matched data when looking up with a struct, finding a slice won't return this error
ErrRecordNotFound = errors.New("record not found") ErrRecordNotFound = errors.New("record not found")
// ErrInvalidSQL invalid SQL error, happens when you passed invalid SQL // ErrInvalidSQL invalid SQL error, happens when you passed invalid SQL
ErrInvalidSQL = errors.New("invalid SQL") ErrInvalidSQL = errors.New("invalid SQL")

41
main.go
View File

@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"reflect" "reflect"
"strings" "strings"
"sync"
"time" "time"
) )
@ -21,7 +22,7 @@ type DB struct {
logMode int logMode int
logger logger logger logger
search *search search *search
values map[string]interface{} values sync.Map
// global db // global db
parent *DB parent *DB
@ -48,6 +49,7 @@ func Open(dialect string, args ...interface{}) (db *DB, err error) {
} }
var source string var source string
var dbSQL SQLCommon var dbSQL SQLCommon
var ownDbSQL bool
switch value := args[0].(type) { switch value := args[0].(type) {
case string: case string:
@ -59,14 +61,17 @@ func Open(dialect string, args ...interface{}) (db *DB, err error) {
source = args[1].(string) source = args[1].(string)
} }
dbSQL, err = sql.Open(driver, source) dbSQL, err = sql.Open(driver, source)
ownDbSQL = true
case SQLCommon: case SQLCommon:
dbSQL = value dbSQL = value
ownDbSQL = false
default:
return nil, fmt.Errorf("invalid database source: %v is not a valid type", value)
} }
db = &DB{ db = &DB{
db: dbSQL, db: dbSQL,
logger: defaultLogger, logger: defaultLogger,
values: map[string]interface{}{},
callbacks: DefaultCallback, callbacks: DefaultCallback,
dialect: newDialect(dialect, dbSQL), dialect: newDialect(dialect, dbSQL),
} }
@ -76,7 +81,7 @@ func Open(dialect string, args ...interface{}) (db *DB, err error) {
} }
// Send a ping to make sure the database connection is alive. // Send a ping to make sure the database connection is alive.
if d, ok := dbSQL.(*sql.DB); ok { if d, ok := dbSQL.(*sql.DB); ok {
if err = d.Ping(); err != nil { if err = d.Ping(); err != nil && ownDbSQL {
d.Close() d.Close()
} }
} }
@ -117,7 +122,7 @@ func (s *DB) CommonDB() SQLCommon {
// Dialect get dialect // Dialect get dialect
func (s *DB) Dialect() Dialect { func (s *DB) Dialect() Dialect {
return s.parent.dialect return s.dialect
} }
// Callback return `Callbacks` container, you could add/change/delete callbacks with it // Callback return `Callbacks` container, you could add/change/delete callbacks with it
@ -157,7 +162,7 @@ func (s *DB) HasBlockGlobalUpdate() bool {
// SingularTable use singular table by default // SingularTable use singular table by default
func (s *DB) SingularTable(enable bool) { func (s *DB) SingularTable(enable bool) {
modelStructsMap = newModelStructsMap() modelStructsMap = sync.Map{}
s.parent.singularTable = enable s.parent.singularTable = enable
} }
@ -309,6 +314,11 @@ func (s *DB) Find(out interface{}, where ...interface{}) *DB {
return s.NewScope(out).inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db return s.NewScope(out).inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db
} }
//Preloads preloads relations, don`t touch out
func (s *DB) Preloads(out interface{}) *DB {
return s.NewScope(out).InstanceSet("gorm:only_preload", 1).callCallbacks(s.parent.callbacks.queries).db
}
// Scan scan value to a struct // Scan scan value to a struct
func (s *DB) Scan(dest interface{}) *DB { func (s *DB) Scan(dest interface{}) *DB {
return s.NewScope(s.Value).Set("gorm:query_destination", dest).callCallbacks(s.parent.callbacks.queries).db return s.NewScope(s.Value).Set("gorm:query_destination", dest).callCallbacks(s.parent.callbacks.queries).db
@ -482,6 +492,8 @@ func (s *DB) Begin() *DB {
if db, ok := c.db.(sqlDb); ok && db != nil { if db, ok := c.db.(sqlDb); ok && db != nil {
tx, err := db.Begin() tx, err := db.Begin()
c.db = interface{}(tx).(SQLCommon) c.db = interface{}(tx).(SQLCommon)
c.dialect.SetDB(c.db)
c.AddError(err) c.AddError(err)
} else { } else {
c.AddError(ErrCantStartTransaction) c.AddError(ErrCantStartTransaction)
@ -491,7 +503,8 @@ func (s *DB) Begin() *DB {
// Commit commit a transaction // Commit commit a transaction
func (s *DB) Commit() *DB { func (s *DB) Commit() *DB {
if db, ok := s.db.(sqlTx); ok && db != nil { var emptySQLTx *sql.Tx
if db, ok := s.db.(sqlTx); ok && db != nil && db != emptySQLTx {
s.AddError(db.Commit()) s.AddError(db.Commit())
} else { } else {
s.AddError(ErrInvalidTransaction) s.AddError(ErrInvalidTransaction)
@ -501,7 +514,8 @@ func (s *DB) Commit() *DB {
// Rollback rollback a transaction // Rollback rollback a transaction
func (s *DB) Rollback() *DB { func (s *DB) Rollback() *DB {
if db, ok := s.db.(sqlTx); ok && db != nil { var emptySQLTx *sql.Tx
if db, ok := s.db.(sqlTx); ok && db != nil && db != emptySQLTx {
s.AddError(db.Rollback()) s.AddError(db.Rollback())
} else { } else {
s.AddError(ErrInvalidTransaction) s.AddError(ErrInvalidTransaction)
@ -670,13 +684,13 @@ func (s *DB) Set(name string, value interface{}) *DB {
// InstantSet instant set setting, will affect current db // InstantSet instant set setting, will affect current db
func (s *DB) InstantSet(name string, value interface{}) *DB { func (s *DB) InstantSet(name string, value interface{}) *DB {
s.values[name] = value s.values.Store(name, value)
return s return s
} }
// Get get setting by name // Get get setting by name
func (s *DB) Get(name string) (value interface{}, ok bool) { func (s *DB) Get(name string) (value interface{}, ok bool) {
value, ok = s.values[name] value, ok = s.values.Load(name)
return return
} }
@ -740,15 +754,16 @@ func (s *DB) clone() *DB {
parent: s.parent, parent: s.parent,
logger: s.logger, logger: s.logger,
logMode: s.logMode, logMode: s.logMode,
values: map[string]interface{}{},
Value: s.Value, Value: s.Value,
Error: s.Error, Error: s.Error,
blockGlobalUpdate: s.blockGlobalUpdate, blockGlobalUpdate: s.blockGlobalUpdate,
dialect: newDialect(s.dialect.GetName(), s.db),
} }
for key, value := range s.values { s.values.Range(func(k, v interface{}) bool {
db.values[key] = value db.values.Store(k, v)
} return true
})
if s.search == nil { if s.search == nil {
db.search = &search{limit: -1, offset: -1} db.search = &search{limit: -1, offset: -1}

View File

@ -8,6 +8,7 @@ import (
"path/filepath" "path/filepath"
"reflect" "reflect"
"strconv" "strconv"
"strings"
"testing" "testing"
"time" "time"
@ -79,6 +80,22 @@ func OpenTestConnection() (db *gorm.DB, err error) {
return return
} }
func TestOpen_ReturnsError_WithBadArgs(t *testing.T) {
stringRef := "foo"
testCases := []interface{}{42, time.Now(), &stringRef}
for _, tc := range testCases {
t.Run(fmt.Sprintf("%v", tc), func(t *testing.T) {
_, err := gorm.Open("postgresql", tc)
if err == nil {
t.Error("Should got error with invalid database source")
}
if !strings.HasPrefix(err.Error(), "invalid database source:") {
t.Errorf("Should got error starting with \"invalid database source:\", but got %q", err.Error())
}
})
}
}
func TestStringPrimaryKey(t *testing.T) { func TestStringPrimaryKey(t *testing.T) {
type UUIDStruct struct { type UUIDStruct struct {
ID string `gorm:"primary_key"` ID string `gorm:"primary_key"`
@ -564,6 +581,60 @@ func TestJoins(t *testing.T) {
} }
} }
type JoinedIds struct {
UserID int64 `gorm:"column:id"`
BillingAddressID int64 `gorm:"column:id"`
EmailID int64 `gorm:"column:id"`
}
func TestScanIdenticalColumnNames(t *testing.T) {
var user = User{
Name: "joinsIds",
Email: "joinIds@example.com",
BillingAddress: Address{
Address1: "One Park Place",
},
Emails: []Email{{Email: "join1@example.com"}, {Email: "join2@example.com"}},
}
DB.Save(&user)
var users []JoinedIds
DB.Select("users.id, addresses.id, emails.id").Table("users").
Joins("left join addresses on users.billing_address_id = addresses.id").
Joins("left join emails on emails.user_id = users.id").
Where("name = ?", "joinsIds").Scan(&users)
if len(users) != 2 {
t.Fatal("should find two rows using left join")
}
if user.Id != users[0].UserID {
t.Errorf("Expected result row to contain UserID %d, but got %d", user.Id, users[0].UserID)
}
if user.Id != users[1].UserID {
t.Errorf("Expected result row to contain UserID %d, but got %d", user.Id, users[1].UserID)
}
if user.BillingAddressID.Int64 != users[0].BillingAddressID {
t.Errorf("Expected result row to contain BillingAddressID %d, but got %d", user.BillingAddressID.Int64, users[0].BillingAddressID)
}
if user.BillingAddressID.Int64 != users[1].BillingAddressID {
t.Errorf("Expected result row to contain BillingAddressID %d, but got %d", user.BillingAddressID.Int64, users[0].BillingAddressID)
}
if users[0].EmailID == users[1].EmailID {
t.Errorf("Email ids should be unique. Got %d and %d", users[0].EmailID, users[1].EmailID)
}
if int64(user.Emails[0].Id) != users[0].EmailID && int64(user.Emails[1].Id) != users[0].EmailID {
t.Errorf("Expected result row ID to be either %d or %d, but was %d", user.Emails[0].Id, user.Emails[1].Id, users[0].EmailID)
}
if int64(user.Emails[0].Id) != users[1].EmailID && int64(user.Emails[1].Id) != users[1].EmailID {
t.Errorf("Expected result row ID to be either %d or %d, but was %d", user.Emails[0].Id, user.Emails[1].Id, users[1].EmailID)
}
}
func TestJoinsWithSelect(t *testing.T) { func TestJoinsWithSelect(t *testing.T) {
type result struct { type result struct {
Name string Name string
@ -862,6 +933,94 @@ func TestOpenWithOneParameter(t *testing.T) {
} }
} }
func TestSaveAssociations(t *testing.T) {
db := DB.New()
deltaAddressCount := 0
if err := db.Model(&Address{}).Count(&deltaAddressCount).Error; err != nil {
t.Errorf("failed to fetch address count")
t.FailNow()
}
placeAddress := &Address{
Address1: "somewhere on earth",
}
ownerAddress1 := &Address{
Address1: "near place address",
}
ownerAddress2 := &Address{
Address1: "address2",
}
db.Create(placeAddress)
addressCountShouldBe := func(t *testing.T, expectedCount int) {
countFromDB := 0
t.Helper()
err := db.Model(&Address{}).Count(&countFromDB).Error
if err != nil {
t.Error("failed to fetch address count")
}
if countFromDB != expectedCount {
t.Errorf("address count mismatch: %d", countFromDB)
}
}
addressCountShouldBe(t, deltaAddressCount+1)
// owner address should be created, place address should be reused
place1 := &Place{
PlaceAddressID: placeAddress.ID,
PlaceAddress: placeAddress,
OwnerAddress: ownerAddress1,
}
err := db.Create(place1).Error
if err != nil {
t.Errorf("failed to store place: %s", err.Error())
}
addressCountShouldBe(t, deltaAddressCount+2)
// owner address should be created again, place address should be reused
place2 := &Place{
PlaceAddressID: placeAddress.ID,
PlaceAddress: &Address{
ID: 777,
Address1: "address1",
},
OwnerAddress: ownerAddress2,
OwnerAddressID: 778,
}
err = db.Create(place2).Error
if err != nil {
t.Errorf("failed to store place: %s", err.Error())
}
addressCountShouldBe(t, deltaAddressCount+3)
count := 0
db.Model(&Place{}).Where(&Place{
PlaceAddressID: placeAddress.ID,
OwnerAddressID: ownerAddress1.ID,
}).Count(&count)
if count != 1 {
t.Errorf("only one instance of (%d, %d) should be available, found: %d",
placeAddress.ID, ownerAddress1.ID, count)
}
db.Model(&Place{}).Where(&Place{
PlaceAddressID: placeAddress.ID,
OwnerAddressID: ownerAddress2.ID,
}).Count(&count)
if count != 1 {
t.Errorf("only one instance of (%d, %d) should be available, found: %d",
placeAddress.ID, ownerAddress2.ID, count)
}
db.Model(&Place{}).Where(&Place{
PlaceAddressID: placeAddress.ID,
}).Count(&count)
if count != 2 {
t.Errorf("two instances of (%d) should be available, found: %d",
placeAddress.ID, count)
}
}
func TestBlockGlobalUpdate(t *testing.T) { func TestBlockGlobalUpdate(t *testing.T) {
db := DB.New() db := DB.New()
db.Create(&Toy{Name: "Stuffed Animal", OwnerType: "Nobody"}) db.Create(&Toy{Name: "Stuffed Animal", OwnerType: "Nobody"})

View File

@ -118,6 +118,14 @@ type Company struct {
Owner *User `sql:"-"` Owner *User `sql:"-"`
} }
type Place struct {
Id int64
PlaceAddressID int
PlaceAddress *Address `gorm:"save_associations:false"`
OwnerAddressID int
OwnerAddress *Address `gorm:"save_associations:true"`
}
type EncryptedData []byte type EncryptedData []byte
func (data *EncryptedData) Scan(value interface{}) error { func (data *EncryptedData) Scan(value interface{}) error {
@ -284,7 +292,7 @@ func runMigration() {
DB.Exec(fmt.Sprintf("drop table %v;", table)) DB.Exec(fmt.Sprintf("drop table %v;", table))
} }
values := []interface{}{&Short{}, &ReallyLongThingThatReferencesShort{}, &ReallyLongTableNameToTestMySQLNameLengthLimit{}, &NotSoLongTableName{}, &Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}, &Animal{}, &User{}, &JoinTable{}, &Post{}, &Category{}, &Comment{}, &Cat{}, &Dog{}, &Hamster{}, &Toy{}, &ElementWithIgnoredField{}} values := []interface{}{&Short{}, &ReallyLongThingThatReferencesShort{}, &ReallyLongTableNameToTestMySQLNameLengthLimit{}, &NotSoLongTableName{}, &Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}, &Animal{}, &User{}, &JoinTable{}, &Post{}, &Category{}, &Comment{}, &Cat{}, &Dog{}, &Hamster{}, &Toy{}, &ElementWithIgnoredField{}, &Place{}}
for _, value := range values { for _, value := range values {
DB.DropTable(value) DB.DropTable(value)
} }
@ -398,6 +406,53 @@ func TestAutoMigration(t *testing.T) {
} }
} }
func TestCreateAndAutomigrateTransaction(t *testing.T) {
tx := DB.Begin()
func() {
type Bar struct {
ID uint
}
DB.DropTableIfExists(&Bar{})
if ok := DB.HasTable("bars"); ok {
t.Errorf("Table should not exist, but does")
}
if ok := tx.HasTable("bars"); ok {
t.Errorf("Table should not exist, but does")
}
}()
func() {
type Bar struct {
Name string
}
err := tx.CreateTable(&Bar{}).Error
if err != nil {
t.Errorf("Should have been able to create the table, but couldn't: %s", err)
}
if ok := tx.HasTable(&Bar{}); !ok {
t.Errorf("The transaction should be able to see the table")
}
}()
func() {
type Bar struct {
Stuff string
}
err := tx.AutoMigrate(&Bar{}).Error
if err != nil {
t.Errorf("Should have been able to alter the table, but couldn't")
}
}()
tx.Rollback()
}
type MultipleIndexes struct { type MultipleIndexes struct {
ID int64 ID int64
UserID int64 `sql:"unique_index:uix_multipleindexes_user_name,uix_multipleindexes_user_email;index:idx_multipleindexes_user_other"` UserID int64 `sql:"unique_index:uix_multipleindexes_user_name,uix_multipleindexes_user_email;index:idx_multipleindexes_user_other"`

View File

@ -17,28 +17,7 @@ var DefaultTableNameHandler = func(db *DB, defaultTableName string) string {
return defaultTableName return defaultTableName
} }
type safeModelStructsMap struct { var modelStructsMap sync.Map
m map[reflect.Type]*ModelStruct
l *sync.RWMutex
}
func (s *safeModelStructsMap) Set(key reflect.Type, value *ModelStruct) {
s.l.Lock()
defer s.l.Unlock()
s.m[key] = value
}
func (s *safeModelStructsMap) Get(key reflect.Type) *ModelStruct {
s.l.RLock()
defer s.l.RUnlock()
return s.m[key]
}
func newModelStructsMap() *safeModelStructsMap {
return &safeModelStructsMap{l: new(sync.RWMutex), m: make(map[reflect.Type]*ModelStruct)}
}
var modelStructsMap = newModelStructsMap()
// ModelStruct model definition // ModelStruct model definition
type ModelStruct struct { type ModelStruct struct {
@ -48,14 +27,14 @@ type ModelStruct struct {
defaultTableName string defaultTableName string
} }
// TableName get model's table name // TableName returns model's table name
func (s *ModelStruct) TableName(db *DB) string { func (s *ModelStruct) TableName(db *DB) string {
if s.defaultTableName == "" && db != nil && s.ModelType != nil { if s.defaultTableName == "" && db != nil && s.ModelType != nil {
// Set default table name // Set default table name
if tabler, ok := reflect.New(s.ModelType).Interface().(tabler); ok { if tabler, ok := reflect.New(s.ModelType).Interface().(tabler); ok {
s.defaultTableName = tabler.TableName() s.defaultTableName = tabler.TableName()
} else { } else {
tableName := ToDBName(s.ModelType.Name()) tableName := ToTableName(s.ModelType.Name())
if db == nil || !db.parent.singularTable { if db == nil || !db.parent.singularTable {
tableName = inflection.Plural(tableName) tableName = inflection.Plural(tableName)
} }
@ -153,7 +132,7 @@ type Relationship struct {
func getForeignField(column string, fields []*StructField) *StructField { func getForeignField(column string, fields []*StructField) *StructField {
for _, field := range fields { for _, field := range fields {
if field.Name == column || field.DBName == column || field.DBName == ToDBName(column) { if field.Name == column || field.DBName == column || field.DBName == ToColumnName(column) {
return field return field
} }
} }
@ -179,8 +158,8 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
} }
// Get Cached model struct // Get Cached model struct
if value := modelStructsMap.Get(reflectType); value != nil { if value, ok := modelStructsMap.Load(reflectType); ok && value != nil {
return value return value.(*ModelStruct)
} }
modelStruct.ModelType = reflectType modelStruct.ModelType = reflectType
@ -317,7 +296,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
// if defined join table's foreign key // if defined join table's foreign key
relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBNames[idx]) relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBNames[idx])
} else { } else {
defaultJointableForeignKey := ToDBName(reflectType.Name()) + "_" + foreignField.DBName defaultJointableForeignKey := ToColumnName(reflectType.Name()) + "_" + foreignField.DBName
relationship.ForeignDBNames = append(relationship.ForeignDBNames, defaultJointableForeignKey) relationship.ForeignDBNames = append(relationship.ForeignDBNames, defaultJointableForeignKey)
} }
} }
@ -348,7 +327,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationJoinTableDBNames[idx]) relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationJoinTableDBNames[idx])
} else { } else {
// join table foreign keys for association // join table foreign keys for association
joinTableDBName := ToDBName(elemType.Name()) + "_" + field.DBName joinTableDBName := ToColumnName(elemType.Name()) + "_" + field.DBName
relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, joinTableDBName) relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, joinTableDBName)
} }
} }
@ -356,7 +335,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
} }
joinTableHandler := JoinTableHandler{} joinTableHandler := JoinTableHandler{}
joinTableHandler.Setup(relationship, many2many, reflectType, elemType) joinTableHandler.Setup(relationship, ToTableName(many2many), reflectType, elemType)
relationship.JoinTableHandler = &joinTableHandler relationship.JoinTableHandler = &joinTableHandler
field.Relationship = relationship field.Relationship = relationship
} else { } else {
@ -614,7 +593,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
if value, ok := field.TagSettingsGet("COLUMN"); ok { if value, ok := field.TagSettingsGet("COLUMN"); ok {
field.DBName = value field.DBName = value
} else { } else {
field.DBName = ToDBName(fieldStruct.Name) field.DBName = ToColumnName(fieldStruct.Name)
} }
modelStruct.StructFields = append(modelStruct.StructFields, field) modelStruct.StructFields = append(modelStruct.StructFields, field)
@ -628,7 +607,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
} }
} }
modelStructsMap.Set(reflectType, &modelStruct) modelStructsMap.Store(reflectType, &modelStruct)
return &modelStruct return &modelStruct
} }

124
naming.go Normal file
View File

@ -0,0 +1,124 @@
package gorm
import (
"bytes"
"strings"
)
// Namer is a function type which is given a string and return a string
type Namer func(string) string
// NamingStrategy represents naming strategies
type NamingStrategy struct {
DB Namer
Table Namer
Column Namer
}
// TheNamingStrategy is being initialized with defaultNamingStrategy
var TheNamingStrategy = &NamingStrategy{
DB: defaultNamer,
Table: defaultNamer,
Column: defaultNamer,
}
// AddNamingStrategy sets the naming strategy
func AddNamingStrategy(ns *NamingStrategy) {
if ns.DB == nil {
ns.DB = defaultNamer
}
if ns.Table == nil {
ns.Table = defaultNamer
}
if ns.Column == nil {
ns.Column = defaultNamer
}
TheNamingStrategy = ns
}
// DBName alters the given name by DB
func (ns *NamingStrategy) DBName(name string) string {
return ns.DB(name)
}
// TableName alters the given name by Table
func (ns *NamingStrategy) TableName(name string) string {
return ns.Table(name)
}
// ColumnName alters the given name by Column
func (ns *NamingStrategy) ColumnName(name string) string {
return ns.Column(name)
}
// ToDBName convert string to db name
func ToDBName(name string) string {
return TheNamingStrategy.DBName(name)
}
// ToTableName convert string to table name
func ToTableName(name string) string {
return TheNamingStrategy.TableName(name)
}
// ToColumnName convert string to db name
func ToColumnName(name string) string {
return TheNamingStrategy.ColumnName(name)
}
var smap = newSafeMap()
func defaultNamer(name string) string {
const (
lower = false
upper = true
)
if v := smap.Get(name); v != "" {
return v
}
if name == "" {
return ""
}
var (
value = commonInitialismsReplacer.Replace(name)
buf = bytes.NewBufferString("")
lastCase, currCase, nextCase, nextNumber bool
)
for i, v := range value[:len(value)-1] {
nextCase = bool(value[i+1] >= 'A' && value[i+1] <= 'Z')
nextNumber = bool(value[i+1] >= '0' && value[i+1] <= '9')
if i > 0 {
if currCase == upper {
if lastCase == upper && (nextCase == upper || nextNumber == upper) {
buf.WriteRune(v)
} else {
if value[i-1] != '_' && value[i+1] != '_' {
buf.WriteRune('_')
}
buf.WriteRune(v)
}
} else {
buf.WriteRune(v)
if i == len(value)-2 && (nextCase == upper && nextNumber == lower) {
buf.WriteRune('_')
}
}
} else {
currCase = upper
buf.WriteRune(v)
}
lastCase = currCase
currCase = nextCase
}
buf.WriteByte(value[len(value)-1])
s := strings.ToLower(buf.String())
smap.Set(name, s)
return s
}

69
naming_test.go Normal file
View File

@ -0,0 +1,69 @@
package gorm_test
import (
"testing"
"github.com/jinzhu/gorm"
)
func TestTheNamingStrategy(t *testing.T) {
cases := []struct {
name string
namer gorm.Namer
expected string
}{
{name: "auth", expected: "auth", namer: gorm.TheNamingStrategy.DB},
{name: "userRestrictions", expected: "user_restrictions", namer: gorm.TheNamingStrategy.Table},
{name: "clientID", expected: "client_id", namer: gorm.TheNamingStrategy.Column},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
result := c.namer(c.name)
if result != c.expected {
t.Errorf("error in naming strategy. expected: %v got :%v\n", c.expected, result)
}
})
}
}
func TestNamingStrategy(t *testing.T) {
dbNameNS := func(name string) string {
return "db_" + name
}
tableNameNS := func(name string) string {
return "tbl_" + name
}
columnNameNS := func(name string) string {
return "col_" + name
}
ns := &gorm.NamingStrategy{
DB: dbNameNS,
Table: tableNameNS,
Column: columnNameNS,
}
cases := []struct {
name string
namer gorm.Namer
expected string
}{
{name: "auth", expected: "db_auth", namer: ns.DB},
{name: "user", expected: "tbl_user", namer: ns.Table},
{name: "password", expected: "col_password", namer: ns.Column},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
result := c.namer(c.name)
if result != c.expected {
t.Errorf("error in naming strategy. expected: %v got :%v\n", c.expected, result)
}
})
}
}

View File

@ -123,6 +123,31 @@ func TestAutoPreload(t *testing.T) {
} }
} }
func TestAutoPreloadFalseDoesntPreload(t *testing.T) {
user1 := getPreloadUser("auto_user1")
DB.Save(user1)
preloadDB := DB.Set("gorm:auto_preload", false).Where("role = ?", "Preload")
var user User
preloadDB.Find(&user)
if user.BillingAddress.Address1 != "" {
t.Error("AutoPreload was set to fasle, but still fetched data")
}
user2 := getPreloadUser("auto_user2")
DB.Save(user2)
var users []User
preloadDB.Find(&users)
for _, user := range users {
if user.BillingAddress.Address1 != "" {
t.Error("AutoPreload was set to fasle, but still fetched data")
}
}
}
func TestNestedPreload1(t *testing.T) { func TestNestedPreload1(t *testing.T) {
type ( type (
Level1 struct { Level1 struct {

View File

@ -181,17 +181,17 @@ func TestSearchWithPlainSQL(t *testing.T) {
scopedb.Where("birthday > ?", parseTime("2000-1-1")).Find(&users) scopedb.Where("birthday > ?", parseTime("2000-1-1")).Find(&users)
if len(users) != 2 { if len(users) != 2 {
t.Errorf("Should found 2 users's birthday > 2000-1-1, but got %v", len(users)) t.Errorf("Should found 2 users' birthday > 2000-1-1, but got %v", len(users))
} }
scopedb.Where("birthday > ?", "2002-10-10").Find(&users) scopedb.Where("birthday > ?", "2002-10-10").Find(&users)
if len(users) != 2 { if len(users) != 2 {
t.Errorf("Should found 2 users's birthday >= 2002-10-10, but got %v", len(users)) t.Errorf("Should found 2 users' birthday >= 2002-10-10, but got %v", len(users))
} }
scopedb.Where("birthday >= ?", "2010-1-1").Where("birthday < ?", "2020-1-1").Find(&users) scopedb.Where("birthday >= ?", "2010-1-1").Where("birthday < ?", "2020-1-1").Find(&users)
if len(users) != 1 { if len(users) != 1 {
t.Errorf("Should found 1 users's birthday < 2020-1-1 and >= 2010-1-1, but got %v", len(users)) t.Errorf("Should found 1 users' birthday < 2020-1-1 and >= 2010-1-1, but got %v", len(users))
} }
DB.Where("name in (?)", []string{user1.Name, user2.Name}).Find(&users) DB.Where("name in (?)", []string{user1.Name, user2.Name}).Find(&users)
@ -532,28 +532,28 @@ func TestNot(t *testing.T) {
DB.Table("users").Where("name = ?", "user3").Count(&name3Count) DB.Table("users").Where("name = ?", "user3").Count(&name3Count)
DB.Not("name", "user3").Find(&users4) DB.Not("name", "user3").Find(&users4)
if len(users1)-len(users4) != int(name3Count) { if len(users1)-len(users4) != int(name3Count) {
t.Errorf("Should find all users's name not equal 3") t.Errorf("Should find all users' name not equal 3")
} }
DB.Not("name = ?", "user3").Find(&users4) DB.Not("name = ?", "user3").Find(&users4)
if len(users1)-len(users4) != int(name3Count) { if len(users1)-len(users4) != int(name3Count) {
t.Errorf("Should find all users's name not equal 3") t.Errorf("Should find all users' name not equal 3")
} }
DB.Not("name <> ?", "user3").Find(&users4) DB.Not("name <> ?", "user3").Find(&users4)
if len(users4) != int(name3Count) { if len(users4) != int(name3Count) {
t.Errorf("Should find all users's name not equal 3") t.Errorf("Should find all users' name not equal 3")
} }
DB.Not(User{Name: "user3"}).Find(&users5) DB.Not(User{Name: "user3"}).Find(&users5)
if len(users1)-len(users5) != int(name3Count) { if len(users1)-len(users5) != int(name3Count) {
t.Errorf("Should find all users's name not equal 3") t.Errorf("Should find all users' name not equal 3")
} }
DB.Not(map[string]interface{}{"name": "user3"}).Find(&users6) DB.Not(map[string]interface{}{"name": "user3"}).Find(&users6)
if len(users1)-len(users6) != int(name3Count) { if len(users1)-len(users6) != int(name3Count) {
t.Errorf("Should find all users's name not equal 3") t.Errorf("Should find all users' name not equal 3")
} }
DB.Not(map[string]interface{}{"name": "user3", "company_id": nil}).Find(&users7) DB.Not(map[string]interface{}{"name": "user3", "company_id": nil}).Find(&users7)
@ -563,14 +563,14 @@ func TestNot(t *testing.T) {
DB.Not("name", []string{"user3"}).Find(&users8) DB.Not("name", []string{"user3"}).Find(&users8)
if len(users1)-len(users8) != int(name3Count) { if len(users1)-len(users8) != int(name3Count) {
t.Errorf("Should find all users's name not equal 3") t.Errorf("Should find all users' name not equal 3")
} }
var name2Count int64 var name2Count int64
DB.Table("users").Where("name = ?", "user2").Count(&name2Count) DB.Table("users").Where("name = ?", "user2").Count(&name2Count)
DB.Not("name", []string{"user3", "user2"}).Find(&users9) DB.Not("name", []string{"user3", "user2"}).Find(&users9)
if len(users1)-len(users9) != (int(name3Count) + int(name2Count)) { if len(users1)-len(users9) != (int(name3Count) + int(name2Count)) {
t.Errorf("Should find all users's name not equal 3") t.Errorf("Should find all users' name not equal 3")
} }
} }

View File

@ -63,7 +63,7 @@ func (scope *Scope) SQLDB() SQLCommon {
// Dialect get dialect // Dialect get dialect
func (scope *Scope) Dialect() Dialect { func (scope *Scope) Dialect() Dialect {
return scope.db.parent.dialect return scope.db.dialect
} }
// Quote used to quote string to escape them for database // Quote used to quote string to escape them for database
@ -134,7 +134,7 @@ func (scope *Scope) Fields() []*Field {
// FieldByName find `gorm.Field` with field name or db name // FieldByName find `gorm.Field` with field name or db name
func (scope *Scope) FieldByName(name string) (field *Field, ok bool) { func (scope *Scope) FieldByName(name string) (field *Field, ok bool) {
var ( var (
dbName = ToDBName(name) dbName = ToColumnName(name)
mostMatchedField *Field mostMatchedField *Field
) )
@ -486,8 +486,10 @@ func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field) {
values[index] = &ignored values[index] = &ignored
selectFields = fields selectFields = fields
offset := 0
if idx, ok := selectedColumnsMap[column]; ok { if idx, ok := selectedColumnsMap[column]; ok {
selectFields = selectFields[idx+1:] offset = idx + 1
selectFields = selectFields[offset:]
} }
for fieldIndex, field := range selectFields { for fieldIndex, field := range selectFields {
@ -501,7 +503,7 @@ func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field) {
resetFields[index] = field resetFields[index] = field
} }
selectedColumnsMap[column] = fieldIndex selectedColumnsMap[column] = offset + fieldIndex
if field.IsNormal { if field.IsNormal {
break break
@ -586,10 +588,10 @@ func (scope *Scope) buildCondition(clause map[string]interface{}, include bool)
scope.Err(fmt.Errorf("invalid query condition: %v", value)) scope.Err(fmt.Errorf("invalid query condition: %v", value))
return return
} }
scopeQuotedTableName := newScope.QuotedTableName()
for _, field := range newScope.Fields() { for _, field := range newScope.Fields() {
if !field.IsIgnored && !field.IsBlank { if !field.IsIgnored && !field.IsBlank {
sqls = append(sqls, fmt.Sprintf("(%v.%v %s %v)", quotedTableName, scope.Quote(field.DBName), equalSQL, scope.AddToVars(field.Field.Interface()))) sqls = append(sqls, fmt.Sprintf("(%v.%v %s %v)", scopeQuotedTableName, scope.Quote(field.DBName), equalSQL, scope.AddToVars(field.Field.Interface())))
} }
} }
return strings.Join(sqls, " AND ") return strings.Join(sqls, " AND ")
@ -692,12 +694,12 @@ func (scope *Scope) buildSelectQuery(clause map[string]interface{}) (str string)
buff := bytes.NewBuffer([]byte{}) buff := bytes.NewBuffer([]byte{})
i := 0 i := 0
for pos := range str { for pos, char := range str {
if str[pos] == '?' { if str[pos] == '?' {
buff.WriteString(replacements[i]) buff.WriteString(replacements[i])
i++ i++
} else { } else {
buff.WriteByte(str[pos]) buff.WriteRune(char)
} }
} }
@ -880,7 +882,7 @@ func convertInterfaceToMap(values interface{}, withIgnoredField bool) map[string
switch reflectValue.Kind() { switch reflectValue.Kind() {
case reflect.Map: case reflect.Map:
for _, key := range reflectValue.MapKeys() { for _, key := range reflectValue.MapKeys() {
attrs[ToDBName(key.Interface().(string))] = reflectValue.MapIndex(key).Interface() attrs[ToColumnName(key.Interface().(string))] = reflectValue.MapIndex(key).Interface()
} }
default: default:
for _, field := range (&Scope{Value: values}).Fields() { for _, field := range (&Scope{Value: values}).Fields() {
@ -907,7 +909,7 @@ func (scope *Scope) updatedAttrsWithValues(value interface{}) (results map[strin
results[field.DBName] = value results[field.DBName] = value
} else { } else {
err := field.Set(value) err := field.Set(value)
if field.IsNormal { if field.IsNormal && !field.IsIgnored {
hasUpdate = true hasUpdate = true
if err == ErrUnaddressable { if err == ErrUnaddressable {
results[field.DBName] = value results[field.DBName] = value
@ -1215,12 +1217,18 @@ func (scope *Scope) addForeignKey(field string, dest string, onDelete string, on
} }
func (scope *Scope) removeForeignKey(field string, dest string) { func (scope *Scope) removeForeignKey(field string, dest string) {
keyName := scope.Dialect().BuildKeyName(scope.TableName(), field, dest) keyName := scope.Dialect().BuildKeyName(scope.TableName(), field, dest, "foreign")
if !scope.Dialect().HasForeignKey(scope.TableName(), keyName) { if !scope.Dialect().HasForeignKey(scope.TableName(), keyName) {
return return
} }
var query = `ALTER TABLE %s DROP CONSTRAINT %s;` var mysql mysql
var query string
if scope.Dialect().GetName() == mysql.GetName() {
query = `ALTER TABLE %s DROP FOREIGN KEY %s;`
} else {
query = `ALTER TABLE %s DROP CONSTRAINT %s;`
}
scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName))).Exec() scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName))).Exec()
} }

View File

@ -1,7 +1,6 @@
package gorm package gorm
import ( import (
"bytes"
"database/sql/driver" "database/sql/driver"
"fmt" "fmt"
"reflect" "reflect"
@ -26,8 +25,8 @@ var NowFunc = func() time.Time {
var commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UID", "UI", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"} var commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UID", "UI", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"}
var commonInitialismsReplacer *strings.Replacer var commonInitialismsReplacer *strings.Replacer
var goSrcRegexp = regexp.MustCompile(`jinzhu/gorm/.*.go`) var goSrcRegexp = regexp.MustCompile(`jinzhu/gorm(@.*)?/.*.go`)
var goTestRegexp = regexp.MustCompile(`jinzhu/gorm/.*test.go`) var goTestRegexp = regexp.MustCompile(`jinzhu/gorm(@.*)?/.*test.go`)
func init() { func init() {
var commonInitialismsForReplacer []string var commonInitialismsForReplacer []string
@ -58,64 +57,6 @@ func newSafeMap() *safeMap {
return &safeMap{l: new(sync.RWMutex), m: make(map[string]string)} return &safeMap{l: new(sync.RWMutex), m: make(map[string]string)}
} }
var smap = newSafeMap()
type strCase bool
const (
lower strCase = false
upper strCase = true
)
// ToDBName convert string to db name
func ToDBName(name string) string {
if v := smap.Get(name); v != "" {
return v
}
if name == "" {
return ""
}
var (
value = commonInitialismsReplacer.Replace(name)
buf = bytes.NewBufferString("")
lastCase, currCase, nextCase strCase
)
for i, v := range value[:len(value)-1] {
nextCase = strCase(value[i+1] >= 'A' && value[i+1] <= 'Z')
if i > 0 {
if currCase == upper {
if lastCase == upper && nextCase == upper {
buf.WriteRune(v)
} else {
if value[i-1] != '_' && value[i+1] != '_' {
buf.WriteRune('_')
}
buf.WriteRune(v)
}
} else {
buf.WriteRune(v)
if i == len(value)-2 && nextCase == upper {
buf.WriteRune('_')
}
}
} else {
currCase = upper
buf.WriteRune(v)
}
lastCase = currCase
currCase = nextCase
}
buf.WriteByte(value[len(value)-1])
s := strings.ToLower(buf.String())
smap.Set(name, s)
return s
}
// SQL expression // SQL expression
type expr struct { type expr struct {
expr string expr string
@ -265,7 +206,7 @@ func getValueFromFields(value reflect.Value, fieldNames []string) (results []int
// as FieldByName could panic // as FieldByName could panic
if indirectValue := reflect.Indirect(value); indirectValue.IsValid() { if indirectValue := reflect.Indirect(value); indirectValue.IsValid() {
for _, fieldName := range fieldNames { for _, fieldName := range fieldNames {
if fieldValue := indirectValue.FieldByName(fieldName); fieldValue.IsValid() { if fieldValue := reflect.Indirect(indirectValue.FieldByName(fieldName)); fieldValue.IsValid() {
result := fieldValue.Interface() result := fieldValue.Interface()
if r, ok := result.(driver.Valuer); ok { if r, ok := result.(driver.Valuer); ok {
result, _ = r.Value() result, _ = r.Value()

View File

@ -1,32 +0,0 @@
package gorm_test
import (
"testing"
"github.com/jinzhu/gorm"
)
func TestToDBNameGenerateFriendlyName(t *testing.T) {
var maps = map[string]string{
"": "",
"X": "x",
"ThisIsATest": "this_is_a_test",
"PFAndESI": "pf_and_esi",
"AbcAndJkl": "abc_and_jkl",
"EmployeeID": "employee_id",
"SKU_ID": "sku_id",
"FieldX": "field_x",
"HTTPAndSMTP": "http_and_smtp",
"HTTPServerHandlerForURLID": "http_server_handler_for_url_id",
"UUID": "uuid",
"HTTPURL": "http_url",
"HTTP_URL": "http_url",
"ThisIsActuallyATestSoWeMayBeAbleToUseThisCodeInGormPackageAlsoIdCanBeUsedAtTheEndAsID": "this_is_actually_a_test_so_we_may_be_able_to_use_this_code_in_gorm_package_also_id_can_be_used_at_the_end_as_id",
}
for key, value := range maps {
if gorm.ToDBName(key) != value {
t.Errorf("%v ToDBName should equal %v, but got %v", key, value, gorm.ToDBName(key))
}
}
}