diff --git a/callback_create.go b/callback_create.go index e7fe6f86..2ab05d3b 100644 --- a/callback_create.go +++ b/callback_create.go @@ -59,7 +59,7 @@ func createCallback(scope *Scope) { for _, field := range scope.Fields() { if scope.changeableField(field) { - if field.IsNormal { + if field.IsNormal && !field.IsIgnored { if field.IsBlank && field.HasDefaultValue { blankColumnsWithDefaultValue = append(blankColumnsWithDefaultValue, scope.Quote(field.DBName)) scope.InstanceSet("gorm:blank_columns_with_default_value", blankColumnsWithDefaultValue) diff --git a/callback_query.go b/callback_query.go index ba10cc7d..593e5d30 100644 --- a/callback_query.go +++ b/callback_query.go @@ -18,6 +18,11 @@ func queryCallback(scope *Scope) { if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip { return } + + //we are only preloading relations, dont touch base model + if _, skip := scope.InstanceGet("gorm:only_preload"); skip { + return + } defer scope.trace(NowFunc()) diff --git a/callback_query_preload.go b/callback_query_preload.go index b1572753..d7c8a133 100644 --- a/callback_query_preload.go +++ b/callback_query_preload.go @@ -14,8 +14,14 @@ func preloadCallback(scope *Scope) { return } - if _, ok := scope.Get("gorm:auto_preload"); ok { - autoPreload(scope) + if ap, ok := scope.Get("gorm:auto_preload"); ok { + // If gorm:auto_preload IS NOT a bool then auto preload. + // Else if it IS a bool, use the value + if apb, ok := ap.(bool); !ok { + autoPreload(scope) + } else if apb { + autoPreload(scope) + } } if scope.Search.preload == nil || scope.HasError() { @@ -155,14 +161,17 @@ func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) ) if indirectScopeValue.Kind() == reflect.Slice { + foreignValuesToResults := make(map[string]reflect.Value) + for i := 0; i < resultsValue.Len(); i++ { + result := resultsValue.Index(i) + foreignValues := toString(getValueFromFields(result, relation.ForeignFieldNames)) + foreignValuesToResults[foreignValues] = result + } for j := 0; j < indirectScopeValue.Len(); j++ { - for i := 0; i < resultsValue.Len(); i++ { - result := resultsValue.Index(i) - foreignValues := getValueFromFields(result, relation.ForeignFieldNames) - if indirectValue := indirect(indirectScopeValue.Index(j)); equalAsString(getValueFromFields(indirectValue, relation.AssociationForeignFieldNames), foreignValues) { - indirectValue.FieldByName(field.Name).Set(result) - break - } + indirectValue := indirect(indirectScopeValue.Index(j)) + valueString := toString(getValueFromFields(indirectValue, relation.AssociationForeignFieldNames)) + if result, found := foreignValuesToResults[valueString]; found { + indirectValue.FieldByName(field.Name).Set(result) } } } else { @@ -249,13 +258,21 @@ func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{ indirectScopeValue = scope.IndirectValue() ) + foreignFieldToObjects := make(map[string][]*reflect.Value) + if indirectScopeValue.Kind() == reflect.Slice { + for j := 0; j < indirectScopeValue.Len(); j++ { + object := indirect(indirectScopeValue.Index(j)) + valueString := toString(getValueFromFields(object, relation.ForeignFieldNames)) + foreignFieldToObjects[valueString] = append(foreignFieldToObjects[valueString], &object) + } + } + for i := 0; i < resultsValue.Len(); i++ { result := resultsValue.Index(i) if indirectScopeValue.Kind() == reflect.Slice { - value := getValueFromFields(result, relation.AssociationForeignFieldNames) - for j := 0; j < indirectScopeValue.Len(); j++ { - object := indirect(indirectScopeValue.Index(j)) - if equalAsString(getValueFromFields(object, relation.ForeignFieldNames), value) { + valueString := toString(getValueFromFields(result, relation.AssociationForeignFieldNames)) + if objects, found := foreignFieldToObjects[valueString]; found { + for _, object := range objects { object.FieldByName(field.Name).Set(result) } } diff --git a/callback_save.go b/callback_save.go index 0d83200a..3b4e0589 100644 --- a/callback_save.go +++ b/callback_save.go @@ -21,9 +21,7 @@ func saveAssociationCheck(scope *Scope, field *Field) (autoUpdate bool, autoCrea if v, ok := value.(string); ok { v = strings.ToLower(v) - if v == "false" || v != "skip" { - return false - } + return v == "true" } return true @@ -36,9 +34,11 @@ func saveAssociationCheck(scope *Scope, field *Field) (autoUpdate bool, autoCrea if value, ok := scope.Get("gorm:save_associations"); ok { autoUpdate = checkTruth(value) autoCreate = autoUpdate + saveReference = autoUpdate } else if value, ok := field.TagSettingsGet("SAVE_ASSOCIATIONS"); ok { autoUpdate = checkTruth(value) autoCreate = autoUpdate + saveReference = autoUpdate } if value, ok := scope.Get("gorm:association_autoupdate"); ok { diff --git a/callback_update.go b/callback_update.go index 373bd726..f6ba0ffd 100644 --- a/callback_update.go +++ b/callback_update.go @@ -76,7 +76,9 @@ func updateCallback(scope *Scope) { for _, field := range scope.Fields() { if scope.changeableField(field) { if !field.IsPrimaryKey && field.IsNormal { - sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) + if !field.IsForeignKey || !field.IsBlank || !field.HasDefaultValue { + sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) + } } else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" { for _, foreignKey := range relationship.ForeignDBNames { if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) { diff --git a/dialect.go b/dialect.go index 9ee4dc6d..27b308af 100644 --- a/dialect.go +++ b/dialect.go @@ -72,6 +72,12 @@ func RegisterDialect(name string, dialect Dialect) { dialectsMap[name] = dialect } +// GetDialect gets the dialect for the specified dialect name +func GetDialect(name string) (dialect Dialect, ok bool) { + dialect, ok = dialectsMap[name] + return +} + // ParseFieldStructForDialect get field's sql data type var ParseFieldStructForDialect = func(field *StructField, dialect Dialect) (fieldValue reflect.Value, sqlType string, size int, additionalType string) { // Get redirected field type diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go index b303ff9c..6c424bc1 100644 --- a/dialects/mssql/mssql.go +++ b/dialects/mssql/mssql.go @@ -1,12 +1,16 @@ package mssql import ( + "database/sql/driver" + "encoding/json" + "errors" "fmt" "reflect" "strconv" "strings" "time" + // Importing mssql driver package only in dialect file, otherwide not needed _ "github.com/denisenkom/go-mssqldb" "github.com/jinzhu/gorm" ) @@ -130,7 +134,14 @@ func (s mssql) RemoveIndex(tableName string, indexName string) error { } func (s mssql) HasForeignKey(tableName string, foreignKeyName string) bool { - return false + var count int + currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) + s.db.QueryRow(`SELECT count(*) + FROM sys.foreign_keys as F inner join sys.tables as T on F.parent_object_id=T.object_id + inner join information_schema.tables as I on I.TABLE_NAME = T.name + WHERE F.name = ? + AND T.Name = ? AND I.TABLE_CATALOG = ?;`, foreignKeyName, tableName, currentDatabase).Scan(&count) + return count > 0 } func (s mssql) HasTable(tableName string) bool { @@ -194,3 +205,27 @@ func currentDatabaseAndTable(dialect gorm.Dialect, tableName string) (string, st } return dialect.CurrentDatabase(), tableName } + +// JSON type to support easy handling of JSON data in character table fields +// using golang json.RawMessage for deferred decoding/encoding +type JSON struct { + json.RawMessage +} + +// Value get value of JSON +func (j JSON) Value() (driver.Value, error) { + if len(j.RawMessage) == 0 { + return nil, nil + } + return j.MarshalJSON() +} + +// Scan scan value into JSON +func (j *JSON) Scan(value interface{}) error { + str, ok := value.(string) + if !ok { + return errors.New(fmt.Sprint("Failed to unmarshal JSONB value (strcast):", value)) + } + bytes := []byte(str) + return json.Unmarshal(bytes, j) +} diff --git a/dialects/postgres/postgres.go b/dialects/postgres/postgres.go index 1d0dcb60..424e8bdc 100644 --- a/dialects/postgres/postgres.go +++ b/dialects/postgres/postgres.go @@ -4,11 +4,11 @@ import ( "database/sql" "database/sql/driver" - _ "github.com/lib/pq" - "github.com/lib/pq/hstore" "encoding/json" "errors" "fmt" + _ "github.com/lib/pq" + "github.com/lib/pq/hstore" ) type Hstore map[string]*string diff --git a/errors.go b/errors.go index da2cf13c..27c9a92d 100644 --- a/errors.go +++ b/errors.go @@ -6,7 +6,7 @@ import ( ) var ( - // ErrRecordNotFound record not found error, happens when haven't find any matched data when looking up with a struct + // ErrRecordNotFound record not found error, happens when only haven't find any matched data when looking up with a struct, finding a slice won't return this error ErrRecordNotFound = errors.New("record not found") // ErrInvalidSQL invalid SQL error, happens when you passed invalid SQL ErrInvalidSQL = errors.New("invalid SQL") diff --git a/main.go b/main.go index 69e676e8..17c75ed3 100644 --- a/main.go +++ b/main.go @@ -6,6 +6,7 @@ import ( "fmt" "reflect" "strings" + "sync" "time" ) @@ -21,7 +22,7 @@ type DB struct { logMode int logger logger search *search - values map[string]interface{} + values sync.Map // global db parent *DB @@ -48,6 +49,7 @@ func Open(dialect string, args ...interface{}) (db *DB, err error) { } var source string var dbSQL SQLCommon + var ownDbSQL bool switch value := args[0].(type) { case string: @@ -59,14 +61,17 @@ func Open(dialect string, args ...interface{}) (db *DB, err error) { source = args[1].(string) } dbSQL, err = sql.Open(driver, source) + ownDbSQL = true case SQLCommon: dbSQL = value + ownDbSQL = false + default: + return nil, fmt.Errorf("invalid database source: %v is not a valid type", value) } db = &DB{ db: dbSQL, logger: defaultLogger, - values: map[string]interface{}{}, callbacks: DefaultCallback, dialect: newDialect(dialect, dbSQL), } @@ -76,7 +81,7 @@ func Open(dialect string, args ...interface{}) (db *DB, err error) { } // Send a ping to make sure the database connection is alive. if d, ok := dbSQL.(*sql.DB); ok { - if err = d.Ping(); err != nil { + if err = d.Ping(); err != nil && ownDbSQL { d.Close() } } @@ -117,7 +122,7 @@ func (s *DB) CommonDB() SQLCommon { // Dialect get dialect func (s *DB) Dialect() Dialect { - return s.parent.dialect + return s.dialect } // Callback return `Callbacks` container, you could add/change/delete callbacks with it @@ -157,7 +162,7 @@ func (s *DB) HasBlockGlobalUpdate() bool { // SingularTable use singular table by default func (s *DB) SingularTable(enable bool) { - modelStructsMap = newModelStructsMap() + modelStructsMap = sync.Map{} s.parent.singularTable = enable } @@ -309,6 +314,11 @@ func (s *DB) Find(out interface{}, where ...interface{}) *DB { return s.NewScope(out).inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db } +//Preloads preloads relations, don`t touch out +func (s *DB) Preloads(out interface{}) *DB { + return s.NewScope(out).InstanceSet("gorm:only_preload", 1).callCallbacks(s.parent.callbacks.queries).db +} + // Scan scan value to a struct func (s *DB) Scan(dest interface{}) *DB { return s.NewScope(s.Value).Set("gorm:query_destination", dest).callCallbacks(s.parent.callbacks.queries).db @@ -482,6 +492,8 @@ func (s *DB) Begin() *DB { if db, ok := c.db.(sqlDb); ok && db != nil { tx, err := db.Begin() c.db = interface{}(tx).(SQLCommon) + + c.dialect.SetDB(c.db) c.AddError(err) } else { c.AddError(ErrCantStartTransaction) @@ -491,7 +503,8 @@ func (s *DB) Begin() *DB { // Commit commit a transaction func (s *DB) Commit() *DB { - if db, ok := s.db.(sqlTx); ok && db != nil { + var emptySQLTx *sql.Tx + if db, ok := s.db.(sqlTx); ok && db != nil && db != emptySQLTx { s.AddError(db.Commit()) } else { s.AddError(ErrInvalidTransaction) @@ -501,7 +514,8 @@ func (s *DB) Commit() *DB { // Rollback rollback a transaction func (s *DB) Rollback() *DB { - if db, ok := s.db.(sqlTx); ok && db != nil { + var emptySQLTx *sql.Tx + if db, ok := s.db.(sqlTx); ok && db != nil && db != emptySQLTx { s.AddError(db.Rollback()) } else { s.AddError(ErrInvalidTransaction) @@ -670,13 +684,13 @@ func (s *DB) Set(name string, value interface{}) *DB { // InstantSet instant set setting, will affect current db func (s *DB) InstantSet(name string, value interface{}) *DB { - s.values[name] = value + s.values.Store(name, value) return s } // Get get setting by name func (s *DB) Get(name string) (value interface{}, ok bool) { - value, ok = s.values[name] + value, ok = s.values.Load(name) return } @@ -740,15 +754,16 @@ func (s *DB) clone() *DB { parent: s.parent, logger: s.logger, logMode: s.logMode, - values: map[string]interface{}{}, Value: s.Value, Error: s.Error, blockGlobalUpdate: s.blockGlobalUpdate, + dialect: newDialect(s.dialect.GetName(), s.db), } - for key, value := range s.values { - db.values[key] = value - } + s.values.Range(func(k, v interface{}) bool { + db.values.Store(k, v) + return true + }) if s.search == nil { db.search = &search{limit: -1, offset: -1} diff --git a/main_test.go b/main_test.go index 66c46af0..94d2fa39 100644 --- a/main_test.go +++ b/main_test.go @@ -8,6 +8,7 @@ import ( "path/filepath" "reflect" "strconv" + "strings" "testing" "time" @@ -79,6 +80,22 @@ func OpenTestConnection() (db *gorm.DB, err error) { return } +func TestOpen_ReturnsError_WithBadArgs(t *testing.T) { + stringRef := "foo" + testCases := []interface{}{42, time.Now(), &stringRef} + for _, tc := range testCases { + t.Run(fmt.Sprintf("%v", tc), func(t *testing.T) { + _, err := gorm.Open("postgresql", tc) + if err == nil { + t.Error("Should got error with invalid database source") + } + if !strings.HasPrefix(err.Error(), "invalid database source:") { + t.Errorf("Should got error starting with \"invalid database source:\", but got %q", err.Error()) + } + }) + } +} + func TestStringPrimaryKey(t *testing.T) { type UUIDStruct struct { ID string `gorm:"primary_key"` @@ -564,6 +581,60 @@ func TestJoins(t *testing.T) { } } +type JoinedIds struct { + UserID int64 `gorm:"column:id"` + BillingAddressID int64 `gorm:"column:id"` + EmailID int64 `gorm:"column:id"` +} + +func TestScanIdenticalColumnNames(t *testing.T) { + var user = User{ + Name: "joinsIds", + Email: "joinIds@example.com", + BillingAddress: Address{ + Address1: "One Park Place", + }, + Emails: []Email{{Email: "join1@example.com"}, {Email: "join2@example.com"}}, + } + DB.Save(&user) + + var users []JoinedIds + DB.Select("users.id, addresses.id, emails.id").Table("users"). + Joins("left join addresses on users.billing_address_id = addresses.id"). + Joins("left join emails on emails.user_id = users.id"). + Where("name = ?", "joinsIds").Scan(&users) + + if len(users) != 2 { + t.Fatal("should find two rows using left join") + } + + if user.Id != users[0].UserID { + t.Errorf("Expected result row to contain UserID %d, but got %d", user.Id, users[0].UserID) + } + if user.Id != users[1].UserID { + t.Errorf("Expected result row to contain UserID %d, but got %d", user.Id, users[1].UserID) + } + + if user.BillingAddressID.Int64 != users[0].BillingAddressID { + t.Errorf("Expected result row to contain BillingAddressID %d, but got %d", user.BillingAddressID.Int64, users[0].BillingAddressID) + } + if user.BillingAddressID.Int64 != users[1].BillingAddressID { + t.Errorf("Expected result row to contain BillingAddressID %d, but got %d", user.BillingAddressID.Int64, users[0].BillingAddressID) + } + + if users[0].EmailID == users[1].EmailID { + t.Errorf("Email ids should be unique. Got %d and %d", users[0].EmailID, users[1].EmailID) + } + + if int64(user.Emails[0].Id) != users[0].EmailID && int64(user.Emails[1].Id) != users[0].EmailID { + t.Errorf("Expected result row ID to be either %d or %d, but was %d", user.Emails[0].Id, user.Emails[1].Id, users[0].EmailID) + } + + if int64(user.Emails[0].Id) != users[1].EmailID && int64(user.Emails[1].Id) != users[1].EmailID { + t.Errorf("Expected result row ID to be either %d or %d, but was %d", user.Emails[0].Id, user.Emails[1].Id, users[1].EmailID) + } +} + func TestJoinsWithSelect(t *testing.T) { type result struct { Name string @@ -862,6 +933,94 @@ func TestOpenWithOneParameter(t *testing.T) { } } +func TestSaveAssociations(t *testing.T) { + db := DB.New() + deltaAddressCount := 0 + if err := db.Model(&Address{}).Count(&deltaAddressCount).Error; err != nil { + t.Errorf("failed to fetch address count") + t.FailNow() + } + + placeAddress := &Address{ + Address1: "somewhere on earth", + } + ownerAddress1 := &Address{ + Address1: "near place address", + } + ownerAddress2 := &Address{ + Address1: "address2", + } + db.Create(placeAddress) + + addressCountShouldBe := func(t *testing.T, expectedCount int) { + countFromDB := 0 + t.Helper() + err := db.Model(&Address{}).Count(&countFromDB).Error + if err != nil { + t.Error("failed to fetch address count") + } + if countFromDB != expectedCount { + t.Errorf("address count mismatch: %d", countFromDB) + } + } + addressCountShouldBe(t, deltaAddressCount+1) + + // owner address should be created, place address should be reused + place1 := &Place{ + PlaceAddressID: placeAddress.ID, + PlaceAddress: placeAddress, + OwnerAddress: ownerAddress1, + } + err := db.Create(place1).Error + if err != nil { + t.Errorf("failed to store place: %s", err.Error()) + } + addressCountShouldBe(t, deltaAddressCount+2) + + // owner address should be created again, place address should be reused + place2 := &Place{ + PlaceAddressID: placeAddress.ID, + PlaceAddress: &Address{ + ID: 777, + Address1: "address1", + }, + OwnerAddress: ownerAddress2, + OwnerAddressID: 778, + } + err = db.Create(place2).Error + if err != nil { + t.Errorf("failed to store place: %s", err.Error()) + } + addressCountShouldBe(t, deltaAddressCount+3) + + count := 0 + db.Model(&Place{}).Where(&Place{ + PlaceAddressID: placeAddress.ID, + OwnerAddressID: ownerAddress1.ID, + }).Count(&count) + if count != 1 { + t.Errorf("only one instance of (%d, %d) should be available, found: %d", + placeAddress.ID, ownerAddress1.ID, count) + } + + db.Model(&Place{}).Where(&Place{ + PlaceAddressID: placeAddress.ID, + OwnerAddressID: ownerAddress2.ID, + }).Count(&count) + if count != 1 { + t.Errorf("only one instance of (%d, %d) should be available, found: %d", + placeAddress.ID, ownerAddress2.ID, count) + } + + db.Model(&Place{}).Where(&Place{ + PlaceAddressID: placeAddress.ID, + }).Count(&count) + if count != 2 { + t.Errorf("two instances of (%d) should be available, found: %d", + placeAddress.ID, count) + } +} + func TestBlockGlobalUpdate(t *testing.T) { db := DB.New() db.Create(&Toy{Name: "Stuffed Animal", OwnerType: "Nobody"}) diff --git a/migration_test.go b/migration_test.go index 7c694485..3fb06648 100644 --- a/migration_test.go +++ b/migration_test.go @@ -118,6 +118,14 @@ type Company struct { Owner *User `sql:"-"` } +type Place struct { + Id int64 + PlaceAddressID int + PlaceAddress *Address `gorm:"save_associations:false"` + OwnerAddressID int + OwnerAddress *Address `gorm:"save_associations:true"` +} + type EncryptedData []byte func (data *EncryptedData) Scan(value interface{}) error { @@ -284,7 +292,7 @@ func runMigration() { DB.Exec(fmt.Sprintf("drop table %v;", table)) } - values := []interface{}{&Short{}, &ReallyLongThingThatReferencesShort{}, &ReallyLongTableNameToTestMySQLNameLengthLimit{}, &NotSoLongTableName{}, &Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}, &Animal{}, &User{}, &JoinTable{}, &Post{}, &Category{}, &Comment{}, &Cat{}, &Dog{}, &Hamster{}, &Toy{}, &ElementWithIgnoredField{}} + values := []interface{}{&Short{}, &ReallyLongThingThatReferencesShort{}, &ReallyLongTableNameToTestMySQLNameLengthLimit{}, &NotSoLongTableName{}, &Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}, &Animal{}, &User{}, &JoinTable{}, &Post{}, &Category{}, &Comment{}, &Cat{}, &Dog{}, &Hamster{}, &Toy{}, &ElementWithIgnoredField{}, &Place{}} for _, value := range values { DB.DropTable(value) } @@ -398,6 +406,53 @@ func TestAutoMigration(t *testing.T) { } } +func TestCreateAndAutomigrateTransaction(t *testing.T) { + tx := DB.Begin() + + func() { + type Bar struct { + ID uint + } + DB.DropTableIfExists(&Bar{}) + + if ok := DB.HasTable("bars"); ok { + t.Errorf("Table should not exist, but does") + } + + if ok := tx.HasTable("bars"); ok { + t.Errorf("Table should not exist, but does") + } + }() + + func() { + type Bar struct { + Name string + } + err := tx.CreateTable(&Bar{}).Error + + if err != nil { + t.Errorf("Should have been able to create the table, but couldn't: %s", err) + } + + if ok := tx.HasTable(&Bar{}); !ok { + t.Errorf("The transaction should be able to see the table") + } + }() + + func() { + type Bar struct { + Stuff string + } + + err := tx.AutoMigrate(&Bar{}).Error + if err != nil { + t.Errorf("Should have been able to alter the table, but couldn't") + } + }() + + tx.Rollback() +} + type MultipleIndexes struct { ID int64 UserID int64 `sql:"unique_index:uix_multipleindexes_user_name,uix_multipleindexes_user_email;index:idx_multipleindexes_user_other"` diff --git a/model_struct.go b/model_struct.go index ca859a43..12860e67 100644 --- a/model_struct.go +++ b/model_struct.go @@ -17,28 +17,7 @@ var DefaultTableNameHandler = func(db *DB, defaultTableName string) string { return defaultTableName } -type safeModelStructsMap struct { - m map[reflect.Type]*ModelStruct - l *sync.RWMutex -} - -func (s *safeModelStructsMap) Set(key reflect.Type, value *ModelStruct) { - s.l.Lock() - defer s.l.Unlock() - s.m[key] = value -} - -func (s *safeModelStructsMap) Get(key reflect.Type) *ModelStruct { - s.l.RLock() - defer s.l.RUnlock() - return s.m[key] -} - -func newModelStructsMap() *safeModelStructsMap { - return &safeModelStructsMap{l: new(sync.RWMutex), m: make(map[reflect.Type]*ModelStruct)} -} - -var modelStructsMap = newModelStructsMap() +var modelStructsMap sync.Map // ModelStruct model definition type ModelStruct struct { @@ -48,14 +27,14 @@ type ModelStruct struct { defaultTableName string } -// TableName get model's table name +// TableName returns model's table name func (s *ModelStruct) TableName(db *DB) string { if s.defaultTableName == "" && db != nil && s.ModelType != nil { // Set default table name if tabler, ok := reflect.New(s.ModelType).Interface().(tabler); ok { s.defaultTableName = tabler.TableName() } else { - tableName := ToDBName(s.ModelType.Name()) + tableName := ToTableName(s.ModelType.Name()) if db == nil || !db.parent.singularTable { tableName = inflection.Plural(tableName) } @@ -153,7 +132,7 @@ type Relationship struct { func getForeignField(column string, fields []*StructField) *StructField { for _, field := range fields { - if field.Name == column || field.DBName == column || field.DBName == ToDBName(column) { + if field.Name == column || field.DBName == column || field.DBName == ToColumnName(column) { return field } } @@ -179,8 +158,8 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } // Get Cached model struct - if value := modelStructsMap.Get(reflectType); value != nil { - return value + if value, ok := modelStructsMap.Load(reflectType); ok && value != nil { + return value.(*ModelStruct) } modelStruct.ModelType = reflectType @@ -317,7 +296,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { // if defined join table's foreign key relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBNames[idx]) } else { - defaultJointableForeignKey := ToDBName(reflectType.Name()) + "_" + foreignField.DBName + defaultJointableForeignKey := ToColumnName(reflectType.Name()) + "_" + foreignField.DBName relationship.ForeignDBNames = append(relationship.ForeignDBNames, defaultJointableForeignKey) } } @@ -348,7 +327,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationJoinTableDBNames[idx]) } else { // join table foreign keys for association - joinTableDBName := ToDBName(elemType.Name()) + "_" + field.DBName + joinTableDBName := ToColumnName(elemType.Name()) + "_" + field.DBName relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, joinTableDBName) } } @@ -356,7 +335,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } joinTableHandler := JoinTableHandler{} - joinTableHandler.Setup(relationship, many2many, reflectType, elemType) + joinTableHandler.Setup(relationship, ToTableName(many2many), reflectType, elemType) relationship.JoinTableHandler = &joinTableHandler field.Relationship = relationship } else { @@ -614,7 +593,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { if value, ok := field.TagSettingsGet("COLUMN"); ok { field.DBName = value } else { - field.DBName = ToDBName(fieldStruct.Name) + field.DBName = ToColumnName(fieldStruct.Name) } modelStruct.StructFields = append(modelStruct.StructFields, field) @@ -628,7 +607,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } } - modelStructsMap.Set(reflectType, &modelStruct) + modelStructsMap.Store(reflectType, &modelStruct) return &modelStruct } diff --git a/naming.go b/naming.go new file mode 100644 index 00000000..6b0a4fdd --- /dev/null +++ b/naming.go @@ -0,0 +1,124 @@ +package gorm + +import ( + "bytes" + "strings" +) + +// Namer is a function type which is given a string and return a string +type Namer func(string) string + +// NamingStrategy represents naming strategies +type NamingStrategy struct { + DB Namer + Table Namer + Column Namer +} + +// TheNamingStrategy is being initialized with defaultNamingStrategy +var TheNamingStrategy = &NamingStrategy{ + DB: defaultNamer, + Table: defaultNamer, + Column: defaultNamer, +} + +// AddNamingStrategy sets the naming strategy +func AddNamingStrategy(ns *NamingStrategy) { + if ns.DB == nil { + ns.DB = defaultNamer + } + if ns.Table == nil { + ns.Table = defaultNamer + } + if ns.Column == nil { + ns.Column = defaultNamer + } + TheNamingStrategy = ns +} + +// DBName alters the given name by DB +func (ns *NamingStrategy) DBName(name string) string { + return ns.DB(name) +} + +// TableName alters the given name by Table +func (ns *NamingStrategy) TableName(name string) string { + return ns.Table(name) +} + +// ColumnName alters the given name by Column +func (ns *NamingStrategy) ColumnName(name string) string { + return ns.Column(name) +} + +// ToDBName convert string to db name +func ToDBName(name string) string { + return TheNamingStrategy.DBName(name) +} + +// ToTableName convert string to table name +func ToTableName(name string) string { + return TheNamingStrategy.TableName(name) +} + +// ToColumnName convert string to db name +func ToColumnName(name string) string { + return TheNamingStrategy.ColumnName(name) +} + +var smap = newSafeMap() + +func defaultNamer(name string) string { + const ( + lower = false + upper = true + ) + + if v := smap.Get(name); v != "" { + return v + } + + if name == "" { + return "" + } + + var ( + value = commonInitialismsReplacer.Replace(name) + buf = bytes.NewBufferString("") + lastCase, currCase, nextCase, nextNumber bool + ) + + for i, v := range value[:len(value)-1] { + nextCase = bool(value[i+1] >= 'A' && value[i+1] <= 'Z') + nextNumber = bool(value[i+1] >= '0' && value[i+1] <= '9') + + if i > 0 { + if currCase == upper { + if lastCase == upper && (nextCase == upper || nextNumber == upper) { + buf.WriteRune(v) + } else { + if value[i-1] != '_' && value[i+1] != '_' { + buf.WriteRune('_') + } + buf.WriteRune(v) + } + } else { + buf.WriteRune(v) + if i == len(value)-2 && (nextCase == upper && nextNumber == lower) { + buf.WriteRune('_') + } + } + } else { + currCase = upper + buf.WriteRune(v) + } + lastCase = currCase + currCase = nextCase + } + + buf.WriteByte(value[len(value)-1]) + + s := strings.ToLower(buf.String()) + smap.Set(name, s) + return s +} diff --git a/naming_test.go b/naming_test.go new file mode 100644 index 00000000..0c6f7713 --- /dev/null +++ b/naming_test.go @@ -0,0 +1,69 @@ +package gorm_test + +import ( + "testing" + + "github.com/jinzhu/gorm" +) + +func TestTheNamingStrategy(t *testing.T) { + + cases := []struct { + name string + namer gorm.Namer + expected string + }{ + {name: "auth", expected: "auth", namer: gorm.TheNamingStrategy.DB}, + {name: "userRestrictions", expected: "user_restrictions", namer: gorm.TheNamingStrategy.Table}, + {name: "clientID", expected: "client_id", namer: gorm.TheNamingStrategy.Column}, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + result := c.namer(c.name) + if result != c.expected { + t.Errorf("error in naming strategy. expected: %v got :%v\n", c.expected, result) + } + }) + } + +} + +func TestNamingStrategy(t *testing.T) { + + dbNameNS := func(name string) string { + return "db_" + name + } + tableNameNS := func(name string) string { + return "tbl_" + name + } + columnNameNS := func(name string) string { + return "col_" + name + } + + ns := &gorm.NamingStrategy{ + DB: dbNameNS, + Table: tableNameNS, + Column: columnNameNS, + } + + cases := []struct { + name string + namer gorm.Namer + expected string + }{ + {name: "auth", expected: "db_auth", namer: ns.DB}, + {name: "user", expected: "tbl_user", namer: ns.Table}, + {name: "password", expected: "col_password", namer: ns.Column}, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + result := c.namer(c.name) + if result != c.expected { + t.Errorf("error in naming strategy. expected: %v got :%v\n", c.expected, result) + } + }) + } + +} diff --git a/preload_test.go b/preload_test.go index 311ad0be..1db625c9 100644 --- a/preload_test.go +++ b/preload_test.go @@ -123,6 +123,31 @@ func TestAutoPreload(t *testing.T) { } } +func TestAutoPreloadFalseDoesntPreload(t *testing.T) { + user1 := getPreloadUser("auto_user1") + DB.Save(user1) + + preloadDB := DB.Set("gorm:auto_preload", false).Where("role = ?", "Preload") + var user User + preloadDB.Find(&user) + + if user.BillingAddress.Address1 != "" { + t.Error("AutoPreload was set to fasle, but still fetched data") + } + + user2 := getPreloadUser("auto_user2") + DB.Save(user2) + + var users []User + preloadDB.Find(&users) + + for _, user := range users { + if user.BillingAddress.Address1 != "" { + t.Error("AutoPreload was set to fasle, but still fetched data") + } + } +} + func TestNestedPreload1(t *testing.T) { type ( Level1 struct { diff --git a/query_test.go b/query_test.go index fac7d4d8..15bf8b3c 100644 --- a/query_test.go +++ b/query_test.go @@ -181,17 +181,17 @@ func TestSearchWithPlainSQL(t *testing.T) { scopedb.Where("birthday > ?", parseTime("2000-1-1")).Find(&users) if len(users) != 2 { - t.Errorf("Should found 2 users's birthday > 2000-1-1, but got %v", len(users)) + t.Errorf("Should found 2 users' birthday > 2000-1-1, but got %v", len(users)) } scopedb.Where("birthday > ?", "2002-10-10").Find(&users) if len(users) != 2 { - t.Errorf("Should found 2 users's birthday >= 2002-10-10, but got %v", len(users)) + t.Errorf("Should found 2 users' birthday >= 2002-10-10, but got %v", len(users)) } scopedb.Where("birthday >= ?", "2010-1-1").Where("birthday < ?", "2020-1-1").Find(&users) if len(users) != 1 { - t.Errorf("Should found 1 users's birthday < 2020-1-1 and >= 2010-1-1, but got %v", len(users)) + t.Errorf("Should found 1 users' birthday < 2020-1-1 and >= 2010-1-1, but got %v", len(users)) } DB.Where("name in (?)", []string{user1.Name, user2.Name}).Find(&users) @@ -532,28 +532,28 @@ func TestNot(t *testing.T) { DB.Table("users").Where("name = ?", "user3").Count(&name3Count) DB.Not("name", "user3").Find(&users4) if len(users1)-len(users4) != int(name3Count) { - t.Errorf("Should find all users's name not equal 3") + t.Errorf("Should find all users' name not equal 3") } DB.Not("name = ?", "user3").Find(&users4) if len(users1)-len(users4) != int(name3Count) { - t.Errorf("Should find all users's name not equal 3") + t.Errorf("Should find all users' name not equal 3") } DB.Not("name <> ?", "user3").Find(&users4) if len(users4) != int(name3Count) { - t.Errorf("Should find all users's name not equal 3") + t.Errorf("Should find all users' name not equal 3") } DB.Not(User{Name: "user3"}).Find(&users5) if len(users1)-len(users5) != int(name3Count) { - t.Errorf("Should find all users's name not equal 3") + t.Errorf("Should find all users' name not equal 3") } DB.Not(map[string]interface{}{"name": "user3"}).Find(&users6) if len(users1)-len(users6) != int(name3Count) { - t.Errorf("Should find all users's name not equal 3") + t.Errorf("Should find all users' name not equal 3") } DB.Not(map[string]interface{}{"name": "user3", "company_id": nil}).Find(&users7) @@ -563,14 +563,14 @@ func TestNot(t *testing.T) { DB.Not("name", []string{"user3"}).Find(&users8) if len(users1)-len(users8) != int(name3Count) { - t.Errorf("Should find all users's name not equal 3") + t.Errorf("Should find all users' name not equal 3") } var name2Count int64 DB.Table("users").Where("name = ?", "user2").Count(&name2Count) DB.Not("name", []string{"user3", "user2"}).Find(&users9) if len(users1)-len(users9) != (int(name3Count) + int(name2Count)) { - t.Errorf("Should find all users's name not equal 3") + t.Errorf("Should find all users' name not equal 3") } } diff --git a/scope.go b/scope.go index 4f91cfb5..fa521ca2 100644 --- a/scope.go +++ b/scope.go @@ -63,7 +63,7 @@ func (scope *Scope) SQLDB() SQLCommon { // Dialect get dialect func (scope *Scope) Dialect() Dialect { - return scope.db.parent.dialect + return scope.db.dialect } // Quote used to quote string to escape them for database @@ -134,7 +134,7 @@ func (scope *Scope) Fields() []*Field { // FieldByName find `gorm.Field` with field name or db name func (scope *Scope) FieldByName(name string) (field *Field, ok bool) { var ( - dbName = ToDBName(name) + dbName = ToColumnName(name) mostMatchedField *Field ) @@ -486,8 +486,10 @@ func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field) { values[index] = &ignored selectFields = fields + offset := 0 if idx, ok := selectedColumnsMap[column]; ok { - selectFields = selectFields[idx+1:] + offset = idx + 1 + selectFields = selectFields[offset:] } for fieldIndex, field := range selectFields { @@ -501,7 +503,7 @@ func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field) { resetFields[index] = field } - selectedColumnsMap[column] = fieldIndex + selectedColumnsMap[column] = offset + fieldIndex if field.IsNormal { break @@ -586,10 +588,10 @@ func (scope *Scope) buildCondition(clause map[string]interface{}, include bool) scope.Err(fmt.Errorf("invalid query condition: %v", value)) return } - + scopeQuotedTableName := newScope.QuotedTableName() for _, field := range newScope.Fields() { if !field.IsIgnored && !field.IsBlank { - sqls = append(sqls, fmt.Sprintf("(%v.%v %s %v)", quotedTableName, scope.Quote(field.DBName), equalSQL, scope.AddToVars(field.Field.Interface()))) + sqls = append(sqls, fmt.Sprintf("(%v.%v %s %v)", scopeQuotedTableName, scope.Quote(field.DBName), equalSQL, scope.AddToVars(field.Field.Interface()))) } } return strings.Join(sqls, " AND ") @@ -692,12 +694,12 @@ func (scope *Scope) buildSelectQuery(clause map[string]interface{}) (str string) buff := bytes.NewBuffer([]byte{}) i := 0 - for pos := range str { + for pos, char := range str { if str[pos] == '?' { buff.WriteString(replacements[i]) i++ } else { - buff.WriteByte(str[pos]) + buff.WriteRune(char) } } @@ -880,7 +882,7 @@ func convertInterfaceToMap(values interface{}, withIgnoredField bool) map[string switch reflectValue.Kind() { case reflect.Map: for _, key := range reflectValue.MapKeys() { - attrs[ToDBName(key.Interface().(string))] = reflectValue.MapIndex(key).Interface() + attrs[ToColumnName(key.Interface().(string))] = reflectValue.MapIndex(key).Interface() } default: for _, field := range (&Scope{Value: values}).Fields() { @@ -907,7 +909,7 @@ func (scope *Scope) updatedAttrsWithValues(value interface{}) (results map[strin results[field.DBName] = value } else { err := field.Set(value) - if field.IsNormal { + if field.IsNormal && !field.IsIgnored { hasUpdate = true if err == ErrUnaddressable { results[field.DBName] = value @@ -1215,12 +1217,18 @@ func (scope *Scope) addForeignKey(field string, dest string, onDelete string, on } func (scope *Scope) removeForeignKey(field string, dest string) { - keyName := scope.Dialect().BuildKeyName(scope.TableName(), field, dest) - + keyName := scope.Dialect().BuildKeyName(scope.TableName(), field, dest, "foreign") if !scope.Dialect().HasForeignKey(scope.TableName(), keyName) { return } - var query = `ALTER TABLE %s DROP CONSTRAINT %s;` + var mysql mysql + var query string + if scope.Dialect().GetName() == mysql.GetName() { + query = `ALTER TABLE %s DROP FOREIGN KEY %s;` + } else { + query = `ALTER TABLE %s DROP CONSTRAINT %s;` + } + scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName))).Exec() } diff --git a/utils.go b/utils.go index dfaae939..e58e57a5 100644 --- a/utils.go +++ b/utils.go @@ -1,7 +1,6 @@ package gorm import ( - "bytes" "database/sql/driver" "fmt" "reflect" @@ -26,8 +25,8 @@ var NowFunc = func() time.Time { var commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UID", "UI", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"} var commonInitialismsReplacer *strings.Replacer -var goSrcRegexp = regexp.MustCompile(`jinzhu/gorm/.*.go`) -var goTestRegexp = regexp.MustCompile(`jinzhu/gorm/.*test.go`) +var goSrcRegexp = regexp.MustCompile(`jinzhu/gorm(@.*)?/.*.go`) +var goTestRegexp = regexp.MustCompile(`jinzhu/gorm(@.*)?/.*test.go`) func init() { var commonInitialismsForReplacer []string @@ -58,64 +57,6 @@ func newSafeMap() *safeMap { return &safeMap{l: new(sync.RWMutex), m: make(map[string]string)} } -var smap = newSafeMap() - -type strCase bool - -const ( - lower strCase = false - upper strCase = true -) - -// ToDBName convert string to db name -func ToDBName(name string) string { - if v := smap.Get(name); v != "" { - return v - } - - if name == "" { - return "" - } - - var ( - value = commonInitialismsReplacer.Replace(name) - buf = bytes.NewBufferString("") - lastCase, currCase, nextCase strCase - ) - - for i, v := range value[:len(value)-1] { - nextCase = strCase(value[i+1] >= 'A' && value[i+1] <= 'Z') - if i > 0 { - if currCase == upper { - if lastCase == upper && nextCase == upper { - buf.WriteRune(v) - } else { - if value[i-1] != '_' && value[i+1] != '_' { - buf.WriteRune('_') - } - buf.WriteRune(v) - } - } else { - buf.WriteRune(v) - if i == len(value)-2 && nextCase == upper { - buf.WriteRune('_') - } - } - } else { - currCase = upper - buf.WriteRune(v) - } - lastCase = currCase - currCase = nextCase - } - - buf.WriteByte(value[len(value)-1]) - - s := strings.ToLower(buf.String()) - smap.Set(name, s) - return s -} - // SQL expression type expr struct { expr string @@ -265,7 +206,7 @@ func getValueFromFields(value reflect.Value, fieldNames []string) (results []int // as FieldByName could panic if indirectValue := reflect.Indirect(value); indirectValue.IsValid() { for _, fieldName := range fieldNames { - if fieldValue := indirectValue.FieldByName(fieldName); fieldValue.IsValid() { + if fieldValue := reflect.Indirect(indirectValue.FieldByName(fieldName)); fieldValue.IsValid() { result := fieldValue.Interface() if r, ok := result.(driver.Valuer); ok { result, _ = r.Value() diff --git a/utils_test.go b/utils_test.go deleted file mode 100644 index 152296d2..00000000 --- a/utils_test.go +++ /dev/null @@ -1,32 +0,0 @@ -package gorm_test - -import ( - "testing" - - "github.com/jinzhu/gorm" -) - -func TestToDBNameGenerateFriendlyName(t *testing.T) { - var maps = map[string]string{ - "": "", - "X": "x", - "ThisIsATest": "this_is_a_test", - "PFAndESI": "pf_and_esi", - "AbcAndJkl": "abc_and_jkl", - "EmployeeID": "employee_id", - "SKU_ID": "sku_id", - "FieldX": "field_x", - "HTTPAndSMTP": "http_and_smtp", - "HTTPServerHandlerForURLID": "http_server_handler_for_url_id", - "UUID": "uuid", - "HTTPURL": "http_url", - "HTTP_URL": "http_url", - "ThisIsActuallyATestSoWeMayBeAbleToUseThisCodeInGormPackageAlsoIdCanBeUsedAtTheEndAsID": "this_is_actually_a_test_so_we_may_be_able_to_use_this_code_in_gorm_package_also_id_can_be_used_at_the_end_as_id", - } - - for key, value := range maps { - if gorm.ToDBName(key) != value { - t.Errorf("%v ToDBName should equal %v, but got %v", key, value, gorm.ToDBName(key)) - } - } -}