From 52c5c8127cf4aeffde3e0aa9222640832075a90f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sa=C3=BAl=20Ortega?= Date: Thu, 15 Mar 2018 09:35:31 -0500 Subject: [PATCH 01/28] Support for UTF8 names on DB (#1793) --- scope.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scope.go b/scope.go index 150ac710..2f39e073 100644 --- a/scope.go +++ b/scope.go @@ -692,12 +692,12 @@ func (scope *Scope) buildSelectQuery(clause map[string]interface{}) (str string) buff := bytes.NewBuffer([]byte{}) i := 0 - for pos := range str { + for pos, char := range str { if str[pos] == '?' { buff.WriteString(replacements[i]) i++ } else { - buff.WriteByte(str[pos]) + buff.WriteRune(char) } } From 919c6db4f854e4feaae94202ae29da4e3779de49 Mon Sep 17 00:00:00 2001 From: Giuseppe Date: Mon, 16 Apr 2018 16:18:51 +0200 Subject: [PATCH 02/28] Do not panic if Begin().Error was ignored (#1830) --- main.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/main.go b/main.go index c26e05c8..ffee4ec6 100644 --- a/main.go +++ b/main.go @@ -491,7 +491,8 @@ func (s *DB) Begin() *DB { // Commit commit a transaction func (s *DB) Commit() *DB { - if db, ok := s.db.(sqlTx); ok && db != nil { + var emptySQLTx *sql.Tx + if db, ok := s.db.(sqlTx); ok && db != nil && db != emptySQLTx { s.AddError(db.Commit()) } else { s.AddError(ErrInvalidTransaction) From 6842b49a1ad0feb6b93be830fe63a682cf853ada Mon Sep 17 00:00:00 2001 From: Shane Date: Mon, 16 Apr 2018 07:20:02 -0700 Subject: [PATCH 03/28] fix scope.removeForeignKey method (#1841) --- scope.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scope.go b/scope.go index 2f39e073..397ccf0b 100644 --- a/scope.go +++ b/scope.go @@ -1215,7 +1215,7 @@ func (scope *Scope) addForeignKey(field string, dest string, onDelete string, on } func (scope *Scope) removeForeignKey(field string, dest string) { - keyName := scope.Dialect().BuildKeyName(scope.TableName(), field, dest) + keyName := scope.Dialect().BuildKeyName(scope.TableName(), field, dest, "foreign") if !scope.Dialect().HasForeignKey(scope.TableName(), keyName) { return From 35efe68ba71d571e64ccd1ee62830c30a53ed967 Mon Sep 17 00:00:00 2001 From: Daniel McDonald Date: Wed, 2 May 2018 07:37:51 -0700 Subject: [PATCH 04/28] add simple input validation on gorm.Open function (#1855) Simply check if the passed-in database source meets the expected types and, if not, early return with error. --- main.go | 2 ++ main_test.go | 17 +++++++++++++++++ 2 files changed, 19 insertions(+) diff --git a/main.go b/main.go index ffee4ec6..c8a43e8c 100644 --- a/main.go +++ b/main.go @@ -61,6 +61,8 @@ func Open(dialect string, args ...interface{}) (db *DB, err error) { dbSQL, err = sql.Open(driver, source) case SQLCommon: dbSQL = value + default: + return nil, fmt.Errorf("invalid database source: %v is not a valid type", value) } db = &DB{ diff --git a/main_test.go b/main_test.go index 66c46af0..265e0be7 100644 --- a/main_test.go +++ b/main_test.go @@ -8,6 +8,7 @@ import ( "path/filepath" "reflect" "strconv" + "strings" "testing" "time" @@ -79,6 +80,22 @@ func OpenTestConnection() (db *gorm.DB, err error) { return } +func TestOpen_ReturnsError_WithBadArgs(t *testing.T) { + stringRef := "foo" + testCases := []interface{}{42, time.Now(), &stringRef} + for _, tc := range testCases { + t.Run(fmt.Sprintf("%v", tc), func(t *testing.T) { + _, err := gorm.Open("postgresql", tc) + if err == nil { + t.Error("Should got error with invalid database source") + } + if !strings.HasPrefix(err.Error(), "invalid database source:") { + t.Errorf("Should got error starting with \"invalid database source:\", but got %q", err.Error()) + } + }) + } +} + func TestStringPrimaryKey(t *testing.T) { type UUIDStruct struct { ID string `gorm:"primary_key"` From 9044197ef935c0969d94cbcfba55ccb94d269bed Mon Sep 17 00:00:00 2001 From: Illya Busigin Date: Wed, 2 May 2018 09:38:52 -0500 Subject: [PATCH 05/28] Adding GetDialect function (#1869) --- dialect.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/dialect.go b/dialect.go index 5f6439c1..506a6e86 100644 --- a/dialect.go +++ b/dialect.go @@ -72,6 +72,12 @@ func RegisterDialect(name string, dialect Dialect) { dialectsMap[name] = dialect } +// GetDialect gets the dialect for the specified dialect name +func GetDialect(name string) (dialect Dialect, ok bool) { + dialect, ok = dialectsMap[name] + return +} + // ParseFieldStructForDialect get field's sql data type var ParseFieldStructForDialect = func(field *StructField, dialect Dialect) (fieldValue reflect.Value, sqlType string, size int, additionalType string) { // Get redirected field type From a58b98acee2f3bf213b2cb0f1fe1468f236c9aec Mon Sep 17 00:00:00 2001 From: lrita Date: Sat, 12 May 2018 14:28:15 +0800 Subject: [PATCH 06/28] Do not panic if Begin().Error was ignored (#1830) (#1881) --- main.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/main.go b/main.go index c8a43e8c..25c3a06b 100644 --- a/main.go +++ b/main.go @@ -504,7 +504,8 @@ func (s *DB) Commit() *DB { // Rollback rollback a transaction func (s *DB) Rollback() *DB { - if db, ok := s.db.(sqlTx); ok && db != nil { + var emptySQLTx *sql.Tx + if db, ok := s.db.(sqlTx); ok && db != nil && db != emptySQLTx { s.AddError(db.Rollback()) } else { s.AddError(ErrInvalidTransaction) From 82eb9f8a5bbb5e6b929d2f0ae5b934e6a253f94e Mon Sep 17 00:00:00 2001 From: Olga Kleitsa Date: Sat, 12 May 2018 09:29:00 +0300 Subject: [PATCH 07/28] included actual sql query to discover fi foreign key with the same name exists in a specific table of the database in use (#1896) --- dialects/mssql/mssql.go | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go index e0606465..a8d3c45a 100644 --- a/dialects/mssql/mssql.go +++ b/dialects/mssql/mssql.go @@ -130,7 +130,14 @@ func (s mssql) RemoveIndex(tableName string, indexName string) error { } func (s mssql) HasForeignKey(tableName string, foreignKeyName string) bool { - return false + var count int + currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) + s.db.QueryRow(`SELECT count(*) + FROM sys.foreign_keys as F inner join sys.tables as T on F.parent_object_id=T.object_id + inner join information_schema.tables as I on I.TABLE_NAME = T.name + WHERE F.name = ? + AND T.Name = ? AND I.TABLE_CATALOG = ?;`, foreignKeyName, tableName, currentDatabase).Scan(&count) + return count > 0 } func (s mssql) HasTable(tableName string) bool { From 1907bff3732cb4c612e4118137d8f3c8829cc8c6 Mon Sep 17 00:00:00 2001 From: ia Date: Mon, 25 Jun 2018 07:06:58 +0200 Subject: [PATCH 08/28] all: gofmt (#1956) Run standard gofmt command on project root. - go version go1.10.3 darwin/amd64 Signed-off-by: ia --- dialects/postgres/postgres.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dialects/postgres/postgres.go b/dialects/postgres/postgres.go index 1d0dcb60..424e8bdc 100644 --- a/dialects/postgres/postgres.go +++ b/dialects/postgres/postgres.go @@ -4,11 +4,11 @@ import ( "database/sql" "database/sql/driver" - _ "github.com/lib/pq" - "github.com/lib/pq/hstore" "encoding/json" "errors" "fmt" + _ "github.com/lib/pq" + "github.com/lib/pq/hstore" ) type Hstore map[string]*string From 0fd395ab37aefd2d50854f0556a4311dccc6f45a Mon Sep 17 00:00:00 2001 From: Masaki Yoshida Date: Mon, 25 Jun 2018 14:07:53 +0900 Subject: [PATCH 09/28] Fix ToDBName (#1941) Don't place '_' before number. - NG: SHA256Hash -> sha_256_hash - OK: SHA256Hash -> sha256_hash --- utils.go | 12 +++++++----- utils_test.go | 3 +++ 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/utils.go b/utils.go index dfaae939..99b532c5 100644 --- a/utils.go +++ b/utils.go @@ -78,16 +78,18 @@ func ToDBName(name string) string { } var ( - value = commonInitialismsReplacer.Replace(name) - buf = bytes.NewBufferString("") - lastCase, currCase, nextCase strCase + value = commonInitialismsReplacer.Replace(name) + buf = bytes.NewBufferString("") + lastCase, currCase, nextCase, nextNumber strCase ) for i, v := range value[:len(value)-1] { nextCase = strCase(value[i+1] >= 'A' && value[i+1] <= 'Z') + nextNumber = strCase(value[i+1] >= '0' && value[i+1] <= '9') + if i > 0 { if currCase == upper { - if lastCase == upper && nextCase == upper { + if lastCase == upper && (nextCase == upper || nextNumber == upper) { buf.WriteRune(v) } else { if value[i-1] != '_' && value[i+1] != '_' { @@ -97,7 +99,7 @@ func ToDBName(name string) string { } } else { buf.WriteRune(v) - if i == len(value)-2 && nextCase == upper { + if i == len(value)-2 && (nextCase == upper && nextNumber == lower) { buf.WriteRune('_') } } diff --git a/utils_test.go b/utils_test.go index 152296d2..086c4450 100644 --- a/utils_test.go +++ b/utils_test.go @@ -15,6 +15,9 @@ func TestToDBNameGenerateFriendlyName(t *testing.T) { "AbcAndJkl": "abc_and_jkl", "EmployeeID": "employee_id", "SKU_ID": "sku_id", + "UTF8": "utf8", + "Level1": "level1", + "SHA256Hash": "sha256_hash", "FieldX": "field_x", "HTTPAndSMTP": "http_and_smtp", "HTTPServerHandlerForURLID": "http_server_handler_for_url_id", From dbb25e94879f463c699430a74d29c9557e15a60f Mon Sep 17 00:00:00 2001 From: Louis Brauer Date: Fri, 27 Jul 2018 01:30:57 +0200 Subject: [PATCH 10/28] Adding json type for mssql dialect, similar to postgres.Jsonb (#1934) * Adding json type for mssql dialect, similar to postgres.Jsonb * Adding proper comments --- dialects/mssql/mssql.go | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go index a8d3c45a..731721cb 100644 --- a/dialects/mssql/mssql.go +++ b/dialects/mssql/mssql.go @@ -1,12 +1,16 @@ package mssql import ( + "database/sql/driver" + "encoding/json" + "errors" "fmt" "reflect" "strconv" "strings" "time" + // Importing mssql driver package only in dialect file, otherwide not needed _ "github.com/denisenkom/go-mssqldb" "github.com/jinzhu/gorm" ) @@ -201,3 +205,27 @@ func currentDatabaseAndTable(dialect gorm.Dialect, tableName string) (string, st } return dialect.CurrentDatabase(), tableName } + +// JSON type to support easy handling of JSON data in character table fields +// using golang json.RawMessage for deferred decoding/encoding +type JSON struct { + json.RawMessage +} + +// Value get value of JSON +func (j JSON) Value() (driver.Value, error) { + if len(j.RawMessage) == 0 { + return nil, nil + } + return j.MarshalJSON() +} + +// Scan scan value into JSON +func (j *JSON) Scan(value interface{}) error { + str, ok := value.(string) + if !ok { + return errors.New(fmt.Sprint("Failed to unmarshal JSONB value (strcast):", value)) + } + bytes := []byte(str) + return json.Unmarshal(bytes, j) +} From ac3ec858a6c375a466f613c86b053726abbe3755 Mon Sep 17 00:00:00 2001 From: Kevin Date: Thu, 26 Jul 2018 19:35:53 -0400 Subject: [PATCH 11/28] Edit DB.clone(), DB.Dialect(), and Scope.Dialect() preserve transactions (#1939) * Edit DB.clone(), DB.Dialect(), and Scope.Dialect() preserve transactions. * Adds a test case for tables creations and autoMigrate in the same transaction. --- main.go | 5 ++++- migration_test.go | 47 +++++++++++++++++++++++++++++++++++++++++++++++ scope.go | 2 +- 3 files changed, 52 insertions(+), 2 deletions(-) diff --git a/main.go b/main.go index 25c3a06b..3a5d6b0c 100644 --- a/main.go +++ b/main.go @@ -119,7 +119,7 @@ func (s *DB) CommonDB() SQLCommon { // Dialect get dialect func (s *DB) Dialect() Dialect { - return s.parent.dialect + return s.dialect } // Callback return `Callbacks` container, you could add/change/delete callbacks with it @@ -484,6 +484,8 @@ func (s *DB) Begin() *DB { if db, ok := c.db.(sqlDb); ok && db != nil { tx, err := db.Begin() c.db = interface{}(tx).(SQLCommon) + + c.dialect.SetDB(c.db) c.AddError(err) } else { c.AddError(ErrCantStartTransaction) @@ -748,6 +750,7 @@ func (s *DB) clone() *DB { Value: s.Value, Error: s.Error, blockGlobalUpdate: s.blockGlobalUpdate, + dialect: newDialect(s.dialect.GetName(), s.db), } for key, value := range s.values { diff --git a/migration_test.go b/migration_test.go index 7c694485..78555dcc 100644 --- a/migration_test.go +++ b/migration_test.go @@ -398,6 +398,53 @@ func TestAutoMigration(t *testing.T) { } } +func TestCreateAndAutomigrateTransaction(t *testing.T) { + tx := DB.Begin() + + func() { + type Bar struct { + ID uint + } + DB.DropTableIfExists(&Bar{}) + + if ok := DB.HasTable("bars"); ok { + t.Errorf("Table should not exist, but does") + } + + if ok := tx.HasTable("bars"); ok { + t.Errorf("Table should not exist, but does") + } + }() + + func() { + type Bar struct { + Name string + } + err := tx.CreateTable(&Bar{}).Error + + if err != nil { + t.Errorf("Should have been able to create the table, but couldn't: %s", err) + } + + if ok := tx.HasTable(&Bar{}); !ok { + t.Errorf("The transaction should be able to see the table") + } + }() + + func() { + type Bar struct { + Stuff string + } + + err := tx.AutoMigrate(&Bar{}).Error + if err != nil { + t.Errorf("Should have been able to alter the table, but couldn't") + } + }() + + tx.Rollback() +} + type MultipleIndexes struct { ID int64 UserID int64 `sql:"unique_index:uix_multipleindexes_user_name,uix_multipleindexes_user_email;index:idx_multipleindexes_user_other"` diff --git a/scope.go b/scope.go index 397ccf0b..5eb98963 100644 --- a/scope.go +++ b/scope.go @@ -63,7 +63,7 @@ func (scope *Scope) SQLDB() SQLCommon { // Dialect get dialect func (scope *Scope) Dialect() Dialect { - return scope.db.parent.dialect + return scope.db.dialect } // Quote used to quote string to escape them for database From 588e2eef5d9c33b11ee52895ad5cdfab0d6648e6 Mon Sep 17 00:00:00 2001 From: David Zhang Date: Fri, 27 Jul 2018 07:38:02 +0800 Subject: [PATCH 12/28] Fix typo in query_test (#1977) --- query_test.go | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/query_test.go b/query_test.go index fac7d4d8..15bf8b3c 100644 --- a/query_test.go +++ b/query_test.go @@ -181,17 +181,17 @@ func TestSearchWithPlainSQL(t *testing.T) { scopedb.Where("birthday > ?", parseTime("2000-1-1")).Find(&users) if len(users) != 2 { - t.Errorf("Should found 2 users's birthday > 2000-1-1, but got %v", len(users)) + t.Errorf("Should found 2 users' birthday > 2000-1-1, but got %v", len(users)) } scopedb.Where("birthday > ?", "2002-10-10").Find(&users) if len(users) != 2 { - t.Errorf("Should found 2 users's birthday >= 2002-10-10, but got %v", len(users)) + t.Errorf("Should found 2 users' birthday >= 2002-10-10, but got %v", len(users)) } scopedb.Where("birthday >= ?", "2010-1-1").Where("birthday < ?", "2020-1-1").Find(&users) if len(users) != 1 { - t.Errorf("Should found 1 users's birthday < 2020-1-1 and >= 2010-1-1, but got %v", len(users)) + t.Errorf("Should found 1 users' birthday < 2020-1-1 and >= 2010-1-1, but got %v", len(users)) } DB.Where("name in (?)", []string{user1.Name, user2.Name}).Find(&users) @@ -532,28 +532,28 @@ func TestNot(t *testing.T) { DB.Table("users").Where("name = ?", "user3").Count(&name3Count) DB.Not("name", "user3").Find(&users4) if len(users1)-len(users4) != int(name3Count) { - t.Errorf("Should find all users's name not equal 3") + t.Errorf("Should find all users' name not equal 3") } DB.Not("name = ?", "user3").Find(&users4) if len(users1)-len(users4) != int(name3Count) { - t.Errorf("Should find all users's name not equal 3") + t.Errorf("Should find all users' name not equal 3") } DB.Not("name <> ?", "user3").Find(&users4) if len(users4) != int(name3Count) { - t.Errorf("Should find all users's name not equal 3") + t.Errorf("Should find all users' name not equal 3") } DB.Not(User{Name: "user3"}).Find(&users5) if len(users1)-len(users5) != int(name3Count) { - t.Errorf("Should find all users's name not equal 3") + t.Errorf("Should find all users' name not equal 3") } DB.Not(map[string]interface{}{"name": "user3"}).Find(&users6) if len(users1)-len(users6) != int(name3Count) { - t.Errorf("Should find all users's name not equal 3") + t.Errorf("Should find all users' name not equal 3") } DB.Not(map[string]interface{}{"name": "user3", "company_id": nil}).Find(&users7) @@ -563,14 +563,14 @@ func TestNot(t *testing.T) { DB.Not("name", []string{"user3"}).Find(&users8) if len(users1)-len(users8) != int(name3Count) { - t.Errorf("Should find all users's name not equal 3") + t.Errorf("Should find all users' name not equal 3") } var name2Count int64 DB.Table("users").Where("name = ?", "user2").Count(&name2Count) DB.Not("name", []string{"user3", "user2"}).Find(&users9) if len(users1)-len(users9) != (int(name3Count) + int(name2Count)) { - t.Errorf("Should find all users's name not equal 3") + t.Errorf("Should find all users' name not equal 3") } } From d68403b29dbf3086b2335f6381545462d96808bc Mon Sep 17 00:00:00 2001 From: antness Date: Fri, 27 Jul 2018 02:43:09 +0300 Subject: [PATCH 13/28] do not close wrapped *sql.DB (#1985) --- main.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/main.go b/main.go index 3a5d6b0c..de6ce428 100644 --- a/main.go +++ b/main.go @@ -48,6 +48,7 @@ func Open(dialect string, args ...interface{}) (db *DB, err error) { } var source string var dbSQL SQLCommon + var ownDbSQL bool switch value := args[0].(type) { case string: @@ -59,8 +60,10 @@ func Open(dialect string, args ...interface{}) (db *DB, err error) { source = args[1].(string) } dbSQL, err = sql.Open(driver, source) + ownDbSQL = true case SQLCommon: dbSQL = value + ownDbSQL = false default: return nil, fmt.Errorf("invalid database source: %v is not a valid type", value) } @@ -78,7 +81,7 @@ func Open(dialect string, args ...interface{}) (db *DB, err error) { } // Send a ping to make sure the database connection is alive. if d, ok := dbSQL.(*sql.DB); ok { - if err = d.Ping(); err != nil { + if err = d.Ping(); err != nil && ownDbSQL { d.Close() } } From 409121d9e394922787885b001d148a05e3a42b6c Mon Sep 17 00:00:00 2001 From: Alexey <10kdmg@gmail.com> Date: Fri, 27 Jul 2018 02:43:49 +0300 Subject: [PATCH 14/28] Fixed mysql query syntax for FK removal (#1993) --- scope.go | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/scope.go b/scope.go index 5eb98963..a05c1d61 100644 --- a/scope.go +++ b/scope.go @@ -1216,11 +1216,17 @@ func (scope *Scope) addForeignKey(field string, dest string, onDelete string, on func (scope *Scope) removeForeignKey(field string, dest string) { keyName := scope.Dialect().BuildKeyName(scope.TableName(), field, dest, "foreign") - if !scope.Dialect().HasForeignKey(scope.TableName(), keyName) { return } - var query = `ALTER TABLE %s DROP CONSTRAINT %s;` + var mysql mysql + var query string + if scope.Dialect().GetName() == mysql.GetName() { + query = `ALTER TABLE %s DROP FOREIGN KEY %s;` + } else { + query = `ALTER TABLE %s DROP CONSTRAINT %s;` + } + scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName))).Exec() } From 0e04d414d59f3154d700692bda0d7649d0e101b3 Mon Sep 17 00:00:00 2001 From: Artemij Shepelev Date: Sun, 19 Aug 2018 02:09:21 +0300 Subject: [PATCH 15/28] Race fix. Changes modelStructsMap implementation from map with mutex to sync.Map (#2022) * fix (https://github.com/jinzhu/gorm/issues/1407) * changed map with mutex to sync.Map (https://github.com/jinzhu/gorm/issues/1407) * removed newModelStructsMap func * commit to rerun pipeline, comment changed --- main.go | 3 ++- model_struct.go | 31 +++++-------------------------- 2 files changed, 7 insertions(+), 27 deletions(-) diff --git a/main.go b/main.go index de6ce428..993e19b1 100644 --- a/main.go +++ b/main.go @@ -6,6 +6,7 @@ import ( "fmt" "reflect" "strings" + "sync" "time" ) @@ -162,7 +163,7 @@ func (s *DB) HasBlockGlobalUpdate() bool { // SingularTable use singular table by default func (s *DB) SingularTable(enable bool) { - modelStructsMap = newModelStructsMap() + modelStructsMap = sync.Map{} s.parent.singularTable = enable } diff --git a/model_struct.go b/model_struct.go index f571e2e8..8506fe87 100644 --- a/model_struct.go +++ b/model_struct.go @@ -17,28 +17,7 @@ var DefaultTableNameHandler = func(db *DB, defaultTableName string) string { return defaultTableName } -type safeModelStructsMap struct { - m map[reflect.Type]*ModelStruct - l *sync.RWMutex -} - -func (s *safeModelStructsMap) Set(key reflect.Type, value *ModelStruct) { - s.l.Lock() - defer s.l.Unlock() - s.m[key] = value -} - -func (s *safeModelStructsMap) Get(key reflect.Type) *ModelStruct { - s.l.RLock() - defer s.l.RUnlock() - return s.m[key] -} - -func newModelStructsMap() *safeModelStructsMap { - return &safeModelStructsMap{l: new(sync.RWMutex), m: make(map[reflect.Type]*ModelStruct)} -} - -var modelStructsMap = newModelStructsMap() +var modelStructsMap sync.Map // ModelStruct model definition type ModelStruct struct { @@ -48,7 +27,7 @@ type ModelStruct struct { defaultTableName string } -// TableName get model's table name +// TableName returns model's table name func (s *ModelStruct) TableName(db *DB) string { if s.defaultTableName == "" && db != nil && s.ModelType != nil { // Set default table name @@ -152,8 +131,8 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } // Get Cached model struct - if value := modelStructsMap.Get(reflectType); value != nil { - return value + if value, ok := modelStructsMap.Load(reflectType); ok && value != nil { + return value.(*ModelStruct) } modelStruct.ModelType = reflectType @@ -601,7 +580,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } } - modelStructsMap.Set(reflectType, &modelStruct) + modelStructsMap.Store(reflectType, &modelStruct) return &modelStruct } From 31ec9255cdc16482f5bef2ceb996ba75ba750a8a Mon Sep 17 00:00:00 2001 From: Elliott <617942+ellman121@users.noreply.github.com> Date: Sun, 19 Aug 2018 01:11:27 +0200 Subject: [PATCH 16/28] Setting gorm:auto_preload to false now prevents preloading (#2031) --- callback_query_preload.go | 10 ++++++++-- preload_test.go | 25 +++++++++++++++++++++++++ 2 files changed, 33 insertions(+), 2 deletions(-) diff --git a/callback_query_preload.go b/callback_query_preload.go index 30f6b585..481bfbe3 100644 --- a/callback_query_preload.go +++ b/callback_query_preload.go @@ -14,8 +14,14 @@ func preloadCallback(scope *Scope) { return } - if _, ok := scope.Get("gorm:auto_preload"); ok { - autoPreload(scope) + if ap, ok := scope.Get("gorm:auto_preload"); ok { + // If gorm:auto_preload IS NOT a bool then auto preload. + // Else if it IS a bool, use the value + if apb, ok := ap.(bool); !ok { + autoPreload(scope) + } else if apb { + autoPreload(scope) + } } if scope.Search.preload == nil || scope.HasError() { diff --git a/preload_test.go b/preload_test.go index 311ad0be..1db625c9 100644 --- a/preload_test.go +++ b/preload_test.go @@ -123,6 +123,31 @@ func TestAutoPreload(t *testing.T) { } } +func TestAutoPreloadFalseDoesntPreload(t *testing.T) { + user1 := getPreloadUser("auto_user1") + DB.Save(user1) + + preloadDB := DB.Set("gorm:auto_preload", false).Where("role = ?", "Preload") + var user User + preloadDB.Find(&user) + + if user.BillingAddress.Address1 != "" { + t.Error("AutoPreload was set to fasle, but still fetched data") + } + + user2 := getPreloadUser("auto_user2") + DB.Save(user2) + + var users []User + preloadDB.Find(&users) + + for _, user := range users { + if user.BillingAddress.Address1 != "" { + t.Error("AutoPreload was set to fasle, but still fetched data") + } + } +} + func TestNestedPreload1(t *testing.T) { type ( Level1 struct { From 53995294ef73980d6eacee993ffa8bcdf769a0e2 Mon Sep 17 00:00:00 2001 From: hector <1069315972@qq.com> Date: Sun, 19 Aug 2018 07:13:16 +0800 Subject: [PATCH 17/28] Change buildCondition TableName to struct's TableName when query is interface{} (#2011) --- scope.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scope.go b/scope.go index a05c1d61..ca861d8a 100644 --- a/scope.go +++ b/scope.go @@ -586,10 +586,10 @@ func (scope *Scope) buildCondition(clause map[string]interface{}, include bool) scope.Err(fmt.Errorf("invalid query condition: %v", value)) return } - + scopeQuotedTableName := newScope.QuotedTableName() for _, field := range newScope.Fields() { if !field.IsIgnored && !field.IsBlank { - sqls = append(sqls, fmt.Sprintf("(%v.%v %s %v)", quotedTableName, scope.Quote(field.DBName), equalSQL, scope.AddToVars(field.Field.Interface()))) + sqls = append(sqls, fmt.Sprintf("(%v.%v %s %v)", scopeQuotedTableName, scope.Quote(field.DBName), equalSQL, scope.AddToVars(field.Field.Interface()))) } } return strings.Join(sqls, " AND ") From 32455088f24d6b1e9a502fb8e40fdc16139dbea8 Mon Sep 17 00:00:00 2001 From: Eason Lin Date: Sun, 19 Aug 2018 07:14:33 +0800 Subject: [PATCH 18/28] doc: document ErrRecordNotFound error more clear (#2015) * doc: document ErrRecordNotFound error more clear * fix goimports * fix goimports * undo change --- errors.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/errors.go b/errors.go index da2cf13c..27c9a92d 100644 --- a/errors.go +++ b/errors.go @@ -6,7 +6,7 @@ import ( ) var ( - // ErrRecordNotFound record not found error, happens when haven't find any matched data when looking up with a struct + // ErrRecordNotFound record not found error, happens when only haven't find any matched data when looking up with a struct, finding a slice won't return this error ErrRecordNotFound = errors.New("record not found") // ErrInvalidSQL invalid SQL error, happens when you passed invalid SQL ErrInvalidSQL = errors.New("invalid SQL") From 6f58f8a52cc3ad21950402d1adaa09682e07ec2f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adem=20=C3=96zay?= Date: Mon, 10 Sep 2018 00:52:20 +0300 Subject: [PATCH 19/28] added naming strategy option for db, table and column names (#2040) --- model_struct.go | 12 ++--- naming.go | 124 ++++++++++++++++++++++++++++++++++++++++++++++++ naming_test.go | 69 +++++++++++++++++++++++++++ scope.go | 4 +- utils.go | 61 ------------------------ utils_test.go | 35 -------------- 6 files changed, 201 insertions(+), 104 deletions(-) create mode 100644 naming.go create mode 100644 naming_test.go delete mode 100644 utils_test.go diff --git a/model_struct.go b/model_struct.go index 8506fe87..5b5be618 100644 --- a/model_struct.go +++ b/model_struct.go @@ -34,7 +34,7 @@ func (s *ModelStruct) TableName(db *DB) string { if tabler, ok := reflect.New(s.ModelType).Interface().(tabler); ok { s.defaultTableName = tabler.TableName() } else { - tableName := ToDBName(s.ModelType.Name()) + tableName := ToTableName(s.ModelType.Name()) if db == nil || !db.parent.singularTable { tableName = inflection.Plural(tableName) } @@ -105,7 +105,7 @@ type Relationship struct { func getForeignField(column string, fields []*StructField) *StructField { for _, field := range fields { - if field.Name == column || field.DBName == column || field.DBName == ToDBName(column) { + if field.Name == column || field.DBName == column || field.DBName == ToColumnName(column) { return field } } @@ -269,7 +269,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { // if defined join table's foreign key relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBNames[idx]) } else { - defaultJointableForeignKey := ToDBName(reflectType.Name()) + "_" + foreignField.DBName + defaultJointableForeignKey := ToColumnName(reflectType.Name()) + "_" + foreignField.DBName relationship.ForeignDBNames = append(relationship.ForeignDBNames, defaultJointableForeignKey) } } @@ -300,7 +300,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationJoinTableDBNames[idx]) } else { // join table foreign keys for association - joinTableDBName := ToDBName(elemType.Name()) + "_" + field.DBName + joinTableDBName := ToColumnName(elemType.Name()) + "_" + field.DBName relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, joinTableDBName) } } @@ -308,7 +308,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } joinTableHandler := JoinTableHandler{} - joinTableHandler.Setup(relationship, many2many, reflectType, elemType) + joinTableHandler.Setup(relationship, ToTableName(many2many), reflectType, elemType) relationship.JoinTableHandler = &joinTableHandler field.Relationship = relationship } else { @@ -566,7 +566,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { if value, ok := field.TagSettings["COLUMN"]; ok { field.DBName = value } else { - field.DBName = ToDBName(fieldStruct.Name) + field.DBName = ToColumnName(fieldStruct.Name) } modelStruct.StructFields = append(modelStruct.StructFields, field) diff --git a/naming.go b/naming.go new file mode 100644 index 00000000..6b0a4fdd --- /dev/null +++ b/naming.go @@ -0,0 +1,124 @@ +package gorm + +import ( + "bytes" + "strings" +) + +// Namer is a function type which is given a string and return a string +type Namer func(string) string + +// NamingStrategy represents naming strategies +type NamingStrategy struct { + DB Namer + Table Namer + Column Namer +} + +// TheNamingStrategy is being initialized with defaultNamingStrategy +var TheNamingStrategy = &NamingStrategy{ + DB: defaultNamer, + Table: defaultNamer, + Column: defaultNamer, +} + +// AddNamingStrategy sets the naming strategy +func AddNamingStrategy(ns *NamingStrategy) { + if ns.DB == nil { + ns.DB = defaultNamer + } + if ns.Table == nil { + ns.Table = defaultNamer + } + if ns.Column == nil { + ns.Column = defaultNamer + } + TheNamingStrategy = ns +} + +// DBName alters the given name by DB +func (ns *NamingStrategy) DBName(name string) string { + return ns.DB(name) +} + +// TableName alters the given name by Table +func (ns *NamingStrategy) TableName(name string) string { + return ns.Table(name) +} + +// ColumnName alters the given name by Column +func (ns *NamingStrategy) ColumnName(name string) string { + return ns.Column(name) +} + +// ToDBName convert string to db name +func ToDBName(name string) string { + return TheNamingStrategy.DBName(name) +} + +// ToTableName convert string to table name +func ToTableName(name string) string { + return TheNamingStrategy.TableName(name) +} + +// ToColumnName convert string to db name +func ToColumnName(name string) string { + return TheNamingStrategy.ColumnName(name) +} + +var smap = newSafeMap() + +func defaultNamer(name string) string { + const ( + lower = false + upper = true + ) + + if v := smap.Get(name); v != "" { + return v + } + + if name == "" { + return "" + } + + var ( + value = commonInitialismsReplacer.Replace(name) + buf = bytes.NewBufferString("") + lastCase, currCase, nextCase, nextNumber bool + ) + + for i, v := range value[:len(value)-1] { + nextCase = bool(value[i+1] >= 'A' && value[i+1] <= 'Z') + nextNumber = bool(value[i+1] >= '0' && value[i+1] <= '9') + + if i > 0 { + if currCase == upper { + if lastCase == upper && (nextCase == upper || nextNumber == upper) { + buf.WriteRune(v) + } else { + if value[i-1] != '_' && value[i+1] != '_' { + buf.WriteRune('_') + } + buf.WriteRune(v) + } + } else { + buf.WriteRune(v) + if i == len(value)-2 && (nextCase == upper && nextNumber == lower) { + buf.WriteRune('_') + } + } + } else { + currCase = upper + buf.WriteRune(v) + } + lastCase = currCase + currCase = nextCase + } + + buf.WriteByte(value[len(value)-1]) + + s := strings.ToLower(buf.String()) + smap.Set(name, s) + return s +} diff --git a/naming_test.go b/naming_test.go new file mode 100644 index 00000000..0c6f7713 --- /dev/null +++ b/naming_test.go @@ -0,0 +1,69 @@ +package gorm_test + +import ( + "testing" + + "github.com/jinzhu/gorm" +) + +func TestTheNamingStrategy(t *testing.T) { + + cases := []struct { + name string + namer gorm.Namer + expected string + }{ + {name: "auth", expected: "auth", namer: gorm.TheNamingStrategy.DB}, + {name: "userRestrictions", expected: "user_restrictions", namer: gorm.TheNamingStrategy.Table}, + {name: "clientID", expected: "client_id", namer: gorm.TheNamingStrategy.Column}, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + result := c.namer(c.name) + if result != c.expected { + t.Errorf("error in naming strategy. expected: %v got :%v\n", c.expected, result) + } + }) + } + +} + +func TestNamingStrategy(t *testing.T) { + + dbNameNS := func(name string) string { + return "db_" + name + } + tableNameNS := func(name string) string { + return "tbl_" + name + } + columnNameNS := func(name string) string { + return "col_" + name + } + + ns := &gorm.NamingStrategy{ + DB: dbNameNS, + Table: tableNameNS, + Column: columnNameNS, + } + + cases := []struct { + name string + namer gorm.Namer + expected string + }{ + {name: "auth", expected: "db_auth", namer: ns.DB}, + {name: "user", expected: "tbl_user", namer: ns.Table}, + {name: "password", expected: "col_password", namer: ns.Column}, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + result := c.namer(c.name) + if result != c.expected { + t.Errorf("error in naming strategy. expected: %v got :%v\n", c.expected, result) + } + }) + } + +} diff --git a/scope.go b/scope.go index ca861d8a..fbf7634e 100644 --- a/scope.go +++ b/scope.go @@ -134,7 +134,7 @@ func (scope *Scope) Fields() []*Field { // FieldByName find `gorm.Field` with field name or db name func (scope *Scope) FieldByName(name string) (field *Field, ok bool) { var ( - dbName = ToDBName(name) + dbName = ToColumnName(name) mostMatchedField *Field ) @@ -880,7 +880,7 @@ func convertInterfaceToMap(values interface{}, withIgnoredField bool) map[string switch reflectValue.Kind() { case reflect.Map: for _, key := range reflectValue.MapKeys() { - attrs[ToDBName(key.Interface().(string))] = reflectValue.MapIndex(key).Interface() + attrs[ToColumnName(key.Interface().(string))] = reflectValue.MapIndex(key).Interface() } default: for _, field := range (&Scope{Value: values}).Fields() { diff --git a/utils.go b/utils.go index 99b532c5..ad700b98 100644 --- a/utils.go +++ b/utils.go @@ -1,7 +1,6 @@ package gorm import ( - "bytes" "database/sql/driver" "fmt" "reflect" @@ -58,66 +57,6 @@ func newSafeMap() *safeMap { return &safeMap{l: new(sync.RWMutex), m: make(map[string]string)} } -var smap = newSafeMap() - -type strCase bool - -const ( - lower strCase = false - upper strCase = true -) - -// ToDBName convert string to db name -func ToDBName(name string) string { - if v := smap.Get(name); v != "" { - return v - } - - if name == "" { - return "" - } - - var ( - value = commonInitialismsReplacer.Replace(name) - buf = bytes.NewBufferString("") - lastCase, currCase, nextCase, nextNumber strCase - ) - - for i, v := range value[:len(value)-1] { - nextCase = strCase(value[i+1] >= 'A' && value[i+1] <= 'Z') - nextNumber = strCase(value[i+1] >= '0' && value[i+1] <= '9') - - if i > 0 { - if currCase == upper { - if lastCase == upper && (nextCase == upper || nextNumber == upper) { - buf.WriteRune(v) - } else { - if value[i-1] != '_' && value[i+1] != '_' { - buf.WriteRune('_') - } - buf.WriteRune(v) - } - } else { - buf.WriteRune(v) - if i == len(value)-2 && (nextCase == upper && nextNumber == lower) { - buf.WriteRune('_') - } - } - } else { - currCase = upper - buf.WriteRune(v) - } - lastCase = currCase - currCase = nextCase - } - - buf.WriteByte(value[len(value)-1]) - - s := strings.ToLower(buf.String()) - smap.Set(name, s) - return s -} - // SQL expression type expr struct { expr string diff --git a/utils_test.go b/utils_test.go deleted file mode 100644 index 086c4450..00000000 --- a/utils_test.go +++ /dev/null @@ -1,35 +0,0 @@ -package gorm_test - -import ( - "testing" - - "github.com/jinzhu/gorm" -) - -func TestToDBNameGenerateFriendlyName(t *testing.T) { - var maps = map[string]string{ - "": "", - "X": "x", - "ThisIsATest": "this_is_a_test", - "PFAndESI": "pf_and_esi", - "AbcAndJkl": "abc_and_jkl", - "EmployeeID": "employee_id", - "SKU_ID": "sku_id", - "UTF8": "utf8", - "Level1": "level1", - "SHA256Hash": "sha256_hash", - "FieldX": "field_x", - "HTTPAndSMTP": "http_and_smtp", - "HTTPServerHandlerForURLID": "http_server_handler_for_url_id", - "UUID": "uuid", - "HTTPURL": "http_url", - "HTTP_URL": "http_url", - "ThisIsActuallyATestSoWeMayBeAbleToUseThisCodeInGormPackageAlsoIdCanBeUsedAtTheEndAsID": "this_is_actually_a_test_so_we_may_be_able_to_use_this_code_in_gorm_package_also_id_can_be_used_at_the_end_as_id", - } - - for key, value := range maps { - if gorm.ToDBName(key) != value { - t.Errorf("%v ToDBName should equal %v, but got %v", key, value, gorm.ToDBName(key)) - } - } -} From dc3b2476c4eb61c37424a1ca2f46859e4e6fcd81 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 10 Sep 2018 06:03:41 +0800 Subject: [PATCH 20/28] Don't save ignored fields into database --- callback_create.go | 2 +- scope.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/callback_create.go b/callback_create.go index e7fe6f86..2ab05d3b 100644 --- a/callback_create.go +++ b/callback_create.go @@ -59,7 +59,7 @@ func createCallback(scope *Scope) { for _, field := range scope.Fields() { if scope.changeableField(field) { - if field.IsNormal { + if field.IsNormal && !field.IsIgnored { if field.IsBlank && field.HasDefaultValue { blankColumnsWithDefaultValue = append(blankColumnsWithDefaultValue, scope.Quote(field.DBName)) scope.InstanceSet("gorm:blank_columns_with_default_value", blankColumnsWithDefaultValue) diff --git a/scope.go b/scope.go index fbf7634e..7d6ba1c0 100644 --- a/scope.go +++ b/scope.go @@ -907,7 +907,7 @@ func (scope *Scope) updatedAttrsWithValues(value interface{}) (results map[strin results[field.DBName] = value } else { err := field.Set(value) - if field.IsNormal { + if field.IsNormal && !field.IsIgnored { hasUpdate = true if err == ErrUnaddressable { results[field.DBName] = value From 71b7f19aad77eaf99a90324c7d2ac5634eaefca8 Mon Sep 17 00:00:00 2001 From: Xy Ziemba Date: Sun, 9 Sep 2018 15:12:58 -0700 Subject: [PATCH 21/28] Fix scanning identical column names occurring >2 times (#2080) Fix the indexing logic used in selectedColumnsMap to skip fields that have already been seen. The values of selectedColumns map must be indexed relative to fields, not relative to selectFields. --- main_test.go | 54 ++++++++++++++++++++++++++++++++++++++++++++++++++++ scope.go | 6 ++++-- 2 files changed, 58 insertions(+), 2 deletions(-) diff --git a/main_test.go b/main_test.go index 265e0be7..11c4bb87 100644 --- a/main_test.go +++ b/main_test.go @@ -581,6 +581,60 @@ func TestJoins(t *testing.T) { } } +type JoinedIds struct { + UserID int64 `gorm:"column:id"` + BillingAddressID int64 `gorm:"column:id"` + EmailID int64 `gorm:"column:id"` +} + +func TestScanIdenticalColumnNames(t *testing.T) { + var user = User{ + Name: "joinsIds", + Email: "joinIds@example.com", + BillingAddress: Address{ + Address1: "One Park Place", + }, + Emails: []Email{{Email: "join1@example.com"}, {Email: "join2@example.com"}}, + } + DB.Save(&user) + + var users []JoinedIds + DB.Select("users.id, addresses.id, emails.id").Table("users"). + Joins("left join addresses on users.billing_address_id = addresses.id"). + Joins("left join emails on emails.user_id = users.id"). + Where("name = ?", "joinsIds").Scan(&users) + + if len(users) != 2 { + t.Fatal("should find two rows using left join") + } + + if user.Id != users[0].UserID { + t.Errorf("Expected result row to contain UserID %d, but got %d", user.Id, users[0].UserID) + } + if user.Id != users[1].UserID { + t.Errorf("Expected result row to contain UserID %d, but got %d", user.Id, users[1].UserID) + } + + if user.BillingAddressID.Int64 != users[0].BillingAddressID { + t.Errorf("Expected result row to contain BillingAddressID %d, but got %d", user.BillingAddressID.Int64, users[0].BillingAddressID) + } + if user.BillingAddressID.Int64 != users[1].BillingAddressID { + t.Errorf("Expected result row to contain BillingAddressID %d, but got %d", user.BillingAddressID.Int64, users[0].BillingAddressID) + } + + if users[0].EmailID == users[1].EmailID { + t.Errorf("Email ids should be unique. Got %d and %d", users[0].EmailID, users[1].EmailID) + } + + if int64(user.Emails[0].Id) != users[0].EmailID && int64(user.Emails[1].Id) != users[0].EmailID { + t.Errorf("Expected result row ID to be either %d or %d, but was %d", user.Emails[0].Id, user.Emails[1].Id, users[0].EmailID) + } + + if int64(user.Emails[0].Id) != users[1].EmailID && int64(user.Emails[1].Id) != users[1].EmailID { + t.Errorf("Expected result row ID to be either %d or %d, but was %d", user.Emails[0].Id, user.Emails[1].Id, users[1].EmailID) + } +} + func TestJoinsWithSelect(t *testing.T) { type result struct { Name string diff --git a/scope.go b/scope.go index 7d6ba1c0..ce80ab86 100644 --- a/scope.go +++ b/scope.go @@ -486,8 +486,10 @@ func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field) { values[index] = &ignored selectFields = fields + offset := 0 if idx, ok := selectedColumnsMap[column]; ok { - selectFields = selectFields[idx+1:] + offset = idx + 1 + selectFields = selectFields[offset:] } for fieldIndex, field := range selectFields { @@ -501,7 +503,7 @@ func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field) { resetFields[index] = field } - selectedColumnsMap[column] = fieldIndex + selectedColumnsMap[column] = offset + fieldIndex if field.IsNormal { break From 12607e8bdf4a724492d53d8c788edc77ad4439e7 Mon Sep 17 00:00:00 2001 From: kuangzhiqiang Date: Mon, 10 Sep 2018 06:14:05 +0800 Subject: [PATCH 22/28] for go1.11 go mod (#2072) when used go1.11 gomodules the code dir will be `$GOPATH/pkg/mod/github.com/jinzhu/gorm@*/` fileWithLineNum check failed --- utils.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/utils.go b/utils.go index ad700b98..8489538c 100644 --- a/utils.go +++ b/utils.go @@ -25,8 +25,8 @@ var NowFunc = func() time.Time { var commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UID", "UI", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"} var commonInitialismsReplacer *strings.Replacer -var goSrcRegexp = regexp.MustCompile(`jinzhu/gorm/.*.go`) -var goTestRegexp = regexp.MustCompile(`jinzhu/gorm/.*test.go`) +var goSrcRegexp = regexp.MustCompile(`jinzhu/gorm(@.*)?/.*.go`) +var goTestRegexp = regexp.MustCompile(`jinzhu/gorm(@.*)?/.*test.go`) func init() { var commonInitialismsForReplacer []string From d3e666a1e086a020905e3f6cf293941806520d97 Mon Sep 17 00:00:00 2001 From: Ikhtiyor <33823221+iahmedov@users.noreply.github.com> Date: Mon, 10 Sep 2018 03:25:26 +0500 Subject: [PATCH 23/28] save_associations:true should store related item (#2067) * save_associations:true should store related item, save_associations priority on related objects * code quality --- callback_save.go | 6 ++-- main_test.go | 88 +++++++++++++++++++++++++++++++++++++++++++++++ migration_test.go | 10 +++++- 3 files changed, 100 insertions(+), 4 deletions(-) diff --git a/callback_save.go b/callback_save.go index ef267141..ebfd0b34 100644 --- a/callback_save.go +++ b/callback_save.go @@ -21,9 +21,7 @@ func saveAssociationCheck(scope *Scope, field *Field) (autoUpdate bool, autoCrea if v, ok := value.(string); ok { v = strings.ToLower(v) - if v == "false" || v != "skip" { - return false - } + return v == "true" } return true @@ -36,9 +34,11 @@ func saveAssociationCheck(scope *Scope, field *Field) (autoUpdate bool, autoCrea if value, ok := scope.Get("gorm:save_associations"); ok { autoUpdate = checkTruth(value) autoCreate = autoUpdate + saveReference = autoUpdate } else if value, ok := field.TagSettings["SAVE_ASSOCIATIONS"]; ok { autoUpdate = checkTruth(value) autoCreate = autoUpdate + saveReference = autoUpdate } if value, ok := scope.Get("gorm:association_autoupdate"); ok { diff --git a/main_test.go b/main_test.go index 11c4bb87..94d2fa39 100644 --- a/main_test.go +++ b/main_test.go @@ -933,6 +933,94 @@ func TestOpenWithOneParameter(t *testing.T) { } } +func TestSaveAssociations(t *testing.T) { + db := DB.New() + deltaAddressCount := 0 + if err := db.Model(&Address{}).Count(&deltaAddressCount).Error; err != nil { + t.Errorf("failed to fetch address count") + t.FailNow() + } + + placeAddress := &Address{ + Address1: "somewhere on earth", + } + ownerAddress1 := &Address{ + Address1: "near place address", + } + ownerAddress2 := &Address{ + Address1: "address2", + } + db.Create(placeAddress) + + addressCountShouldBe := func(t *testing.T, expectedCount int) { + countFromDB := 0 + t.Helper() + err := db.Model(&Address{}).Count(&countFromDB).Error + if err != nil { + t.Error("failed to fetch address count") + } + if countFromDB != expectedCount { + t.Errorf("address count mismatch: %d", countFromDB) + } + } + addressCountShouldBe(t, deltaAddressCount+1) + + // owner address should be created, place address should be reused + place1 := &Place{ + PlaceAddressID: placeAddress.ID, + PlaceAddress: placeAddress, + OwnerAddress: ownerAddress1, + } + err := db.Create(place1).Error + if err != nil { + t.Errorf("failed to store place: %s", err.Error()) + } + addressCountShouldBe(t, deltaAddressCount+2) + + // owner address should be created again, place address should be reused + place2 := &Place{ + PlaceAddressID: placeAddress.ID, + PlaceAddress: &Address{ + ID: 777, + Address1: "address1", + }, + OwnerAddress: ownerAddress2, + OwnerAddressID: 778, + } + err = db.Create(place2).Error + if err != nil { + t.Errorf("failed to store place: %s", err.Error()) + } + addressCountShouldBe(t, deltaAddressCount+3) + + count := 0 + db.Model(&Place{}).Where(&Place{ + PlaceAddressID: placeAddress.ID, + OwnerAddressID: ownerAddress1.ID, + }).Count(&count) + if count != 1 { + t.Errorf("only one instance of (%d, %d) should be available, found: %d", + placeAddress.ID, ownerAddress1.ID, count) + } + + db.Model(&Place{}).Where(&Place{ + PlaceAddressID: placeAddress.ID, + OwnerAddressID: ownerAddress2.ID, + }).Count(&count) + if count != 1 { + t.Errorf("only one instance of (%d, %d) should be available, found: %d", + placeAddress.ID, ownerAddress2.ID, count) + } + + db.Model(&Place{}).Where(&Place{ + PlaceAddressID: placeAddress.ID, + }).Count(&count) + if count != 2 { + t.Errorf("two instances of (%d) should be available, found: %d", + placeAddress.ID, count) + } +} + func TestBlockGlobalUpdate(t *testing.T) { db := DB.New() db.Create(&Toy{Name: "Stuffed Animal", OwnerType: "Nobody"}) diff --git a/migration_test.go b/migration_test.go index 78555dcc..3fb06648 100644 --- a/migration_test.go +++ b/migration_test.go @@ -118,6 +118,14 @@ type Company struct { Owner *User `sql:"-"` } +type Place struct { + Id int64 + PlaceAddressID int + PlaceAddress *Address `gorm:"save_associations:false"` + OwnerAddressID int + OwnerAddress *Address `gorm:"save_associations:true"` +} + type EncryptedData []byte func (data *EncryptedData) Scan(value interface{}) error { @@ -284,7 +292,7 @@ func runMigration() { DB.Exec(fmt.Sprintf("drop table %v;", table)) } - values := []interface{}{&Short{}, &ReallyLongThingThatReferencesShort{}, &ReallyLongTableNameToTestMySQLNameLengthLimit{}, &NotSoLongTableName{}, &Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}, &Animal{}, &User{}, &JoinTable{}, &Post{}, &Category{}, &Comment{}, &Cat{}, &Dog{}, &Hamster{}, &Toy{}, &ElementWithIgnoredField{}} + values := []interface{}{&Short{}, &ReallyLongThingThatReferencesShort{}, &ReallyLongTableNameToTestMySQLNameLengthLimit{}, &NotSoLongTableName{}, &Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}, &Animal{}, &User{}, &JoinTable{}, &Post{}, &Category{}, &Comment{}, &Cat{}, &Dog{}, &Hamster{}, &Toy{}, &ElementWithIgnoredField{}, &Place{}} for _, value := range values { DB.DropTable(value) } From 73e7561e20e8e554ec54463ccbed38e426aad17f Mon Sep 17 00:00:00 2001 From: Aaron Leung Date: Sun, 9 Sep 2018 15:26:29 -0700 Subject: [PATCH 24/28] Use sync.Map for DB.values (#2064) * Replace the regular map with a sync.Map to avoid fatal concurrent map reads/writes * fix the formatting --- main.go | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/main.go b/main.go index 993e19b1..364d8e8e 100644 --- a/main.go +++ b/main.go @@ -22,7 +22,7 @@ type DB struct { logMode int logger logger search *search - values map[string]interface{} + values sync.Map // global db parent *DB @@ -72,7 +72,6 @@ func Open(dialect string, args ...interface{}) (db *DB, err error) { db = &DB{ db: dbSQL, logger: defaultLogger, - values: map[string]interface{}{}, callbacks: DefaultCallback, dialect: newDialect(dialect, dbSQL), } @@ -680,13 +679,13 @@ func (s *DB) Set(name string, value interface{}) *DB { // InstantSet instant set setting, will affect current db func (s *DB) InstantSet(name string, value interface{}) *DB { - s.values[name] = value + s.values.Store(name, value) return s } // Get get setting by name func (s *DB) Get(name string) (value interface{}, ok bool) { - value, ok = s.values[name] + value, ok = s.values.Load(name) return } @@ -750,16 +749,16 @@ func (s *DB) clone() *DB { parent: s.parent, logger: s.logger, logMode: s.logMode, - values: map[string]interface{}{}, Value: s.Value, Error: s.Error, blockGlobalUpdate: s.blockGlobalUpdate, dialect: newDialect(s.dialect.GetName(), s.db), } - for key, value := range s.values { - db.values[key] = value - } + s.values.Range(func(k, v interface{}) bool { + db.values.Store(k, v) + return true + }) if s.search == nil { db.search = &search{limit: -1, offset: -1} From 012d1479740ec593b0c07f0372e0111c01c3b34a Mon Sep 17 00:00:00 2001 From: maddie Date: Mon, 10 Sep 2018 06:45:55 +0800 Subject: [PATCH 25/28] Improve preload speed (#2058) All credits to @vanjapt who came up with this patch. Closes #1672 --- callback_query_preload.go | 33 ++++++++++++++++++++++----------- 1 file changed, 22 insertions(+), 11 deletions(-) diff --git a/callback_query_preload.go b/callback_query_preload.go index 481bfbe3..46405c38 100644 --- a/callback_query_preload.go +++ b/callback_query_preload.go @@ -161,14 +161,17 @@ func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) ) if indirectScopeValue.Kind() == reflect.Slice { + foreignValuesToResults := make(map[string]reflect.Value) + for i := 0; i < resultsValue.Len(); i++ { + result := resultsValue.Index(i) + foreignValues := toString(getValueFromFields(result, relation.ForeignFieldNames)) + foreignValuesToResults[foreignValues] = result + } for j := 0; j < indirectScopeValue.Len(); j++ { - for i := 0; i < resultsValue.Len(); i++ { - result := resultsValue.Index(i) - foreignValues := getValueFromFields(result, relation.ForeignFieldNames) - if indirectValue := indirect(indirectScopeValue.Index(j)); equalAsString(getValueFromFields(indirectValue, relation.AssociationForeignFieldNames), foreignValues) { - indirectValue.FieldByName(field.Name).Set(result) - break - } + indirectValue := indirect(indirectScopeValue.Index(j)) + valueString := toString(getValueFromFields(indirectValue, relation.AssociationForeignFieldNames)) + if result, found := foreignValuesToResults[valueString]; found { + indirectValue.FieldByName(field.Name).Set(result) } } } else { @@ -255,13 +258,21 @@ func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{ indirectScopeValue = scope.IndirectValue() ) + foreignFieldToObjects := make(map[string][]*reflect.Value) + if indirectScopeValue.Kind() == reflect.Slice { + for j := 0; j < indirectScopeValue.Len(); j++ { + object := indirect(indirectScopeValue.Index(j)) + valueString := toString(getValueFromFields(object, relation.ForeignFieldNames)) + foreignFieldToObjects[valueString] = append(foreignFieldToObjects[valueString], &object) + } + } + for i := 0; i < resultsValue.Len(); i++ { result := resultsValue.Index(i) if indirectScopeValue.Kind() == reflect.Slice { - value := getValueFromFields(result, relation.AssociationForeignFieldNames) - for j := 0; j < indirectScopeValue.Len(); j++ { - object := indirect(indirectScopeValue.Index(j)) - if equalAsString(getValueFromFields(object, relation.ForeignFieldNames), value) { + valueString := toString(getValueFromFields(result, relation.AssociationForeignFieldNames)) + if objects, found := foreignFieldToObjects[valueString]; found { + for _, object := range objects { object.FieldByName(field.Name).Set(result) } } From 26fde9110f932df8cb5cc24396e7a54a6d3a94c2 Mon Sep 17 00:00:00 2001 From: Gustavo Brunoro Date: Sun, 9 Sep 2018 19:47:18 -0300 Subject: [PATCH 26/28] getValueFromFields doesn't panic on nil pointers (#2021) * `IsValid()` won't return `false` for nil pointers unless Value is wrapped in a `reflect.Indirect`. --- utils.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils.go b/utils.go index 8489538c..e58e57a5 100644 --- a/utils.go +++ b/utils.go @@ -206,7 +206,7 @@ func getValueFromFields(value reflect.Value, fieldNames []string) (results []int // as FieldByName could panic if indirectValue := reflect.Indirect(value); indirectValue.IsValid() { for _, fieldName := range fieldNames { - if fieldValue := indirectValue.FieldByName(fieldName); fieldValue.IsValid() { + if fieldValue := reflect.Indirect(indirectValue.FieldByName(fieldName)); fieldValue.IsValid() { result := fieldValue.Interface() if r, ok := result.(driver.Valuer); ok { result, _ = r.Value() From 588b598f9fbf9a0c84b6ec18f617940b045c54d4 Mon Sep 17 00:00:00 2001 From: Phillip Shipley Date: Sun, 9 Sep 2018 18:50:22 -0400 Subject: [PATCH 27/28] Fix issue updating models with foreign key constraints (#1988) * fix update callback to not try to write zero values when field has default value * fix to update callback for gorm tests --- callback_update.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/callback_update.go b/callback_update.go index 373bd726..f6ba0ffd 100644 --- a/callback_update.go +++ b/callback_update.go @@ -76,7 +76,9 @@ func updateCallback(scope *Scope) { for _, field := range scope.Fields() { if scope.changeableField(field) { if !field.IsPrimaryKey && field.IsNormal { - sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) + if !field.IsForeignKey || !field.IsBlank || !field.HasDefaultValue { + sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) + } } else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" { for _, foreignKey := range relationship.ForeignDBNames { if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) { From 282f11af1900a36646b607797273056d76350223 Mon Sep 17 00:00:00 2001 From: gstvg <28798827+gstvg@users.noreply.github.com> Date: Sun, 9 Sep 2018 19:52:32 -0300 Subject: [PATCH 28/28] Support only preloading (#1926) * add support for only preloading relations on an already populated model * Update callback_query.go comments --- callback_query.go | 5 +++++ main.go | 5 +++++ 2 files changed, 10 insertions(+) diff --git a/callback_query.go b/callback_query.go index ba10cc7d..593e5d30 100644 --- a/callback_query.go +++ b/callback_query.go @@ -18,6 +18,11 @@ func queryCallback(scope *Scope) { if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip { return } + + //we are only preloading relations, dont touch base model + if _, skip := scope.InstanceGet("gorm:only_preload"); skip { + return + } defer scope.trace(NowFunc()) diff --git a/main.go b/main.go index 364d8e8e..4dbda61e 100644 --- a/main.go +++ b/main.go @@ -314,6 +314,11 @@ func (s *DB) Find(out interface{}, where ...interface{}) *DB { return s.NewScope(out).inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db } +//Preloads preloads relations, don`t touch out +func (s *DB) Preloads(out interface{}) *DB { + return s.NewScope(out).InstanceSet("gorm:only_preload", 1).callCallbacks(s.parent.callbacks.queries).db +} + // Scan scan value to a struct func (s *DB) Scan(dest interface{}) *DB { return s.NewScope(s.Value).Set("gorm:query_destination", dest).callCallbacks(s.parent.callbacks.queries).db