From a12c2a2e13b0f644647dbd369a88b01fac109bd5 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 27 Feb 2018 10:48:59 +0800 Subject: [PATCH 01/34] Remove mysql8 from CI --- wercker.yml | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/wercker.yml b/wercker.yml index 2f2370b3..0c3e73ef 100644 --- a/wercker.yml +++ b/wercker.yml @@ -9,13 +9,6 @@ services: MYSQL_USER: gorm MYSQL_PASSWORD: gorm MYSQL_RANDOM_ROOT_PASSWORD: "yes" - - name: mysql - id: mysql:8 - env: - MYSQL_DATABASE: gorm - MYSQL_USER: gorm - MYSQL_PASSWORD: gorm - MYSQL_RANDOM_ROOT_PASSWORD: "yes" - name: mysql57 id: mysql:5.7 env: @@ -109,11 +102,6 @@ build: code: | GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mariadb:3306)/gorm?charset=utf8&parseTime=True" go test ./... - - script: - name: test mysql - code: | - GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql:3306)/gorm?charset=utf8&parseTime=True" go test ./... - - script: name: test mysql5.7 code: | From 6ed508ec6a4ecb3531899a69cbc746ccf65a4166 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 28 Feb 2018 07:43:56 +0800 Subject: [PATCH 02/34] Fix panic with raw SQL --- scope.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scope.go b/scope.go index 25077efc..150ac710 100644 --- a/scope.go +++ b/scope.go @@ -650,7 +650,7 @@ func (scope *Scope) buildCondition(clause map[string]interface{}, include bool) buff := bytes.NewBuffer([]byte{}) i := 0 for _, s := range str { - if s == '?' { + if s == '?' && len(replacements) > i { buff.WriteString(replacements[i]) i++ } else { From 52c5c8127cf4aeffde3e0aa9222640832075a90f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sa=C3=BAl=20Ortega?= Date: Thu, 15 Mar 2018 09:35:31 -0500 Subject: [PATCH 03/34] Support for UTF8 names on DB (#1793) --- scope.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scope.go b/scope.go index 150ac710..2f39e073 100644 --- a/scope.go +++ b/scope.go @@ -692,12 +692,12 @@ func (scope *Scope) buildSelectQuery(clause map[string]interface{}) (str string) buff := bytes.NewBuffer([]byte{}) i := 0 - for pos := range str { + for pos, char := range str { if str[pos] == '?' { buff.WriteString(replacements[i]) i++ } else { - buff.WriteByte(str[pos]) + buff.WriteRune(char) } } From 919c6db4f854e4feaae94202ae29da4e3779de49 Mon Sep 17 00:00:00 2001 From: Giuseppe Date: Mon, 16 Apr 2018 16:18:51 +0200 Subject: [PATCH 04/34] Do not panic if Begin().Error was ignored (#1830) --- main.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/main.go b/main.go index c26e05c8..ffee4ec6 100644 --- a/main.go +++ b/main.go @@ -491,7 +491,8 @@ func (s *DB) Begin() *DB { // Commit commit a transaction func (s *DB) Commit() *DB { - if db, ok := s.db.(sqlTx); ok && db != nil { + var emptySQLTx *sql.Tx + if db, ok := s.db.(sqlTx); ok && db != nil && db != emptySQLTx { s.AddError(db.Commit()) } else { s.AddError(ErrInvalidTransaction) From 6842b49a1ad0feb6b93be830fe63a682cf853ada Mon Sep 17 00:00:00 2001 From: Shane Date: Mon, 16 Apr 2018 07:20:02 -0700 Subject: [PATCH 05/34] fix scope.removeForeignKey method (#1841) --- scope.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scope.go b/scope.go index 2f39e073..397ccf0b 100644 --- a/scope.go +++ b/scope.go @@ -1215,7 +1215,7 @@ func (scope *Scope) addForeignKey(field string, dest string, onDelete string, on } func (scope *Scope) removeForeignKey(field string, dest string) { - keyName := scope.Dialect().BuildKeyName(scope.TableName(), field, dest) + keyName := scope.Dialect().BuildKeyName(scope.TableName(), field, dest, "foreign") if !scope.Dialect().HasForeignKey(scope.TableName(), keyName) { return From 35efe68ba71d571e64ccd1ee62830c30a53ed967 Mon Sep 17 00:00:00 2001 From: Daniel McDonald Date: Wed, 2 May 2018 07:37:51 -0700 Subject: [PATCH 06/34] add simple input validation on gorm.Open function (#1855) Simply check if the passed-in database source meets the expected types and, if not, early return with error. --- main.go | 2 ++ main_test.go | 17 +++++++++++++++++ 2 files changed, 19 insertions(+) diff --git a/main.go b/main.go index ffee4ec6..c8a43e8c 100644 --- a/main.go +++ b/main.go @@ -61,6 +61,8 @@ func Open(dialect string, args ...interface{}) (db *DB, err error) { dbSQL, err = sql.Open(driver, source) case SQLCommon: dbSQL = value + default: + return nil, fmt.Errorf("invalid database source: %v is not a valid type", value) } db = &DB{ diff --git a/main_test.go b/main_test.go index 66c46af0..265e0be7 100644 --- a/main_test.go +++ b/main_test.go @@ -8,6 +8,7 @@ import ( "path/filepath" "reflect" "strconv" + "strings" "testing" "time" @@ -79,6 +80,22 @@ func OpenTestConnection() (db *gorm.DB, err error) { return } +func TestOpen_ReturnsError_WithBadArgs(t *testing.T) { + stringRef := "foo" + testCases := []interface{}{42, time.Now(), &stringRef} + for _, tc := range testCases { + t.Run(fmt.Sprintf("%v", tc), func(t *testing.T) { + _, err := gorm.Open("postgresql", tc) + if err == nil { + t.Error("Should got error with invalid database source") + } + if !strings.HasPrefix(err.Error(), "invalid database source:") { + t.Errorf("Should got error starting with \"invalid database source:\", but got %q", err.Error()) + } + }) + } +} + func TestStringPrimaryKey(t *testing.T) { type UUIDStruct struct { ID string `gorm:"primary_key"` From 9044197ef935c0969d94cbcfba55ccb94d269bed Mon Sep 17 00:00:00 2001 From: Illya Busigin Date: Wed, 2 May 2018 09:38:52 -0500 Subject: [PATCH 07/34] Adding GetDialect function (#1869) --- dialect.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/dialect.go b/dialect.go index 5f6439c1..506a6e86 100644 --- a/dialect.go +++ b/dialect.go @@ -72,6 +72,12 @@ func RegisterDialect(name string, dialect Dialect) { dialectsMap[name] = dialect } +// GetDialect gets the dialect for the specified dialect name +func GetDialect(name string) (dialect Dialect, ok bool) { + dialect, ok = dialectsMap[name] + return +} + // ParseFieldStructForDialect get field's sql data type var ParseFieldStructForDialect = func(field *StructField, dialect Dialect) (fieldValue reflect.Value, sqlType string, size int, additionalType string) { // Get redirected field type From a58b98acee2f3bf213b2cb0f1fe1468f236c9aec Mon Sep 17 00:00:00 2001 From: lrita Date: Sat, 12 May 2018 14:28:15 +0800 Subject: [PATCH 08/34] Do not panic if Begin().Error was ignored (#1830) (#1881) --- main.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/main.go b/main.go index c8a43e8c..25c3a06b 100644 --- a/main.go +++ b/main.go @@ -504,7 +504,8 @@ func (s *DB) Commit() *DB { // Rollback rollback a transaction func (s *DB) Rollback() *DB { - if db, ok := s.db.(sqlTx); ok && db != nil { + var emptySQLTx *sql.Tx + if db, ok := s.db.(sqlTx); ok && db != nil && db != emptySQLTx { s.AddError(db.Rollback()) } else { s.AddError(ErrInvalidTransaction) From 82eb9f8a5bbb5e6b929d2f0ae5b934e6a253f94e Mon Sep 17 00:00:00 2001 From: Olga Kleitsa Date: Sat, 12 May 2018 09:29:00 +0300 Subject: [PATCH 09/34] included actual sql query to discover fi foreign key with the same name exists in a specific table of the database in use (#1896) --- dialects/mssql/mssql.go | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go index e0606465..a8d3c45a 100644 --- a/dialects/mssql/mssql.go +++ b/dialects/mssql/mssql.go @@ -130,7 +130,14 @@ func (s mssql) RemoveIndex(tableName string, indexName string) error { } func (s mssql) HasForeignKey(tableName string, foreignKeyName string) bool { - return false + var count int + currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) + s.db.QueryRow(`SELECT count(*) + FROM sys.foreign_keys as F inner join sys.tables as T on F.parent_object_id=T.object_id + inner join information_schema.tables as I on I.TABLE_NAME = T.name + WHERE F.name = ? + AND T.Name = ? AND I.TABLE_CATALOG = ?;`, foreignKeyName, tableName, currentDatabase).Scan(&count) + return count > 0 } func (s mssql) HasTable(tableName string) bool { From 1907bff3732cb4c612e4118137d8f3c8829cc8c6 Mon Sep 17 00:00:00 2001 From: ia Date: Mon, 25 Jun 2018 07:06:58 +0200 Subject: [PATCH 10/34] all: gofmt (#1956) Run standard gofmt command on project root. - go version go1.10.3 darwin/amd64 Signed-off-by: ia --- dialects/postgres/postgres.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dialects/postgres/postgres.go b/dialects/postgres/postgres.go index 1d0dcb60..424e8bdc 100644 --- a/dialects/postgres/postgres.go +++ b/dialects/postgres/postgres.go @@ -4,11 +4,11 @@ import ( "database/sql" "database/sql/driver" - _ "github.com/lib/pq" - "github.com/lib/pq/hstore" "encoding/json" "errors" "fmt" + _ "github.com/lib/pq" + "github.com/lib/pq/hstore" ) type Hstore map[string]*string From 0fd395ab37aefd2d50854f0556a4311dccc6f45a Mon Sep 17 00:00:00 2001 From: Masaki Yoshida Date: Mon, 25 Jun 2018 14:07:53 +0900 Subject: [PATCH 11/34] Fix ToDBName (#1941) Don't place '_' before number. - NG: SHA256Hash -> sha_256_hash - OK: SHA256Hash -> sha256_hash --- utils.go | 12 +++++++----- utils_test.go | 3 +++ 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/utils.go b/utils.go index dfaae939..99b532c5 100644 --- a/utils.go +++ b/utils.go @@ -78,16 +78,18 @@ func ToDBName(name string) string { } var ( - value = commonInitialismsReplacer.Replace(name) - buf = bytes.NewBufferString("") - lastCase, currCase, nextCase strCase + value = commonInitialismsReplacer.Replace(name) + buf = bytes.NewBufferString("") + lastCase, currCase, nextCase, nextNumber strCase ) for i, v := range value[:len(value)-1] { nextCase = strCase(value[i+1] >= 'A' && value[i+1] <= 'Z') + nextNumber = strCase(value[i+1] >= '0' && value[i+1] <= '9') + if i > 0 { if currCase == upper { - if lastCase == upper && nextCase == upper { + if lastCase == upper && (nextCase == upper || nextNumber == upper) { buf.WriteRune(v) } else { if value[i-1] != '_' && value[i+1] != '_' { @@ -97,7 +99,7 @@ func ToDBName(name string) string { } } else { buf.WriteRune(v) - if i == len(value)-2 && nextCase == upper { + if i == len(value)-2 && (nextCase == upper && nextNumber == lower) { buf.WriteRune('_') } } diff --git a/utils_test.go b/utils_test.go index 152296d2..086c4450 100644 --- a/utils_test.go +++ b/utils_test.go @@ -15,6 +15,9 @@ func TestToDBNameGenerateFriendlyName(t *testing.T) { "AbcAndJkl": "abc_and_jkl", "EmployeeID": "employee_id", "SKU_ID": "sku_id", + "UTF8": "utf8", + "Level1": "level1", + "SHA256Hash": "sha256_hash", "FieldX": "field_x", "HTTPAndSMTP": "http_and_smtp", "HTTPServerHandlerForURLID": "http_server_handler_for_url_id", From dbb25e94879f463c699430a74d29c9557e15a60f Mon Sep 17 00:00:00 2001 From: Louis Brauer Date: Fri, 27 Jul 2018 01:30:57 +0200 Subject: [PATCH 12/34] Adding json type for mssql dialect, similar to postgres.Jsonb (#1934) * Adding json type for mssql dialect, similar to postgres.Jsonb * Adding proper comments --- dialects/mssql/mssql.go | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go index a8d3c45a..731721cb 100644 --- a/dialects/mssql/mssql.go +++ b/dialects/mssql/mssql.go @@ -1,12 +1,16 @@ package mssql import ( + "database/sql/driver" + "encoding/json" + "errors" "fmt" "reflect" "strconv" "strings" "time" + // Importing mssql driver package only in dialect file, otherwide not needed _ "github.com/denisenkom/go-mssqldb" "github.com/jinzhu/gorm" ) @@ -201,3 +205,27 @@ func currentDatabaseAndTable(dialect gorm.Dialect, tableName string) (string, st } return dialect.CurrentDatabase(), tableName } + +// JSON type to support easy handling of JSON data in character table fields +// using golang json.RawMessage for deferred decoding/encoding +type JSON struct { + json.RawMessage +} + +// Value get value of JSON +func (j JSON) Value() (driver.Value, error) { + if len(j.RawMessage) == 0 { + return nil, nil + } + return j.MarshalJSON() +} + +// Scan scan value into JSON +func (j *JSON) Scan(value interface{}) error { + str, ok := value.(string) + if !ok { + return errors.New(fmt.Sprint("Failed to unmarshal JSONB value (strcast):", value)) + } + bytes := []byte(str) + return json.Unmarshal(bytes, j) +} From ac3ec858a6c375a466f613c86b053726abbe3755 Mon Sep 17 00:00:00 2001 From: Kevin Date: Thu, 26 Jul 2018 19:35:53 -0400 Subject: [PATCH 13/34] Edit DB.clone(), DB.Dialect(), and Scope.Dialect() preserve transactions (#1939) * Edit DB.clone(), DB.Dialect(), and Scope.Dialect() preserve transactions. * Adds a test case for tables creations and autoMigrate in the same transaction. --- main.go | 5 ++++- migration_test.go | 47 +++++++++++++++++++++++++++++++++++++++++++++++ scope.go | 2 +- 3 files changed, 52 insertions(+), 2 deletions(-) diff --git a/main.go b/main.go index 25c3a06b..3a5d6b0c 100644 --- a/main.go +++ b/main.go @@ -119,7 +119,7 @@ func (s *DB) CommonDB() SQLCommon { // Dialect get dialect func (s *DB) Dialect() Dialect { - return s.parent.dialect + return s.dialect } // Callback return `Callbacks` container, you could add/change/delete callbacks with it @@ -484,6 +484,8 @@ func (s *DB) Begin() *DB { if db, ok := c.db.(sqlDb); ok && db != nil { tx, err := db.Begin() c.db = interface{}(tx).(SQLCommon) + + c.dialect.SetDB(c.db) c.AddError(err) } else { c.AddError(ErrCantStartTransaction) @@ -748,6 +750,7 @@ func (s *DB) clone() *DB { Value: s.Value, Error: s.Error, blockGlobalUpdate: s.blockGlobalUpdate, + dialect: newDialect(s.dialect.GetName(), s.db), } for key, value := range s.values { diff --git a/migration_test.go b/migration_test.go index 7c694485..78555dcc 100644 --- a/migration_test.go +++ b/migration_test.go @@ -398,6 +398,53 @@ func TestAutoMigration(t *testing.T) { } } +func TestCreateAndAutomigrateTransaction(t *testing.T) { + tx := DB.Begin() + + func() { + type Bar struct { + ID uint + } + DB.DropTableIfExists(&Bar{}) + + if ok := DB.HasTable("bars"); ok { + t.Errorf("Table should not exist, but does") + } + + if ok := tx.HasTable("bars"); ok { + t.Errorf("Table should not exist, but does") + } + }() + + func() { + type Bar struct { + Name string + } + err := tx.CreateTable(&Bar{}).Error + + if err != nil { + t.Errorf("Should have been able to create the table, but couldn't: %s", err) + } + + if ok := tx.HasTable(&Bar{}); !ok { + t.Errorf("The transaction should be able to see the table") + } + }() + + func() { + type Bar struct { + Stuff string + } + + err := tx.AutoMigrate(&Bar{}).Error + if err != nil { + t.Errorf("Should have been able to alter the table, but couldn't") + } + }() + + tx.Rollback() +} + type MultipleIndexes struct { ID int64 UserID int64 `sql:"unique_index:uix_multipleindexes_user_name,uix_multipleindexes_user_email;index:idx_multipleindexes_user_other"` diff --git a/scope.go b/scope.go index 397ccf0b..5eb98963 100644 --- a/scope.go +++ b/scope.go @@ -63,7 +63,7 @@ func (scope *Scope) SQLDB() SQLCommon { // Dialect get dialect func (scope *Scope) Dialect() Dialect { - return scope.db.parent.dialect + return scope.db.dialect } // Quote used to quote string to escape them for database From 588e2eef5d9c33b11ee52895ad5cdfab0d6648e6 Mon Sep 17 00:00:00 2001 From: David Zhang Date: Fri, 27 Jul 2018 07:38:02 +0800 Subject: [PATCH 14/34] Fix typo in query_test (#1977) --- query_test.go | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/query_test.go b/query_test.go index fac7d4d8..15bf8b3c 100644 --- a/query_test.go +++ b/query_test.go @@ -181,17 +181,17 @@ func TestSearchWithPlainSQL(t *testing.T) { scopedb.Where("birthday > ?", parseTime("2000-1-1")).Find(&users) if len(users) != 2 { - t.Errorf("Should found 2 users's birthday > 2000-1-1, but got %v", len(users)) + t.Errorf("Should found 2 users' birthday > 2000-1-1, but got %v", len(users)) } scopedb.Where("birthday > ?", "2002-10-10").Find(&users) if len(users) != 2 { - t.Errorf("Should found 2 users's birthday >= 2002-10-10, but got %v", len(users)) + t.Errorf("Should found 2 users' birthday >= 2002-10-10, but got %v", len(users)) } scopedb.Where("birthday >= ?", "2010-1-1").Where("birthday < ?", "2020-1-1").Find(&users) if len(users) != 1 { - t.Errorf("Should found 1 users's birthday < 2020-1-1 and >= 2010-1-1, but got %v", len(users)) + t.Errorf("Should found 1 users' birthday < 2020-1-1 and >= 2010-1-1, but got %v", len(users)) } DB.Where("name in (?)", []string{user1.Name, user2.Name}).Find(&users) @@ -532,28 +532,28 @@ func TestNot(t *testing.T) { DB.Table("users").Where("name = ?", "user3").Count(&name3Count) DB.Not("name", "user3").Find(&users4) if len(users1)-len(users4) != int(name3Count) { - t.Errorf("Should find all users's name not equal 3") + t.Errorf("Should find all users' name not equal 3") } DB.Not("name = ?", "user3").Find(&users4) if len(users1)-len(users4) != int(name3Count) { - t.Errorf("Should find all users's name not equal 3") + t.Errorf("Should find all users' name not equal 3") } DB.Not("name <> ?", "user3").Find(&users4) if len(users4) != int(name3Count) { - t.Errorf("Should find all users's name not equal 3") + t.Errorf("Should find all users' name not equal 3") } DB.Not(User{Name: "user3"}).Find(&users5) if len(users1)-len(users5) != int(name3Count) { - t.Errorf("Should find all users's name not equal 3") + t.Errorf("Should find all users' name not equal 3") } DB.Not(map[string]interface{}{"name": "user3"}).Find(&users6) if len(users1)-len(users6) != int(name3Count) { - t.Errorf("Should find all users's name not equal 3") + t.Errorf("Should find all users' name not equal 3") } DB.Not(map[string]interface{}{"name": "user3", "company_id": nil}).Find(&users7) @@ -563,14 +563,14 @@ func TestNot(t *testing.T) { DB.Not("name", []string{"user3"}).Find(&users8) if len(users1)-len(users8) != int(name3Count) { - t.Errorf("Should find all users's name not equal 3") + t.Errorf("Should find all users' name not equal 3") } var name2Count int64 DB.Table("users").Where("name = ?", "user2").Count(&name2Count) DB.Not("name", []string{"user3", "user2"}).Find(&users9) if len(users1)-len(users9) != (int(name3Count) + int(name2Count)) { - t.Errorf("Should find all users's name not equal 3") + t.Errorf("Should find all users' name not equal 3") } } From d68403b29dbf3086b2335f6381545462d96808bc Mon Sep 17 00:00:00 2001 From: antness Date: Fri, 27 Jul 2018 02:43:09 +0300 Subject: [PATCH 15/34] do not close wrapped *sql.DB (#1985) --- main.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/main.go b/main.go index 3a5d6b0c..de6ce428 100644 --- a/main.go +++ b/main.go @@ -48,6 +48,7 @@ func Open(dialect string, args ...interface{}) (db *DB, err error) { } var source string var dbSQL SQLCommon + var ownDbSQL bool switch value := args[0].(type) { case string: @@ -59,8 +60,10 @@ func Open(dialect string, args ...interface{}) (db *DB, err error) { source = args[1].(string) } dbSQL, err = sql.Open(driver, source) + ownDbSQL = true case SQLCommon: dbSQL = value + ownDbSQL = false default: return nil, fmt.Errorf("invalid database source: %v is not a valid type", value) } @@ -78,7 +81,7 @@ func Open(dialect string, args ...interface{}) (db *DB, err error) { } // Send a ping to make sure the database connection is alive. if d, ok := dbSQL.(*sql.DB); ok { - if err = d.Ping(); err != nil { + if err = d.Ping(); err != nil && ownDbSQL { d.Close() } } From 409121d9e394922787885b001d148a05e3a42b6c Mon Sep 17 00:00:00 2001 From: Alexey <10kdmg@gmail.com> Date: Fri, 27 Jul 2018 02:43:49 +0300 Subject: [PATCH 16/34] Fixed mysql query syntax for FK removal (#1993) --- scope.go | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/scope.go b/scope.go index 5eb98963..a05c1d61 100644 --- a/scope.go +++ b/scope.go @@ -1216,11 +1216,17 @@ func (scope *Scope) addForeignKey(field string, dest string, onDelete string, on func (scope *Scope) removeForeignKey(field string, dest string) { keyName := scope.Dialect().BuildKeyName(scope.TableName(), field, dest, "foreign") - if !scope.Dialect().HasForeignKey(scope.TableName(), keyName) { return } - var query = `ALTER TABLE %s DROP CONSTRAINT %s;` + var mysql mysql + var query string + if scope.Dialect().GetName() == mysql.GetName() { + query = `ALTER TABLE %s DROP FOREIGN KEY %s;` + } else { + query = `ALTER TABLE %s DROP CONSTRAINT %s;` + } + scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName))).Exec() } From 0e04d414d59f3154d700692bda0d7649d0e101b3 Mon Sep 17 00:00:00 2001 From: Artemij Shepelev Date: Sun, 19 Aug 2018 02:09:21 +0300 Subject: [PATCH 17/34] Race fix. Changes modelStructsMap implementation from map with mutex to sync.Map (#2022) * fix (https://github.com/jinzhu/gorm/issues/1407) * changed map with mutex to sync.Map (https://github.com/jinzhu/gorm/issues/1407) * removed newModelStructsMap func * commit to rerun pipeline, comment changed --- main.go | 3 ++- model_struct.go | 31 +++++-------------------------- 2 files changed, 7 insertions(+), 27 deletions(-) diff --git a/main.go b/main.go index de6ce428..993e19b1 100644 --- a/main.go +++ b/main.go @@ -6,6 +6,7 @@ import ( "fmt" "reflect" "strings" + "sync" "time" ) @@ -162,7 +163,7 @@ func (s *DB) HasBlockGlobalUpdate() bool { // SingularTable use singular table by default func (s *DB) SingularTable(enable bool) { - modelStructsMap = newModelStructsMap() + modelStructsMap = sync.Map{} s.parent.singularTable = enable } diff --git a/model_struct.go b/model_struct.go index f571e2e8..8506fe87 100644 --- a/model_struct.go +++ b/model_struct.go @@ -17,28 +17,7 @@ var DefaultTableNameHandler = func(db *DB, defaultTableName string) string { return defaultTableName } -type safeModelStructsMap struct { - m map[reflect.Type]*ModelStruct - l *sync.RWMutex -} - -func (s *safeModelStructsMap) Set(key reflect.Type, value *ModelStruct) { - s.l.Lock() - defer s.l.Unlock() - s.m[key] = value -} - -func (s *safeModelStructsMap) Get(key reflect.Type) *ModelStruct { - s.l.RLock() - defer s.l.RUnlock() - return s.m[key] -} - -func newModelStructsMap() *safeModelStructsMap { - return &safeModelStructsMap{l: new(sync.RWMutex), m: make(map[reflect.Type]*ModelStruct)} -} - -var modelStructsMap = newModelStructsMap() +var modelStructsMap sync.Map // ModelStruct model definition type ModelStruct struct { @@ -48,7 +27,7 @@ type ModelStruct struct { defaultTableName string } -// TableName get model's table name +// TableName returns model's table name func (s *ModelStruct) TableName(db *DB) string { if s.defaultTableName == "" && db != nil && s.ModelType != nil { // Set default table name @@ -152,8 +131,8 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } // Get Cached model struct - if value := modelStructsMap.Get(reflectType); value != nil { - return value + if value, ok := modelStructsMap.Load(reflectType); ok && value != nil { + return value.(*ModelStruct) } modelStruct.ModelType = reflectType @@ -601,7 +580,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } } - modelStructsMap.Set(reflectType, &modelStruct) + modelStructsMap.Store(reflectType, &modelStruct) return &modelStruct } From 31ec9255cdc16482f5bef2ceb996ba75ba750a8a Mon Sep 17 00:00:00 2001 From: Elliott <617942+ellman121@users.noreply.github.com> Date: Sun, 19 Aug 2018 01:11:27 +0200 Subject: [PATCH 18/34] Setting gorm:auto_preload to false now prevents preloading (#2031) --- callback_query_preload.go | 10 ++++++++-- preload_test.go | 25 +++++++++++++++++++++++++ 2 files changed, 33 insertions(+), 2 deletions(-) diff --git a/callback_query_preload.go b/callback_query_preload.go index 30f6b585..481bfbe3 100644 --- a/callback_query_preload.go +++ b/callback_query_preload.go @@ -14,8 +14,14 @@ func preloadCallback(scope *Scope) { return } - if _, ok := scope.Get("gorm:auto_preload"); ok { - autoPreload(scope) + if ap, ok := scope.Get("gorm:auto_preload"); ok { + // If gorm:auto_preload IS NOT a bool then auto preload. + // Else if it IS a bool, use the value + if apb, ok := ap.(bool); !ok { + autoPreload(scope) + } else if apb { + autoPreload(scope) + } } if scope.Search.preload == nil || scope.HasError() { diff --git a/preload_test.go b/preload_test.go index 311ad0be..1db625c9 100644 --- a/preload_test.go +++ b/preload_test.go @@ -123,6 +123,31 @@ func TestAutoPreload(t *testing.T) { } } +func TestAutoPreloadFalseDoesntPreload(t *testing.T) { + user1 := getPreloadUser("auto_user1") + DB.Save(user1) + + preloadDB := DB.Set("gorm:auto_preload", false).Where("role = ?", "Preload") + var user User + preloadDB.Find(&user) + + if user.BillingAddress.Address1 != "" { + t.Error("AutoPreload was set to fasle, but still fetched data") + } + + user2 := getPreloadUser("auto_user2") + DB.Save(user2) + + var users []User + preloadDB.Find(&users) + + for _, user := range users { + if user.BillingAddress.Address1 != "" { + t.Error("AutoPreload was set to fasle, but still fetched data") + } + } +} + func TestNestedPreload1(t *testing.T) { type ( Level1 struct { From 53995294ef73980d6eacee993ffa8bcdf769a0e2 Mon Sep 17 00:00:00 2001 From: hector <1069315972@qq.com> Date: Sun, 19 Aug 2018 07:13:16 +0800 Subject: [PATCH 19/34] Change buildCondition TableName to struct's TableName when query is interface{} (#2011) --- scope.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scope.go b/scope.go index a05c1d61..ca861d8a 100644 --- a/scope.go +++ b/scope.go @@ -586,10 +586,10 @@ func (scope *Scope) buildCondition(clause map[string]interface{}, include bool) scope.Err(fmt.Errorf("invalid query condition: %v", value)) return } - + scopeQuotedTableName := newScope.QuotedTableName() for _, field := range newScope.Fields() { if !field.IsIgnored && !field.IsBlank { - sqls = append(sqls, fmt.Sprintf("(%v.%v %s %v)", quotedTableName, scope.Quote(field.DBName), equalSQL, scope.AddToVars(field.Field.Interface()))) + sqls = append(sqls, fmt.Sprintf("(%v.%v %s %v)", scopeQuotedTableName, scope.Quote(field.DBName), equalSQL, scope.AddToVars(field.Field.Interface()))) } } return strings.Join(sqls, " AND ") From 32455088f24d6b1e9a502fb8e40fdc16139dbea8 Mon Sep 17 00:00:00 2001 From: Eason Lin Date: Sun, 19 Aug 2018 07:14:33 +0800 Subject: [PATCH 20/34] doc: document ErrRecordNotFound error more clear (#2015) * doc: document ErrRecordNotFound error more clear * fix goimports * fix goimports * undo change --- errors.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/errors.go b/errors.go index da2cf13c..27c9a92d 100644 --- a/errors.go +++ b/errors.go @@ -6,7 +6,7 @@ import ( ) var ( - // ErrRecordNotFound record not found error, happens when haven't find any matched data when looking up with a struct + // ErrRecordNotFound record not found error, happens when only haven't find any matched data when looking up with a struct, finding a slice won't return this error ErrRecordNotFound = errors.New("record not found") // ErrInvalidSQL invalid SQL error, happens when you passed invalid SQL ErrInvalidSQL = errors.New("invalid SQL") From 6f58f8a52cc3ad21950402d1adaa09682e07ec2f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adem=20=C3=96zay?= Date: Mon, 10 Sep 2018 00:52:20 +0300 Subject: [PATCH 21/34] added naming strategy option for db, table and column names (#2040) --- model_struct.go | 12 ++--- naming.go | 124 ++++++++++++++++++++++++++++++++++++++++++++++++ naming_test.go | 69 +++++++++++++++++++++++++++ scope.go | 4 +- utils.go | 61 ------------------------ utils_test.go | 35 -------------- 6 files changed, 201 insertions(+), 104 deletions(-) create mode 100644 naming.go create mode 100644 naming_test.go delete mode 100644 utils_test.go diff --git a/model_struct.go b/model_struct.go index 8506fe87..5b5be618 100644 --- a/model_struct.go +++ b/model_struct.go @@ -34,7 +34,7 @@ func (s *ModelStruct) TableName(db *DB) string { if tabler, ok := reflect.New(s.ModelType).Interface().(tabler); ok { s.defaultTableName = tabler.TableName() } else { - tableName := ToDBName(s.ModelType.Name()) + tableName := ToTableName(s.ModelType.Name()) if db == nil || !db.parent.singularTable { tableName = inflection.Plural(tableName) } @@ -105,7 +105,7 @@ type Relationship struct { func getForeignField(column string, fields []*StructField) *StructField { for _, field := range fields { - if field.Name == column || field.DBName == column || field.DBName == ToDBName(column) { + if field.Name == column || field.DBName == column || field.DBName == ToColumnName(column) { return field } } @@ -269,7 +269,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { // if defined join table's foreign key relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBNames[idx]) } else { - defaultJointableForeignKey := ToDBName(reflectType.Name()) + "_" + foreignField.DBName + defaultJointableForeignKey := ToColumnName(reflectType.Name()) + "_" + foreignField.DBName relationship.ForeignDBNames = append(relationship.ForeignDBNames, defaultJointableForeignKey) } } @@ -300,7 +300,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationJoinTableDBNames[idx]) } else { // join table foreign keys for association - joinTableDBName := ToDBName(elemType.Name()) + "_" + field.DBName + joinTableDBName := ToColumnName(elemType.Name()) + "_" + field.DBName relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, joinTableDBName) } } @@ -308,7 +308,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } joinTableHandler := JoinTableHandler{} - joinTableHandler.Setup(relationship, many2many, reflectType, elemType) + joinTableHandler.Setup(relationship, ToTableName(many2many), reflectType, elemType) relationship.JoinTableHandler = &joinTableHandler field.Relationship = relationship } else { @@ -566,7 +566,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { if value, ok := field.TagSettings["COLUMN"]; ok { field.DBName = value } else { - field.DBName = ToDBName(fieldStruct.Name) + field.DBName = ToColumnName(fieldStruct.Name) } modelStruct.StructFields = append(modelStruct.StructFields, field) diff --git a/naming.go b/naming.go new file mode 100644 index 00000000..6b0a4fdd --- /dev/null +++ b/naming.go @@ -0,0 +1,124 @@ +package gorm + +import ( + "bytes" + "strings" +) + +// Namer is a function type which is given a string and return a string +type Namer func(string) string + +// NamingStrategy represents naming strategies +type NamingStrategy struct { + DB Namer + Table Namer + Column Namer +} + +// TheNamingStrategy is being initialized with defaultNamingStrategy +var TheNamingStrategy = &NamingStrategy{ + DB: defaultNamer, + Table: defaultNamer, + Column: defaultNamer, +} + +// AddNamingStrategy sets the naming strategy +func AddNamingStrategy(ns *NamingStrategy) { + if ns.DB == nil { + ns.DB = defaultNamer + } + if ns.Table == nil { + ns.Table = defaultNamer + } + if ns.Column == nil { + ns.Column = defaultNamer + } + TheNamingStrategy = ns +} + +// DBName alters the given name by DB +func (ns *NamingStrategy) DBName(name string) string { + return ns.DB(name) +} + +// TableName alters the given name by Table +func (ns *NamingStrategy) TableName(name string) string { + return ns.Table(name) +} + +// ColumnName alters the given name by Column +func (ns *NamingStrategy) ColumnName(name string) string { + return ns.Column(name) +} + +// ToDBName convert string to db name +func ToDBName(name string) string { + return TheNamingStrategy.DBName(name) +} + +// ToTableName convert string to table name +func ToTableName(name string) string { + return TheNamingStrategy.TableName(name) +} + +// ToColumnName convert string to db name +func ToColumnName(name string) string { + return TheNamingStrategy.ColumnName(name) +} + +var smap = newSafeMap() + +func defaultNamer(name string) string { + const ( + lower = false + upper = true + ) + + if v := smap.Get(name); v != "" { + return v + } + + if name == "" { + return "" + } + + var ( + value = commonInitialismsReplacer.Replace(name) + buf = bytes.NewBufferString("") + lastCase, currCase, nextCase, nextNumber bool + ) + + for i, v := range value[:len(value)-1] { + nextCase = bool(value[i+1] >= 'A' && value[i+1] <= 'Z') + nextNumber = bool(value[i+1] >= '0' && value[i+1] <= '9') + + if i > 0 { + if currCase == upper { + if lastCase == upper && (nextCase == upper || nextNumber == upper) { + buf.WriteRune(v) + } else { + if value[i-1] != '_' && value[i+1] != '_' { + buf.WriteRune('_') + } + buf.WriteRune(v) + } + } else { + buf.WriteRune(v) + if i == len(value)-2 && (nextCase == upper && nextNumber == lower) { + buf.WriteRune('_') + } + } + } else { + currCase = upper + buf.WriteRune(v) + } + lastCase = currCase + currCase = nextCase + } + + buf.WriteByte(value[len(value)-1]) + + s := strings.ToLower(buf.String()) + smap.Set(name, s) + return s +} diff --git a/naming_test.go b/naming_test.go new file mode 100644 index 00000000..0c6f7713 --- /dev/null +++ b/naming_test.go @@ -0,0 +1,69 @@ +package gorm_test + +import ( + "testing" + + "github.com/jinzhu/gorm" +) + +func TestTheNamingStrategy(t *testing.T) { + + cases := []struct { + name string + namer gorm.Namer + expected string + }{ + {name: "auth", expected: "auth", namer: gorm.TheNamingStrategy.DB}, + {name: "userRestrictions", expected: "user_restrictions", namer: gorm.TheNamingStrategy.Table}, + {name: "clientID", expected: "client_id", namer: gorm.TheNamingStrategy.Column}, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + result := c.namer(c.name) + if result != c.expected { + t.Errorf("error in naming strategy. expected: %v got :%v\n", c.expected, result) + } + }) + } + +} + +func TestNamingStrategy(t *testing.T) { + + dbNameNS := func(name string) string { + return "db_" + name + } + tableNameNS := func(name string) string { + return "tbl_" + name + } + columnNameNS := func(name string) string { + return "col_" + name + } + + ns := &gorm.NamingStrategy{ + DB: dbNameNS, + Table: tableNameNS, + Column: columnNameNS, + } + + cases := []struct { + name string + namer gorm.Namer + expected string + }{ + {name: "auth", expected: "db_auth", namer: ns.DB}, + {name: "user", expected: "tbl_user", namer: ns.Table}, + {name: "password", expected: "col_password", namer: ns.Column}, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + result := c.namer(c.name) + if result != c.expected { + t.Errorf("error in naming strategy. expected: %v got :%v\n", c.expected, result) + } + }) + } + +} diff --git a/scope.go b/scope.go index ca861d8a..fbf7634e 100644 --- a/scope.go +++ b/scope.go @@ -134,7 +134,7 @@ func (scope *Scope) Fields() []*Field { // FieldByName find `gorm.Field` with field name or db name func (scope *Scope) FieldByName(name string) (field *Field, ok bool) { var ( - dbName = ToDBName(name) + dbName = ToColumnName(name) mostMatchedField *Field ) @@ -880,7 +880,7 @@ func convertInterfaceToMap(values interface{}, withIgnoredField bool) map[string switch reflectValue.Kind() { case reflect.Map: for _, key := range reflectValue.MapKeys() { - attrs[ToDBName(key.Interface().(string))] = reflectValue.MapIndex(key).Interface() + attrs[ToColumnName(key.Interface().(string))] = reflectValue.MapIndex(key).Interface() } default: for _, field := range (&Scope{Value: values}).Fields() { diff --git a/utils.go b/utils.go index 99b532c5..ad700b98 100644 --- a/utils.go +++ b/utils.go @@ -1,7 +1,6 @@ package gorm import ( - "bytes" "database/sql/driver" "fmt" "reflect" @@ -58,66 +57,6 @@ func newSafeMap() *safeMap { return &safeMap{l: new(sync.RWMutex), m: make(map[string]string)} } -var smap = newSafeMap() - -type strCase bool - -const ( - lower strCase = false - upper strCase = true -) - -// ToDBName convert string to db name -func ToDBName(name string) string { - if v := smap.Get(name); v != "" { - return v - } - - if name == "" { - return "" - } - - var ( - value = commonInitialismsReplacer.Replace(name) - buf = bytes.NewBufferString("") - lastCase, currCase, nextCase, nextNumber strCase - ) - - for i, v := range value[:len(value)-1] { - nextCase = strCase(value[i+1] >= 'A' && value[i+1] <= 'Z') - nextNumber = strCase(value[i+1] >= '0' && value[i+1] <= '9') - - if i > 0 { - if currCase == upper { - if lastCase == upper && (nextCase == upper || nextNumber == upper) { - buf.WriteRune(v) - } else { - if value[i-1] != '_' && value[i+1] != '_' { - buf.WriteRune('_') - } - buf.WriteRune(v) - } - } else { - buf.WriteRune(v) - if i == len(value)-2 && (nextCase == upper && nextNumber == lower) { - buf.WriteRune('_') - } - } - } else { - currCase = upper - buf.WriteRune(v) - } - lastCase = currCase - currCase = nextCase - } - - buf.WriteByte(value[len(value)-1]) - - s := strings.ToLower(buf.String()) - smap.Set(name, s) - return s -} - // SQL expression type expr struct { expr string diff --git a/utils_test.go b/utils_test.go deleted file mode 100644 index 086c4450..00000000 --- a/utils_test.go +++ /dev/null @@ -1,35 +0,0 @@ -package gorm_test - -import ( - "testing" - - "github.com/jinzhu/gorm" -) - -func TestToDBNameGenerateFriendlyName(t *testing.T) { - var maps = map[string]string{ - "": "", - "X": "x", - "ThisIsATest": "this_is_a_test", - "PFAndESI": "pf_and_esi", - "AbcAndJkl": "abc_and_jkl", - "EmployeeID": "employee_id", - "SKU_ID": "sku_id", - "UTF8": "utf8", - "Level1": "level1", - "SHA256Hash": "sha256_hash", - "FieldX": "field_x", - "HTTPAndSMTP": "http_and_smtp", - "HTTPServerHandlerForURLID": "http_server_handler_for_url_id", - "UUID": "uuid", - "HTTPURL": "http_url", - "HTTP_URL": "http_url", - "ThisIsActuallyATestSoWeMayBeAbleToUseThisCodeInGormPackageAlsoIdCanBeUsedAtTheEndAsID": "this_is_actually_a_test_so_we_may_be_able_to_use_this_code_in_gorm_package_also_id_can_be_used_at_the_end_as_id", - } - - for key, value := range maps { - if gorm.ToDBName(key) != value { - t.Errorf("%v ToDBName should equal %v, but got %v", key, value, gorm.ToDBName(key)) - } - } -} From dc3b2476c4eb61c37424a1ca2f46859e4e6fcd81 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 10 Sep 2018 06:03:41 +0800 Subject: [PATCH 22/34] Don't save ignored fields into database --- callback_create.go | 2 +- scope.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/callback_create.go b/callback_create.go index e7fe6f86..2ab05d3b 100644 --- a/callback_create.go +++ b/callback_create.go @@ -59,7 +59,7 @@ func createCallback(scope *Scope) { for _, field := range scope.Fields() { if scope.changeableField(field) { - if field.IsNormal { + if field.IsNormal && !field.IsIgnored { if field.IsBlank && field.HasDefaultValue { blankColumnsWithDefaultValue = append(blankColumnsWithDefaultValue, scope.Quote(field.DBName)) scope.InstanceSet("gorm:blank_columns_with_default_value", blankColumnsWithDefaultValue) diff --git a/scope.go b/scope.go index fbf7634e..7d6ba1c0 100644 --- a/scope.go +++ b/scope.go @@ -907,7 +907,7 @@ func (scope *Scope) updatedAttrsWithValues(value interface{}) (results map[strin results[field.DBName] = value } else { err := field.Set(value) - if field.IsNormal { + if field.IsNormal && !field.IsIgnored { hasUpdate = true if err == ErrUnaddressable { results[field.DBName] = value From 71b7f19aad77eaf99a90324c7d2ac5634eaefca8 Mon Sep 17 00:00:00 2001 From: Xy Ziemba Date: Sun, 9 Sep 2018 15:12:58 -0700 Subject: [PATCH 23/34] Fix scanning identical column names occurring >2 times (#2080) Fix the indexing logic used in selectedColumnsMap to skip fields that have already been seen. The values of selectedColumns map must be indexed relative to fields, not relative to selectFields. --- main_test.go | 54 ++++++++++++++++++++++++++++++++++++++++++++++++++++ scope.go | 6 ++++-- 2 files changed, 58 insertions(+), 2 deletions(-) diff --git a/main_test.go b/main_test.go index 265e0be7..11c4bb87 100644 --- a/main_test.go +++ b/main_test.go @@ -581,6 +581,60 @@ func TestJoins(t *testing.T) { } } +type JoinedIds struct { + UserID int64 `gorm:"column:id"` + BillingAddressID int64 `gorm:"column:id"` + EmailID int64 `gorm:"column:id"` +} + +func TestScanIdenticalColumnNames(t *testing.T) { + var user = User{ + Name: "joinsIds", + Email: "joinIds@example.com", + BillingAddress: Address{ + Address1: "One Park Place", + }, + Emails: []Email{{Email: "join1@example.com"}, {Email: "join2@example.com"}}, + } + DB.Save(&user) + + var users []JoinedIds + DB.Select("users.id, addresses.id, emails.id").Table("users"). + Joins("left join addresses on users.billing_address_id = addresses.id"). + Joins("left join emails on emails.user_id = users.id"). + Where("name = ?", "joinsIds").Scan(&users) + + if len(users) != 2 { + t.Fatal("should find two rows using left join") + } + + if user.Id != users[0].UserID { + t.Errorf("Expected result row to contain UserID %d, but got %d", user.Id, users[0].UserID) + } + if user.Id != users[1].UserID { + t.Errorf("Expected result row to contain UserID %d, but got %d", user.Id, users[1].UserID) + } + + if user.BillingAddressID.Int64 != users[0].BillingAddressID { + t.Errorf("Expected result row to contain BillingAddressID %d, but got %d", user.BillingAddressID.Int64, users[0].BillingAddressID) + } + if user.BillingAddressID.Int64 != users[1].BillingAddressID { + t.Errorf("Expected result row to contain BillingAddressID %d, but got %d", user.BillingAddressID.Int64, users[0].BillingAddressID) + } + + if users[0].EmailID == users[1].EmailID { + t.Errorf("Email ids should be unique. Got %d and %d", users[0].EmailID, users[1].EmailID) + } + + if int64(user.Emails[0].Id) != users[0].EmailID && int64(user.Emails[1].Id) != users[0].EmailID { + t.Errorf("Expected result row ID to be either %d or %d, but was %d", user.Emails[0].Id, user.Emails[1].Id, users[0].EmailID) + } + + if int64(user.Emails[0].Id) != users[1].EmailID && int64(user.Emails[1].Id) != users[1].EmailID { + t.Errorf("Expected result row ID to be either %d or %d, but was %d", user.Emails[0].Id, user.Emails[1].Id, users[1].EmailID) + } +} + func TestJoinsWithSelect(t *testing.T) { type result struct { Name string diff --git a/scope.go b/scope.go index 7d6ba1c0..ce80ab86 100644 --- a/scope.go +++ b/scope.go @@ -486,8 +486,10 @@ func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field) { values[index] = &ignored selectFields = fields + offset := 0 if idx, ok := selectedColumnsMap[column]; ok { - selectFields = selectFields[idx+1:] + offset = idx + 1 + selectFields = selectFields[offset:] } for fieldIndex, field := range selectFields { @@ -501,7 +503,7 @@ func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field) { resetFields[index] = field } - selectedColumnsMap[column] = fieldIndex + selectedColumnsMap[column] = offset + fieldIndex if field.IsNormal { break From 12607e8bdf4a724492d53d8c788edc77ad4439e7 Mon Sep 17 00:00:00 2001 From: kuangzhiqiang Date: Mon, 10 Sep 2018 06:14:05 +0800 Subject: [PATCH 24/34] for go1.11 go mod (#2072) when used go1.11 gomodules the code dir will be `$GOPATH/pkg/mod/github.com/jinzhu/gorm@*/` fileWithLineNum check failed --- utils.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/utils.go b/utils.go index ad700b98..8489538c 100644 --- a/utils.go +++ b/utils.go @@ -25,8 +25,8 @@ var NowFunc = func() time.Time { var commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UID", "UI", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"} var commonInitialismsReplacer *strings.Replacer -var goSrcRegexp = regexp.MustCompile(`jinzhu/gorm/.*.go`) -var goTestRegexp = regexp.MustCompile(`jinzhu/gorm/.*test.go`) +var goSrcRegexp = regexp.MustCompile(`jinzhu/gorm(@.*)?/.*.go`) +var goTestRegexp = regexp.MustCompile(`jinzhu/gorm(@.*)?/.*test.go`) func init() { var commonInitialismsForReplacer []string From d3e666a1e086a020905e3f6cf293941806520d97 Mon Sep 17 00:00:00 2001 From: Ikhtiyor <33823221+iahmedov@users.noreply.github.com> Date: Mon, 10 Sep 2018 03:25:26 +0500 Subject: [PATCH 25/34] save_associations:true should store related item (#2067) * save_associations:true should store related item, save_associations priority on related objects * code quality --- callback_save.go | 6 ++-- main_test.go | 88 +++++++++++++++++++++++++++++++++++++++++++++++ migration_test.go | 10 +++++- 3 files changed, 100 insertions(+), 4 deletions(-) diff --git a/callback_save.go b/callback_save.go index ef267141..ebfd0b34 100644 --- a/callback_save.go +++ b/callback_save.go @@ -21,9 +21,7 @@ func saveAssociationCheck(scope *Scope, field *Field) (autoUpdate bool, autoCrea if v, ok := value.(string); ok { v = strings.ToLower(v) - if v == "false" || v != "skip" { - return false - } + return v == "true" } return true @@ -36,9 +34,11 @@ func saveAssociationCheck(scope *Scope, field *Field) (autoUpdate bool, autoCrea if value, ok := scope.Get("gorm:save_associations"); ok { autoUpdate = checkTruth(value) autoCreate = autoUpdate + saveReference = autoUpdate } else if value, ok := field.TagSettings["SAVE_ASSOCIATIONS"]; ok { autoUpdate = checkTruth(value) autoCreate = autoUpdate + saveReference = autoUpdate } if value, ok := scope.Get("gorm:association_autoupdate"); ok { diff --git a/main_test.go b/main_test.go index 11c4bb87..94d2fa39 100644 --- a/main_test.go +++ b/main_test.go @@ -933,6 +933,94 @@ func TestOpenWithOneParameter(t *testing.T) { } } +func TestSaveAssociations(t *testing.T) { + db := DB.New() + deltaAddressCount := 0 + if err := db.Model(&Address{}).Count(&deltaAddressCount).Error; err != nil { + t.Errorf("failed to fetch address count") + t.FailNow() + } + + placeAddress := &Address{ + Address1: "somewhere on earth", + } + ownerAddress1 := &Address{ + Address1: "near place address", + } + ownerAddress2 := &Address{ + Address1: "address2", + } + db.Create(placeAddress) + + addressCountShouldBe := func(t *testing.T, expectedCount int) { + countFromDB := 0 + t.Helper() + err := db.Model(&Address{}).Count(&countFromDB).Error + if err != nil { + t.Error("failed to fetch address count") + } + if countFromDB != expectedCount { + t.Errorf("address count mismatch: %d", countFromDB) + } + } + addressCountShouldBe(t, deltaAddressCount+1) + + // owner address should be created, place address should be reused + place1 := &Place{ + PlaceAddressID: placeAddress.ID, + PlaceAddress: placeAddress, + OwnerAddress: ownerAddress1, + } + err := db.Create(place1).Error + if err != nil { + t.Errorf("failed to store place: %s", err.Error()) + } + addressCountShouldBe(t, deltaAddressCount+2) + + // owner address should be created again, place address should be reused + place2 := &Place{ + PlaceAddressID: placeAddress.ID, + PlaceAddress: &Address{ + ID: 777, + Address1: "address1", + }, + OwnerAddress: ownerAddress2, + OwnerAddressID: 778, + } + err = db.Create(place2).Error + if err != nil { + t.Errorf("failed to store place: %s", err.Error()) + } + addressCountShouldBe(t, deltaAddressCount+3) + + count := 0 + db.Model(&Place{}).Where(&Place{ + PlaceAddressID: placeAddress.ID, + OwnerAddressID: ownerAddress1.ID, + }).Count(&count) + if count != 1 { + t.Errorf("only one instance of (%d, %d) should be available, found: %d", + placeAddress.ID, ownerAddress1.ID, count) + } + + db.Model(&Place{}).Where(&Place{ + PlaceAddressID: placeAddress.ID, + OwnerAddressID: ownerAddress2.ID, + }).Count(&count) + if count != 1 { + t.Errorf("only one instance of (%d, %d) should be available, found: %d", + placeAddress.ID, ownerAddress2.ID, count) + } + + db.Model(&Place{}).Where(&Place{ + PlaceAddressID: placeAddress.ID, + }).Count(&count) + if count != 2 { + t.Errorf("two instances of (%d) should be available, found: %d", + placeAddress.ID, count) + } +} + func TestBlockGlobalUpdate(t *testing.T) { db := DB.New() db.Create(&Toy{Name: "Stuffed Animal", OwnerType: "Nobody"}) diff --git a/migration_test.go b/migration_test.go index 78555dcc..3fb06648 100644 --- a/migration_test.go +++ b/migration_test.go @@ -118,6 +118,14 @@ type Company struct { Owner *User `sql:"-"` } +type Place struct { + Id int64 + PlaceAddressID int + PlaceAddress *Address `gorm:"save_associations:false"` + OwnerAddressID int + OwnerAddress *Address `gorm:"save_associations:true"` +} + type EncryptedData []byte func (data *EncryptedData) Scan(value interface{}) error { @@ -284,7 +292,7 @@ func runMigration() { DB.Exec(fmt.Sprintf("drop table %v;", table)) } - values := []interface{}{&Short{}, &ReallyLongThingThatReferencesShort{}, &ReallyLongTableNameToTestMySQLNameLengthLimit{}, &NotSoLongTableName{}, &Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}, &Animal{}, &User{}, &JoinTable{}, &Post{}, &Category{}, &Comment{}, &Cat{}, &Dog{}, &Hamster{}, &Toy{}, &ElementWithIgnoredField{}} + values := []interface{}{&Short{}, &ReallyLongThingThatReferencesShort{}, &ReallyLongTableNameToTestMySQLNameLengthLimit{}, &NotSoLongTableName{}, &Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}, &Animal{}, &User{}, &JoinTable{}, &Post{}, &Category{}, &Comment{}, &Cat{}, &Dog{}, &Hamster{}, &Toy{}, &ElementWithIgnoredField{}, &Place{}} for _, value := range values { DB.DropTable(value) } From 73e7561e20e8e554ec54463ccbed38e426aad17f Mon Sep 17 00:00:00 2001 From: Aaron Leung Date: Sun, 9 Sep 2018 15:26:29 -0700 Subject: [PATCH 26/34] Use sync.Map for DB.values (#2064) * Replace the regular map with a sync.Map to avoid fatal concurrent map reads/writes * fix the formatting --- main.go | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/main.go b/main.go index 993e19b1..364d8e8e 100644 --- a/main.go +++ b/main.go @@ -22,7 +22,7 @@ type DB struct { logMode int logger logger search *search - values map[string]interface{} + values sync.Map // global db parent *DB @@ -72,7 +72,6 @@ func Open(dialect string, args ...interface{}) (db *DB, err error) { db = &DB{ db: dbSQL, logger: defaultLogger, - values: map[string]interface{}{}, callbacks: DefaultCallback, dialect: newDialect(dialect, dbSQL), } @@ -680,13 +679,13 @@ func (s *DB) Set(name string, value interface{}) *DB { // InstantSet instant set setting, will affect current db func (s *DB) InstantSet(name string, value interface{}) *DB { - s.values[name] = value + s.values.Store(name, value) return s } // Get get setting by name func (s *DB) Get(name string) (value interface{}, ok bool) { - value, ok = s.values[name] + value, ok = s.values.Load(name) return } @@ -750,16 +749,16 @@ func (s *DB) clone() *DB { parent: s.parent, logger: s.logger, logMode: s.logMode, - values: map[string]interface{}{}, Value: s.Value, Error: s.Error, blockGlobalUpdate: s.blockGlobalUpdate, dialect: newDialect(s.dialect.GetName(), s.db), } - for key, value := range s.values { - db.values[key] = value - } + s.values.Range(func(k, v interface{}) bool { + db.values.Store(k, v) + return true + }) if s.search == nil { db.search = &search{limit: -1, offset: -1} From 012d1479740ec593b0c07f0372e0111c01c3b34a Mon Sep 17 00:00:00 2001 From: maddie Date: Mon, 10 Sep 2018 06:45:55 +0800 Subject: [PATCH 27/34] Improve preload speed (#2058) All credits to @vanjapt who came up with this patch. Closes #1672 --- callback_query_preload.go | 33 ++++++++++++++++++++++----------- 1 file changed, 22 insertions(+), 11 deletions(-) diff --git a/callback_query_preload.go b/callback_query_preload.go index 481bfbe3..46405c38 100644 --- a/callback_query_preload.go +++ b/callback_query_preload.go @@ -161,14 +161,17 @@ func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) ) if indirectScopeValue.Kind() == reflect.Slice { + foreignValuesToResults := make(map[string]reflect.Value) + for i := 0; i < resultsValue.Len(); i++ { + result := resultsValue.Index(i) + foreignValues := toString(getValueFromFields(result, relation.ForeignFieldNames)) + foreignValuesToResults[foreignValues] = result + } for j := 0; j < indirectScopeValue.Len(); j++ { - for i := 0; i < resultsValue.Len(); i++ { - result := resultsValue.Index(i) - foreignValues := getValueFromFields(result, relation.ForeignFieldNames) - if indirectValue := indirect(indirectScopeValue.Index(j)); equalAsString(getValueFromFields(indirectValue, relation.AssociationForeignFieldNames), foreignValues) { - indirectValue.FieldByName(field.Name).Set(result) - break - } + indirectValue := indirect(indirectScopeValue.Index(j)) + valueString := toString(getValueFromFields(indirectValue, relation.AssociationForeignFieldNames)) + if result, found := foreignValuesToResults[valueString]; found { + indirectValue.FieldByName(field.Name).Set(result) } } } else { @@ -255,13 +258,21 @@ func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{ indirectScopeValue = scope.IndirectValue() ) + foreignFieldToObjects := make(map[string][]*reflect.Value) + if indirectScopeValue.Kind() == reflect.Slice { + for j := 0; j < indirectScopeValue.Len(); j++ { + object := indirect(indirectScopeValue.Index(j)) + valueString := toString(getValueFromFields(object, relation.ForeignFieldNames)) + foreignFieldToObjects[valueString] = append(foreignFieldToObjects[valueString], &object) + } + } + for i := 0; i < resultsValue.Len(); i++ { result := resultsValue.Index(i) if indirectScopeValue.Kind() == reflect.Slice { - value := getValueFromFields(result, relation.AssociationForeignFieldNames) - for j := 0; j < indirectScopeValue.Len(); j++ { - object := indirect(indirectScopeValue.Index(j)) - if equalAsString(getValueFromFields(object, relation.ForeignFieldNames), value) { + valueString := toString(getValueFromFields(result, relation.AssociationForeignFieldNames)) + if objects, found := foreignFieldToObjects[valueString]; found { + for _, object := range objects { object.FieldByName(field.Name).Set(result) } } From 26fde9110f932df8cb5cc24396e7a54a6d3a94c2 Mon Sep 17 00:00:00 2001 From: Gustavo Brunoro Date: Sun, 9 Sep 2018 19:47:18 -0300 Subject: [PATCH 28/34] getValueFromFields doesn't panic on nil pointers (#2021) * `IsValid()` won't return `false` for nil pointers unless Value is wrapped in a `reflect.Indirect`. --- utils.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils.go b/utils.go index 8489538c..e58e57a5 100644 --- a/utils.go +++ b/utils.go @@ -206,7 +206,7 @@ func getValueFromFields(value reflect.Value, fieldNames []string) (results []int // as FieldByName could panic if indirectValue := reflect.Indirect(value); indirectValue.IsValid() { for _, fieldName := range fieldNames { - if fieldValue := indirectValue.FieldByName(fieldName); fieldValue.IsValid() { + if fieldValue := reflect.Indirect(indirectValue.FieldByName(fieldName)); fieldValue.IsValid() { result := fieldValue.Interface() if r, ok := result.(driver.Valuer); ok { result, _ = r.Value() From 588b598f9fbf9a0c84b6ec18f617940b045c54d4 Mon Sep 17 00:00:00 2001 From: Phillip Shipley Date: Sun, 9 Sep 2018 18:50:22 -0400 Subject: [PATCH 29/34] Fix issue updating models with foreign key constraints (#1988) * fix update callback to not try to write zero values when field has default value * fix to update callback for gorm tests --- callback_update.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/callback_update.go b/callback_update.go index 373bd726..f6ba0ffd 100644 --- a/callback_update.go +++ b/callback_update.go @@ -76,7 +76,9 @@ func updateCallback(scope *Scope) { for _, field := range scope.Fields() { if scope.changeableField(field) { if !field.IsPrimaryKey && field.IsNormal { - sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) + if !field.IsForeignKey || !field.IsBlank || !field.HasDefaultValue { + sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) + } } else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" { for _, foreignKey := range relationship.ForeignDBNames { if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) { From 282f11af1900a36646b607797273056d76350223 Mon Sep 17 00:00:00 2001 From: gstvg <28798827+gstvg@users.noreply.github.com> Date: Sun, 9 Sep 2018 19:52:32 -0300 Subject: [PATCH 30/34] Support only preloading (#1926) * add support for only preloading relations on an already populated model * Update callback_query.go comments --- callback_query.go | 5 +++++ main.go | 5 +++++ 2 files changed, 10 insertions(+) diff --git a/callback_query.go b/callback_query.go index ba10cc7d..593e5d30 100644 --- a/callback_query.go +++ b/callback_query.go @@ -18,6 +18,11 @@ func queryCallback(scope *Scope) { if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip { return } + + //we are only preloading relations, dont touch base model + if _, skip := scope.InstanceGet("gorm:only_preload"); skip { + return + } defer scope.trace(NowFunc()) diff --git a/main.go b/main.go index 364d8e8e..4dbda61e 100644 --- a/main.go +++ b/main.go @@ -314,6 +314,11 @@ func (s *DB) Find(out interface{}, where ...interface{}) *DB { return s.NewScope(out).inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db } +//Preloads preloads relations, don`t touch out +func (s *DB) Preloads(out interface{}) *DB { + return s.NewScope(out).InstanceSet("gorm:only_preload", 1).callCallbacks(s.parent.callbacks.queries).db +} + // Scan scan value to a struct func (s *DB) Scan(dest interface{}) *DB { return s.NewScope(s.Value).Set("gorm:query_destination", dest).callCallbacks(s.parent.callbacks.queries).db From 123d4f50ef8a8209ee8434daa41c6045a9111864 Mon Sep 17 00:00:00 2001 From: Eyal Posener Date: Mon, 10 Sep 2018 02:11:00 +0300 Subject: [PATCH 31/34] lock TagSettings structure when modified (#1796) The map is modified in different places in the code which results in race conditions on execution. This commit locks the map with read-write lock when it is modified --- callback_query_preload.go | 2 +- callback_save.go | 8 ++--- dialect.go | 10 +++--- dialect_common.go | 2 +- dialect_mysql.go | 22 ++++++------ dialect_postgres.go | 6 ++-- dialect_sqlite3.go | 4 +-- dialects/mssql/mssql.go | 8 ++--- field_test.go | 2 +- main.go | 2 +- model_struct.go | 73 +++++++++++++++++++++++++++------------ scope.go | 12 +++---- 12 files changed, 90 insertions(+), 61 deletions(-) diff --git a/callback_query_preload.go b/callback_query_preload.go index 46405c38..d7c8a133 100644 --- a/callback_query_preload.go +++ b/callback_query_preload.go @@ -100,7 +100,7 @@ func autoPreload(scope *Scope) { continue } - if val, ok := field.TagSettings["PRELOAD"]; ok { + if val, ok := field.TagSettingsGet("PRELOAD"); ok { if preload, err := strconv.ParseBool(val); err != nil { scope.Err(errors.New("invalid preload option")) return diff --git a/callback_save.go b/callback_save.go index ebfd0b34..3b4e0589 100644 --- a/callback_save.go +++ b/callback_save.go @@ -35,7 +35,7 @@ func saveAssociationCheck(scope *Scope, field *Field) (autoUpdate bool, autoCrea autoUpdate = checkTruth(value) autoCreate = autoUpdate saveReference = autoUpdate - } else if value, ok := field.TagSettings["SAVE_ASSOCIATIONS"]; ok { + } else if value, ok := field.TagSettingsGet("SAVE_ASSOCIATIONS"); ok { autoUpdate = checkTruth(value) autoCreate = autoUpdate saveReference = autoUpdate @@ -43,19 +43,19 @@ func saveAssociationCheck(scope *Scope, field *Field) (autoUpdate bool, autoCrea if value, ok := scope.Get("gorm:association_autoupdate"); ok { autoUpdate = checkTruth(value) - } else if value, ok := field.TagSettings["ASSOCIATION_AUTOUPDATE"]; ok { + } else if value, ok := field.TagSettingsGet("ASSOCIATION_AUTOUPDATE"); ok { autoUpdate = checkTruth(value) } if value, ok := scope.Get("gorm:association_autocreate"); ok { autoCreate = checkTruth(value) - } else if value, ok := field.TagSettings["ASSOCIATION_AUTOCREATE"]; ok { + } else if value, ok := field.TagSettingsGet("ASSOCIATION_AUTOCREATE"); ok { autoCreate = checkTruth(value) } if value, ok := scope.Get("gorm:association_save_reference"); ok { saveReference = checkTruth(value) - } else if value, ok := field.TagSettings["ASSOCIATION_SAVE_REFERENCE"]; ok { + } else if value, ok := field.TagSettingsGet("ASSOCIATION_SAVE_REFERENCE"); ok { saveReference = checkTruth(value) } } diff --git a/dialect.go b/dialect.go index 506a6e86..27b308af 100644 --- a/dialect.go +++ b/dialect.go @@ -83,7 +83,7 @@ var ParseFieldStructForDialect = func(field *StructField, dialect Dialect) (fiel // Get redirected field type var ( reflectType = field.Struct.Type - dataType = field.TagSettings["TYPE"] + dataType, _ = field.TagSettingsGet("TYPE") ) for reflectType.Kind() == reflect.Ptr { @@ -112,15 +112,17 @@ var ParseFieldStructForDialect = func(field *StructField, dialect Dialect) (fiel } // Default Size - if num, ok := field.TagSettings["SIZE"]; ok { + if num, ok := field.TagSettingsGet("SIZE"); ok { size, _ = strconv.Atoi(num) } else { size = 255 } // Default type from tag setting - additionalType = field.TagSettings["NOT NULL"] + " " + field.TagSettings["UNIQUE"] - if value, ok := field.TagSettings["DEFAULT"]; ok { + notNull, _ := field.TagSettingsGet("NOT NULL") + unique, _ := field.TagSettingsGet("UNIQUE") + additionalType = notNull + " " + unique + if value, ok := field.TagSettingsGet("DEFAULT"); ok { additionalType = additionalType + " DEFAULT " + value } diff --git a/dialect_common.go b/dialect_common.go index b9f0c7da..a479be79 100644 --- a/dialect_common.go +++ b/dialect_common.go @@ -39,7 +39,7 @@ func (commonDialect) Quote(key string) string { } func (s *commonDialect) fieldCanAutoIncrement(field *StructField) bool { - if value, ok := field.TagSettings["AUTO_INCREMENT"]; ok { + if value, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok { return strings.ToLower(value) != "false" } return field.IsPrimaryKey diff --git a/dialect_mysql.go b/dialect_mysql.go index b162bade..5d63e5cd 100644 --- a/dialect_mysql.go +++ b/dialect_mysql.go @@ -33,9 +33,9 @@ func (s *mysql) DataTypeOf(field *StructField) string { // MySQL allows only one auto increment column per table, and it must // be a KEY column. - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok { - if _, ok = field.TagSettings["INDEX"]; !ok && !field.IsPrimaryKey { - delete(field.TagSettings, "AUTO_INCREMENT") + if _, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok { + if _, ok = field.TagSettingsGet("INDEX"); !ok && !field.IsPrimaryKey { + field.TagSettingsDelete("AUTO_INCREMENT") } } @@ -45,42 +45,42 @@ func (s *mysql) DataTypeOf(field *StructField) string { sqlType = "boolean" case reflect.Int8: if s.fieldCanAutoIncrement(field) { - field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" + field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") sqlType = "tinyint AUTO_INCREMENT" } else { sqlType = "tinyint" } case reflect.Int, reflect.Int16, reflect.Int32: if s.fieldCanAutoIncrement(field) { - field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" + field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") sqlType = "int AUTO_INCREMENT" } else { sqlType = "int" } case reflect.Uint8: if s.fieldCanAutoIncrement(field) { - field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" + field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") sqlType = "tinyint unsigned AUTO_INCREMENT" } else { sqlType = "tinyint unsigned" } case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uintptr: if s.fieldCanAutoIncrement(field) { - field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" + field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") sqlType = "int unsigned AUTO_INCREMENT" } else { sqlType = "int unsigned" } case reflect.Int64: if s.fieldCanAutoIncrement(field) { - field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" + field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") sqlType = "bigint AUTO_INCREMENT" } else { sqlType = "bigint" } case reflect.Uint64: if s.fieldCanAutoIncrement(field) { - field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" + field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") sqlType = "bigint unsigned AUTO_INCREMENT" } else { sqlType = "bigint unsigned" @@ -96,11 +96,11 @@ func (s *mysql) DataTypeOf(field *StructField) string { case reflect.Struct: if _, ok := dataValue.Interface().(time.Time); ok { precision := "" - if p, ok := field.TagSettings["PRECISION"]; ok { + if p, ok := field.TagSettingsGet("PRECISION"); ok { precision = fmt.Sprintf("(%s)", p) } - if _, ok := field.TagSettings["NOT NULL"]; ok { + if _, ok := field.TagSettingsGet("NOT NULL"); ok { sqlType = fmt.Sprintf("timestamp%v", precision) } else { sqlType = fmt.Sprintf("timestamp%v NULL", precision) diff --git a/dialect_postgres.go b/dialect_postgres.go index c44c6a5b..53d31388 100644 --- a/dialect_postgres.go +++ b/dialect_postgres.go @@ -34,14 +34,14 @@ func (s *postgres) DataTypeOf(field *StructField) string { sqlType = "boolean" case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uintptr: if s.fieldCanAutoIncrement(field) { - field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" + field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") sqlType = "serial" } else { sqlType = "integer" } case reflect.Int64, reflect.Uint32, reflect.Uint64: if s.fieldCanAutoIncrement(field) { - field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" + field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") sqlType = "bigserial" } else { sqlType = "bigint" @@ -49,7 +49,7 @@ func (s *postgres) DataTypeOf(field *StructField) string { case reflect.Float32, reflect.Float64: sqlType = "numeric" case reflect.String: - if _, ok := field.TagSettings["SIZE"]; !ok { + if _, ok := field.TagSettingsGet("SIZE"); !ok { size = 0 // if SIZE haven't been set, use `text` as the default type, as there are no performance different } diff --git a/dialect_sqlite3.go b/dialect_sqlite3.go index f26f6be3..5f96c363 100644 --- a/dialect_sqlite3.go +++ b/dialect_sqlite3.go @@ -29,14 +29,14 @@ func (s *sqlite3) DataTypeOf(field *StructField) string { sqlType = "bool" case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: if s.fieldCanAutoIncrement(field) { - field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" + field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") sqlType = "integer primary key autoincrement" } else { sqlType = "integer" } case reflect.Int64, reflect.Uint64: if s.fieldCanAutoIncrement(field) { - field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" + field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") sqlType = "integer primary key autoincrement" } else { sqlType = "bigint" diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go index 731721cb..6c424bc1 100644 --- a/dialects/mssql/mssql.go +++ b/dialects/mssql/mssql.go @@ -18,7 +18,7 @@ import ( func setIdentityInsert(scope *gorm.Scope) { if scope.Dialect().GetName() == "mssql" { for _, field := range scope.PrimaryFields() { - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok && !field.IsBlank { + if _, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok && !field.IsBlank { scope.NewDB().Exec(fmt.Sprintf("SET IDENTITY_INSERT %v ON", scope.TableName())) scope.InstanceSet("mssql:identity_insert_on", true) } @@ -70,14 +70,14 @@ func (s *mssql) DataTypeOf(field *gorm.StructField) string { sqlType = "bit" case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: if s.fieldCanAutoIncrement(field) { - field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" + field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") sqlType = "int IDENTITY(1,1)" } else { sqlType = "int" } case reflect.Int64, reflect.Uint64: if s.fieldCanAutoIncrement(field) { - field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" + field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") sqlType = "bigint IDENTITY(1,1)" } else { sqlType = "bigint" @@ -116,7 +116,7 @@ func (s *mssql) DataTypeOf(field *gorm.StructField) string { } func (s mssql) fieldCanAutoIncrement(field *gorm.StructField) bool { - if value, ok := field.TagSettings["AUTO_INCREMENT"]; ok { + if value, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok { return value != "FALSE" } return field.IsPrimaryKey diff --git a/field_test.go b/field_test.go index 30e9a778..c3afdff5 100644 --- a/field_test.go +++ b/field_test.go @@ -43,7 +43,7 @@ func TestCalculateField(t *testing.T) { if field, ok := scope.FieldByName("embedded_name"); !ok { t.Errorf("should find embedded field") - } else if _, ok := field.TagSettings["NOT NULL"]; !ok { + } else if _, ok := field.TagSettingsGet("NOT NULL"); !ok { t.Errorf("should find embedded field's tag settings") } } diff --git a/main.go b/main.go index 4dbda61e..17c75ed3 100644 --- a/main.go +++ b/main.go @@ -699,7 +699,7 @@ func (s *DB) SetJoinTableHandler(source interface{}, column string, handler Join scope := s.NewScope(source) for _, field := range scope.GetModelStruct().StructFields { if field.Name == column || field.DBName == column { - if many2many := field.TagSettings["MANY2MANY"]; many2many != "" { + if many2many, _ := field.TagSettingsGet("MANY2MANY"); many2many != "" { source := (&Scope{Value: source}).GetModelStruct().ModelType destination := (&Scope{Value: reflect.New(field.Struct.Type).Interface()}).GetModelStruct().ModelType handler.Setup(field.Relationship, many2many, source, destination) diff --git a/model_struct.go b/model_struct.go index 5b5be618..12860e67 100644 --- a/model_struct.go +++ b/model_struct.go @@ -60,6 +60,30 @@ type StructField struct { Struct reflect.StructField IsForeignKey bool Relationship *Relationship + + tagSettingsLock sync.RWMutex +} + +// TagSettingsSet Sets a tag in the tag settings map +func (s *StructField) TagSettingsSet(key, val string) { + s.tagSettingsLock.Lock() + defer s.tagSettingsLock.Unlock() + s.TagSettings[key] = val +} + +// TagSettingsGet returns a tag from the tag settings +func (s *StructField) TagSettingsGet(key string) (string, bool) { + s.tagSettingsLock.RLock() + defer s.tagSettingsLock.RUnlock() + val, ok := s.TagSettings[key] + return val, ok +} + +// TagSettingsDelete deletes a tag +func (s *StructField) TagSettingsDelete(key string) { + s.tagSettingsLock.Lock() + defer s.tagSettingsLock.Unlock() + delete(s.TagSettings, key) } func (structField *StructField) clone() *StructField { @@ -83,6 +107,9 @@ func (structField *StructField) clone() *StructField { clone.Relationship = &relationship } + // copy the struct field tagSettings, they should be read-locked while they are copied + structField.tagSettingsLock.Lock() + defer structField.tagSettingsLock.Unlock() for key, value := range structField.TagSettings { clone.TagSettings[key] = value } @@ -149,19 +176,19 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } // is ignored field - if _, ok := field.TagSettings["-"]; ok { + if _, ok := field.TagSettingsGet("-"); ok { field.IsIgnored = true } else { - if _, ok := field.TagSettings["PRIMARY_KEY"]; ok { + if _, ok := field.TagSettingsGet("PRIMARY_KEY"); ok { field.IsPrimaryKey = true modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field) } - if _, ok := field.TagSettings["DEFAULT"]; ok { + if _, ok := field.TagSettingsGet("DEFAULT"); ok { field.HasDefaultValue = true } - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok && !field.IsPrimaryKey { + if _, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok && !field.IsPrimaryKey { field.HasDefaultValue = true } @@ -177,8 +204,8 @@ func (scope *Scope) GetModelStruct() *ModelStruct { if indirectType.Kind() == reflect.Struct { for i := 0; i < indirectType.NumField(); i++ { for key, value := range parseTagSetting(indirectType.Field(i).Tag) { - if _, ok := field.TagSettings[key]; !ok { - field.TagSettings[key] = value + if _, ok := field.TagSettingsGet(key); !ok { + field.TagSettingsSet(key, value) } } } @@ -186,17 +213,17 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } else if _, isTime := fieldValue.(*time.Time); isTime { // is time field.IsNormal = true - } else if _, ok := field.TagSettings["EMBEDDED"]; ok || fieldStruct.Anonymous { + } else if _, ok := field.TagSettingsGet("EMBEDDED"); ok || fieldStruct.Anonymous { // is embedded struct for _, subField := range scope.New(fieldValue).GetModelStruct().StructFields { subField = subField.clone() subField.Names = append([]string{fieldStruct.Name}, subField.Names...) - if prefix, ok := field.TagSettings["EMBEDDED_PREFIX"]; ok { + if prefix, ok := field.TagSettingsGet("EMBEDDED_PREFIX"); ok { subField.DBName = prefix + subField.DBName } if subField.IsPrimaryKey { - if _, ok := subField.TagSettings["PRIMARY_KEY"]; ok { + if _, ok := subField.TagSettingsGet("PRIMARY_KEY"); ok { modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, subField) } else { subField.IsPrimaryKey = false @@ -227,13 +254,13 @@ func (scope *Scope) GetModelStruct() *ModelStruct { elemType = field.Struct.Type ) - if foreignKey := field.TagSettings["FOREIGNKEY"]; foreignKey != "" { + if foreignKey, _ := field.TagSettingsGet("FOREIGNKEY"); foreignKey != "" { foreignKeys = strings.Split(foreignKey, ",") } - if foreignKey := field.TagSettings["ASSOCIATION_FOREIGNKEY"]; foreignKey != "" { + if foreignKey, _ := field.TagSettingsGet("ASSOCIATION_FOREIGNKEY"); foreignKey != "" { associationForeignKeys = strings.Split(foreignKey, ",") - } else if foreignKey := field.TagSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" { + } else if foreignKey, _ := field.TagSettingsGet("ASSOCIATIONFOREIGNKEY"); foreignKey != "" { associationForeignKeys = strings.Split(foreignKey, ",") } @@ -242,13 +269,13 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } if elemType.Kind() == reflect.Struct { - if many2many := field.TagSettings["MANY2MANY"]; many2many != "" { + if many2many, _ := field.TagSettingsGet("MANY2MANY"); many2many != "" { relationship.Kind = "many_to_many" { // Foreign Keys for Source joinTableDBNames := []string{} - if foreignKey := field.TagSettings["JOINTABLE_FOREIGNKEY"]; foreignKey != "" { + if foreignKey, _ := field.TagSettingsGet("JOINTABLE_FOREIGNKEY"); foreignKey != "" { joinTableDBNames = strings.Split(foreignKey, ",") } @@ -279,7 +306,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { { // Foreign Keys for Association (Destination) associationJoinTableDBNames := []string{} - if foreignKey := field.TagSettings["ASSOCIATION_JOINTABLE_FOREIGNKEY"]; foreignKey != "" { + if foreignKey, _ := field.TagSettingsGet("ASSOCIATION_JOINTABLE_FOREIGNKEY"); foreignKey != "" { associationJoinTableDBNames = strings.Split(foreignKey, ",") } @@ -317,7 +344,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { var toFields = toScope.GetStructFields() relationship.Kind = "has_many" - if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" { + if polymorphic, _ := field.TagSettingsGet("POLYMORPHIC"); polymorphic != "" { // Dog has many toys, tag polymorphic is Owner, then associationType is Owner // Toy use OwnerID, OwnerType ('dogs') as foreign key if polymorphicType := getForeignField(polymorphic+"Type", toFields); polymorphicType != nil { @@ -325,7 +352,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { relationship.PolymorphicType = polymorphicType.Name relationship.PolymorphicDBName = polymorphicType.DBName // if Dog has multiple set of toys set name of the set (instead of default 'dogs') - if value, ok := field.TagSettings["POLYMORPHIC_VALUE"]; ok { + if value, ok := field.TagSettingsGet("POLYMORPHIC_VALUE"); ok { relationship.PolymorphicValue = value } else { relationship.PolymorphicValue = scope.TableName() @@ -407,17 +434,17 @@ func (scope *Scope) GetModelStruct() *ModelStruct { tagAssociationForeignKeys []string ) - if foreignKey := field.TagSettings["FOREIGNKEY"]; foreignKey != "" { + if foreignKey, _ := field.TagSettingsGet("FOREIGNKEY"); foreignKey != "" { tagForeignKeys = strings.Split(foreignKey, ",") } - if foreignKey := field.TagSettings["ASSOCIATION_FOREIGNKEY"]; foreignKey != "" { + if foreignKey, _ := field.TagSettingsGet("ASSOCIATION_FOREIGNKEY"); foreignKey != "" { tagAssociationForeignKeys = strings.Split(foreignKey, ",") - } else if foreignKey := field.TagSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" { + } else if foreignKey, _ := field.TagSettingsGet("ASSOCIATIONFOREIGNKEY"); foreignKey != "" { tagAssociationForeignKeys = strings.Split(foreignKey, ",") } - if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" { + if polymorphic, _ := field.TagSettingsGet("POLYMORPHIC"); polymorphic != "" { // Cat has one toy, tag polymorphic is Owner, then associationType is Owner // Toy use OwnerID, OwnerType ('cats') as foreign key if polymorphicType := getForeignField(polymorphic+"Type", toFields); polymorphicType != nil { @@ -425,7 +452,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { relationship.PolymorphicType = polymorphicType.Name relationship.PolymorphicDBName = polymorphicType.DBName // if Cat has several different types of toys set name for each (instead of default 'cats') - if value, ok := field.TagSettings["POLYMORPHIC_VALUE"]; ok { + if value, ok := field.TagSettingsGet("POLYMORPHIC_VALUE"); ok { relationship.PolymorphicValue = value } else { relationship.PolymorphicValue = scope.TableName() @@ -563,7 +590,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } // Even it is ignored, also possible to decode db value into the field - if value, ok := field.TagSettings["COLUMN"]; ok { + if value, ok := field.TagSettingsGet("COLUMN"); ok { field.DBName = value } else { field.DBName = ToColumnName(fieldStruct.Name) diff --git a/scope.go b/scope.go index ce80ab86..fa521ca2 100644 --- a/scope.go +++ b/scope.go @@ -1115,8 +1115,8 @@ func (scope *Scope) createJoinTable(field *StructField) { if field, ok := scope.FieldByName(fieldName); ok { foreignKeyStruct := field.clone() foreignKeyStruct.IsPrimaryKey = false - foreignKeyStruct.TagSettings["IS_JOINTABLE_FOREIGNKEY"] = "true" - delete(foreignKeyStruct.TagSettings, "AUTO_INCREMENT") + foreignKeyStruct.TagSettingsSet("IS_JOINTABLE_FOREIGNKEY", "true") + foreignKeyStruct.TagSettingsDelete("AUTO_INCREMENT") sqlTypes = append(sqlTypes, scope.Quote(relationship.ForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(foreignKeyStruct)) primaryKeys = append(primaryKeys, scope.Quote(relationship.ForeignDBNames[idx])) } @@ -1126,8 +1126,8 @@ func (scope *Scope) createJoinTable(field *StructField) { if field, ok := toScope.FieldByName(fieldName); ok { foreignKeyStruct := field.clone() foreignKeyStruct.IsPrimaryKey = false - foreignKeyStruct.TagSettings["IS_JOINTABLE_FOREIGNKEY"] = "true" - delete(foreignKeyStruct.TagSettings, "AUTO_INCREMENT") + foreignKeyStruct.TagSettingsSet("IS_JOINTABLE_FOREIGNKEY", "true") + foreignKeyStruct.TagSettingsDelete("AUTO_INCREMENT") sqlTypes = append(sqlTypes, scope.Quote(relationship.AssociationForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(foreignKeyStruct)) primaryKeys = append(primaryKeys, scope.Quote(relationship.AssociationForeignDBNames[idx])) } @@ -1262,7 +1262,7 @@ func (scope *Scope) autoIndex() *Scope { var uniqueIndexes = map[string][]string{} for _, field := range scope.GetStructFields() { - if name, ok := field.TagSettings["INDEX"]; ok { + if name, ok := field.TagSettingsGet("INDEX"); ok { names := strings.Split(name, ",") for _, name := range names { @@ -1273,7 +1273,7 @@ func (scope *Scope) autoIndex() *Scope { } } - if name, ok := field.TagSettings["UNIQUE_INDEX"]; ok { + if name, ok := field.TagSettingsGet("UNIQUE_INDEX"); ok { names := strings.Split(name, ",") for _, name := range names { From 5be9bd34135805e0332b993378864b159784d8a8 Mon Sep 17 00:00:00 2001 From: ch3rub1m Date: Fri, 14 Sep 2018 15:53:49 +0800 Subject: [PATCH 32/34] Rollback transaction when a panic happens in callback (#2093) --- scope.go | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/scope.go b/scope.go index fa521ca2..378025bd 100644 --- a/scope.go +++ b/scope.go @@ -855,6 +855,14 @@ func (scope *Scope) inlineCondition(values ...interface{}) *Scope { } func (scope *Scope) callCallbacks(funcs []*func(s *Scope)) *Scope { + defer func() { + if err := recover(); err != nil { + if db, ok := scope.db.db.(sqlTx); ok { + db.Rollback() + } + panic(err) + } + }() for _, f := range funcs { (*f)(scope) if scope.skipLeft { From f6260a00852946a10a57e8bb9f505f19bc9389b7 Mon Sep 17 00:00:00 2001 From: Artemij Shepelev Date: Sat, 22 Sep 2018 14:59:11 +0300 Subject: [PATCH 33/34] Second part of the defaultTableName field race fix (#2060) * fix (https://github.com/jinzhu/gorm/issues/1407) * changed map with mutex to sync.Map (https://github.com/jinzhu/gorm/issues/1407) * removed newModelStructsMap func * commit to rerun pipeline, comment changed * fix race with defaultTableName field (again) --- model_struct.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/model_struct.go b/model_struct.go index 12860e67..8c27e209 100644 --- a/model_struct.go +++ b/model_struct.go @@ -24,11 +24,16 @@ type ModelStruct struct { PrimaryFields []*StructField StructFields []*StructField ModelType reflect.Type + defaultTableName string + l sync.Mutex } // TableName returns model's table name func (s *ModelStruct) TableName(db *DB) string { + s.l.Lock() + defer s.l.Unlock() + if s.defaultTableName == "" && db != nil && s.ModelType != nil { // Set default table name if tabler, ok := reflect.New(s.ModelType).Interface().(tabler); ok { From 742154be9a26e849f02d296073c077e0a7c23828 Mon Sep 17 00:00:00 2001 From: "Iskander (Alex) Sharipov" Date: Sun, 7 Oct 2018 03:49:37 +0300 Subject: [PATCH 34/34] rewrite if-else chain as switch statement (#2121) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit From effective Go: https://golang.org/doc/effective_go.html#switch > It's therefore possible—and idiomatic—to write an if-else-if-else chain as a switch. --- association.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/association.go b/association.go index 8c6d9864..1b7744b5 100644 --- a/association.go +++ b/association.go @@ -267,15 +267,16 @@ func (association *Association) Count() int { query = scope.DB() ) - if relationship.Kind == "many_to_many" { + switch relationship.Kind { + case "many_to_many": query = relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, query, scope.Value) - } else if relationship.Kind == "has_many" || relationship.Kind == "has_one" { + case "has_many", "has_one": primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, scope.Value) query = query.Where( fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)..., ) - } else if relationship.Kind == "belongs_to" { + case "belongs_to": primaryKeys := scope.getColumnAsArray(relationship.ForeignFieldNames, scope.Value) query = query.Where( fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(primaryKeys)),