From 56fffcb25b6e63540dcc2071ae653daed016105e Mon Sep 17 00:00:00 2001 From: Code Date: Tue, 29 Aug 2017 18:50:40 +0800 Subject: [PATCH 001/881] =?UTF-8?q?fix=20count()=20=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit COUNT()函数逻辑有错误,本应该是在执行任何SQL的时候,都可以返回正确的行数。而现在复杂的SQL集合无法正确获取行数。 --- scope.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/scope.go b/scope.go index fda7f653..6b8ce53f 100644 --- a/scope.go +++ b/scope.go @@ -950,7 +950,12 @@ func (scope *Scope) pluck(column string, value interface{}) *Scope { func (scope *Scope) count(value interface{}) *Scope { if query, ok := scope.Search.selects["query"]; !ok || !countingQueryRegexp.MatchString(fmt.Sprint(query)) { - scope.Search.Select("count(*)") + if len(scope.Search.group) != 0 { + scope.Search.Select("count(*) FROM ( SELECT count(*) ") + scope.Search.group += " ) AS count" + } else { + scope.Search.Select("count(*)") + } } scope.Search.ignoreOrderQuery = true scope.Err(scope.row().Scan(value)) From 26262ef9bb897b06d4e7ad6f1316e1037e030283 Mon Sep 17 00:00:00 2001 From: Wing Gao Date: Tue, 28 Nov 2017 13:05:10 +0800 Subject: [PATCH 002/881] autoIndex should throw an error on failed --- scope.go | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/scope.go b/scope.go index 51ebd5a0..f1b9da4b 100644 --- a/scope.go +++ b/scope.go @@ -1228,11 +1228,19 @@ func (scope *Scope) autoIndex() *Scope { } for name, columns := range indexes { - scope.NewDB().Model(scope.Value).AddIndex(name, columns...) + db := scope.NewDB().Model(scope.Value).AddIndex(name, columns...) + if db.Error != nil { + scope.db.Error = db.Error + return scope + } } for name, columns := range uniqueIndexes { - scope.NewDB().Model(scope.Value).AddUniqueIndex(name, columns...) + db := scope.NewDB().Model(scope.Value).AddUniqueIndex(name, columns...) + if db.Error != nil { + scope.db.Error = db.Error + return scope + } } return scope From 2ff44ee8d72785386e42e11f637ac8fa816cc4ca Mon Sep 17 00:00:00 2001 From: s-takehana Date: Wed, 31 Jan 2018 17:32:36 +0900 Subject: [PATCH 003/881] Fix regex in BuildForeignKeyName #1681 (#1728) --- dialect_common.go | 2 +- dialect_mysql.go | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/dialect_common.go b/dialect_common.go index a99627f2..7d0c3ce7 100644 --- a/dialect_common.go +++ b/dialect_common.go @@ -146,7 +146,7 @@ func (commonDialect) LastInsertIDReturningSuffix(tableName, columnName string) s func (DefaultForeignKeyNamer) BuildForeignKeyName(tableName, field, dest string) string { keyName := fmt.Sprintf("%s_%s_%s_foreign", tableName, field, dest) - keyName = regexp.MustCompile("(_*[^a-zA-Z]+_*|_+)").ReplaceAllString(keyName, "_") + keyName = regexp.MustCompile("[^a-zA-Z0-9]+").ReplaceAllString(keyName, "_") return keyName } diff --git a/dialect_mysql.go b/dialect_mysql.go index 6fcd0079..686ad1ee 100644 --- a/dialect_mysql.go +++ b/dialect_mysql.go @@ -166,8 +166,8 @@ func (s mysql) BuildForeignKeyName(tableName, field, dest string) string { h.Write([]byte(keyName)) bs := h.Sum(nil) - // sha1 is 40 digits, keep first 24 characters of destination - destRunes := []rune(regexp.MustCompile("(_*[^a-zA-Z]+_*|_+)").ReplaceAllString(dest, "_")) + // sha1 is 40 characters, keep first 24 characters of destination + destRunes := []rune(regexp.MustCompile("[^a-zA-Z0-9]+").ReplaceAllString(dest, "_")) if len(destRunes) > 24 { destRunes = destRunes[:24] } From a2c7c4b63f2ba7da2ae6428269bfc43efd29a4e8 Mon Sep 17 00:00:00 2001 From: rightjoin Date: Wed, 31 Jan 2018 14:38:03 +0530 Subject: [PATCH 004/881] UID should come before UI in common abbreviations (#1678) This will fix the following issue https://github.com/jinzhu/gorm/issues/1460 --- utils.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils.go b/utils.go index 97a3d175..dfaae939 100644 --- a/utils.go +++ b/utils.go @@ -23,7 +23,7 @@ var NowFunc = func() time.Time { } // Copied from golint -var commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UI", "UID", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"} +var commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UID", "UI", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"} var commonInitialismsReplacer *strings.Replacer var goSrcRegexp = regexp.MustCompile(`jinzhu/gorm/.*.go`) From b9035a7602b7734076ac4a3146fc88d285e326a5 Mon Sep 17 00:00:00 2001 From: s-takehana Date: Wed, 31 Jan 2018 17:32:36 +0900 Subject: [PATCH 005/881] Fix regex in BuildForeignKeyName #1681 (#1728) --- dialect_common.go | 2 +- dialect_mysql.go | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/dialect_common.go b/dialect_common.go index a99627f2..7d0c3ce7 100644 --- a/dialect_common.go +++ b/dialect_common.go @@ -146,7 +146,7 @@ func (commonDialect) LastInsertIDReturningSuffix(tableName, columnName string) s func (DefaultForeignKeyNamer) BuildForeignKeyName(tableName, field, dest string) string { keyName := fmt.Sprintf("%s_%s_%s_foreign", tableName, field, dest) - keyName = regexp.MustCompile("(_*[^a-zA-Z]+_*|_+)").ReplaceAllString(keyName, "_") + keyName = regexp.MustCompile("[^a-zA-Z0-9]+").ReplaceAllString(keyName, "_") return keyName } diff --git a/dialect_mysql.go b/dialect_mysql.go index 6fcd0079..686ad1ee 100644 --- a/dialect_mysql.go +++ b/dialect_mysql.go @@ -166,8 +166,8 @@ func (s mysql) BuildForeignKeyName(tableName, field, dest string) string { h.Write([]byte(keyName)) bs := h.Sum(nil) - // sha1 is 40 digits, keep first 24 characters of destination - destRunes := []rune(regexp.MustCompile("(_*[^a-zA-Z]+_*|_+)").ReplaceAllString(dest, "_")) + // sha1 is 40 characters, keep first 24 characters of destination + destRunes := []rune(regexp.MustCompile("[^a-zA-Z0-9]+").ReplaceAllString(dest, "_")) if len(destRunes) > 24 { destRunes = destRunes[:24] } From 630c12b54936a0b20a6ddf8a35dab18279165dd8 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 31 Jan 2018 17:14:21 +0800 Subject: [PATCH 006/881] Refactor #1693 --- scope.go | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/scope.go b/scope.go index f1b9da4b..0a7e8861 100644 --- a/scope.go +++ b/scope.go @@ -1228,18 +1228,14 @@ func (scope *Scope) autoIndex() *Scope { } for name, columns := range indexes { - db := scope.NewDB().Model(scope.Value).AddIndex(name, columns...) - if db.Error != nil { - scope.db.Error = db.Error - return scope + if db := scope.NewDB().Model(scope.Value).AddIndex(name, columns...); db.Error != nil { + scope.db.AddError(db.Error) } } for name, columns := range uniqueIndexes { - db := scope.NewDB().Model(scope.Value).AddUniqueIndex(name, columns...) - if db.Error != nil { - scope.db.Error = db.Error - return scope + if db := scope.NewDB().Model(scope.Value).AddUniqueIndex(name, columns...); db.Error != nil { + scope.db.AddError(db.Error) } } From cbc3d3cd509bee9f1c0d6f03bf02ff91e9dd47dd Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 31 Jan 2018 18:16:20 +0800 Subject: [PATCH 007/881] Add go report card --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 44eb4a69..e904ef80 100644 --- a/README.md +++ b/README.md @@ -3,6 +3,7 @@ The fantastic ORM library for Golang, aims to be developer friendly. [![Join the chat at https://gitter.im/jinzhu/gorm](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/jinzhu/gorm?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) +[![go report card](https://goreportcard.com/badge/github.com/jinzhu/gorm "go report card")](https://goreportcard.com/report/github.com/jinzhu/gorm) [![wercker status](https://app.wercker.com/status/0cb7bb1039e21b74f8274941428e0921/s/master "wercker status")](https://app.wercker.com/project/bykey/0cb7bb1039e21b74f8274941428e0921) [![GoDoc](https://godoc.org/github.com/jinzhu/gorm?status.svg)](https://godoc.org/github.com/jinzhu/gorm) From ca46ec0770003aab3c0ed7d7b336643362221c21 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 31 Jan 2018 18:22:30 +0800 Subject: [PATCH 008/881] Smaller image --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index e904ef80..e5c21dc5 100644 --- a/README.md +++ b/README.md @@ -32,7 +32,7 @@ The fantastic ORM library for Golang, aims to be developer friendly. ## Supporting the project -[![http://patreon.com/jinzhu](http://patreon_public_assets.s3.amazonaws.com/sized/becomeAPatronBanner.png)](http://patreon.com/jinzhu) +[![http://patreon.com/jinzhu](https://c5.patreon.com/external/logo/become_a_patron_button.png)](http://patreon.com/jinzhu) ## Author From 802104cc7cfe58153cccc9bc76e5b9078296c16b Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 2 Feb 2018 22:01:31 +0800 Subject: [PATCH 009/881] Use BuildKeyName to build db's index name --- dialect.go | 4 ++-- dialect_common.go | 4 ++-- dialect_mysql.go | 6 +++--- scope.go | 7 ++++--- 4 files changed, 11 insertions(+), 10 deletions(-) diff --git a/dialect.go b/dialect.go index e879588b..9d3be249 100644 --- a/dialect.go +++ b/dialect.go @@ -41,8 +41,8 @@ type Dialect interface { // LastInsertIdReturningSuffix most dbs support LastInsertId, but postgres needs to use `RETURNING` LastInsertIDReturningSuffix(tableName, columnName string) string - // BuildForeignKeyName returns a foreign key name for the given table, field and reference - BuildForeignKeyName(tableName, field, dest string) string + // BuildKeyName returns a valid key name (foreign key, index key) for the given table, field and reference + BuildKeyName(kind, tableName string, fields ...string) string // CurrentDatabase return current database name CurrentDatabase() string diff --git a/dialect_common.go b/dialect_common.go index 7d0c3ce7..ef351f9e 100644 --- a/dialect_common.go +++ b/dialect_common.go @@ -144,8 +144,8 @@ func (commonDialect) LastInsertIDReturningSuffix(tableName, columnName string) s return "" } -func (DefaultForeignKeyNamer) BuildForeignKeyName(tableName, field, dest string) string { - keyName := fmt.Sprintf("%s_%s_%s_foreign", tableName, field, dest) +func (DefaultForeignKeyNamer) BuildKeyName(kind, tableName string, fields ...string) string { + keyName := fmt.Sprintf("%s_%s_%s", kind, tableName, strings.Join(fields, "_")) keyName = regexp.MustCompile("[^a-zA-Z0-9]+").ReplaceAllString(keyName, "_") return keyName } diff --git a/dialect_mysql.go b/dialect_mysql.go index 686ad1ee..d2fd53ca 100644 --- a/dialect_mysql.go +++ b/dialect_mysql.go @@ -157,8 +157,8 @@ func (mysql) SelectFromDummyTable() string { return "FROM DUAL" } -func (s mysql) BuildForeignKeyName(tableName, field, dest string) string { - keyName := s.commonDialect.BuildForeignKeyName(tableName, field, dest) +func (s mysql) BuildKeyName(kind, tableName string, fields ...string) string { + keyName := s.commonDialect.BuildKeyName(kind, tableName, fields...) if utf8.RuneCountInString(keyName) <= 64 { return keyName } @@ -167,7 +167,7 @@ func (s mysql) BuildForeignKeyName(tableName, field, dest string) string { bs := h.Sum(nil) // sha1 is 40 characters, keep first 24 characters of destination - destRunes := []rune(regexp.MustCompile("[^a-zA-Z0-9]+").ReplaceAllString(dest, "_")) + destRunes := []rune(regexp.MustCompile("[^a-zA-Z0-9]+").ReplaceAllString(fields[0], "_")) if len(destRunes) > 24 { destRunes = destRunes[:24] } diff --git a/scope.go b/scope.go index 0a7e8861..c447d8a0 100644 --- a/scope.go +++ b/scope.go @@ -1165,7 +1165,8 @@ func (scope *Scope) addIndex(unique bool, indexName string, column ...string) { } func (scope *Scope) addForeignKey(field string, dest string, onDelete string, onUpdate string) { - keyName := scope.Dialect().BuildForeignKeyName(scope.TableName(), field, dest) + // Compatible with old generated key + keyName := scope.Dialect().BuildKeyName(scope.TableName(), field, dest, "foreign") if scope.Dialect().HasForeignKey(scope.TableName(), keyName) { return @@ -1209,7 +1210,7 @@ func (scope *Scope) autoIndex() *Scope { for _, name := range names { if name == "INDEX" || name == "" { - name = fmt.Sprintf("idx_%v_%v", scope.TableName(), field.DBName) + name = scope.Dialect().BuildKeyName("idx", scope.TableName(), field.DBName) } indexes[name] = append(indexes[name], field.DBName) } @@ -1220,7 +1221,7 @@ func (scope *Scope) autoIndex() *Scope { for _, name := range names { if name == "UNIQUE_INDEX" || name == "" { - name = fmt.Sprintf("uix_%v_%v", scope.TableName(), field.DBName) + name = scope.Dialect().BuildKeyName("uix", scope.TableName(), field.DBName) } uniqueIndexes[name] = append(uniqueIndexes[name], field.DBName) } From 57f031e08380b8b76252ccbb6a0bc21c85b28a7d Mon Sep 17 00:00:00 2001 From: Piyush Mishra Date: Fri, 2 Feb 2018 22:29:40 +0530 Subject: [PATCH 010/881] Use table name to guess current database if none is given --- dialect_common.go | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/dialect_common.go b/dialect_common.go index ef351f9e..9ccff6e9 100644 --- a/dialect_common.go +++ b/dialect_common.go @@ -92,7 +92,8 @@ func (s *commonDialect) DataTypeOf(field *StructField) string { func (s commonDialect) HasIndex(tableName string, indexName string) bool { var count int - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", s.CurrentDatabase(), tableName, indexName).Scan(&count) + currentDatabase, tableName := s.currentDatabaseAndTable(tableName) + s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", currentDatabase, tableName, indexName).Scan(&count) return count > 0 } @@ -107,13 +108,25 @@ func (s commonDialect) HasForeignKey(tableName string, foreignKeyName string) bo func (s commonDialect) HasTable(tableName string) bool { var count int - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", s.CurrentDatabase(), tableName).Scan(&count) + currentDatabase, tableName := s.currentDatabaseAndTable(tableName) + s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", currentDatabase, tableName).Scan(&count) return count > 0 } +func (s commonDialect) currentDatabaseAndTable(tableName string) (string, string) { + currentDatabase := s.CurrentDatabase() + if currentDatabase == "" && strings.Contains(tableName, ".") { + splitStrings := strings.SplitN(tableName, ".", 2) + currentDatabase = splitStrings[0] + tableName = splitStrings[1] + } + return currentDatabase, tableName +} + func (s commonDialect) HasColumn(tableName string, columnName string) bool { var count int - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", s.CurrentDatabase(), tableName, columnName).Scan(&count) + currentDatabase, tableName := s.currentDatabaseAndTable(tableName) + s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", currentDatabase, tableName, columnName).Scan(&count) return count > 0 } From 87fc1b24737a885147240041293603eceb844356 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 3 Feb 2018 20:27:19 +0800 Subject: [PATCH 011/881] Refactor PR #1751 --- dialect.go | 8 ++++++++ dialect_common.go | 17 ++++------------- dialect_mysql.go | 3 ++- dialects/mssql/mssql.go | 14 ++++++++++++-- 4 files changed, 26 insertions(+), 16 deletions(-) diff --git a/dialect.go b/dialect.go index 9d3be249..90b1723f 100644 --- a/dialect.go +++ b/dialect.go @@ -114,3 +114,11 @@ var ParseFieldStructForDialect = func(field *StructField, dialect Dialect) (fiel return fieldValue, dataType, size, strings.TrimSpace(additionalType) } + +func currentDatabaseAndTable(dialect Dialect, tableName string) (string, string) { + if strings.Contains(tableName, ".") { + splitStrings := strings.SplitN(tableName, ".", 2) + return splitStrings[0], splitStrings[1] + } + return dialect.CurrentDatabase(), tableName +} diff --git a/dialect_common.go b/dialect_common.go index 9ccff6e9..30f035a5 100644 --- a/dialect_common.go +++ b/dialect_common.go @@ -92,7 +92,7 @@ func (s *commonDialect) DataTypeOf(field *StructField) string { func (s commonDialect) HasIndex(tableName string, indexName string) bool { var count int - currentDatabase, tableName := s.currentDatabaseAndTable(tableName) + currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", currentDatabase, tableName, indexName).Scan(&count) return count > 0 } @@ -108,24 +108,14 @@ func (s commonDialect) HasForeignKey(tableName string, foreignKeyName string) bo func (s commonDialect) HasTable(tableName string) bool { var count int - currentDatabase, tableName := s.currentDatabaseAndTable(tableName) + currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", currentDatabase, tableName).Scan(&count) return count > 0 } -func (s commonDialect) currentDatabaseAndTable(tableName string) (string, string) { - currentDatabase := s.CurrentDatabase() - if currentDatabase == "" && strings.Contains(tableName, ".") { - splitStrings := strings.SplitN(tableName, ".", 2) - currentDatabase = splitStrings[0] - tableName = splitStrings[1] - } - return currentDatabase, tableName -} - func (s commonDialect) HasColumn(tableName string, columnName string) bool { var count int - currentDatabase, tableName := s.currentDatabaseAndTable(tableName) + currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", currentDatabase, tableName, columnName).Scan(&count) return count > 0 } @@ -157,6 +147,7 @@ func (commonDialect) LastInsertIDReturningSuffix(tableName, columnName string) s return "" } +// BuildKeyName returns a valid key name (foreign key, index key) for the given table, field and reference func (DefaultForeignKeyNamer) BuildKeyName(kind, tableName string, fields ...string) string { keyName := fmt.Sprintf("%s_%s_%s", kind, tableName, strings.Join(fields, "_")) keyName = regexp.MustCompile("[^a-zA-Z0-9]+").ReplaceAllString(keyName, "_") diff --git a/dialect_mysql.go b/dialect_mysql.go index d2fd53ca..f4858e10 100644 --- a/dialect_mysql.go +++ b/dialect_mysql.go @@ -144,7 +144,8 @@ func (s mysql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) { func (s mysql) HasForeignKey(tableName string, foreignKeyName string) bool { var count int - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE CONSTRAINT_SCHEMA=? AND TABLE_NAME=? AND CONSTRAINT_NAME=? AND CONSTRAINT_TYPE='FOREIGN KEY'", s.CurrentDatabase(), tableName, foreignKeyName).Scan(&count) + currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) + s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE CONSTRAINT_SCHEMA=? AND TABLE_NAME=? AND CONSTRAINT_NAME=? AND CONSTRAINT_TYPE='FOREIGN KEY'", currentDatabase, tableName, foreignKeyName).Scan(&count) return count > 0 } diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go index de2ae7ca..a4f8e87c 100644 --- a/dialects/mssql/mssql.go +++ b/dialects/mssql/mssql.go @@ -128,13 +128,15 @@ func (s mssql) HasForeignKey(tableName string, foreignKeyName string) bool { func (s mssql) HasTable(tableName string) bool { var count int - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, s.CurrentDatabase()).Scan(&count) + currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) + s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, currentDatabase).Scan(&count) return count > 0 } func (s mssql) HasColumn(tableName string, columnName string) bool { var count int - s.db.QueryRow("SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", s.CurrentDatabase(), tableName, columnName).Scan(&count) + currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) + s.db.QueryRow("SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", currentDatabase, tableName, columnName).Scan(&count) return count > 0 } @@ -168,3 +170,11 @@ func (mssql) SelectFromDummyTable() string { func (mssql) LastInsertIDReturningSuffix(tableName, columnName string) string { return "" } + +func currentDatabaseAndTable(dialect gorm.Dialect, tableName string) (string, string) { + if strings.Contains(tableName, ".") { + splitStrings := strings.SplitN(tableName, ".", 2) + return splitStrings[0], splitStrings[1] + } + return dialect.CurrentDatabase(), tableName +} From 3f98904fe72ef13a4add9c051dbff5509e233679 Mon Sep 17 00:00:00 2001 From: Louis Tran Date: Thu, 8 Feb 2018 16:21:39 -0800 Subject: [PATCH 012/881] Update PULL_REQUEST_TEMPLATE.md, A vs. An (#1757) Only a small change. `a` agreement => `an` agreement --- .github/PULL_REQUEST_TEMPLATE.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 0ee0d73b..4923abdc 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -6,7 +6,7 @@ Make sure these boxes checked before submitting your pull request. - [] Write good commit message, try to squash your commits into a single one - [] Run `./build.sh` in `gh-pages` branch for document changes -For significant changes like big bug fixes, new features, please open an issue to make a agreement on an implementation design/plan first before starting it. +For significant changes like big bug fixes, new features, please open an issue to make an agreement on an implementation design/plan first before starting it. Thank you. From 48e41440afa6a741a3e345f2cfbabca08f6fb1ac Mon Sep 17 00:00:00 2001 From: Adrian Heng Date: Fri, 9 Feb 2018 08:22:30 +0800 Subject: [PATCH 013/881] Allow for proper table creation with Jsonb fields (#1758) * DataTypeOf should now correctly identify dataValues that are 'json.RawMessage' types as 'jsonb' columns * move the json check to its own function * ran gofmt and did some minor tweaks to satisfy CodeClimate --- dialect_postgres.go | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/dialect_postgres.go b/dialect_postgres.go index 6fdf4df1..3bcea536 100644 --- a/dialect_postgres.go +++ b/dialect_postgres.go @@ -1,6 +1,7 @@ package gorm import ( + "encoding/json" "fmt" "reflect" "strings" @@ -68,9 +69,14 @@ func (s *postgres) DataTypeOf(field *StructField) string { default: if IsByteArrayOrSlice(dataValue) { sqlType = "bytea" + if isUUID(dataValue) { sqlType = "uuid" } + + if isJSON(dataValue) { + sqlType = "jsonb" + } } } } @@ -130,3 +136,8 @@ func isUUID(value reflect.Value) bool { lower := strings.ToLower(typename) return "uuid" == lower || "guid" == lower } + +func isJSON(value reflect.Value) bool { + _, ok := value.Interface().(json.RawMessage) + return ok +} From 38f96c65140f00f0b15efc495a487cfd5db510b8 Mon Sep 17 00:00:00 2001 From: daisy1754 Date: Fri, 9 Feb 2018 05:59:33 -0800 Subject: [PATCH 014/881] Add handling for empty Jsonb to fix #1649 (#1650) --- dialects/postgres/postgres.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/dialects/postgres/postgres.go b/dialects/postgres/postgres.go index b8e76891..1d0dcb60 100644 --- a/dialects/postgres/postgres.go +++ b/dialects/postgres/postgres.go @@ -63,6 +63,9 @@ type Jsonb struct { // Value get value of Jsonb func (j Jsonb) Value() (driver.Value, error) { + if len(j.RawMessage) == 0 { + return nil, nil + } return j.MarshalJSON() } From 0e1cb6ece9d27b56ee6c1e514987175bba94711b Mon Sep 17 00:00:00 2001 From: Amit Yadav <154998+ayadav@users.noreply.github.com> Date: Fri, 9 Feb 2018 19:50:26 +0530 Subject: [PATCH 015/881] Add support to remove foreign key constraints (#1686) --- main.go | 8 ++++++++ scope.go | 10 ++++++++++ 2 files changed, 18 insertions(+) diff --git a/main.go b/main.go index 16fa0b79..b23ae2f2 100644 --- a/main.go +++ b/main.go @@ -611,6 +611,14 @@ func (s *DB) AddForeignKey(field string, dest string, onDelete string, onUpdate return scope.db } +// RemoveForeignKey Remove foreign key from the given scope, e.g: +// db.Model(&User{}).RemoveForeignKey("city_id", "cities(id)") +func (s *DB) RemoveForeignKey(field string, dest string) *DB { + scope := s.clone().NewScope(s.Value) + scope.removeForeignKey(field, dest) + return scope.db +} + // Association start `Association Mode` to handler relations things easir in that mode, refer: https://jinzhu.github.io/gorm/associations.html#association-mode func (s *DB) Association(column string) *Association { var err error diff --git a/scope.go b/scope.go index c447d8a0..4c404b38 100644 --- a/scope.go +++ b/scope.go @@ -1175,6 +1175,16 @@ func (scope *Scope) addForeignKey(field string, dest string, onDelete string, on scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName), scope.quoteIfPossible(field), dest, onDelete, onUpdate)).Exec() } +func (scope *Scope) removeForeignKey(field string, dest string) { + keyName := scope.Dialect().BuildForeignKeyName(scope.TableName(), field, dest) + + if !scope.Dialect().HasForeignKey(scope.TableName(), keyName) { + return + } + var query = `ALTER TABLE %s DROP CONSTRAINT %s;` + scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName))).Exec() +} + func (scope *Scope) removeIndex(indexName string) { scope.Dialect().RemoveIndex(scope.TableName(), indexName) } From e9309d361f8777f861997089ce142744109e1aa2 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 9 Feb 2018 22:34:59 +0800 Subject: [PATCH 016/881] Fix build exception --- scope.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scope.go b/scope.go index 4c404b38..0ef087bc 100644 --- a/scope.go +++ b/scope.go @@ -1176,7 +1176,7 @@ func (scope *Scope) addForeignKey(field string, dest string, onDelete string, on } func (scope *Scope) removeForeignKey(field string, dest string) { - keyName := scope.Dialect().BuildForeignKeyName(scope.TableName(), field, dest) + keyName := scope.Dialect().BuildKeyName(scope.TableName(), field, dest) if !scope.Dialect().HasForeignKey(scope.TableName(), keyName) { return From 89a726ce5da26da893dd3c2d8475e1d66677fd9c Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 9 Feb 2018 22:58:34 +0800 Subject: [PATCH 017/881] Move ModifyColumn implemention to Dialect --- dialect.go | 2 ++ dialect_common.go | 5 +++++ dialect_mysql.go | 5 +++++ dialects/mssql/mssql.go | 5 +++++ migration_test.go | 5 +---- scope.go | 2 +- 6 files changed, 19 insertions(+), 5 deletions(-) diff --git a/dialect.go b/dialect.go index 90b1723f..fe8e2f62 100644 --- a/dialect.go +++ b/dialect.go @@ -33,6 +33,8 @@ type Dialect interface { HasTable(tableName string) bool // HasColumn check has column or not HasColumn(tableName string, columnName string) bool + // ModifyColumn modify column's type + ModifyColumn(tableName string, columnName string, typ string) error // LimitAndOffsetSQL return generated SQL with Limit and Offset, as mssql has special case LimitAndOffsetSQL(limit, offset interface{}) string diff --git a/dialect_common.go b/dialect_common.go index 30f035a5..06d0bd07 100644 --- a/dialect_common.go +++ b/dialect_common.go @@ -120,6 +120,11 @@ func (s commonDialect) HasColumn(tableName string, columnName string) bool { return count > 0 } +func (s commonDialect) ModifyColumn(tableName string, columnName string, typ string) error { + _, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v TYPE %v", tableName, columnName, typ)) + return err +} + func (s commonDialect) CurrentDatabase() (name string) { s.db.QueryRow("SELECT DATABASE()").Scan(&name) return diff --git a/dialect_mysql.go b/dialect_mysql.go index f4858e10..b9887a5c 100644 --- a/dialect_mysql.go +++ b/dialect_mysql.go @@ -127,6 +127,11 @@ func (s mysql) RemoveIndex(tableName string, indexName string) error { return err } +func (s mysql) ModifyColumn(tableName string, columnName string, typ string) error { + _, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v MODIFY COLUMN %v %v", tableName, columnName, typ)) + return err +} + func (s mysql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) { if limit != nil { if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 { diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go index a4f8e87c..10a779de 100644 --- a/dialects/mssql/mssql.go +++ b/dialects/mssql/mssql.go @@ -140,6 +140,11 @@ func (s mssql) HasColumn(tableName string, columnName string) bool { return count > 0 } +func (s mssql) ModifyColumn(tableName string, columnName string, typ string) error { + _, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v %v", tableName, columnName, typ)) + return err +} + func (s mssql) CurrentDatabase() (name string) { s.db.QueryRow("SELECT DB_NAME() AS [Current Database]").Scan(&name) return diff --git a/migration_test.go b/migration_test.go index 3f3a5c8f..6b4470a6 100644 --- a/migration_test.go +++ b/migration_test.go @@ -435,10 +435,7 @@ func TestMultipleIndexes(t *testing.T) { } func TestModifyColumnType(t *testing.T) { - dialect := os.Getenv("GORM_DIALECT") - if dialect != "postgres" && - dialect != "mysql" && - dialect != "mssql" { + if dialect := os.Getenv("GORM_DIALECT"); dialect != "postgres" && dialect != "mysql" && dialect != "mssql" { t.Skip("Skipping this because only postgres, mysql and mssql support altering a column type") } diff --git a/scope.go b/scope.go index 0ef087bc..a10cb3a2 100644 --- a/scope.go +++ b/scope.go @@ -1139,7 +1139,7 @@ func (scope *Scope) dropTable() *Scope { } func (scope *Scope) modifyColumn(column string, typ string) { - scope.Raw(fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v TYPE %v", scope.QuotedTableName(), scope.Quote(column), typ)).Exec() + scope.db.AddError(scope.Dialect().ModifyColumn(scope.QuotedTableName(), scope.Quote(column), typ)) } func (scope *Scope) dropColumn(column string) { From ae696d051fdd183c27ca75f7aa13bde5649b7264 Mon Sep 17 00:00:00 2001 From: miyauchi Date: Fri, 20 Oct 2017 10:24:09 +0900 Subject: [PATCH 018/881] corresponds timestamp precision for mysql --- dialect_mysql.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dialect_mysql.go b/dialect_mysql.go index b9887a5c..573bfc0f 100644 --- a/dialect_mysql.go +++ b/dialect_mysql.go @@ -96,9 +96,9 @@ func (s *mysql) DataTypeOf(field *StructField) string { case reflect.Struct: if _, ok := dataValue.Interface().(time.Time); ok { if _, ok := field.TagSettings["NOT NULL"]; ok { - sqlType = "timestamp" + sqlType = fmt.Sprintf("timestamp(%d)", size) } else { - sqlType = "timestamp NULL" + sqlType = fmt.Sprintf("timestamp(%d) NULL", size) } } default: From 8d4e3e5a832d78a11ea13bb1166569095238cfd0 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 9 Feb 2018 23:18:47 +0800 Subject: [PATCH 019/881] Use tag PRECISION to set time's precision for mysql --- dialect_mysql.go | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/dialect_mysql.go b/dialect_mysql.go index 573bfc0f..fee61819 100644 --- a/dialect_mysql.go +++ b/dialect_mysql.go @@ -95,10 +95,15 @@ func (s *mysql) DataTypeOf(field *StructField) string { } case reflect.Struct: if _, ok := dataValue.Interface().(time.Time); ok { + precision := "" + if p, ok := field.TagSettings["PRECISION"]; ok { + precision = fmt.Sprintf("(%s)", p) + } + if _, ok := field.TagSettings["NOT NULL"]; ok { - sqlType = fmt.Sprintf("timestamp(%d)", size) + sqlType = fmt.Sprintf("timestamp%v", precision) } else { - sqlType = fmt.Sprintf("timestamp(%d) NULL", size) + sqlType = fmt.Sprintf("timestamp%v NULL", precision) } } default: From ec72a4cb6b0fc60c2dda9ab842416b17ae4b3ad7 Mon Sep 17 00:00:00 2001 From: Geoff Baskwill Date: Fri, 9 Feb 2018 10:22:53 -0500 Subject: [PATCH 020/881] Call Query callback chain when preloading many2many (#1622) When using `Preload` on a `many2many` association, the `Query` callback chain was not being called. This made it difficult to write a plugin that could reliably get called regardless of how objects were being queried. Now `handleManyToManyPreload` will call the `Query` callback chain for each object that is retrieved by following the association. Since the data has already been read by the `handleManyToManyPreload` method, a new scope setting called `gorm:skip_queryCallback` is set to `true` before calling the callbacks. Callbacks can check for the presence of this setting if they should not be run; the default `queryCallback` is an example of this case. Fixes jinzhu/gorm#1621. --- callback_query.go | 4 ++++ callback_query_preload.go | 4 ++++ preload_test.go | 40 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 48 insertions(+) diff --git a/callback_query.go b/callback_query.go index 20e88161..f9940880 100644 --- a/callback_query.go +++ b/callback_query.go @@ -15,6 +15,10 @@ func init() { // queryCallback used to query data from database func queryCallback(scope *Scope) { + if _, skip := scope.Get("gorm:skip_query_callback"); skip { + return + } + defer scope.trace(NowFunc()) var ( diff --git a/callback_query_preload.go b/callback_query_preload.go index 21ab22ce..f2a218c7 100644 --- a/callback_query_preload.go +++ b/callback_query_preload.go @@ -324,6 +324,10 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface scope.scan(rows, columns, append(fields, joinTableFields...)) + scope.New(elem.Addr().Interface()). + Set("gorm:skip_query_callback", true). + callCallbacks(scope.db.parent.callbacks.queries) + var foreignKeys = make([]interface{}, len(sourceKeys)) // generate hashed forkey keys in join table for idx, joinTableField := range joinTableFields { diff --git a/preload_test.go b/preload_test.go index 1b89e77b..66f2629b 100644 --- a/preload_test.go +++ b/preload_test.go @@ -1627,6 +1627,46 @@ func TestPrefixedPreloadDuplication(t *testing.T) { } } +func TestPreloadManyToManyCallbacks(t *testing.T) { + type ( + Level2 struct { + ID uint + } + Level1 struct { + ID uint + Level2s []Level2 `gorm:"many2many:level1_level2s;AssociationForeignKey:ID;ForeignKey:ID"` + } + ) + + DB.DropTableIfExists("level1_level2s") + DB.DropTableIfExists(new(Level1)) + DB.DropTableIfExists(new(Level2)) + + if err := DB.AutoMigrate(new(Level1), new(Level2)).Error; err != nil { + t.Error(err) + } + + lvl := Level1{ + Level2s: []Level2{ + Level2{}, + }, + } + DB.Save(&lvl) + + called := 0 + + DB.Callback().Query().After("gorm:query").Register("TestPreloadManyToManyCallbacks", func(scope *gorm.Scope) { + called = called + 1 + }) + + found := Level1{ID: lvl.ID} + DB.Preload("Level2s").First(&found, &found) + + if called != 2 { + t.Errorf("Wanted callback to be called 2 times but got %d", called) + } +} + func toJSONString(v interface{}) []byte { r, _ := json.MarshalIndent(v, "", " ") return r From 77eb925ea09471b7082d9d5749b2c96be726eac2 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 10 Feb 2018 00:07:16 +0800 Subject: [PATCH 021/881] Refactor preloading many2many for auto preload --- callback_query.go | 2 +- callback_query_preload.go | 5 ++++- preload_test.go | 14 ++++++++------ 3 files changed, 13 insertions(+), 8 deletions(-) diff --git a/callback_query.go b/callback_query.go index f9940880..ba10cc7d 100644 --- a/callback_query.go +++ b/callback_query.go @@ -15,7 +15,7 @@ func init() { // queryCallback used to query data from database func queryCallback(scope *Scope) { - if _, skip := scope.Get("gorm:skip_query_callback"); skip { + if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip { return } diff --git a/callback_query_preload.go b/callback_query_preload.go index f2a218c7..30f6b585 100644 --- a/callback_query_preload.go +++ b/callback_query_preload.go @@ -10,6 +10,9 @@ import ( // preloadCallback used to preload associations func preloadCallback(scope *Scope) { + if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip { + return + } if _, ok := scope.Get("gorm:auto_preload"); ok { autoPreload(scope) @@ -325,7 +328,7 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface scope.scan(rows, columns, append(fields, joinTableFields...)) scope.New(elem.Addr().Interface()). - Set("gorm:skip_query_callback", true). + InstanceSet("gorm:skip_query_callback", true). callCallbacks(scope.db.parent.callbacks.queries) var foreignKeys = make([]interface{}, len(sourceKeys)) diff --git a/preload_test.go b/preload_test.go index 66f2629b..311ad0be 100644 --- a/preload_test.go +++ b/preload_test.go @@ -1630,10 +1630,12 @@ func TestPrefixedPreloadDuplication(t *testing.T) { func TestPreloadManyToManyCallbacks(t *testing.T) { type ( Level2 struct { - ID uint + ID uint + Name string } Level1 struct { ID uint + Name string Level2s []Level2 `gorm:"many2many:level1_level2s;AssociationForeignKey:ID;ForeignKey:ID"` } ) @@ -1647,8 +1649,9 @@ func TestPreloadManyToManyCallbacks(t *testing.T) { } lvl := Level1{ + Name: "l1", Level2s: []Level2{ - Level2{}, + Level2{Name: "l2-1"}, Level2{Name: "l2-2"}, }, } DB.Save(&lvl) @@ -1659,11 +1662,10 @@ func TestPreloadManyToManyCallbacks(t *testing.T) { called = called + 1 }) - found := Level1{ID: lvl.ID} - DB.Preload("Level2s").First(&found, &found) + DB.Preload("Level2s").First(&Level1{}, "id = ?", lvl.ID) - if called != 2 { - t.Errorf("Wanted callback to be called 2 times but got %d", called) + if called != 3 { + t.Errorf("Wanted callback to be called 3 times but got %d", called) } } From 97495a5e4067bc254ac33ce0e54c0af97a8c35d5 Mon Sep 17 00:00:00 2001 From: Wing Gao Date: Fri, 13 Oct 2017 15:08:55 +0800 Subject: [PATCH 022/881] Add new tag "not_auto_increment" to set a column can auto increase or not --- dialect_common.go | 14 ++++++++++++-- dialect_mysql.go | 12 ++++++------ dialect_postgres.go | 4 ++-- dialect_sqlite3.go | 4 ++-- 4 files changed, 22 insertions(+), 12 deletions(-) diff --git a/dialect_common.go b/dialect_common.go index 06d0bd07..64d720db 100644 --- a/dialect_common.go +++ b/dialect_common.go @@ -38,6 +38,16 @@ func (commonDialect) Quote(key string) string { return fmt.Sprintf(`"%s"`, key) } +func (s *commonDialect) fieldCanAutoIncrement(field *StructField) bool { + // add a new tag "NOT_AUTO_INCREMENT" + _, not := field.TagSettings["NOT_AUTO_INCREMENT"] + if not { + return false + } + _, ok := field.TagSettings["AUTO_INCREMENT"] + return ok || field.IsPrimaryKey +} + func (s *commonDialect) DataTypeOf(field *StructField) string { var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s) @@ -46,13 +56,13 @@ func (s *commonDialect) DataTypeOf(field *StructField) string { case reflect.Bool: sqlType = "BOOLEAN" case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok { + if s.fieldCanAutoIncrement(field) { sqlType = "INTEGER AUTO_INCREMENT" } else { sqlType = "INTEGER" } case reflect.Int64, reflect.Uint64: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok { + if s.fieldCanAutoIncrement(field) { sqlType = "BIGINT AUTO_INCREMENT" } else { sqlType = "BIGINT" diff --git a/dialect_mysql.go b/dialect_mysql.go index fee61819..1feed1f6 100644 --- a/dialect_mysql.go +++ b/dialect_mysql.go @@ -44,42 +44,42 @@ func (s *mysql) DataTypeOf(field *StructField) string { case reflect.Bool: sqlType = "boolean" case reflect.Int8: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { + if s.fieldCanAutoIncrement(field) { field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" sqlType = "tinyint AUTO_INCREMENT" } else { sqlType = "tinyint" } case reflect.Int, reflect.Int16, reflect.Int32: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { + if s.fieldCanAutoIncrement(field) { field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" sqlType = "int AUTO_INCREMENT" } else { sqlType = "int" } case reflect.Uint8: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { + if s.fieldCanAutoIncrement(field) { field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" sqlType = "tinyint unsigned AUTO_INCREMENT" } else { sqlType = "tinyint unsigned" } case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uintptr: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { + if s.fieldCanAutoIncrement(field) { field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" sqlType = "int unsigned AUTO_INCREMENT" } else { sqlType = "int unsigned" } case reflect.Int64: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { + if s.fieldCanAutoIncrement(field) { field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" sqlType = "bigint AUTO_INCREMENT" } else { sqlType = "bigint" } case reflect.Uint64: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { + if s.fieldCanAutoIncrement(field) { field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" sqlType = "bigint unsigned AUTO_INCREMENT" } else { diff --git a/dialect_postgres.go b/dialect_postgres.go index 3bcea536..c44c6a5b 100644 --- a/dialect_postgres.go +++ b/dialect_postgres.go @@ -33,14 +33,14 @@ func (s *postgres) DataTypeOf(field *StructField) string { case reflect.Bool: sqlType = "boolean" case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uintptr: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { + if s.fieldCanAutoIncrement(field) { field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" sqlType = "serial" } else { sqlType = "integer" } case reflect.Int64, reflect.Uint32, reflect.Uint64: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { + if s.fieldCanAutoIncrement(field) { field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" sqlType = "bigserial" } else { diff --git a/dialect_sqlite3.go b/dialect_sqlite3.go index de9c05cb..f26f6be3 100644 --- a/dialect_sqlite3.go +++ b/dialect_sqlite3.go @@ -28,14 +28,14 @@ func (s *sqlite3) DataTypeOf(field *StructField) string { case reflect.Bool: sqlType = "bool" case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: - if field.IsPrimaryKey { + if s.fieldCanAutoIncrement(field) { field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" sqlType = "integer primary key autoincrement" } else { sqlType = "integer" } case reflect.Int64, reflect.Uint64: - if field.IsPrimaryKey { + if s.fieldCanAutoIncrement(field) { field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" sqlType = "integer primary key autoincrement" } else { From 2c68f695c3de3b05f31e0f4c0132a19e236a0f23 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 10 Feb 2018 08:24:39 +0800 Subject: [PATCH 023/881] Set AutoIncrement to false with tag --- dialect_common.go | 9 +++------ main_test.go | 4 +++- test_all.sh | 2 +- 3 files changed, 7 insertions(+), 8 deletions(-) diff --git a/dialect_common.go b/dialect_common.go index 64d720db..1e5e3b61 100644 --- a/dialect_common.go +++ b/dialect_common.go @@ -39,13 +39,10 @@ func (commonDialect) Quote(key string) string { } func (s *commonDialect) fieldCanAutoIncrement(field *StructField) bool { - // add a new tag "NOT_AUTO_INCREMENT" - _, not := field.TagSettings["NOT_AUTO_INCREMENT"] - if not { - return false + if value, ok := field.TagSettings["AUTO_INCREMENT"]; ok { + return value != "FALSE" } - _, ok := field.TagSettings["AUTO_INCREMENT"] - return ok || field.IsPrimaryKey + return field.IsPrimaryKey } func (s *commonDialect) DataTypeOf(field *StructField) string { diff --git a/main_test.go b/main_test.go index 34f96a86..499324bc 100644 --- a/main_test.go +++ b/main_test.go @@ -72,8 +72,10 @@ func OpenTestConnection() (db *gorm.DB, err error) { // db.SetLogger(Logger{log.New(os.Stdout, "\r\n", 0)}) // db.SetLogger(log.New(os.Stdout, "\r\n", 0)) - if os.Getenv("DEBUG") == "true" { + if debug := os.Getenv("DEBUG"); debug == "true" { db.LogMode(true) + } else if debug == "false" { + db.LogMode(false) } db.DB().SetMaxIdleConns(10) diff --git a/test_all.sh b/test_all.sh index 80b319bf..5cfb3321 100755 --- a/test_all.sh +++ b/test_all.sh @@ -1,5 +1,5 @@ dialects=("postgres" "mysql" "mssql" "sqlite") for dialect in "${dialects[@]}" ; do - GORM_DIALECT=${dialect} go test + DEBUG=false GORM_DIALECT=${dialect} go test done From ae509ab23743e034b8c4e1d0d72d60a31ac7f6fd Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 10 Feb 2018 08:30:05 +0800 Subject: [PATCH 024/881] Port AUTO_INCREMENT false support to mssql --- dialects/mssql/mssql.go | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go index 10a779de..1c735a84 100644 --- a/dialects/mssql/mssql.go +++ b/dialects/mssql/mssql.go @@ -65,14 +65,14 @@ func (s *mssql) DataTypeOf(field *gorm.StructField) string { case reflect.Bool: sqlType = "bit" case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { + if s.fieldCanAutoIncrement(field) { field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" sqlType = "int IDENTITY(1,1)" } else { sqlType = "int" } case reflect.Int64, reflect.Uint64: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { + if s.fieldCanAutoIncrement(field) { field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" sqlType = "bigint IDENTITY(1,1)" } else { @@ -111,6 +111,13 @@ func (s *mssql) DataTypeOf(field *gorm.StructField) string { return fmt.Sprintf("%v %v", sqlType, additionalType) } +func (s mssql) fieldCanAutoIncrement(field *gorm.StructField) bool { + if value, ok := field.TagSettings["AUTO_INCREMENT"]; ok { + return value != "FALSE" + } + return field.IsPrimaryKey +} + func (s mssql) HasIndex(tableName string, indexName string) bool { var count int s.db.QueryRow("SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName).Scan(&count) From e0f9087c8d67b035172c15aabe1953aae4293d9f Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 10 Feb 2018 11:06:43 +0800 Subject: [PATCH 025/881] Setup test env --- docker-compose.yml | 30 ++++++++++++++++++++++++++++++ main_test.go | 26 +++++++++++--------------- wercker.yml | 17 +++++++++++++++-- 3 files changed, 56 insertions(+), 17 deletions(-) create mode 100644 docker-compose.yml diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 00000000..79bf5fc3 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,30 @@ +version: '3' + +services: + mysql: + image: 'mysql:latest' + ports: + - 9910:3306 + environment: + - MYSQL_DATABASE=gorm + - MYSQL_USER=gorm + - MYSQL_PASSWORD=gorm + - MYSQL_RANDOM_ROOT_PASSWORD="yes" + postgres: + image: 'postgres:latest' + ports: + - 9920:5432 + environment: + - POSTGRES_USER=gorm + - POSTGRES_DB=gorm + - POSTGRES_PASSWORD=gorm + mssql: + image: 'mcmoe/mssqldocker:latest' + ports: + - 9930:1433 + environment: + - ACCEPT_EULA=Y + - SA_PASSWORD=LoremIpsum86 + - MSSQL_DB=gorm + - MSSQL_USER=gorm + - MSSQL_PASSWORD=LoremIpsum86 diff --git a/main_test.go b/main_test.go index 499324bc..83e6f7aa 100644 --- a/main_test.go +++ b/main_test.go @@ -36,27 +36,20 @@ func init() { } func OpenTestConnection() (db *gorm.DB, err error) { + dbDSN := os.Getenv("GORM_DSN") switch os.Getenv("GORM_DIALECT") { case "mysql": - // CREATE USER 'gorm'@'localhost' IDENTIFIED BY 'gorm'; - // CREATE DATABASE gorm; - // GRANT ALL ON gorm.* TO 'gorm'@'localhost'; fmt.Println("testing mysql...") - dbhost := os.Getenv("GORM_DBADDRESS") - if dbhost != "" { - dbhost = fmt.Sprintf("tcp(%v)", dbhost) + if dbDSN == "" { + dbDSN = "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True" } - db, err = gorm.Open("mysql", fmt.Sprintf("gorm:gorm@%v/gorm?charset=utf8&parseTime=True", dbhost)) + db, err = gorm.Open("mysql", dbDSN) case "postgres": fmt.Println("testing postgres...") - dbhost := os.Getenv("GORM_DBHOST") - if dbhost != "" { - dbhost = fmt.Sprintf("host=%v ", dbhost) + if dbDSN == "" { + dbDSN = "user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable" } - db, err = gorm.Open("postgres", fmt.Sprintf("%vuser=gorm password=gorm DB.name=gorm sslmode=disable", dbhost)) - case "foundation": - fmt.Println("testing foundation...") - db, err = gorm.Open("foundation", "dbname=gorm port=15432 sslmode=disable") + db, err = gorm.Open("postgres", dbDSN) case "mssql": // CREATE LOGIN gorm WITH PASSWORD = 'LoremIpsum86'; // CREATE DATABASE gorm; @@ -64,7 +57,10 @@ func OpenTestConnection() (db *gorm.DB, err error) { // CREATE USER gorm FROM LOGIN gorm; // sp_changedbowner 'gorm'; fmt.Println("testing mssql...") - db, err = gorm.Open("mssql", "sqlserver://gorm:LoremIpsum86@localhost:1433?database=gorm") + if dbDSN == "" { + dbDSN = "sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm" + } + db, err = gorm.Open("mssql", dbDSN) default: fmt.Println("testing sqlite3...") db, err = gorm.Open("sqlite3", filepath.Join(os.TempDir(), "gorm.db")) diff --git a/wercker.yml b/wercker.yml index ff6fb17c..c3045c54 100644 --- a/wercker.yml +++ b/wercker.yml @@ -13,6 +13,14 @@ services: POSTGRES_USER: gorm POSTGRES_PASSWORD: gorm POSTGRES_DB: gorm + - name: mssql + id: mcmoe/mssqldocker: + env: + ACCEPT_EULA: Y + SA_PASSWORD: LoremIpsum86 + MSSQL_DB: gorm + MSSQL_USER: gorm + MSSQL_PASSWORD: LoremIpsum86 # The steps that will be executed in the build pipeline build: @@ -45,9 +53,14 @@ build: - script: name: test mysql code: | - GORM_DIALECT=mysql GORM_DBADDRESS=mariadb:3306 go test ./... + GORM_DIALECT=mysql GORM_DSN=gorm:gorm@tcp(mariadb:3306)/gorm?charset=utf8&parseTime=True go test ./... - script: name: test postgres code: | - GORM_DIALECT=postgres GORM_DBHOST=postgres go test ./... + GORM_DIALECT=postgres GORM_DSN="host=postgres user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test ./... + + - script: + name: test mssql + code: | + GORM_DIALECT=mssql GORM_DSN="sqlserver://gorm:LoremIpsum86@mssql:1433?database=gorm" go test ./... From 2e5d98a42020e99e9270e5caa9125b9de2dc56e8 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 10 Feb 2018 11:45:38 +0800 Subject: [PATCH 026/881] Update wercker.yml --- wercker.yml | 102 +++++++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 98 insertions(+), 4 deletions(-) diff --git a/wercker.yml b/wercker.yml index c3045c54..2f2370b3 100644 --- a/wercker.yml +++ b/wercker.yml @@ -2,19 +2,73 @@ box: golang services: - - id: mariadb:10.0 + - name: mariadb + id: mariadb:latest env: MYSQL_DATABASE: gorm MYSQL_USER: gorm MYSQL_PASSWORD: gorm MYSQL_RANDOM_ROOT_PASSWORD: "yes" - - id: postgres + - name: mysql + id: mysql:8 + env: + MYSQL_DATABASE: gorm + MYSQL_USER: gorm + MYSQL_PASSWORD: gorm + MYSQL_RANDOM_ROOT_PASSWORD: "yes" + - name: mysql57 + id: mysql:5.7 + env: + MYSQL_DATABASE: gorm + MYSQL_USER: gorm + MYSQL_PASSWORD: gorm + MYSQL_RANDOM_ROOT_PASSWORD: "yes" + - name: mysql56 + id: mysql:5.6 + env: + MYSQL_DATABASE: gorm + MYSQL_USER: gorm + MYSQL_PASSWORD: gorm + MYSQL_RANDOM_ROOT_PASSWORD: "yes" + - name: mysql55 + id: mysql:5.5 + env: + MYSQL_DATABASE: gorm + MYSQL_USER: gorm + MYSQL_PASSWORD: gorm + MYSQL_RANDOM_ROOT_PASSWORD: "yes" + - name: postgres + id: postgres:latest + env: + POSTGRES_USER: gorm + POSTGRES_PASSWORD: gorm + POSTGRES_DB: gorm + - name: postgres96 + id: postgres:9.6 + env: + POSTGRES_USER: gorm + POSTGRES_PASSWORD: gorm + POSTGRES_DB: gorm + - name: postgres95 + id: postgres:9.5 + env: + POSTGRES_USER: gorm + POSTGRES_PASSWORD: gorm + POSTGRES_DB: gorm + - name: postgres94 + id: postgres:9.4 + env: + POSTGRES_USER: gorm + POSTGRES_PASSWORD: gorm + POSTGRES_DB: gorm + - name: postgres93 + id: postgres:9.3 env: POSTGRES_USER: gorm POSTGRES_PASSWORD: gorm POSTGRES_DB: gorm - name: mssql - id: mcmoe/mssqldocker: + id: mcmoe/mssqldocker:latest env: ACCEPT_EULA: Y SA_PASSWORD: LoremIpsum86 @@ -50,16 +104,56 @@ build: code: | go test ./... + - script: + name: test mariadb + code: | + GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mariadb:3306)/gorm?charset=utf8&parseTime=True" go test ./... + - script: name: test mysql code: | - GORM_DIALECT=mysql GORM_DSN=gorm:gorm@tcp(mariadb:3306)/gorm?charset=utf8&parseTime=True go test ./... + GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql:3306)/gorm?charset=utf8&parseTime=True" go test ./... + + - script: + name: test mysql5.7 + code: | + GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql57:3306)/gorm?charset=utf8&parseTime=True" go test ./... + + - script: + name: test mysql5.6 + code: | + GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql56:3306)/gorm?charset=utf8&parseTime=True" go test ./... + + - script: + name: test mysql5.5 + code: | + GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql55:3306)/gorm?charset=utf8&parseTime=True" go test ./... - script: name: test postgres code: | GORM_DIALECT=postgres GORM_DSN="host=postgres user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test ./... + - script: + name: test postgres96 + code: | + GORM_DIALECT=postgres GORM_DSN="host=postgres96 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test ./... + + - script: + name: test postgres95 + code: | + GORM_DIALECT=postgres GORM_DSN="host=postgres95 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test ./... + + - script: + name: test postgres94 + code: | + GORM_DIALECT=postgres GORM_DSN="host=postgres94 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test ./... + + - script: + name: test postgres93 + code: | + GORM_DIALECT=postgres GORM_DSN="host=postgres93 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test ./... + - script: name: test mssql code: | From 706b8f55da67c097aede7662a45c9ae577ea3ed9 Mon Sep 17 00:00:00 2001 From: Giuseppe Date: Sat, 10 Feb 2018 05:28:01 +0100 Subject: [PATCH 027/881] Use brackets for quoting (#1736) --- dialects/mssql/mssql.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go index 1c735a84..1dd5fb69 100644 --- a/dialects/mssql/mssql.go +++ b/dialects/mssql/mssql.go @@ -54,7 +54,7 @@ func (mssql) BindVar(i int) string { } func (mssql) Quote(key string) string { - return fmt.Sprintf(`"%s"`, key) + return fmt.Sprintf(`[%s]`, key) } func (s *mssql) DataTypeOf(field *gorm.StructField) string { From 21fb3ae1febe4581f80a4d5633f3fffd6d10a606 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 10 Feb 2018 13:15:04 +0800 Subject: [PATCH 028/881] Simplify GitHub templates --- .github/ISSUE_TEMPLATE.md | 21 ++++++--------------- .github/PULL_REQUEST_TEMPLATE.md | 5 ----- README.md | 2 +- 3 files changed, 7 insertions(+), 21 deletions(-) diff --git a/.github/ISSUE_TEMPLATE.md b/.github/ISSUE_TEMPLATE.md index 8b4f03b7..a0b64bfa 100644 --- a/.github/ISSUE_TEMPLATE.md +++ b/.github/ISSUE_TEMPLATE.md @@ -1,10 +1,4 @@ -Before posting a bug report about a problem, please try to verify that it is a bug and that it has not been reported already, please apply corresponding GitHub labels to the issue, for feature requests, please apply `type:feature`. - -DON'T post usage related questions, ask in https://gitter.im/jinzhu/gorm or http://stackoverflow.com/questions/tagged/go-gorm, - -Please answer these questions before submitting your issue. Thanks! - - +Your issue may already be reported! Please search on the [issue track](https://github.com/jinzhu/gorm/issues) before creating one. ### What version of Go are you using (`go version`)? @@ -12,9 +6,9 @@ Please answer these questions before submitting your issue. Thanks! ### Which database and its version are you using? -### What did you do? +### Please provide a complete runnable program to reproduce your issue. **IMPORTANT** -Please provide a complete runnable program to reproduce your issue. +Need to runnable with [GORM's docker compose config](https://github.com/jinzhu/gorm/blob/master/docker-compose.yml) or please provides your config. ```go package main @@ -32,10 +26,9 @@ var db *gorm.DB func init() { var err error db, err = gorm.Open("sqlite3", "test.db") - // Please use below username, password as your database's account for the script. - // db, err = gorm.Open("postgres", "user=gorm dbname=gorm sslmode=disable") - // db, err = gorm.Open("mysql", "gorm:gorm@/dbname?charset=utf8&parseTime=True") - // db, err = gorm.Open("mssql", "sqlserver://gorm:LoremIpsum86@localhost:1433?database=gorm") + // db, err = gorm.Open("postgres", "user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable") + // db, err = gorm.Open("mysql", "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True") + // db, err = gorm.Open("mssql", "sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm") if err != nil { panic(err) } @@ -43,8 +36,6 @@ func init() { } func main() { - // your code here - if /* failure condition */ { fmt.Println("failed") } else { diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 4923abdc..b467b6ce 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -3,12 +3,7 @@ Make sure these boxes checked before submitting your pull request. - [] Do only one thing - [] No API-breaking changes - [] New code/logic commented & tested -- [] Write good commit message, try to squash your commits into a single one -- [] Run `./build.sh` in `gh-pages` branch for document changes For significant changes like big bug fixes, new features, please open an issue to make an agreement on an implementation design/plan first before starting it. -Thank you. - - ### What did this pull request do? diff --git a/README.md b/README.md index e5c21dc5..8c6e2302 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ The fantastic ORM library for Golang, aims to be developer friendly. [![Join the chat at https://gitter.im/jinzhu/gorm](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/jinzhu/gorm?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) [![go report card](https://goreportcard.com/badge/github.com/jinzhu/gorm "go report card")](https://goreportcard.com/report/github.com/jinzhu/gorm) -[![wercker status](https://app.wercker.com/status/0cb7bb1039e21b74f8274941428e0921/s/master "wercker status")](https://app.wercker.com/project/bykey/0cb7bb1039e21b74f8274941428e0921) +[![wercker status](https://app.wercker.com/status/8596cace912c9947dd9c8542ecc8cb8b/s/master "wercker status")](https://app.wercker.com/project/byKey/8596cace912c9947dd9c8542ecc8cb8b) [![GoDoc](https://godoc.org/github.com/jinzhu/gorm?status.svg)](https://godoc.org/github.com/jinzhu/gorm) ## Overview From aa3fd6de13fee7e0ae715eeaad3bc2f329db2366 Mon Sep 17 00:00:00 2001 From: Jess Smith Date: Sat, 10 Feb 2018 01:26:01 -0500 Subject: [PATCH 029/881] Sort column names before generating SQL in `DB.UpdateColumns` (#1734) --- callback_update.go | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/callback_update.go b/callback_update.go index 6948439f..373bd726 100644 --- a/callback_update.go +++ b/callback_update.go @@ -3,6 +3,7 @@ package gorm import ( "errors" "fmt" + "sort" "strings" ) @@ -59,7 +60,16 @@ func updateCallback(scope *Scope) { var sqls []string if updateAttrs, ok := scope.InstanceGet("gorm:update_attrs"); ok { - for column, value := range updateAttrs.(map[string]interface{}) { + // Sort the column names so that the generated SQL is the same every time. + updateMap := updateAttrs.(map[string]interface{}) + var columns []string + for c := range updateMap { + columns = append(columns, c) + } + sort.Strings(columns) + + for _, column := range columns { + value := updateMap[column] sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(column), scope.AddToVars(value))) } } else { From 9235b47ea28d816ef25d6bf4e037ccb5c7c7096b Mon Sep 17 00:00:00 2001 From: joe-at-startupmedia Date: Wed, 4 Oct 2017 08:19:16 +0000 Subject: [PATCH 030/881] Allows foreign keys to be saved without saving the assoication when specified #1628 --- callback_save.go | 53 ++++++++++++++++++++++++------------------------ 1 file changed, 26 insertions(+), 27 deletions(-) diff --git a/callback_save.go b/callback_save.go index f4bc918e..ad4eda2f 100644 --- a/callback_save.go +++ b/callback_save.go @@ -11,35 +11,34 @@ func commitOrRollbackTransactionCallback(scope *Scope) { } func saveFieldAsAssociation(scope *Scope, field *Field) (bool, *Relationship) { - if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored { - if value, ok := field.TagSettings["SAVE_ASSOCIATIONS"]; !ok || (value != "false" && value != "skip") { - if relationship := field.Relationship; relationship != nil { - return true, relationship - } - } - } - return false, nil + if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored { + if value, ok := field.TagSettings["SAVE_ASSOCIATIONS"]; !ok || (value != "false" && value != "skip") { + return true, field.Relationship + } + return false, field.Relationship + } + return false, nil } func saveBeforeAssociationsCallback(scope *Scope) { - if !scope.shouldSaveAssociations() { - return - } - for _, field := range scope.Fields() { - if ok, relationship := saveFieldAsAssociation(scope, field); ok && relationship.Kind == "belongs_to" { - fieldValue := field.Field.Addr().Interface() - scope.Err(scope.NewDB().Save(fieldValue).Error) - if len(relationship.ForeignFieldNames) != 0 { - // set value's foreign key - for idx, fieldName := range relationship.ForeignFieldNames { - associationForeignName := relationship.AssociationForeignDBNames[idx] - if foreignField, ok := scope.New(fieldValue).FieldByName(associationForeignName); ok { - scope.Err(scope.SetColumn(fieldName, foreignField.Field.Interface())) - } - } - } - } - } + for _, field := range scope.Fields() { + ok, relationship := saveFieldAsAssociation(scope, field); + if relationship != nil && relationship.Kind == "belongs_to" { + fieldValue := field.Field.Addr().Interface() + if ok && scope.shouldSaveAssociations() { + scope.Err(scope.NewDB().Save(fieldValue).Error) + } + if len(relationship.ForeignFieldNames) != 0 { + // set value's foreign key + for idx, fieldName := range relationship.ForeignFieldNames { + associationForeignName := relationship.AssociationForeignDBNames[idx] + if foreignField, ok := scope.New(fieldValue).FieldByName(associationForeignName); ok { + scope.Err(scope.SetColumn(fieldName, foreignField.Field.Interface())) + } + } + } + } + } } func saveAfterAssociationsCallback(scope *Scope) { @@ -47,7 +46,7 @@ func saveAfterAssociationsCallback(scope *Scope) { return } for _, field := range scope.Fields() { - if ok, relationship := saveFieldAsAssociation(scope, field); ok && + if ok, relationship := saveFieldAsAssociation(scope, field); ok && relationship != nil && (relationship.Kind == "has_one" || relationship.Kind == "has_many" || relationship.Kind == "many_to_many") { value := field.Field From 9f409820dfdfc2ab7cb20a56d4cefdf1a111c315 Mon Sep 17 00:00:00 2001 From: joe-at-startupmedia Date: Tue, 10 Oct 2017 18:20:56 +0000 Subject: [PATCH 031/881] Formatting code with gomt --- callback_save.go | 50 ++++++++++++++++++++++++------------------------ 1 file changed, 25 insertions(+), 25 deletions(-) diff --git a/callback_save.go b/callback_save.go index ad4eda2f..fa32c907 100644 --- a/callback_save.go +++ b/callback_save.go @@ -11,34 +11,34 @@ func commitOrRollbackTransactionCallback(scope *Scope) { } func saveFieldAsAssociation(scope *Scope, field *Field) (bool, *Relationship) { - if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored { - if value, ok := field.TagSettings["SAVE_ASSOCIATIONS"]; !ok || (value != "false" && value != "skip") { - return true, field.Relationship - } - return false, field.Relationship - } - return false, nil + if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored { + if value, ok := field.TagSettings["SAVE_ASSOCIATIONS"]; !ok || (value != "false" && value != "skip") { + return true, field.Relationship + } + return false, field.Relationship + } + return false, nil } func saveBeforeAssociationsCallback(scope *Scope) { - for _, field := range scope.Fields() { - ok, relationship := saveFieldAsAssociation(scope, field); - if relationship != nil && relationship.Kind == "belongs_to" { - fieldValue := field.Field.Addr().Interface() - if ok && scope.shouldSaveAssociations() { - scope.Err(scope.NewDB().Save(fieldValue).Error) - } - if len(relationship.ForeignFieldNames) != 0 { - // set value's foreign key - for idx, fieldName := range relationship.ForeignFieldNames { - associationForeignName := relationship.AssociationForeignDBNames[idx] - if foreignField, ok := scope.New(fieldValue).FieldByName(associationForeignName); ok { - scope.Err(scope.SetColumn(fieldName, foreignField.Field.Interface())) - } - } - } - } - } + for _, field := range scope.Fields() { + ok, relationship := saveFieldAsAssociation(scope, field) + if relationship != nil && relationship.Kind == "belongs_to" { + fieldValue := field.Field.Addr().Interface() + if ok && scope.shouldSaveAssociations() { + scope.Err(scope.NewDB().Save(fieldValue).Error) + } + if len(relationship.ForeignFieldNames) != 0 { + // set value's foreign key + for idx, fieldName := range relationship.ForeignFieldNames { + associationForeignName := relationship.AssociationForeignDBNames[idx] + if foreignField, ok := scope.New(fieldValue).FieldByName(associationForeignName); ok { + scope.Err(scope.SetColumn(fieldName, foreignField.Field.Interface())) + } + } + } + } + } } func saveAfterAssociationsCallback(scope *Scope) { From 63cb513b4978a49870ff20d27fb18c721f64d977 Mon Sep 17 00:00:00 2001 From: Ezequiel Muns Date: Wed, 1 Nov 2017 18:45:08 +0100 Subject: [PATCH 032/881] Tests for saving foreign key when save_associations:false --- association_test.go | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/association_test.go b/association_test.go index c84f84ed..f37047d1 100644 --- a/association_test.go +++ b/association_test.go @@ -902,6 +902,20 @@ func TestSkipSaveAssociation(t *testing.T) { DB.Save(&User{Name: "jinzhu", Company: Company{Name: "skip_save_association"}}) if !DB.Where("name = ?", "skip_save_association").First(&Company{}).RecordNotFound() { - t.Errorf("Company skip_save_association should not been saved") + t.Errorf("Company skip_save_association should not have been saved") + } + + // if foreign key is set, this should be saved even if association isn't + company := Company{Name: "skip_save_association"} + DB.Save(&company) + company.Name = "skip_save_association_modified" + user := User{Name: "jinzhu", CompanyID: company.ID, Company: company} + DB.Save(&user) + + if !DB.Where("name = ?", "skip_save_association_modified").First(&Company{}).RecordNotFound() { + t.Errorf("Company skip_save_association should not have been updated") + } + if DB.Where("id = ? AND company_id = ?", user.ID, company.ID).First(&User{}).RecordNotFound() { + t.Errorf("User's foreign key should have been saved") } } From 43dc867644b879f8f87fd0598ac0b459232d9293 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 10 Feb 2018 15:16:20 +0800 Subject: [PATCH 033/881] Allow save association relations w/o saving association --- association_test.go | 2 +- callback_save.go | 31 ++++++++++++++++++------------- scope.go | 2 +- 3 files changed, 20 insertions(+), 15 deletions(-) diff --git a/association_test.go b/association_test.go index f37047d1..34822dbc 100644 --- a/association_test.go +++ b/association_test.go @@ -909,7 +909,7 @@ func TestSkipSaveAssociation(t *testing.T) { company := Company{Name: "skip_save_association"} DB.Save(&company) company.Name = "skip_save_association_modified" - user := User{Name: "jinzhu", CompanyID: company.ID, Company: company} + user := User{Name: "jinzhu", Company: company} DB.Save(&user) if !DB.Where("name = ?", "skip_save_association_modified").First(&Company{}).RecordNotFound() { diff --git a/callback_save.go b/callback_save.go index fa32c907..544354d0 100644 --- a/callback_save.go +++ b/callback_save.go @@ -12,22 +12,25 @@ func commitOrRollbackTransactionCallback(scope *Scope) { func saveFieldAsAssociation(scope *Scope, field *Field) (bool, *Relationship) { if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored { - if value, ok := field.TagSettings["SAVE_ASSOCIATIONS"]; !ok || (value != "false" && value != "skip") { - return true, field.Relationship + if field.Relationship != nil { + if value, ok := field.TagSettings["SAVE_ASSOCIATIONS"]; (!ok || (value != "false" && value != "skip")) && scope.allowSaveAssociations() { + return true, field.Relationship + } + return false, field.Relationship } - return false, field.Relationship } return false, nil } func saveBeforeAssociationsCallback(scope *Scope) { for _, field := range scope.Fields() { - ok, relationship := saveFieldAsAssociation(scope, field) - if relationship != nil && relationship.Kind == "belongs_to" { + if allowSaveAssociation, relationship := saveFieldAsAssociation(scope, field); relationship != nil && relationship.Kind == "belongs_to" { fieldValue := field.Field.Addr().Interface() - if ok && scope.shouldSaveAssociations() { + + if allowSaveAssociation { scope.Err(scope.NewDB().Save(fieldValue).Error) } + if len(relationship.ForeignFieldNames) != 0 { // set value's foreign key for idx, fieldName := range relationship.ForeignFieldNames { @@ -42,11 +45,8 @@ func saveBeforeAssociationsCallback(scope *Scope) { } func saveAfterAssociationsCallback(scope *Scope) { - if !scope.shouldSaveAssociations() { - return - } for _, field := range scope.Fields() { - if ok, relationship := saveFieldAsAssociation(scope, field); ok && relationship != nil && + if allowSaveAssociation, relationship := saveFieldAsAssociation(scope, field); relationship != nil && (relationship.Kind == "has_one" || relationship.Kind == "has_many" || relationship.Kind == "many_to_many") { value := field.Field @@ -70,9 +70,11 @@ func saveAfterAssociationsCallback(scope *Scope) { scope.Err(newScope.SetColumn(relationship.PolymorphicType, relationship.PolymorphicValue)) } - scope.Err(newDB.Save(elem).Error) + if allowSaveAssociation { + scope.Err(newDB.Save(elem).Error) + } - if joinTableHandler := relationship.JoinTableHandler; joinTableHandler != nil { + if joinTableHandler := relationship.JoinTableHandler; joinTableHandler != nil && !newScope.PrimaryKeyZero() { scope.Err(joinTableHandler.Add(joinTableHandler, newDB, scope.Value, newScope.Value)) } } @@ -91,7 +93,10 @@ func saveAfterAssociationsCallback(scope *Scope) { if relationship.PolymorphicType != "" { scope.Err(newScope.SetColumn(relationship.PolymorphicType, relationship.PolymorphicValue)) } - scope.Err(scope.NewDB().Save(elem).Error) + + if allowSaveAssociation { + scope.Err(scope.NewDB().Save(elem).Error) + } } } } diff --git a/scope.go b/scope.go index a10cb3a2..9ae33913 100644 --- a/scope.go +++ b/scope.go @@ -993,7 +993,7 @@ func (scope *Scope) changeableField(field *Field) bool { return true } -func (scope *Scope) shouldSaveAssociations() bool { +func (scope *Scope) allowSaveAssociations() bool { if saveAssociations, ok := scope.Get("gorm:save_associations"); ok { if v, ok := saveAssociations.(bool); ok && !v { return false From b2b568daa8e27966c39c942e5aefc74bcc8af88d Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 10 Feb 2018 16:47:48 +0800 Subject: [PATCH 034/881] Add tag association_autoupdate, association_autocreate, association_save_reference support --- association_test.go | 147 +++++++++++++++++++++++++++++++++++++++--- callback_save.go | 153 +++++++++++++++++++++++++++++++------------- query_test.go | 2 +- scope.go | 12 ---- 4 files changed, 248 insertions(+), 66 deletions(-) diff --git a/association_test.go b/association_test.go index 34822dbc..60d0cf48 100644 --- a/association_test.go +++ b/association_test.go @@ -885,7 +885,7 @@ func TestHasManyChildrenWithOneStruct(t *testing.T) { DB.Save(&category) } -func TestSkipSaveAssociation(t *testing.T) { +func TestAutoSaveBelongsToAssociation(t *testing.T) { type Company struct { gorm.Model Name string @@ -895,27 +895,156 @@ func TestSkipSaveAssociation(t *testing.T) { gorm.Model Name string CompanyID uint - Company Company `gorm:"save_associations:false"` + Company Company `gorm:"association_autoupdate:false;association_autocreate:false;"` } + + DB.Where("name = ?", "auto_save_association").Delete(&Company{}) DB.AutoMigrate(&Company{}, &User{}) - DB.Save(&User{Name: "jinzhu", Company: Company{Name: "skip_save_association"}}) + DB.Save(&User{Name: "jinzhu", Company: Company{Name: "auto_save_association"}}) - if !DB.Where("name = ?", "skip_save_association").First(&Company{}).RecordNotFound() { - t.Errorf("Company skip_save_association should not have been saved") + if !DB.Where("name = ?", "auto_save_association").First(&Company{}).RecordNotFound() { + t.Errorf("Company auto_save_association should not have been saved when autosave is false") } // if foreign key is set, this should be saved even if association isn't - company := Company{Name: "skip_save_association"} + company := Company{Name: "auto_save_association"} DB.Save(&company) - company.Name = "skip_save_association_modified" + + company.Name = "auto_save_association_new_name" user := User{Name: "jinzhu", Company: company} + DB.Save(&user) - if !DB.Where("name = ?", "skip_save_association_modified").First(&Company{}).RecordNotFound() { - t.Errorf("Company skip_save_association should not have been updated") + if !DB.Where("name = ?", "auto_save_association_new_name").First(&Company{}).RecordNotFound() { + t.Errorf("Company should not have been updated") } + if DB.Where("id = ? AND company_id = ?", user.ID, company.ID).First(&User{}).RecordNotFound() { t.Errorf("User's foreign key should have been saved") } + + user2 := User{Name: "jinzhu_2", Company: Company{Name: "auto_save_association_2"}} + DB.Set("gorm:association_autocreate", true).Save(&user2) + if DB.Where("name = ?", "auto_save_association_2").First(&Company{}).RecordNotFound() { + t.Errorf("Company auto_save_association_2 should been created when autocreate is true") + } + + user2.Company.Name = "auto_save_association_2_newname" + DB.Set("gorm:association_autoupdate", true).Save(&user2) + + if DB.Where("name = ?", "auto_save_association_2_newname").First(&Company{}).RecordNotFound() { + t.Errorf("Company should been updated") + } +} + +func TestAutoSaveHasOneAssociation(t *testing.T) { + type Company struct { + gorm.Model + UserID uint + Name string + } + + type User struct { + gorm.Model + Name string + Company Company `gorm:"association_autoupdate:false;association_autocreate:false;"` + } + + DB.Where("name = ?", "auto_save_has_one_association").Delete(&Company{}) + DB.AutoMigrate(&Company{}, &User{}) + + DB.Save(&User{Name: "jinzhu", Company: Company{Name: "auto_save_has_one_association"}}) + + if !DB.Where("name = ?", "auto_save_has_one_association").First(&Company{}).RecordNotFound() { + t.Errorf("Company auto_save_has_one_association should not have been saved when autosave is false") + } + + company := Company{Name: "auto_save_has_one_association"} + DB.Save(&company) + + company.Name = "auto_save_has_one_association_new_name" + user := User{Name: "jinzhu", Company: company} + + DB.Save(&user) + + if !DB.Where("name = ?", "auto_save_has_one_association_new_name").First(&Company{}).RecordNotFound() { + t.Errorf("Company should not have been updated") + } + + if !DB.Where("name = ? AND user_id = ?", "auto_save_has_one_association", user.ID).First(&Company{}).RecordNotFound() { + t.Errorf("Company should not have been updated") + } + + if user.Company.UserID == 0 { + t.Errorf("UserID should be assigned") + } + + company.Name = "auto_save_has_one_association_2_new_name" + DB.Set("gorm:association_autoupdate", true).Save(&user) + + if DB.Where("name = ? AND user_id = ?", "auto_save_has_one_association_new_name", user.ID).First(&Company{}).RecordNotFound() { + t.Errorf("Company should been updated") + } + + user2 := User{Name: "jinzhu_2", Company: Company{Name: "auto_save_has_one_association_2"}} + DB.Set("gorm:association_autocreate", true).Save(&user2) + if DB.Where("name = ?", "auto_save_has_one_association_2").First(&Company{}).RecordNotFound() { + t.Errorf("Company auto_save_has_one_association_2 should been created when autocreate is true") + } +} + +func TestAutoSaveMany2ManyAssociation(t *testing.T) { + type Company struct { + gorm.Model + Name string + } + + type User struct { + gorm.Model + Name string + Companies []Company `gorm:"many2many:user_companies;association_autoupdate:false;association_autocreate:false;"` + } + + DB.AutoMigrate(&Company{}, &User{}) + + DB.Save(&User{Name: "jinzhu", Companies: []Company{{Name: "auto_save_m2m_association"}}}) + + if !DB.Where("name = ?", "auto_save_m2m_association").First(&Company{}).RecordNotFound() { + t.Errorf("Company auto_save_m2m_association should not have been saved when autosave is false") + } + + company := Company{Name: "auto_save_m2m_association"} + DB.Save(&company) + + company.Name = "auto_save_m2m_association_new_name" + user := User{Name: "jinzhu", Companies: []Company{company, {Name: "auto_save_m2m_association_new_name_2"}}} + + DB.Save(&user) + + if !DB.Where("name = ?", "auto_save_m2m_association_new_name").First(&Company{}).RecordNotFound() { + t.Errorf("Company should not have been updated") + } + + if !DB.Where("name = ?", "auto_save_m2m_association_new_name_2").First(&Company{}).RecordNotFound() { + t.Errorf("Company should not been created") + } + + if DB.Model(&user).Association("Companies").Count() != 1 { + t.Errorf("Relationship should been saved") + } + + DB.Set("gorm:association_autoupdate", true).Set("gorm:association_autocreate", true).Save(&user) + + if DB.Where("name = ?", "auto_save_m2m_association_new_name").First(&Company{}).RecordNotFound() { + t.Errorf("Company should been updated") + } + + if DB.Where("name = ?", "auto_save_m2m_association_new_name_2").First(&Company{}).RecordNotFound() { + t.Errorf("Company should been created") + } + + if DB.Model(&user).Association("Companies").Count() != 2 { + t.Errorf("Relationship should been updated") + } } diff --git a/callback_save.go b/callback_save.go index 544354d0..243c986e 100644 --- a/callback_save.go +++ b/callback_save.go @@ -1,6 +1,9 @@ package gorm -import "reflect" +import ( + "reflect" + "strings" +) func beginTransactionCallback(scope *Scope) { scope.Begin() @@ -10,33 +13,79 @@ func commitOrRollbackTransactionCallback(scope *Scope) { scope.CommitOrRollback() } -func saveFieldAsAssociation(scope *Scope, field *Field) (bool, *Relationship) { - if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored { - if field.Relationship != nil { - if value, ok := field.TagSettings["SAVE_ASSOCIATIONS"]; (!ok || (value != "false" && value != "skip")) && scope.allowSaveAssociations() { - return true, field.Relationship +func saveAssociationCheck(scope *Scope, field *Field) (autoUpdate bool, autoCreate bool, saveReference bool, r *Relationship) { + checkTruth := func(value interface{}) bool { + if v, ok := value.(bool); ok && !v { + return false + } + + if v, ok := value.(string); ok { + v = strings.ToLower(v) + if v == "false" || v != "skip" { + return false + } + } + + return true + } + + if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored { + if r = field.Relationship; r != nil { + autoUpdate, autoCreate, saveReference = true, true, true + + if value, ok := scope.Get("gorm:save_associations"); ok { + autoUpdate = checkTruth(value) + autoCreate = autoUpdate + } else if value, ok := field.TagSettings["SAVE_ASSOCIATIONS"]; ok { + autoUpdate = checkTruth(value) + autoCreate = autoUpdate + } + + if value, ok := scope.Get("gorm:association_autoupdate"); ok { + autoUpdate = checkTruth(value) + } else if value, ok := field.TagSettings["ASSOCIATION_AUTOUPDATE"]; ok { + autoUpdate = checkTruth(value) + } + + if value, ok := scope.Get("gorm:association_autocreate"); ok { + autoCreate = checkTruth(value) + } else if value, ok := field.TagSettings["ASSOCIATION_AUTOCREATE"]; ok { + autoCreate = checkTruth(value) + } + + if value, ok := field.TagSettings["ASSOCIATION_SAVE_REFERENCE"]; ok { + saveReference = checkTruth(value) } - return false, field.Relationship } } - return false, nil + + return } func saveBeforeAssociationsCallback(scope *Scope) { for _, field := range scope.Fields() { - if allowSaveAssociation, relationship := saveFieldAsAssociation(scope, field); relationship != nil && relationship.Kind == "belongs_to" { - fieldValue := field.Field.Addr().Interface() + autoUpdate, autoCreate, saveReference, relationship := saveAssociationCheck(scope, field) - if allowSaveAssociation { + if relationship != nil && relationship.Kind == "belongs_to" { + fieldValue := field.Field.Addr().Interface() + newScope := scope.New(fieldValue) + + if newScope.PrimaryKeyZero() { + if autoCreate { + scope.Err(scope.NewDB().Save(fieldValue).Error) + } + } else if autoUpdate { scope.Err(scope.NewDB().Save(fieldValue).Error) } - if len(relationship.ForeignFieldNames) != 0 { - // set value's foreign key - for idx, fieldName := range relationship.ForeignFieldNames { - associationForeignName := relationship.AssociationForeignDBNames[idx] - if foreignField, ok := scope.New(fieldValue).FieldByName(associationForeignName); ok { - scope.Err(scope.SetColumn(fieldName, foreignField.Field.Interface())) + if saveReference { + if len(relationship.ForeignFieldNames) != 0 { + // set value's foreign key + for idx, fieldName := range relationship.ForeignFieldNames { + associationForeignName := relationship.AssociationForeignDBNames[idx] + if foreignField, ok := scope.New(fieldValue).FieldByName(associationForeignName); ok { + scope.Err(scope.SetColumn(fieldName, foreignField.Field.Interface())) + } } } } @@ -46,8 +95,9 @@ func saveBeforeAssociationsCallback(scope *Scope) { func saveAfterAssociationsCallback(scope *Scope) { for _, field := range scope.Fields() { - if allowSaveAssociation, relationship := saveFieldAsAssociation(scope, field); relationship != nil && - (relationship.Kind == "has_one" || relationship.Kind == "has_many" || relationship.Kind == "many_to_many") { + autoUpdate, autoCreate, saveReference, relationship := saveAssociationCheck(scope, field) + + if relationship != nil && (relationship.Kind == "has_one" || relationship.Kind == "has_many" || relationship.Kind == "many_to_many") { value := field.Field switch value.Kind() { @@ -57,7 +107,41 @@ func saveAfterAssociationsCallback(scope *Scope) { elem := value.Index(i).Addr().Interface() newScope := newDB.NewScope(elem) - if relationship.JoinTableHandler == nil && len(relationship.ForeignFieldNames) != 0 { + if saveReference { + if relationship.JoinTableHandler == nil && len(relationship.ForeignFieldNames) != 0 { + for idx, fieldName := range relationship.ForeignFieldNames { + associationForeignName := relationship.AssociationForeignDBNames[idx] + if f, ok := scope.FieldByName(associationForeignName); ok { + scope.Err(newScope.SetColumn(fieldName, f.Field.Interface())) + } + } + } + + if relationship.PolymorphicType != "" { + scope.Err(newScope.SetColumn(relationship.PolymorphicType, relationship.PolymorphicValue)) + } + } + + if newScope.PrimaryKeyZero() { + if autoCreate { + scope.Err(newDB.Save(elem).Error) + } + } else if autoUpdate { + scope.Err(newDB.Save(elem).Error) + } + + if !scope.New(newScope.Value).PrimaryKeyZero() && saveReference { + if joinTableHandler := relationship.JoinTableHandler; joinTableHandler != nil { + scope.Err(joinTableHandler.Add(joinTableHandler, newDB, scope.Value, newScope.Value)) + } + } + } + default: + elem := value.Addr().Interface() + newScope := scope.New(elem) + + if saveReference { + if len(relationship.ForeignFieldNames) != 0 { for idx, fieldName := range relationship.ForeignFieldNames { associationForeignName := relationship.AssociationForeignDBNames[idx] if f, ok := scope.FieldByName(associationForeignName); ok { @@ -69,32 +153,13 @@ func saveAfterAssociationsCallback(scope *Scope) { if relationship.PolymorphicType != "" { scope.Err(newScope.SetColumn(relationship.PolymorphicType, relationship.PolymorphicValue)) } - - if allowSaveAssociation { - scope.Err(newDB.Save(elem).Error) - } - - if joinTableHandler := relationship.JoinTableHandler; joinTableHandler != nil && !newScope.PrimaryKeyZero() { - scope.Err(joinTableHandler.Add(joinTableHandler, newDB, scope.Value, newScope.Value)) - } - } - default: - elem := value.Addr().Interface() - newScope := scope.New(elem) - if len(relationship.ForeignFieldNames) != 0 { - for idx, fieldName := range relationship.ForeignFieldNames { - associationForeignName := relationship.AssociationForeignDBNames[idx] - if f, ok := scope.FieldByName(associationForeignName); ok { - scope.Err(newScope.SetColumn(fieldName, f.Field.Interface())) - } - } } - if relationship.PolymorphicType != "" { - scope.Err(newScope.SetColumn(relationship.PolymorphicType, relationship.PolymorphicValue)) - } - - if allowSaveAssociation { + if newScope.PrimaryKeyZero() { + if autoCreate { + scope.Err(scope.NewDB().Save(elem).Error) + } + } else if autoUpdate { scope.Err(scope.NewDB().Save(elem).Error) } } diff --git a/query_test.go b/query_test.go index def84e04..98721800 100644 --- a/query_test.go +++ b/query_test.go @@ -389,7 +389,7 @@ func TestOffset(t *testing.T) { DB.Save(&User{Name: fmt.Sprintf("OffsetUser%v", i)}) } var users1, users2, users3, users4 []User - DB.Limit(100).Order("age desc").Find(&users1).Offset(3).Find(&users2).Offset(5).Find(&users3).Offset(-1).Find(&users4) + DB.Limit(100).Where("name like ?", "OffsetUser%").Order("age desc").Find(&users1).Offset(3).Find(&users2).Offset(5).Find(&users3).Offset(-1).Find(&users4) if (len(users1) != len(users4)) || (len(users1)-len(users2) != 3) || (len(users1)-len(users3) != 5) { t.Errorf("Offset should work") diff --git a/scope.go b/scope.go index 9ae33913..125e02b0 100644 --- a/scope.go +++ b/scope.go @@ -993,18 +993,6 @@ func (scope *Scope) changeableField(field *Field) bool { return true } -func (scope *Scope) allowSaveAssociations() bool { - if saveAssociations, ok := scope.Get("gorm:save_associations"); ok { - if v, ok := saveAssociations.(bool); ok && !v { - return false - } - if v, ok := saveAssociations.(string); ok && (v != "skip") { - return false - } - } - return true && !scope.HasError() -} - func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope { toScope := scope.db.NewScope(value) tx := scope.db.Set("gorm:association:source", scope.Value) From 2940c553eb9763e966effbdca702e2d5b2b255da Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 10 Feb 2018 18:01:41 +0800 Subject: [PATCH 035/881] Add DB setting gorm:association_save_reference --- callback_save.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/callback_save.go b/callback_save.go index 243c986e..ef267141 100644 --- a/callback_save.go +++ b/callback_save.go @@ -53,7 +53,9 @@ func saveAssociationCheck(scope *Scope, field *Field) (autoUpdate bool, autoCrea autoCreate = checkTruth(value) } - if value, ok := field.TagSettings["ASSOCIATION_SAVE_REFERENCE"]; ok { + if value, ok := scope.Get("gorm:association_save_reference"); ok { + saveReference = checkTruth(value) + } else if value, ok := field.TagSettings["ASSOCIATION_SAVE_REFERENCE"]; ok { saveReference = checkTruth(value) } } From c6ce739b2a4d3b26af9326a31723883b4f136a74 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 10 Feb 2018 19:25:58 +0800 Subject: [PATCH 036/881] Convert auto_increment's value to lower case when checking its value --- dialect_common.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dialect_common.go b/dialect_common.go index 1e5e3b61..fbbaef33 100644 --- a/dialect_common.go +++ b/dialect_common.go @@ -40,7 +40,7 @@ func (commonDialect) Quote(key string) string { func (s *commonDialect) fieldCanAutoIncrement(field *StructField) bool { if value, ok := field.TagSettings["AUTO_INCREMENT"]; ok { - return value != "FALSE" + return strings.ToLower(value) != "false" } return field.IsPrimaryKey } From c0359226dc500354fd8c18366ad2fb6616f8c322 Mon Sep 17 00:00:00 2001 From: Emil Davtyan Date: Sat, 10 Feb 2018 12:31:55 +0100 Subject: [PATCH 037/881] Removed unnecessary cloning. (#1462) `NewScope` clones `DB` no need to chain a call to clone with `NewScope`. --- main.go | 38 +++++++++++++++++++------------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/main.go b/main.go index b23ae2f2..fc4859ac 100644 --- a/main.go +++ b/main.go @@ -274,7 +274,7 @@ func (s *DB) Assign(attrs ...interface{}) *DB { // First find first record that match given conditions, order by primary key func (s *DB) First(out interface{}, where ...interface{}) *DB { - newScope := s.clone().NewScope(out) + newScope := s.NewScope(out) newScope.Search.Limit(1) return newScope.Set("gorm:order_by_primary_key", "ASC"). inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db @@ -282,7 +282,7 @@ func (s *DB) First(out interface{}, where ...interface{}) *DB { // Last find last record that match given conditions, order by primary key func (s *DB) Last(out interface{}, where ...interface{}) *DB { - newScope := s.clone().NewScope(out) + newScope := s.NewScope(out) newScope.Search.Limit(1) return newScope.Set("gorm:order_by_primary_key", "DESC"). inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db @@ -290,12 +290,12 @@ func (s *DB) Last(out interface{}, where ...interface{}) *DB { // Find find records that match given conditions func (s *DB) Find(out interface{}, where ...interface{}) *DB { - return s.clone().NewScope(out).inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db + return s.NewScope(out).inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db } // Scan scan value to a struct func (s *DB) Scan(dest interface{}) *DB { - return s.clone().NewScope(s.Value).Set("gorm:query_destination", dest).callCallbacks(s.parent.callbacks.queries).db + return s.NewScope(s.Value).Set("gorm:query_destination", dest).callCallbacks(s.parent.callbacks.queries).db } // Row return `*sql.Row` with given conditions @@ -311,8 +311,8 @@ func (s *DB) Rows() (*sql.Rows, error) { // ScanRows scan `*sql.Rows` to give struct func (s *DB) ScanRows(rows *sql.Rows, result interface{}) error { var ( - clone = s.clone() - scope = clone.NewScope(result) + scope = s.NewScope(result) + clone = scope.db columns, err = rows.Columns() ) @@ -337,7 +337,7 @@ func (s *DB) Count(value interface{}) *DB { // Related get related associations func (s *DB) Related(value interface{}, foreignKeys ...string) *DB { - return s.clone().NewScope(s.Value).related(value, foreignKeys...).db + return s.NewScope(s.Value).related(value, foreignKeys...).db } // FirstOrInit find first matched record or initialize a new one with given conditions (only works with struct, map conditions) @@ -377,7 +377,7 @@ func (s *DB) Update(attrs ...interface{}) *DB { // Updates update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update func (s *DB) Updates(values interface{}, ignoreProtectedAttrs ...bool) *DB { - return s.clone().NewScope(s.Value). + return s.NewScope(s.Value). Set("gorm:ignore_protected_attrs", len(ignoreProtectedAttrs) > 0). InstanceSet("gorm:update_interface", values). callCallbacks(s.parent.callbacks.updates).db @@ -390,7 +390,7 @@ func (s *DB) UpdateColumn(attrs ...interface{}) *DB { // UpdateColumns update attributes without callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update func (s *DB) UpdateColumns(values interface{}) *DB { - return s.clone().NewScope(s.Value). + return s.NewScope(s.Value). Set("gorm:update_column", true). Set("gorm:save_associations", false). InstanceSet("gorm:update_interface", values). @@ -399,7 +399,7 @@ func (s *DB) UpdateColumns(values interface{}) *DB { // Save update value in database, if the value doesn't have primary key, will insert it func (s *DB) Save(value interface{}) *DB { - scope := s.clone().NewScope(value) + scope := s.NewScope(value) if !scope.PrimaryKeyZero() { newDB := scope.callCallbacks(s.parent.callbacks.updates).db if newDB.Error == nil && newDB.RowsAffected == 0 { @@ -412,13 +412,13 @@ func (s *DB) Save(value interface{}) *DB { // Create insert the value into database func (s *DB) Create(value interface{}) *DB { - scope := s.clone().NewScope(value) + scope := s.NewScope(value) return scope.callCallbacks(s.parent.callbacks.creates).db } // Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition func (s *DB) Delete(value interface{}, where ...interface{}) *DB { - return s.clone().NewScope(value).inlineCondition(where...).callCallbacks(s.parent.callbacks.deletes).db + return s.NewScope(value).inlineCondition(where...).callCallbacks(s.parent.callbacks.deletes).db } // Raw use raw sql as conditions, won't run it unless invoked by other methods @@ -429,7 +429,7 @@ func (s *DB) Raw(sql string, values ...interface{}) *DB { // Exec execute raw sql func (s *DB) Exec(sql string, values ...interface{}) *DB { - scope := s.clone().NewScope(nil) + scope := s.NewScope(nil) generatedSQL := scope.buildWhereCondition(map[string]interface{}{"query": sql, "args": values}) generatedSQL = strings.TrimSuffix(strings.TrimPrefix(generatedSQL, "("), ")") scope.Raw(generatedSQL) @@ -495,7 +495,7 @@ func (s *DB) Rollback() *DB { // NewRecord check if value's primary key is blank func (s *DB) NewRecord(value interface{}) bool { - return s.clone().NewScope(value).PrimaryKeyZero() + return s.NewScope(value).PrimaryKeyZero() } // RecordNotFound check if returning ErrRecordNotFound error @@ -544,7 +544,7 @@ func (s *DB) DropTableIfExists(values ...interface{}) *DB { // HasTable check has table or not func (s *DB) HasTable(value interface{}) bool { var ( - scope = s.clone().NewScope(value) + scope = s.NewScope(value) tableName string ) @@ -570,14 +570,14 @@ func (s *DB) AutoMigrate(values ...interface{}) *DB { // ModifyColumn modify column to type func (s *DB) ModifyColumn(column string, typ string) *DB { - scope := s.clone().NewScope(s.Value) + scope := s.NewScope(s.Value) scope.modifyColumn(column, typ) return scope.db } // DropColumn drop a column func (s *DB) DropColumn(column string) *DB { - scope := s.clone().NewScope(s.Value) + scope := s.NewScope(s.Value) scope.dropColumn(column) return scope.db } @@ -598,7 +598,7 @@ func (s *DB) AddUniqueIndex(indexName string, columns ...string) *DB { // RemoveIndex remove index with name func (s *DB) RemoveIndex(indexName string) *DB { - scope := s.clone().NewScope(s.Value) + scope := s.NewScope(s.Value) scope.removeIndex(indexName) return scope.db } @@ -606,7 +606,7 @@ func (s *DB) RemoveIndex(indexName string) *DB { // AddForeignKey Add foreign key to the given scope, e.g: // db.Model(&User{}).AddForeignKey("city_id", "cities(id)", "RESTRICT", "RESTRICT") func (s *DB) AddForeignKey(field string, dest string, onDelete string, onUpdate string) *DB { - scope := s.clone().NewScope(s.Value) + scope := s.NewScope(s.Value) scope.addForeignKey(field, dest, onDelete, onUpdate) return scope.db } From 8e7d807ebf902bf239ac9ccd509b42659ee378ba Mon Sep 17 00:00:00 2001 From: Nathan Osman Date: Fri, 22 Dec 2017 17:59:15 -0800 Subject: [PATCH 038/881] Allow name of column to be customized to support self-referencing many2many fields. --- customize_column_test.go | 22 ++++++++++++++++++++++ join_table_handler.go | 19 ++++++++++++++++++- model_struct.go | 15 ++++++++++++++- 3 files changed, 54 insertions(+), 2 deletions(-) diff --git a/customize_column_test.go b/customize_column_test.go index ddb536b8..c96b2d40 100644 --- a/customize_column_test.go +++ b/customize_column_test.go @@ -279,3 +279,25 @@ func TestBelongsToWithPartialCustomizedColumn(t *testing.T) { t.Errorf("should preload discount from coupon") } } + +type SelfReferencingUser struct { + gorm.Model + Friends []*SelfReferencingUser `gorm:"many2many:UserFriends;AssociationForeignKey:ID=friend_id"` +} + +func TestSelfReferencingMany2ManyColumn(t *testing.T) { + DB.DropTable(&SelfReferencingUser{}, "UserFriends") + DB.AutoMigrate(&SelfReferencingUser{}) + + friend := SelfReferencingUser{} + if err := DB.Create(&friend).Error; err != nil { + t.Errorf("no error should happen, but got %v", err) + } + + user := SelfReferencingUser{ + Friends: []*SelfReferencingUser{&friend}, + } + if err := DB.Create(&user).Error; err != nil { + t.Errorf("no error should happen, but got %v", err) + } +} diff --git a/join_table_handler.go b/join_table_handler.go index 2d1a5055..b4be6cf9 100644 --- a/join_table_handler.go +++ b/join_table_handler.go @@ -109,7 +109,24 @@ func (s JoinTableHandler) getSearchMap(db *DB, sources ...interface{}) map[strin // Add create relationship in join table for source and destination func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error { scope := db.NewScope("") - searchMap := s.getSearchMap(db, source, destination) + searchMap := map[string]interface{}{} + + // getSearchMap() cannot be used here since the source and destination + // model types may be identical + + sourceScope := db.NewScope(source) + for _, foreignKey := range s.Source.ForeignKeys { + if field, ok := sourceScope.FieldByName(foreignKey.AssociationDBName); ok { + searchMap[foreignKey.DBName] = field.Field.Interface() + } + } + + destinationScope := db.NewScope(destination) + for _, foreignKey := range s.Destination.ForeignKeys { + if field, ok := destinationScope.FieldByName(foreignKey.AssociationDBName); ok { + searchMap[foreignKey.DBName] = field.Field.Interface() + } + } var assignColumns, binVars, conditions []string var values []interface{} diff --git a/model_struct.go b/model_struct.go index 315028c4..463ec517 100644 --- a/model_struct.go +++ b/model_struct.go @@ -289,11 +289,24 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } for _, name := range associationForeignKeys { + + // In order to allow self-referencing many2many tables, the name + // may be followed by "=" to allow renaming the column + parts := strings.Split(name, "=") + name = parts[0] + if field, ok := toScope.FieldByName(name); ok { // association foreign keys (db names) relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.DBName) + + // If a new name was provided for the field, use it + name = field.DBName + if len(parts) > 1 { + name = parts[1] + } + // join table foreign keys for association - joinTableDBName := ToDBName(elemType.Name()) + "_" + field.DBName + joinTableDBName := ToDBName(elemType.Name()) + "_" + name relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, joinTableDBName) } } From 44b9911f5157e6b7d03c08fcf730ded96b2eda66 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 10 Feb 2018 20:57:39 +0800 Subject: [PATCH 039/881] Refactor self referencing m2m support --- association.go | 4 +- customize_column_test.go | 30 +++++++++-- join_table_handler.go | 60 +++++++++------------ model_struct.go | 111 +++++++++++++++++++++++---------------- 4 files changed, 117 insertions(+), 88 deletions(-) diff --git a/association.go b/association.go index 3d522ccc..8c6d9864 100644 --- a/association.go +++ b/association.go @@ -107,7 +107,7 @@ func (association *Association) Replace(values ...interface{}) *Association { if sourcePrimaryKeys := scope.getColumnAsArray(sourceForeignFieldNames, scope.Value); len(sourcePrimaryKeys) > 0 { newDB = newDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(sourcePrimaryKeys)), toQueryValues(sourcePrimaryKeys)...) - association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB, relationship)) + association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB)) } } else if relationship.Kind == "has_one" || relationship.Kind == "has_many" { // has_one or has_many relations, set foreign key to be nil (TODO or delete them?) @@ -173,7 +173,7 @@ func (association *Association) Delete(values ...interface{}) *Association { sql := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(deletingPrimaryKeys)) newDB = newDB.Where(sql, toQueryValues(deletingPrimaryKeys)...) - association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB, relationship)) + association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB)) } else { var foreignKeyMap = map[string]interface{}{} for _, foreignKey := range relationship.ForeignDBNames { diff --git a/customize_column_test.go b/customize_column_test.go index c96b2d40..629d85f9 100644 --- a/customize_column_test.go +++ b/customize_column_test.go @@ -282,22 +282,44 @@ func TestBelongsToWithPartialCustomizedColumn(t *testing.T) { type SelfReferencingUser struct { gorm.Model - Friends []*SelfReferencingUser `gorm:"many2many:UserFriends;AssociationForeignKey:ID=friend_id"` + Name string + Friends []*SelfReferencingUser `gorm:"many2many:UserFriends;association_jointable_foreignkey:friend_id"` } func TestSelfReferencingMany2ManyColumn(t *testing.T) { DB.DropTable(&SelfReferencingUser{}, "UserFriends") DB.AutoMigrate(&SelfReferencingUser{}) - friend := SelfReferencingUser{} - if err := DB.Create(&friend).Error; err != nil { + friend1 := SelfReferencingUser{Name: "friend1_m2m"} + if err := DB.Create(&friend1).Error; err != nil { + t.Errorf("no error should happen, but got %v", err) + } + + friend2 := SelfReferencingUser{Name: "friend2_m2m"} + if err := DB.Create(&friend2).Error; err != nil { t.Errorf("no error should happen, but got %v", err) } user := SelfReferencingUser{ - Friends: []*SelfReferencingUser{&friend}, + Name: "self_m2m", + Friends: []*SelfReferencingUser{&friend1, &friend2}, } + if err := DB.Create(&user).Error; err != nil { t.Errorf("no error should happen, but got %v", err) } + + if DB.Model(&user).Association("Friends").Count() != 2 { + t.Errorf("Should find created friends correctly") + } + + var newUser = SelfReferencingUser{} + + if err := DB.Preload("Friends").First(&newUser, "id = ?", user.ID).Error; err != nil { + t.Errorf("no error should happen, but got %v", err) + } + + if len(newUser.Friends) != 2 { + t.Errorf("Should preload created frineds for self reference m2m") + } } diff --git a/join_table_handler.go b/join_table_handler.go index b4be6cf9..f07541ba 100644 --- a/join_table_handler.go +++ b/join_table_handler.go @@ -82,55 +82,40 @@ func (s JoinTableHandler) Table(db *DB) string { return s.TableName } -func (s JoinTableHandler) getSearchMap(db *DB, sources ...interface{}) map[string]interface{} { - values := map[string]interface{}{} - +func (s JoinTableHandler) updateConditionMap(conditionMap map[string]interface{}, db *DB, joinTableSources []JoinTableSource, sources ...interface{}) { for _, source := range sources { scope := db.NewScope(source) modelType := scope.GetModelStruct().ModelType - if s.Source.ModelType == modelType { - for _, foreignKey := range s.Source.ForeignKeys { - if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok { - values[foreignKey.DBName] = field.Field.Interface() - } - } - } else if s.Destination.ModelType == modelType { - for _, foreignKey := range s.Destination.ForeignKeys { - if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok { - values[foreignKey.DBName] = field.Field.Interface() + for _, joinTableSource := range joinTableSources { + if joinTableSource.ModelType == modelType { + for _, foreignKey := range joinTableSource.ForeignKeys { + if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok { + conditionMap[foreignKey.DBName] = field.Field.Interface() + } } + break } } } - return values } // Add create relationship in join table for source and destination func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error { - scope := db.NewScope("") - searchMap := map[string]interface{}{} + var ( + scope = db.NewScope("") + conditionMap = map[string]interface{}{} + ) - // getSearchMap() cannot be used here since the source and destination - // model types may be identical + // Update condition map for source + s.updateConditionMap(conditionMap, db, []JoinTableSource{s.Source}, source) - sourceScope := db.NewScope(source) - for _, foreignKey := range s.Source.ForeignKeys { - if field, ok := sourceScope.FieldByName(foreignKey.AssociationDBName); ok { - searchMap[foreignKey.DBName] = field.Field.Interface() - } - } - - destinationScope := db.NewScope(destination) - for _, foreignKey := range s.Destination.ForeignKeys { - if field, ok := destinationScope.FieldByName(foreignKey.AssociationDBName); ok { - searchMap[foreignKey.DBName] = field.Field.Interface() - } - } + // Update condition map for destination + s.updateConditionMap(conditionMap, db, []JoinTableSource{s.Destination}, destination) var assignColumns, binVars, conditions []string var values []interface{} - for key, value := range searchMap { + for key, value := range conditionMap { assignColumns = append(assignColumns, scope.Quote(key)) binVars = append(binVars, `?`) conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key))) @@ -158,12 +143,15 @@ func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source // Delete delete relationship in join table for sources func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error { var ( - scope = db.NewScope(nil) - conditions []string - values []interface{} + scope = db.NewScope(nil) + conditions []string + values []interface{} + conditionMap = map[string]interface{}{} ) - for key, value := range s.getSearchMap(db, sources...) { + s.updateConditionMap(conditionMap, db, []JoinTableSource{s.Source, s.Destination}, sources...) + + for key, value := range conditionMap { conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key))) values = append(values, value) } diff --git a/model_struct.go b/model_struct.go index 463ec517..f571e2e8 100644 --- a/model_struct.go +++ b/model_struct.go @@ -249,11 +249,13 @@ func (scope *Scope) GetModelStruct() *ModelStruct { ) if foreignKey := field.TagSettings["FOREIGNKEY"]; foreignKey != "" { - foreignKeys = strings.Split(field.TagSettings["FOREIGNKEY"], ",") + foreignKeys = strings.Split(foreignKey, ",") } - if foreignKey := field.TagSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" { - associationForeignKeys = strings.Split(field.TagSettings["ASSOCIATIONFOREIGNKEY"], ",") + if foreignKey := field.TagSettings["ASSOCIATION_FOREIGNKEY"]; foreignKey != "" { + associationForeignKeys = strings.Split(foreignKey, ",") + } else if foreignKey := field.TagSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" { + associationForeignKeys = strings.Split(foreignKey, ",") } for elemType.Kind() == reflect.Slice || elemType.Kind() == reflect.Ptr { @@ -264,50 +266,65 @@ func (scope *Scope) GetModelStruct() *ModelStruct { if many2many := field.TagSettings["MANY2MANY"]; many2many != "" { relationship.Kind = "many_to_many" - // if no foreign keys defined with tag - if len(foreignKeys) == 0 { - for _, field := range modelStruct.PrimaryFields { - foreignKeys = append(foreignKeys, field.DBName) + { // Foreign Keys for Source + joinTableDBNames := []string{} + + if foreignKey := field.TagSettings["JOINTABLE_FOREIGNKEY"]; foreignKey != "" { + joinTableDBNames = strings.Split(foreignKey, ",") } - } - for _, foreignKey := range foreignKeys { - if foreignField := getForeignField(foreignKey, modelStruct.StructFields); foreignField != nil { - // source foreign keys (db names) - relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.DBName) - // join table foreign keys for source - joinTableDBName := ToDBName(reflectType.Name()) + "_" + foreignField.DBName - relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBName) - } - } - - // if no association foreign keys defined with tag - if len(associationForeignKeys) == 0 { - for _, field := range toScope.PrimaryFields() { - associationForeignKeys = append(associationForeignKeys, field.DBName) - } - } - - for _, name := range associationForeignKeys { - - // In order to allow self-referencing many2many tables, the name - // may be followed by "=" to allow renaming the column - parts := strings.Split(name, "=") - name = parts[0] - - if field, ok := toScope.FieldByName(name); ok { - // association foreign keys (db names) - relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.DBName) - - // If a new name was provided for the field, use it - name = field.DBName - if len(parts) > 1 { - name = parts[1] + // if no foreign keys defined with tag + if len(foreignKeys) == 0 { + for _, field := range modelStruct.PrimaryFields { + foreignKeys = append(foreignKeys, field.DBName) } + } - // join table foreign keys for association - joinTableDBName := ToDBName(elemType.Name()) + "_" + name - relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, joinTableDBName) + for idx, foreignKey := range foreignKeys { + if foreignField := getForeignField(foreignKey, modelStruct.StructFields); foreignField != nil { + // source foreign keys (db names) + relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.DBName) + + // setup join table foreign keys for source + if len(joinTableDBNames) > idx { + // if defined join table's foreign key + relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBNames[idx]) + } else { + defaultJointableForeignKey := ToDBName(reflectType.Name()) + "_" + foreignField.DBName + relationship.ForeignDBNames = append(relationship.ForeignDBNames, defaultJointableForeignKey) + } + } + } + } + + { // Foreign Keys for Association (Destination) + associationJoinTableDBNames := []string{} + + if foreignKey := field.TagSettings["ASSOCIATION_JOINTABLE_FOREIGNKEY"]; foreignKey != "" { + associationJoinTableDBNames = strings.Split(foreignKey, ",") + } + + // if no association foreign keys defined with tag + if len(associationForeignKeys) == 0 { + for _, field := range toScope.PrimaryFields() { + associationForeignKeys = append(associationForeignKeys, field.DBName) + } + } + + for idx, name := range associationForeignKeys { + if field, ok := toScope.FieldByName(name); ok { + // association foreign keys (db names) + relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.DBName) + + // setup join table foreign keys for association + if len(associationJoinTableDBNames) > idx { + relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationJoinTableDBNames[idx]) + } else { + // join table foreign keys for association + joinTableDBName := ToDBName(elemType.Name()) + "_" + field.DBName + relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, joinTableDBName) + } + } } } @@ -412,11 +429,13 @@ func (scope *Scope) GetModelStruct() *ModelStruct { ) if foreignKey := field.TagSettings["FOREIGNKEY"]; foreignKey != "" { - tagForeignKeys = strings.Split(field.TagSettings["FOREIGNKEY"], ",") + tagForeignKeys = strings.Split(foreignKey, ",") } - if foreignKey := field.TagSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" { - tagAssociationForeignKeys = strings.Split(field.TagSettings["ASSOCIATIONFOREIGNKEY"], ",") + if foreignKey := field.TagSettings["ASSOCIATION_FOREIGNKEY"]; foreignKey != "" { + tagAssociationForeignKeys = strings.Split(foreignKey, ",") + } else if foreignKey := field.TagSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" { + tagAssociationForeignKeys = strings.Split(foreignKey, ",") } if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" { From cb7c41e0b6e3863e7934a50c0aed76b8cfb61bfd Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 10 Feb 2018 22:14:18 +0800 Subject: [PATCH 040/881] Add more tests for self-referencing many2many relationship --- customize_column_test.go | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/customize_column_test.go b/customize_column_test.go index 629d85f9..5e19d6f4 100644 --- a/customize_column_test.go +++ b/customize_column_test.go @@ -322,4 +322,25 @@ func TestSelfReferencingMany2ManyColumn(t *testing.T) { if len(newUser.Friends) != 2 { t.Errorf("Should preload created frineds for self reference m2m") } + + DB.Model(&newUser).Association("Friends").Append(&SelfReferencingUser{Name: "friend3_m2m"}) + if DB.Model(&user).Association("Friends").Count() != 3 { + t.Errorf("Should find created friends correctly") + } + + DB.Model(&newUser).Association("Friends").Replace(&SelfReferencingUser{Name: "friend4_m2m"}) + if DB.Model(&user).Association("Friends").Count() != 1 { + t.Errorf("Should find created friends correctly") + } + + friend := SelfReferencingUser{} + DB.Model(&newUser).Association("Friends").Find(&friend) + if friend.Name != "friend4_m2m" { + t.Errorf("Should find created friends correctly") + } + + DB.Model(&newUser).Association("Friends").Delete(friend) + if DB.Model(&user).Association("Friends").Count() != 0 { + t.Errorf("All friends should be deleted") + } } From fd15156d399274bcf281ac25ca0536075abd637a Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 11 Feb 2018 09:16:10 +0800 Subject: [PATCH 041/881] Fix Count in mssql for SQL with group --- query_test.go | 9 +++++++++ scope.go | 4 ++-- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/query_test.go b/query_test.go index 98721800..882fd611 100644 --- a/query_test.go +++ b/query_test.go @@ -430,6 +430,15 @@ func TestCount(t *testing.T) { if count1 != 1 || count2 != 3 { t.Errorf("Multiple count in chain") } + + var count3 int + if err := DB.Model(&User{}).Where("name in (?)", []string{user2.Name, user2.Name, user3.Name}).Group("id").Count(&count3).Error; err != nil { + t.Errorf("Not error should happen, but got %v", err) + } + + if count3 != 2 { + t.Errorf("Should get correct count, but got %v", count3) + } } func TestNot(t *testing.T) { diff --git a/scope.go b/scope.go index 63bf618f..ae98d251 100644 --- a/scope.go +++ b/scope.go @@ -951,8 +951,8 @@ func (scope *Scope) pluck(column string, value interface{}) *Scope { func (scope *Scope) count(value interface{}) *Scope { if query, ok := scope.Search.selects["query"]; !ok || !countingQueryRegexp.MatchString(fmt.Sprint(query)) { if len(scope.Search.group) != 0 { - scope.Search.Select("count(*) FROM ( SELECT count(*) ") - scope.Search.group += " ) AS count" + scope.Search.Select("count(*) FROM ( SELECT count(*) as name ") + scope.Search.group += " ) AS count_table" } else { scope.Search.Select("count(*)") } From 3b6d790e93e9715cafbc66179f9435994e7413a2 Mon Sep 17 00:00:00 2001 From: Viktor Nikolaiev Date: Wed, 30 Aug 2017 22:52:45 +0300 Subject: [PATCH 042/881] Made it possible to implement driver.Valuer for byte slices --- migration_test.go | 26 ++++++++++++++ scope.go | 92 ++++++++++++++++++++++++++++++----------------- scope_test.go | 42 ++++++++++++++++++++++ 3 files changed, 128 insertions(+), 32 deletions(-) diff --git a/migration_test.go b/migration_test.go index 6b4470a6..7c3436ca 100644 --- a/migration_test.go +++ b/migration_test.go @@ -33,6 +33,7 @@ type User struct { CompanyID *int Company Company Role Role + Password EncryptedData PasswordHash []byte IgnoreMe int64 `sql:"-"` IgnoreStringSlice []string `sql:"-"` @@ -116,6 +117,31 @@ type Company struct { Owner *User `sql:"-"` } +type EncryptedData []byte + +func (data *EncryptedData) Scan(value interface{}) error { + if b, ok := value.([]byte); ok { + if len(b) < 3 || b[0] != '*' || b[1] != '*' || b[2] != '*'{ + return errors.New("Too short") + } + + *data = b[3:] + return nil + } else { + return errors.New("Bytes expected") + } +} + +func (data EncryptedData) Value() (driver.Value, error) { + if len(data) > 0 && data[0] == 'x' { + //needed to test failures + return nil, errors.New("Should not start with 'x'") + } + + //prepend asterisks + return append([]byte("***"), data...), nil +} + type Role struct { Name string `gorm:"size:256"` } diff --git a/scope.go b/scope.go index ae98d251..65d35461 100644 --- a/scope.go +++ b/scope.go @@ -557,22 +557,29 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri args := clause["args"].([]interface{}) for _, arg := range args { - switch reflect.ValueOf(arg).Kind() { - case reflect.Slice: // For where("id in (?)", []int64{1,2}) - if bytes, ok := arg.([]byte); ok { - str = strings.Replace(str, "?", scope.AddToVars(bytes), 1) - } else if values := reflect.ValueOf(arg); values.Len() > 0 { - var tempMarks []string - for i := 0; i < values.Len(); i++ { - tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface())) + rArg := reflect.ValueOf(arg) + rArgType := reflect.TypeOf(arg) + vArg, isValuer := arg.(driver.Valuer) + var err error + + //non byte slice and non driver.Valuer + if rArgType.Kind() == reflect.Slice && rArgType.Elem().Kind() != reflect.Uint8 && !isValuer { + if rArg.Len() > 0 { + tempMarks := make([]string, 0, rArg.Len()) + for i := 0; i < rArg.Len(); i++ { + tempMarks = append(tempMarks, scope.AddToVars(rArg.Index(i).Interface())) } + str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1) } else { str = strings.Replace(str, "?", scope.AddToVars(Expr("NULL")), 1) } - default: - if valuer, ok := interface{}(arg).(driver.Valuer); ok { - arg, _ = valuer.Value() + } else { + if isValuer { + arg, err = vArg.Value() + if err != nil { + scope.Err(err) + } } str = strings.Replace(str, "?", scope.AddToVars(arg), 1) @@ -629,23 +636,31 @@ func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string args := clause["args"].([]interface{}) for _, arg := range args { - switch reflect.ValueOf(arg).Kind() { - case reflect.Slice: // For where("id in (?)", []int64{1,2}) - if bytes, ok := arg.([]byte); ok { - str = strings.Replace(str, "?", scope.AddToVars(bytes), 1) - } else if values := reflect.ValueOf(arg); values.Len() > 0 { - var tempMarks []string - for i := 0; i < values.Len(); i++ { - tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface())) + rArg := reflect.ValueOf(arg) + rArgType := reflect.TypeOf(arg) + vArg, isValuer := arg.(driver.Valuer) + var err error + + //non byte slice and non driver.Valuer + if rArgType.Kind() == reflect.Slice && rArgType.Elem().Kind() != reflect.Uint8 && !isValuer { + if rArg.Len() > 0 { + tempMarks := make([]string, 0, rArg.Len()) + for i := 0; i < rArg.Len(); i++ { + tempMarks = append(tempMarks, scope.AddToVars(rArg.Index(i).Interface())) } + str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1) } else { str = strings.Replace(str, "?", scope.AddToVars(Expr("NULL")), 1) } - default: - if scanner, ok := interface{}(arg).(driver.Valuer); ok { - arg, _ = scanner.Value() + } else { + if isValuer { + arg, err = vArg.Value() + if err != nil { + scope.Err(err) + } } + str = strings.Replace(notEqualSQL, "?", scope.AddToVars(arg), 1) } } @@ -662,18 +677,31 @@ func (scope *Scope) buildSelectQuery(clause map[string]interface{}) (str string) args := clause["args"].([]interface{}) for _, arg := range args { - switch reflect.ValueOf(arg).Kind() { - case reflect.Slice: - values := reflect.ValueOf(arg) - var tempMarks []string - for i := 0; i < values.Len(); i++ { - tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface())) + rArg := reflect.ValueOf(arg) + rArgType := reflect.TypeOf(arg) + vArg, isValuer := arg.(driver.Valuer) + var err error + + //non byte slice and non driver.Valuer + if rArgType.Kind() == reflect.Slice && rArgType.Elem().Kind() != reflect.Uint8 && !isValuer { + if rArg.Len() > 0 { + tempMarks := make([]string, 0, rArg.Len()) + for i := 0; i < rArg.Len(); i++ { + tempMarks = append(tempMarks, scope.AddToVars(rArg.Index(i).Interface())) + } + + str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1) + } else { + str = strings.Replace(str, "?", scope.AddToVars(Expr("NULL")), 1) } - str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1) - default: - if valuer, ok := interface{}(arg).(driver.Valuer); ok { - arg, _ = valuer.Value() + } else { + if isValuer { + arg, err = vArg.Value() + if err != nil { + scope.Err(err) + } } + str = strings.Replace(str, "?", scope.AddToVars(arg), 1) } } diff --git a/scope_test.go b/scope_test.go index 42458995..71e80225 100644 --- a/scope_test.go +++ b/scope_test.go @@ -1,7 +1,10 @@ package gorm_test import ( + "encoding/hex" "github.com/jinzhu/gorm" + "math/rand" + "strings" "testing" ) @@ -41,3 +44,42 @@ func TestScopes(t *testing.T) { t.Errorf("Should found two users's name in 1, 3") } } + +func randName() string { + data := make([]byte, 8) + rand.Read(data) + + return "n-" + hex.EncodeToString(data) +} + +func TestValuer(t *testing.T) { + name := randName() + + origUser := User{Name: name, Age: 1, Password: EncryptedData("pass1"), PasswordHash: []byte("abc")} + err := DB.Save(&origUser).Error + if err != nil { + t.Log(err) + t.FailNow() + } + + var user2 User + err = DB.Where("name=? AND password=? AND password_hash=?", name, EncryptedData("pass1"), []byte("abc")).First(&user2).Error + if err != nil { + t.Log(err) + t.FailNow() + } +} + +func TestFailedValuer(t *testing.T) { + name := randName() + + err := DB.Exec("INSERT INTO users(name, password) VALUES(?, ?)", name, EncryptedData("xpass1")).Error + if err == nil { + t.FailNow() + } + + if !strings.HasPrefix(err.Error(), "Should not start with") { + t.FailNow() + } + +} From fce49136e8cd59940611b8510316e9cef59a5f86 Mon Sep 17 00:00:00 2001 From: Viktor Nikolaiev Date: Thu, 31 Aug 2017 10:30:48 +0300 Subject: [PATCH 043/881] fixed golint issues --- migration_test.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/migration_test.go b/migration_test.go index 7c3436ca..d58e1fb5 100644 --- a/migration_test.go +++ b/migration_test.go @@ -121,15 +121,15 @@ type EncryptedData []byte func (data *EncryptedData) Scan(value interface{}) error { if b, ok := value.([]byte); ok { - if len(b) < 3 || b[0] != '*' || b[1] != '*' || b[2] != '*'{ + if len(b) < 3 || b[0] != '*' || b[1] != '*' || b[2] != '*' { return errors.New("Too short") } *data = b[3:] return nil - } else { - return errors.New("Bytes expected") } + + return errors.New("Bytes expected") } func (data EncryptedData) Value() (driver.Value, error) { From ba3e6201c72c22584cbe39f87a564c5ecdf440a6 Mon Sep 17 00:00:00 2001 From: Viktor Nikolaiev Date: Tue, 3 Oct 2017 17:17:39 +0300 Subject: [PATCH 044/881] fixed issue with null values in where conditions --- scope.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/scope.go b/scope.go index 65d35461..252a1240 100644 --- a/scope.go +++ b/scope.go @@ -563,7 +563,7 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri var err error //non byte slice and non driver.Valuer - if rArgType.Kind() == reflect.Slice && rArgType.Elem().Kind() != reflect.Uint8 && !isValuer { + if arg != nil && rArgType.Kind() == reflect.Slice && rArgType.Elem().Kind() != reflect.Uint8 && !isValuer { if rArg.Len() > 0 { tempMarks := make([]string, 0, rArg.Len()) for i := 0; i < rArg.Len(); i++ { @@ -642,7 +642,7 @@ func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string var err error //non byte slice and non driver.Valuer - if rArgType.Kind() == reflect.Slice && rArgType.Elem().Kind() != reflect.Uint8 && !isValuer { + if arg != nil && rArgType.Kind() == reflect.Slice && rArgType.Elem().Kind() != reflect.Uint8 && !isValuer { if rArg.Len() > 0 { tempMarks := make([]string, 0, rArg.Len()) for i := 0; i < rArg.Len(); i++ { @@ -683,7 +683,7 @@ func (scope *Scope) buildSelectQuery(clause map[string]interface{}) (str string) var err error //non byte slice and non driver.Valuer - if rArgType.Kind() == reflect.Slice && rArgType.Elem().Kind() != reflect.Uint8 && !isValuer { + if arg != nil && rArgType.Kind() == reflect.Slice && rArgType.Elem().Kind() != reflect.Uint8 && !isValuer { if rArg.Len() > 0 { tempMarks := make([]string, 0, rArg.Len()) for i := 0; i < rArg.Len(); i++ { From c503108f8345b65e02549846cdb9313487022932 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 11 Feb 2018 12:48:08 +0800 Subject: [PATCH 045/881] Refactor fix valuer --- scope.go | 102 ++++++++++++++++++++++---------------------------- scope_test.go | 25 +++++-------- 2 files changed, 54 insertions(+), 73 deletions(-) diff --git a/scope.go b/scope.go index 252a1240..0dcea855 100644 --- a/scope.go +++ b/scope.go @@ -557,33 +557,33 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri args := clause["args"].([]interface{}) for _, arg := range args { - rArg := reflect.ValueOf(arg) - rArgType := reflect.TypeOf(arg) - vArg, isValuer := arg.(driver.Valuer) var err error - - //non byte slice and non driver.Valuer - if arg != nil && rArgType.Kind() == reflect.Slice && rArgType.Elem().Kind() != reflect.Uint8 && !isValuer { - if rArg.Len() > 0 { - tempMarks := make([]string, 0, rArg.Len()) - for i := 0; i < rArg.Len(); i++ { - tempMarks = append(tempMarks, scope.AddToVars(rArg.Index(i).Interface())) + switch reflect.ValueOf(arg).Kind() { + case reflect.Slice: // For where("id in (?)", []int64{1,2}) + if scanner, ok := interface{}(arg).(driver.Valuer); ok { + arg, err = scanner.Value() + str = strings.Replace(str, "?", scope.AddToVars(arg), 1) + } else if bytes, ok := arg.([]byte); ok { + str = strings.Replace(str, "?", scope.AddToVars(bytes), 1) + } else if values := reflect.ValueOf(arg); values.Len() > 0 { + var tempMarks []string + for i := 0; i < values.Len(); i++ { + tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface())) } - str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1) } else { str = strings.Replace(str, "?", scope.AddToVars(Expr("NULL")), 1) } - } else { - if isValuer { - arg, err = vArg.Value() - if err != nil { - scope.Err(err) - } + default: + if valuer, ok := interface{}(arg).(driver.Valuer); ok { + arg, err = valuer.Value() } str = strings.Replace(str, "?", scope.AddToVars(arg), 1) } + if err != nil { + scope.Err(err) + } } return } @@ -636,33 +636,32 @@ func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string args := clause["args"].([]interface{}) for _, arg := range args { - rArg := reflect.ValueOf(arg) - rArgType := reflect.TypeOf(arg) - vArg, isValuer := arg.(driver.Valuer) var err error - - //non byte slice and non driver.Valuer - if arg != nil && rArgType.Kind() == reflect.Slice && rArgType.Elem().Kind() != reflect.Uint8 && !isValuer { - if rArg.Len() > 0 { - tempMarks := make([]string, 0, rArg.Len()) - for i := 0; i < rArg.Len(); i++ { - tempMarks = append(tempMarks, scope.AddToVars(rArg.Index(i).Interface())) + switch reflect.ValueOf(arg).Kind() { + case reflect.Slice: // For where("id in (?)", []int64{1,2}) + if scanner, ok := interface{}(arg).(driver.Valuer); ok { + arg, err = scanner.Value() + str = strings.Replace(str, "?", scope.AddToVars(arg), 1) + } else if bytes, ok := arg.([]byte); ok { + str = strings.Replace(str, "?", scope.AddToVars(bytes), 1) + } else if values := reflect.ValueOf(arg); values.Len() > 0 { + var tempMarks []string + for i := 0; i < values.Len(); i++ { + tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface())) } - str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1) } else { str = strings.Replace(str, "?", scope.AddToVars(Expr("NULL")), 1) } - } else { - if isValuer { - arg, err = vArg.Value() - if err != nil { - scope.Err(err) - } + default: + if scanner, ok := interface{}(arg).(driver.Valuer); ok { + arg, err = scanner.Value() } - str = strings.Replace(notEqualSQL, "?", scope.AddToVars(arg), 1) } + if err != nil { + scope.Err(err) + } } return } @@ -677,31 +676,18 @@ func (scope *Scope) buildSelectQuery(clause map[string]interface{}) (str string) args := clause["args"].([]interface{}) for _, arg := range args { - rArg := reflect.ValueOf(arg) - rArgType := reflect.TypeOf(arg) - vArg, isValuer := arg.(driver.Valuer) - var err error - - //non byte slice and non driver.Valuer - if arg != nil && rArgType.Kind() == reflect.Slice && rArgType.Elem().Kind() != reflect.Uint8 && !isValuer { - if rArg.Len() > 0 { - tempMarks := make([]string, 0, rArg.Len()) - for i := 0; i < rArg.Len(); i++ { - tempMarks = append(tempMarks, scope.AddToVars(rArg.Index(i).Interface())) - } - - str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1) - } else { - str = strings.Replace(str, "?", scope.AddToVars(Expr("NULL")), 1) + switch reflect.ValueOf(arg).Kind() { + case reflect.Slice: + values := reflect.ValueOf(arg) + var tempMarks []string + for i := 0; i < values.Len(); i++ { + tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface())) } - } else { - if isValuer { - arg, err = vArg.Value() - if err != nil { - scope.Err(err) - } + str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1) + default: + if valuer, ok := interface{}(arg).(driver.Valuer); ok { + arg, _ = valuer.Value() } - str = strings.Replace(str, "?", scope.AddToVars(arg), 1) } } diff --git a/scope_test.go b/scope_test.go index 71e80225..3018f350 100644 --- a/scope_test.go +++ b/scope_test.go @@ -2,10 +2,11 @@ package gorm_test import ( "encoding/hex" - "github.com/jinzhu/gorm" "math/rand" "strings" "testing" + + "github.com/jinzhu/gorm" ) func NameIn1And2(d *gorm.DB) *gorm.DB { @@ -56,17 +57,13 @@ func TestValuer(t *testing.T) { name := randName() origUser := User{Name: name, Age: 1, Password: EncryptedData("pass1"), PasswordHash: []byte("abc")} - err := DB.Save(&origUser).Error - if err != nil { - t.Log(err) - t.FailNow() + if err := DB.Save(&origUser).Error; err != nil { + t.Errorf("No error should happen when saving user, but got %v", err) } var user2 User - err = DB.Where("name=? AND password=? AND password_hash=?", name, EncryptedData("pass1"), []byte("abc")).First(&user2).Error - if err != nil { - t.Log(err) - t.FailNow() + if err := DB.Where("name = ? AND password = ? AND password_hash = ?", name, EncryptedData("pass1"), []byte("abc")).First(&user2).Error; err != nil { + t.Errorf("No error should happen when querying user with valuer, but got %v", err) } } @@ -74,12 +71,10 @@ func TestFailedValuer(t *testing.T) { name := randName() err := DB.Exec("INSERT INTO users(name, password) VALUES(?, ?)", name, EncryptedData("xpass1")).Error + if err == nil { - t.FailNow() + t.Errorf("There should be an error should happen when insert data") + } else if !strings.HasPrefix(err.Error(), "Should not start with") { + t.Errorf("The error should be returned from Valuer, but get %v", err) } - - if !strings.HasPrefix(err.Error(), "Should not start with") { - t.FailNow() - } - } From 841ea1bde530b7d046262861cc39a041f42bdce3 Mon Sep 17 00:00:00 2001 From: matematik7 Date: Mon, 14 Aug 2017 20:46:39 +0200 Subject: [PATCH 046/881] Do not always override select on pluck --- scope.go | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/scope.go b/scope.go index 0dcea855..db797dcc 100644 --- a/scope.go +++ b/scope.go @@ -938,14 +938,30 @@ func (scope *Scope) initialize() *Scope { return scope } +func (scope *Scope) isQueryForColumn(query interface{}, column string) bool { + queryStr := fmt.Sprint(query) + if queryStr == column { + return true + } + + if strings.HasSuffix(strings.ToLower(queryStr), "as "+column) { + return true + } + + return false +} + func (scope *Scope) pluck(column string, value interface{}) *Scope { dest := reflect.Indirect(reflect.ValueOf(value)) - scope.Search.Select(column) if dest.Kind() != reflect.Slice { scope.Err(fmt.Errorf("results should be a slice, not %s", dest.Kind())) return scope } + if query, ok := scope.Search.selects["query"]; !ok || !scope.isQueryForColumn(query, column) { + scope.Search.Select(column) + } + rows, err := scope.rows() if scope.Err(err) == nil { defer rows.Close() From 36043ad905ae3c19feaebd68327b1bf6b291ec29 Mon Sep 17 00:00:00 2001 From: matematik7 Date: Mon, 4 Sep 2017 18:12:20 +0200 Subject: [PATCH 047/881] Fix for quoted column names and add test --- query_test.go | 24 ++++++++++++++++++++++++ scope.go | 8 ++++++-- 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/query_test.go b/query_test.go index 882fd611..df8893fd 100644 --- a/query_test.go +++ b/query_test.go @@ -674,3 +674,27 @@ func TestSelectWithArrayInput(t *testing.T) { t.Errorf("Should have selected both age and name") } } + +func TestPluckWithSelect(t *testing.T) { + DB.Save(&User{Name: "matematik7", Age: 25}) + + var userAges []string + err := DB.Model(&User{}).Where("age = ?", 25).Select("name || ' - ' || age as user_age").Pluck("user_age", &userAges).Error + if err != nil { + t.Error(err) + } + + if len(userAges) != 1 || userAges[0] != "matematik7 - 25" { + t.Errorf("Should correctly pluck with select, got: %s", userAges) + } + + userAges = userAges[:0] + err = DB.Model(&User{}).Where("age = ?", 25).Select("name || ' - ' || age as \"user_age\"").Pluck("user_age", &userAges).Error + if err != nil { + t.Error(err) + } + + if len(userAges) != 1 || userAges[0] != "matematik7 - 25" { + t.Errorf("Should correctly pluck with select, got: %s", userAges) + } +} diff --git a/scope.go b/scope.go index db797dcc..65ac62d9 100644 --- a/scope.go +++ b/scope.go @@ -939,12 +939,16 @@ func (scope *Scope) initialize() *Scope { } func (scope *Scope) isQueryForColumn(query interface{}, column string) bool { - queryStr := fmt.Sprint(query) + queryStr := strings.ToLower(fmt.Sprint(query)) if queryStr == column { return true } - if strings.HasSuffix(strings.ToLower(queryStr), "as "+column) { + if strings.HasSuffix(queryStr, "as "+column) { + return true + } + + if strings.HasSuffix(queryStr, "as \""+column+"\"") { return true } From 46269198a4e50bbffb5682321fe5865836dd17b9 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 11 Feb 2018 13:41:46 +0800 Subject: [PATCH 048/881] Refactor PR #1569 --- query_test.go | 23 ++++++++++++++++++----- scope.go | 2 +- 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/query_test.go b/query_test.go index df8893fd..135805a7 100644 --- a/query_test.go +++ b/query_test.go @@ -2,6 +2,7 @@ package gorm_test import ( "fmt" + "os" "reflect" "github.com/jinzhu/gorm" @@ -676,25 +677,37 @@ func TestSelectWithArrayInput(t *testing.T) { } func TestPluckWithSelect(t *testing.T) { - DB.Save(&User{Name: "matematik7", Age: 25}) + var ( + user = User{Name: "matematik7_pluck_with_select", Age: 25} + combinedName = fmt.Sprintf("%v%v", user.Name, user.Age) + combineUserAgeSQL = fmt.Sprintf("concat(%v, %v)", DB.Dialect().Quote("name"), DB.Dialect().Quote("age")) + ) + if dialect := os.Getenv("GORM_DIALECT"); dialect == "sqlite" { + combineUserAgeSQL = fmt.Sprintf("(%v || %v)", DB.Dialect().Quote("name"), DB.Dialect().Quote("age")) + } + + DB.Save(&user) + + selectStr := combineUserAgeSQL + " as user_age" var userAges []string - err := DB.Model(&User{}).Where("age = ?", 25).Select("name || ' - ' || age as user_age").Pluck("user_age", &userAges).Error + err := DB.Model(&User{}).Where("age = ?", 25).Select(selectStr).Pluck("user_age", &userAges).Error if err != nil { t.Error(err) } - if len(userAges) != 1 || userAges[0] != "matematik7 - 25" { + if len(userAges) != 1 || userAges[0] != combinedName { t.Errorf("Should correctly pluck with select, got: %s", userAges) } + selectStr = combineUserAgeSQL + fmt.Sprintf(" as %v", DB.Dialect().Quote("user_age")) userAges = userAges[:0] - err = DB.Model(&User{}).Where("age = ?", 25).Select("name || ' - ' || age as \"user_age\"").Pluck("user_age", &userAges).Error + err = DB.Model(&User{}).Where("age = ?", 25).Select(selectStr).Pluck("user_age", &userAges).Error if err != nil { t.Error(err) } - if len(userAges) != 1 || userAges[0] != "matematik7 - 25" { + if len(userAges) != 1 || userAges[0] != combinedName { t.Errorf("Should correctly pluck with select, got: %s", userAges) } } diff --git a/scope.go b/scope.go index 65ac62d9..29508d8d 100644 --- a/scope.go +++ b/scope.go @@ -948,7 +948,7 @@ func (scope *Scope) isQueryForColumn(query interface{}, column string) bool { return true } - if strings.HasSuffix(queryStr, "as \""+column+"\"") { + if strings.HasSuffix(queryStr, "as "+scope.Quote(column)) { return true } From 3c70f83833b62a5f106d16039b46658b21e90ea6 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 11 Feb 2018 13:57:59 +0800 Subject: [PATCH 049/881] Fix query test --- query_test.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/query_test.go b/query_test.go index 135805a7..77449f4f 100644 --- a/query_test.go +++ b/query_test.go @@ -2,7 +2,6 @@ package gorm_test import ( "fmt" - "os" "reflect" "github.com/jinzhu/gorm" @@ -683,7 +682,7 @@ func TestPluckWithSelect(t *testing.T) { combineUserAgeSQL = fmt.Sprintf("concat(%v, %v)", DB.Dialect().Quote("name"), DB.Dialect().Quote("age")) ) - if dialect := os.Getenv("GORM_DIALECT"); dialect == "sqlite" { + if dialect := DB.Dialect().GetName(); dialect == "sqlite3" { combineUserAgeSQL = fmt.Sprintf("(%v || %v)", DB.Dialect().Quote("name"), DB.Dialect().Quote("age")) } From 8d66eb4926845fd1210dfa88c6fbd052fc4867bf Mon Sep 17 00:00:00 2001 From: andrew Date: Mon, 23 Oct 2017 10:50:44 +0800 Subject: [PATCH 050/881] fixed wrong param substitution order --- main_test.go | 24 +++++++++++++++++++++++ scope.go | 54 ++++++++++++++++++++++++++++++++++++++++------------ 2 files changed, 66 insertions(+), 12 deletions(-) diff --git a/main_test.go b/main_test.go index 83e6f7aa..48a8bd63 100644 --- a/main_test.go +++ b/main_test.go @@ -631,6 +631,30 @@ func TestQueryBuilderSubselectInWhere(t *testing.T) { } } +func TestQueryBuilderRawQueryWithSubquery(t *testing.T) { + user := User{Name: "subquery_test_user1", Age: 10} + DB.Save(&user) + user = User{Name: "subquery_test_user2", Age: 11} + DB.Save(&user) + user = User{Name: "subquery_test_user2", Age: 12} + DB.Save(&user) + + var count int + err := DB.Raw("select count(*) from (?) tmp", + DB.Table("users"). + Select("name"). + Where("age >= ? and name in (?)", 10, []string{"subquery_test_user1", "subquery_test_user2"}). + Group("name"). + QueryExpr(), + ).Count(&count).Error + if err != nil { + t.Errorf("Expected to get no errors, but got %v", err) + } + if count != 2 { + t.Errorf("Row count must be 2, instead got %d", count) + } +} + func TestQueryBuilderSubselectInHaving(t *testing.T) { user := User{Name: "query_expr_having_ruser1", Email: "root@user1.com", Age: 64} DB.Save(&user) diff --git a/scope.go b/scope.go index 29508d8d..ba9bd37c 100644 --- a/scope.go +++ b/scope.go @@ -1,16 +1,16 @@ package gorm import ( + "bytes" "database/sql" "database/sql/driver" "errors" "fmt" + "reflect" "regexp" "strconv" "strings" "time" - - "reflect" ) // Scope contain current operation's information when you perform any operation on the database @@ -555,6 +555,7 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri return strings.Join(sqls, " AND ") } + replacements := []string{} args := clause["args"].([]interface{}) for _, arg := range args { var err error @@ -562,29 +563,43 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri case reflect.Slice: // For where("id in (?)", []int64{1,2}) if scanner, ok := interface{}(arg).(driver.Valuer); ok { arg, err = scanner.Value() - str = strings.Replace(str, "?", scope.AddToVars(arg), 1) - } else if bytes, ok := arg.([]byte); ok { - str = strings.Replace(str, "?", scope.AddToVars(bytes), 1) + replacements = append(replacements, scope.AddToVars(arg)) + } else if b, ok := arg.([]byte); ok { + replacements = append(replacements, scope.AddToVars(b)) } else if values := reflect.ValueOf(arg); values.Len() > 0 { var tempMarks []string for i := 0; i < values.Len(); i++ { tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface())) } - str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1) + replacements = append(replacements, strings.Join(tempMarks, ",")) } else { - str = strings.Replace(str, "?", scope.AddToVars(Expr("NULL")), 1) + replacements = append(replacements, scope.AddToVars(Expr("NULL"))) } default: if valuer, ok := interface{}(arg).(driver.Valuer); ok { arg, err = valuer.Value() } - str = strings.Replace(str, "?", scope.AddToVars(arg), 1) + replacements = append(replacements, scope.AddToVars(arg)) + } + } + + buff := bytes.NewBuffer([]byte{}) + i := 0 + for pos := range str { + if str[pos] == '?' { + buff.WriteString(replacements[i]) + i++ + } else { + buff.WriteByte(str[pos]) } if err != nil { scope.Err(err) } } + + str = buff.String() + return } @@ -642,8 +657,8 @@ func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string if scanner, ok := interface{}(arg).(driver.Valuer); ok { arg, err = scanner.Value() str = strings.Replace(str, "?", scope.AddToVars(arg), 1) - } else if bytes, ok := arg.([]byte); ok { - str = strings.Replace(str, "?", scope.AddToVars(bytes), 1) + } else if b, ok := arg.([]byte); ok { + str = strings.Replace(str, "?", scope.AddToVars(b), 1) } else if values := reflect.ValueOf(arg); values.Len() > 0 { var tempMarks []string for i := 0; i < values.Len(); i++ { @@ -675,6 +690,7 @@ func (scope *Scope) buildSelectQuery(clause map[string]interface{}) (str string) } args := clause["args"].([]interface{}) + replacements := []string{} for _, arg := range args { switch reflect.ValueOf(arg).Kind() { case reflect.Slice: @@ -683,14 +699,28 @@ func (scope *Scope) buildSelectQuery(clause map[string]interface{}) (str string) for i := 0; i < values.Len(); i++ { tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface())) } - str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1) + replacements = append(replacements, strings.Join(tempMarks, ",")) default: if valuer, ok := interface{}(arg).(driver.Valuer); ok { arg, _ = valuer.Value() } - str = strings.Replace(str, "?", scope.AddToVars(arg), 1) + replacements = append(replacements, scope.AddToVars(arg)) } } + + buff := bytes.NewBuffer([]byte{}) + i := 0 + for pos := range str { + if str[pos] == '?' { + buff.WriteString(replacements[i]) + i++ + } else { + buff.WriteByte(str[pos]) + } + } + + str = buff.String() + return } From 86c04795b754c96ec5bbeee05284a35e8caa4de1 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 11 Feb 2018 15:52:52 +0800 Subject: [PATCH 051/881] Port PR #1655 to Not query builder --- main_test.go | 19 ++++++++++++++- scope.go | 68 +++++++++++++++++++++++++++++++--------------------- 2 files changed, 59 insertions(+), 28 deletions(-) diff --git a/main_test.go b/main_test.go index 48a8bd63..66c46af0 100644 --- a/main_test.go +++ b/main_test.go @@ -636,7 +636,7 @@ func TestQueryBuilderRawQueryWithSubquery(t *testing.T) { DB.Save(&user) user = User{Name: "subquery_test_user2", Age: 11} DB.Save(&user) - user = User{Name: "subquery_test_user2", Age: 12} + user = User{Name: "subquery_test_user3", Age: 12} DB.Save(&user) var count int @@ -647,12 +647,29 @@ func TestQueryBuilderRawQueryWithSubquery(t *testing.T) { Group("name"). QueryExpr(), ).Count(&count).Error + if err != nil { t.Errorf("Expected to get no errors, but got %v", err) } if count != 2 { t.Errorf("Row count must be 2, instead got %d", count) } + + err = DB.Raw("select count(*) from (?) tmp", + DB.Table("users"). + Select("name"). + Where("name LIKE ?", "subquery_test%"). + Not("age <= ?", 10).Not("name in (?)", []string{"subquery_test_user1", "subquery_test_user2"}). + Group("name"). + QueryExpr(), + ).Count(&count).Error + + if err != nil { + t.Errorf("Expected to get no errors, but got %v", err) + } + if count != 1 { + t.Errorf("Row count must be 1, instead got %d", count) + } } func TestQueryBuilderSubselectInHaving(t *testing.T) { diff --git a/scope.go b/scope.go index ba9bd37c..762904d7 100644 --- a/scope.go +++ b/scope.go @@ -460,7 +460,7 @@ func (scope *Scope) callMethod(methodName string, reflectValue reflect.Value) { var ( columnRegexp = regexp.MustCompile("^[a-zA-Z\\d]+(\\.[a-zA-Z\\d]+)*$") // only match string like `name`, `users.name` isNumberRegexp = regexp.MustCompile("^\\s*\\d+\\s*$") // match if string is number - comparisonRegexp = regexp.MustCompile("(?i) (=|<>|>|<|LIKE|IS|IN) ") + comparisonRegexp = regexp.MustCompile("(?i) (=|<>|(>|<)(=?)|LIKE|IS|IN) ") countingQueryRegexp = regexp.MustCompile("(?i)^count(.+)$") ) @@ -523,17 +523,17 @@ func (scope *Scope) primaryCondition(value interface{}) string { func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str string) { switch value := clause["query"].(type) { + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, sql.NullInt64: + return scope.primaryCondition(scope.AddToVars(value)) + case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string, []interface{}: + str = fmt.Sprintf("(%v.%v IN (?))", scope.QuotedTableName(), scope.Quote(scope.PrimaryKey())) + clause["args"] = []interface{}{value} case string: if isNumberRegexp.MatchString(value) { return scope.primaryCondition(scope.AddToVars(value)) } else if value != "" { str = fmt.Sprintf("(%v)", value) } - case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, sql.NullInt64: - return scope.primaryCondition(scope.AddToVars(value)) - case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string, []interface{}: - str = fmt.Sprintf("(%v.%v IN (?))", scope.QuotedTableName(), scope.Quote(scope.PrimaryKey())) - clause["args"] = []interface{}{value} case map[string]interface{}: var sqls []string for key, value := range value { @@ -582,6 +582,9 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri replacements = append(replacements, scope.AddToVars(arg)) } + if err != nil { + scope.Err(err) + } } buff := bytes.NewBuffer([]byte{}) @@ -593,9 +596,6 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri } else { buff.WriteByte(str[pos]) } - if err != nil { - scope.Err(err) - } } str = buff.String() @@ -604,21 +604,9 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri } func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string) { - var notEqualSQL string var primaryKey = scope.PrimaryKey() switch value := clause["query"].(type) { - case string: - if isNumberRegexp.MatchString(value) { - id, _ := strconv.Atoi(value) - return fmt.Sprintf("(%v <> %v)", scope.Quote(primaryKey), id) - } else if comparisonRegexp.MatchString(value) { - str = fmt.Sprintf(" NOT (%v) ", value) - notEqualSQL = fmt.Sprintf("NOT (%v)", value) - } else { - str = fmt.Sprintf("(%v.%v NOT IN (?))", scope.QuotedTableName(), scope.Quote(value)) - notEqualSQL = fmt.Sprintf("(%v.%v <> ?)", scope.QuotedTableName(), scope.Quote(value)) - } case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, sql.NullInt64: return fmt.Sprintf("(%v.%v <> %v)", scope.QuotedTableName(), scope.Quote(primaryKey), value) case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string: @@ -628,6 +616,15 @@ func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string } else { return "" } + case string: + if isNumberRegexp.MatchString(value) { + id, _ := strconv.Atoi(value) + return fmt.Sprintf("(%v.%v <> %v)", scope.QuotedTableName(), scope.Quote(primaryKey), id) + } else if comparisonRegexp.MatchString(value) { + str = fmt.Sprintf("NOT (%v)", value) + } else { + str = fmt.Sprintf("(%v.%v NOT IN (?))", scope.QuotedTableName(), scope.Quote(value)) + } case map[string]interface{}: var sqls []string for key, value := range value { @@ -642,13 +639,14 @@ func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string var sqls []string var newScope = scope.New(value) for _, field := range newScope.Fields() { - if !field.IsBlank { + if !field.IsIgnored && !field.IsBlank { sqls = append(sqls, fmt.Sprintf("(%v.%v <> %v)", scope.QuotedTableName(), scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) } } return strings.Join(sqls, " AND ") } + replacements := []string{} args := clause["args"].([]interface{}) for _, arg := range args { var err error @@ -656,28 +654,44 @@ func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string case reflect.Slice: // For where("id in (?)", []int64{1,2}) if scanner, ok := interface{}(arg).(driver.Valuer); ok { arg, err = scanner.Value() - str = strings.Replace(str, "?", scope.AddToVars(arg), 1) + replacements = append(replacements, scope.AddToVars(arg)) } else if b, ok := arg.([]byte); ok { - str = strings.Replace(str, "?", scope.AddToVars(b), 1) + replacements = append(replacements, scope.AddToVars(b)) } else if values := reflect.ValueOf(arg); values.Len() > 0 { var tempMarks []string for i := 0; i < values.Len(); i++ { tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface())) } - str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1) + replacements = append(replacements, strings.Join(tempMarks, ",")) } else { - str = strings.Replace(str, "?", scope.AddToVars(Expr("NULL")), 1) + replacements = append(replacements, scope.AddToVars(Expr("NULL"))) } default: if scanner, ok := interface{}(arg).(driver.Valuer); ok { arg, err = scanner.Value() } - str = strings.Replace(notEqualSQL, "?", scope.AddToVars(arg), 1) + + replacements = append(replacements, scope.AddToVars(arg)) } + if err != nil { scope.Err(err) } } + + buff := bytes.NewBuffer([]byte{}) + i := 0 + + for pos := range str { + if str[pos] == '?' { + buff.WriteString(replacements[i]) + i++ + } else { + buff.WriteByte(str[pos]) + } + } + + str = buff.String() return } From 7a8c2bbff8d0327b20017b24299394263b94f69f Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 11 Feb 2018 23:52:38 +0800 Subject: [PATCH 052/881] Refactor build SQL condition --- create_test.go | 6 +- main.go | 2 +- migration_test.go | 3 + query_test.go | 2 +- scope.go | 156 ++++++++++++++-------------------------------- 5 files changed, 57 insertions(+), 112 deletions(-) diff --git a/create_test.go b/create_test.go index 36472914..83b3a4ef 100644 --- a/create_test.go +++ b/create_test.go @@ -27,7 +27,9 @@ func TestCreate(t *testing.T) { } var newUser User - DB.First(&newUser, user.Id) + if err := DB.First(&newUser, user.Id).Error; err != nil { + t.Errorf("No error should happen, but got %v", err) + } if !reflect.DeepEqual(newUser.PasswordHash, []byte{'f', 'a', 'k', '4'}) { t.Errorf("User's PasswordHash should be saved ([]byte)") @@ -38,7 +40,7 @@ func TestCreate(t *testing.T) { } if newUser.UserNum != Num(111) { - t.Errorf("User's UserNum should be saved (custom type)") + t.Errorf("User's UserNum should be saved (custom type), but got %v", newUser.UserNum) } if newUser.Latitude != float { diff --git a/main.go b/main.go index fc4859ac..d342571d 100644 --- a/main.go +++ b/main.go @@ -430,7 +430,7 @@ func (s *DB) Raw(sql string, values ...interface{}) *DB { // Exec execute raw sql func (s *DB) Exec(sql string, values ...interface{}) *DB { scope := s.NewScope(nil) - generatedSQL := scope.buildWhereCondition(map[string]interface{}{"query": sql, "args": values}) + generatedSQL := scope.buildCondition(map[string]interface{}{"query": sql, "args": values}, true) generatedSQL = strings.TrimSuffix(strings.TrimPrefix(generatedSQL, "("), ")") scope.Raw(generatedSQL) return scope.Exec().db diff --git a/migration_test.go b/migration_test.go index d58e1fb5..7c694485 100644 --- a/migration_test.go +++ b/migration_test.go @@ -7,6 +7,7 @@ import ( "fmt" "os" "reflect" + "strconv" "testing" "time" @@ -168,6 +169,8 @@ type Num int64 func (i *Num) Scan(src interface{}) error { switch s := src.(type) { case []byte: + n, _ := strconv.Atoi(string(s)) + *i = Num(n) case int64: *i = Num(s) default: diff --git a/query_test.go b/query_test.go index 77449f4f..3c3c74b5 100644 --- a/query_test.go +++ b/query_test.go @@ -99,7 +99,7 @@ func TestStringPrimaryKeyForNumericValueStartingWithZero(t *testing.T) { var address AddressByZipCode DB.First(&address, "00501") if address.ZipCode != "00501" { - t.Errorf("Fetch a record from with a string primary key for a numeric value starting with zero should work, but failed") + t.Errorf("Fetch a record from with a string primary key for a numeric value starting with zero should work, but failed, zip code is %v", address.ZipCode) } } diff --git a/scope.go b/scope.go index 762904d7..5ac147e4 100644 --- a/scope.go +++ b/scope.go @@ -8,7 +8,6 @@ import ( "fmt" "reflect" "regexp" - "strconv" "strings" "time" ) @@ -521,26 +520,58 @@ func (scope *Scope) primaryCondition(value interface{}) string { return fmt.Sprintf("(%v.%v = %v)", scope.QuotedTableName(), scope.Quote(scope.PrimaryKey()), value) } -func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str string) { +func (scope *Scope) buildCondition(clause map[string]interface{}, include bool) (str string) { + var ( + quotedTableName = scope.QuotedTableName() + quotedPrimaryKey = scope.Quote(scope.PrimaryKey()) + equalSQL = "=" + inSQL = "IN" + ) + + // If building not conditions + if !include { + equalSQL = "<>" + inSQL = "NOT IN" + } + switch value := clause["query"].(type) { - case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, sql.NullInt64: - return scope.primaryCondition(scope.AddToVars(value)) + case sql.NullInt64: + return fmt.Sprintf("(%v.%v %s %v)", quotedTableName, quotedPrimaryKey, equalSQL, value.Int64) + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: + return fmt.Sprintf("(%v.%v %s %v)", quotedTableName, quotedPrimaryKey, equalSQL, value) case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string, []interface{}: - str = fmt.Sprintf("(%v.%v IN (?))", scope.QuotedTableName(), scope.Quote(scope.PrimaryKey())) + if !include && reflect.ValueOf(value).Len() == 0 { + return + } + str = fmt.Sprintf("(%v.%v %s (?))", quotedTableName, quotedPrimaryKey, inSQL) clause["args"] = []interface{}{value} case string: if isNumberRegexp.MatchString(value) { - return scope.primaryCondition(scope.AddToVars(value)) - } else if value != "" { - str = fmt.Sprintf("(%v)", value) + return fmt.Sprintf("(%v.%v %s %v)", quotedTableName, quotedPrimaryKey, equalSQL, scope.AddToVars(value)) + } + + if value != "" { + if !include { + if comparisonRegexp.MatchString(value) { + str = fmt.Sprintf("NOT (%v)", value) + } else { + str = fmt.Sprintf("(%v.%v NOT IN (?))", quotedTableName, scope.Quote(value)) + } + } else { + str = fmt.Sprintf("(%v)", value) + } } case map[string]interface{}: var sqls []string for key, value := range value { if value != nil { - sqls = append(sqls, fmt.Sprintf("(%v.%v = %v)", scope.QuotedTableName(), scope.Quote(key), scope.AddToVars(value))) + sqls = append(sqls, fmt.Sprintf("(%v.%v %s %v)", quotedTableName, scope.Quote(key), equalSQL, scope.AddToVars(value))) } else { - sqls = append(sqls, fmt.Sprintf("(%v.%v IS NULL)", scope.QuotedTableName(), scope.Quote(key))) + if !include { + sqls = append(sqls, fmt.Sprintf("(%v.%v IS NOT NULL)", quotedTableName, scope.Quote(key))) + } else { + sqls = append(sqls, fmt.Sprintf("(%v.%v IS NULL)", quotedTableName, scope.Quote(key))) + } } } return strings.Join(sqls, " AND ") @@ -549,7 +580,7 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri newScope := scope.New(value) for _, field := range newScope.Fields() { if !field.IsIgnored && !field.IsBlank { - sqls = append(sqls, fmt.Sprintf("(%v.%v = %v)", scope.QuotedTableName(), scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) + sqls = append(sqls, fmt.Sprintf("(%v.%v %s %v)", quotedTableName, scope.Quote(field.DBName), equalSQL, scope.AddToVars(field.Field.Interface()))) } } return strings.Join(sqls, " AND ") @@ -582,6 +613,7 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri replacements = append(replacements, scope.AddToVars(arg)) } + if err != nil { scope.Err(err) } @@ -603,98 +635,6 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri return } -func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string) { - var primaryKey = scope.PrimaryKey() - - switch value := clause["query"].(type) { - case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, sql.NullInt64: - return fmt.Sprintf("(%v.%v <> %v)", scope.QuotedTableName(), scope.Quote(primaryKey), value) - case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string: - if reflect.ValueOf(value).Len() > 0 { - str = fmt.Sprintf("(%v.%v NOT IN (?))", scope.QuotedTableName(), scope.Quote(primaryKey)) - clause["args"] = []interface{}{value} - } else { - return "" - } - case string: - if isNumberRegexp.MatchString(value) { - id, _ := strconv.Atoi(value) - return fmt.Sprintf("(%v.%v <> %v)", scope.QuotedTableName(), scope.Quote(primaryKey), id) - } else if comparisonRegexp.MatchString(value) { - str = fmt.Sprintf("NOT (%v)", value) - } else { - str = fmt.Sprintf("(%v.%v NOT IN (?))", scope.QuotedTableName(), scope.Quote(value)) - } - case map[string]interface{}: - var sqls []string - for key, value := range value { - if value != nil { - sqls = append(sqls, fmt.Sprintf("(%v.%v <> %v)", scope.QuotedTableName(), scope.Quote(key), scope.AddToVars(value))) - } else { - sqls = append(sqls, fmt.Sprintf("(%v.%v IS NOT NULL)", scope.QuotedTableName(), scope.Quote(key))) - } - } - return strings.Join(sqls, " AND ") - case interface{}: - var sqls []string - var newScope = scope.New(value) - for _, field := range newScope.Fields() { - if !field.IsIgnored && !field.IsBlank { - sqls = append(sqls, fmt.Sprintf("(%v.%v <> %v)", scope.QuotedTableName(), scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) - } - } - return strings.Join(sqls, " AND ") - } - - replacements := []string{} - args := clause["args"].([]interface{}) - for _, arg := range args { - var err error - switch reflect.ValueOf(arg).Kind() { - case reflect.Slice: // For where("id in (?)", []int64{1,2}) - if scanner, ok := interface{}(arg).(driver.Valuer); ok { - arg, err = scanner.Value() - replacements = append(replacements, scope.AddToVars(arg)) - } else if b, ok := arg.([]byte); ok { - replacements = append(replacements, scope.AddToVars(b)) - } else if values := reflect.ValueOf(arg); values.Len() > 0 { - var tempMarks []string - for i := 0; i < values.Len(); i++ { - tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface())) - } - replacements = append(replacements, strings.Join(tempMarks, ",")) - } else { - replacements = append(replacements, scope.AddToVars(Expr("NULL"))) - } - default: - if scanner, ok := interface{}(arg).(driver.Valuer); ok { - arg, err = scanner.Value() - } - - replacements = append(replacements, scope.AddToVars(arg)) - } - - if err != nil { - scope.Err(err) - } - } - - buff := bytes.NewBuffer([]byte{}) - i := 0 - - for pos := range str { - if str[pos] == '?' { - buff.WriteString(replacements[i]) - i++ - } else { - buff.WriteByte(str[pos]) - } - } - - str = buff.String() - return -} - func (scope *Scope) buildSelectQuery(clause map[string]interface{}) (str string) { switch value := clause["query"].(type) { case string: @@ -758,19 +698,19 @@ func (scope *Scope) whereSQL() (sql string) { } for _, clause := range scope.Search.whereConditions { - if sql := scope.buildWhereCondition(clause); sql != "" { + if sql := scope.buildCondition(clause, true); sql != "" { andConditions = append(andConditions, sql) } } for _, clause := range scope.Search.orConditions { - if sql := scope.buildWhereCondition(clause); sql != "" { + if sql := scope.buildCondition(clause, true); sql != "" { orConditions = append(orConditions, sql) } } for _, clause := range scope.Search.notConditions { - if sql := scope.buildNotCondition(clause); sql != "" { + if sql := scope.buildCondition(clause, false); sql != "" { andConditions = append(andConditions, sql) } } @@ -844,7 +784,7 @@ func (scope *Scope) havingSQL() string { var andConditions []string for _, clause := range scope.Search.havingConditions { - if sql := scope.buildWhereCondition(clause); sql != "" { + if sql := scope.buildCondition(clause, true); sql != "" { andConditions = append(andConditions, sql) } } @@ -860,7 +800,7 @@ func (scope *Scope) havingSQL() string { func (scope *Scope) joinsSQL() string { var joinConditions []string for _, clause := range scope.Search.joinConditions { - if sql := scope.buildWhereCondition(clause); sql != "" { + if sql := scope.buildCondition(clause, true); sql != "" { joinConditions = append(joinConditions, strings.TrimSuffix(strings.TrimPrefix(sql, "("), ")")) } } From c54d23473c3f5ded7f0d1fbdd993a8c4a957ef9b Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 12 Feb 2018 09:38:16 +0800 Subject: [PATCH 053/881] Add IsRecordNotFoundError method --- README.md | 4 ++-- errors.go | 12 ++++++++++++ 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 8c6e2302..6ff49b87 100644 --- a/README.md +++ b/README.md @@ -24,11 +24,11 @@ The fantastic ORM library for Golang, aims to be developer friendly. ## Getting Started -* GORM Guides [jinzhu.github.com/gorm](http://jinzhu.github.io/gorm) +* GORM Guides [jinzhu.github.com/gorm](https://jinzhu.github.io/gorm) ## Upgrading To V1.0 -* [CHANGELOG](http://jinzhu.github.io/gorm/changelog.html) +* [CHANGELOG](https://jinzhu.github.io/gorm/changelog.html) ## Supporting the project diff --git a/errors.go b/errors.go index 6845188e..da2cf13c 100644 --- a/errors.go +++ b/errors.go @@ -21,6 +21,18 @@ var ( // Errors contains all happened errors type Errors []error +// IsRecordNotFoundError returns current error has record not found error or not +func IsRecordNotFoundError(err error) bool { + if errs, ok := err.(Errors); ok { + for _, err := range errs { + if err == ErrRecordNotFound { + return true + } + } + } + return err == ErrRecordNotFound +} + // GetErrors gets all happened errors func (errs Errors) GetErrors() []error { return errs From 49934ff3bf729e5465ae8c3129e820419a0edd2a Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 12 Feb 2018 09:43:28 +0800 Subject: [PATCH 054/881] Call DefaultTableNameHandler for JoinTableHandler's table --- join_table_handler.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/join_table_handler.go b/join_table_handler.go index f07541ba..a036d46d 100644 --- a/join_table_handler.go +++ b/join_table_handler.go @@ -79,7 +79,7 @@ func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, s // Table return join table's table name func (s JoinTableHandler) Table(db *DB) string { - return s.TableName + return DefaultTableNameHandler(db, s.TableName) } func (s JoinTableHandler) updateConditionMap(conditionMap map[string]interface{}, db *DB, joinTableSources []JoinTableSource, sources ...interface{}) { From 7e2bb3d7fa0916f4cdf50236af59e735f7e67739 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 12 Feb 2018 11:56:45 +0800 Subject: [PATCH 055/881] Allow customize table name when creating index, close #1656 --- scope.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scope.go b/scope.go index 5ac147e4..1bd32a28 100644 --- a/scope.go +++ b/scope.go @@ -1250,13 +1250,13 @@ func (scope *Scope) autoIndex() *Scope { } for name, columns := range indexes { - if db := scope.NewDB().Model(scope.Value).AddIndex(name, columns...); db.Error != nil { + if db := scope.NewDB().Table(scope.TableName()).Model(scope.Value).AddIndex(name, columns...); db.Error != nil { scope.db.AddError(db.Error) } } for name, columns := range uniqueIndexes { - if db := scope.NewDB().Model(scope.Value).AddUniqueIndex(name, columns...); db.Error != nil { + if db := scope.NewDB().Table(scope.TableName()).Model(scope.Value).AddUniqueIndex(name, columns...); db.Error != nil { scope.db.AddError(db.Error) } } From 30adc80edc91ab4934e12f33fa1a0b07bfa4da03 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 12 Feb 2018 13:09:37 +0800 Subject: [PATCH 056/881] Test customize data type for primary key --- query_test.go | 31 +++++++++++++++++++++++++++++++ scope.go | 9 +++++++++ 2 files changed, 40 insertions(+) diff --git a/query_test.go b/query_test.go index 3c3c74b5..80ebd473 100644 --- a/query_test.go +++ b/query_test.go @@ -87,6 +87,37 @@ func TestUIntPrimaryKey(t *testing.T) { } } +func TestCustomizedTypePrimaryKey(t *testing.T) { + type ID uint + type CustomizedTypePrimaryKey struct { + ID ID + Name string + } + + DB.AutoMigrate(&CustomizedTypePrimaryKey{}) + + p1 := CustomizedTypePrimaryKey{Name: "p1"} + p2 := CustomizedTypePrimaryKey{Name: "p2"} + p3 := CustomizedTypePrimaryKey{Name: "p3"} + DB.Create(&p1) + DB.Create(&p2) + DB.Create(&p3) + + var p CustomizedTypePrimaryKey + + if err := DB.First(&p, p2.ID).Error; err == nil { + t.Errorf("Should return error for invalid query condition") + } + + if err := DB.First(&p, "id = ?", p2.ID).Error; err != nil { + t.Errorf("No error should happen when querying with customized type for primary key, got err %v", err) + } + + if p.Name != "p2" { + t.Errorf("Should find correct value when querying with customized type for primary key") + } +} + func TestStringPrimaryKeyForNumericValueStartingWithZero(t *testing.T) { type AddressByZipCode struct { ZipCode string `gorm:"primary_key"` diff --git a/scope.go b/scope.go index 1bd32a28..04d549bf 100644 --- a/scope.go +++ b/scope.go @@ -578,12 +578,21 @@ func (scope *Scope) buildCondition(clause map[string]interface{}, include bool) case interface{}: var sqls []string newScope := scope.New(value) + + if len(newScope.Fields()) == 0 { + scope.Err(fmt.Errorf("invalid query condition: %v", value)) + return + } + for _, field := range newScope.Fields() { if !field.IsIgnored && !field.IsBlank { sqls = append(sqls, fmt.Sprintf("(%v.%v %s %v)", quotedTableName, scope.Quote(field.DBName), equalSQL, scope.AddToVars(field.Field.Interface()))) } } return strings.Join(sqls, " AND ") + default: + scope.Err(fmt.Errorf("invalid query condition: %v", value)) + return } replacements := []string{} From 8005321a1c1da0f2b8ceb868f72aa97ebef0e9dc Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 12 Feb 2018 14:48:11 +0800 Subject: [PATCH 057/881] Allow table option when DropTable, close #1514 --- scope.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/scope.go b/scope.go index 04d549bf..3fe4675d 100644 --- a/scope.go +++ b/scope.go @@ -1079,7 +1079,7 @@ func (scope *Scope) getTableOptions() string { if !ok { return "" } - return tableOptions.(string) + return " " + tableOptions.(string) } func (scope *Scope) createJoinTable(field *StructField) { @@ -1112,7 +1112,7 @@ func (scope *Scope) createJoinTable(field *StructField) { } } - scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v, PRIMARY KEY (%v)) %s", scope.Quote(joinTable), strings.Join(sqlTypes, ","), strings.Join(primaryKeys, ","), scope.getTableOptions())).Error) + scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v, PRIMARY KEY (%v))%s", scope.Quote(joinTable), strings.Join(sqlTypes, ","), strings.Join(primaryKeys, ","), scope.getTableOptions())).Error) } scope.NewDB().Table(joinTable).AutoMigrate(joinTableHandler) } @@ -1147,14 +1147,14 @@ func (scope *Scope) createTable() *Scope { primaryKeyStr = fmt.Sprintf(", PRIMARY KEY (%v)", strings.Join(primaryKeys, ",")) } - scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v %v) %s", scope.QuotedTableName(), strings.Join(tags, ","), primaryKeyStr, scope.getTableOptions())).Exec() + scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v %v)%s", scope.QuotedTableName(), strings.Join(tags, ","), primaryKeyStr, scope.getTableOptions())).Exec() scope.autoIndex() return scope } func (scope *Scope) dropTable() *Scope { - scope.Raw(fmt.Sprintf("DROP TABLE %v", scope.QuotedTableName())).Exec() + scope.Raw(fmt.Sprintf("DROP TABLE %v%s", scope.QuotedTableName(), scope.getTableOptions())).Exec() return scope } From 3b2c4b3608621404821708218da56ac6ea75f0d9 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 12 Feb 2018 17:39:34 +0800 Subject: [PATCH 058/881] Fix insert with default value for mysql --- callback_create.go | 3 ++- create_test.go | 11 +++++++++++ dialect.go | 2 ++ dialect_common.go | 4 ++++ dialect_mysql.go | 4 ++++ dialects/mssql/mssql.go | 4 ++++ 6 files changed, 27 insertions(+), 1 deletion(-) diff --git a/callback_create.go b/callback_create.go index a4da39e8..e7fe6f86 100644 --- a/callback_create.go +++ b/callback_create.go @@ -97,8 +97,9 @@ func createCallback(scope *Scope) { if len(columns) == 0 { scope.Raw(fmt.Sprintf( - "INSERT INTO %v DEFAULT VALUES%v%v", + "INSERT INTO %v %v%v%v", quotedTableName, + scope.Dialect().DefaultValueStr(), addExtraSpaceIfExist(extraOption), addExtraSpaceIfExist(lastInsertIDReturningSuffix), )) diff --git a/create_test.go b/create_test.go index 83b3a4ef..92560643 100644 --- a/create_test.go +++ b/create_test.go @@ -62,6 +62,17 @@ func TestCreate(t *testing.T) { } } +func TestCreateEmptyStrut(t *testing.T) { + type EmptyStruct struct { + ID uint + } + DB.AutoMigrate(&EmptyStruct{}) + + if err := DB.Create(&EmptyStruct{}).Error; err != nil { + t.Errorf("No error should happen when creating user, but got %v", err) + } +} + func TestCreateWithExistingTimestamp(t *testing.T) { user := User{Name: "CreateUserExistingTimestamp"} diff --git a/dialect.go b/dialect.go index fe8e2f62..b20bfd5b 100644 --- a/dialect.go +++ b/dialect.go @@ -42,6 +42,8 @@ type Dialect interface { SelectFromDummyTable() string // LastInsertIdReturningSuffix most dbs support LastInsertId, but postgres needs to use `RETURNING` LastInsertIDReturningSuffix(tableName, columnName string) string + // DefaultValueStr + DefaultValueStr() string // BuildKeyName returns a valid key name (foreign key, index key) for the given table, field and reference BuildKeyName(kind, tableName string, fields ...string) string diff --git a/dialect_common.go b/dialect_common.go index fbbaef33..b9f0c7da 100644 --- a/dialect_common.go +++ b/dialect_common.go @@ -159,6 +159,10 @@ func (commonDialect) LastInsertIDReturningSuffix(tableName, columnName string) s return "" } +func (commonDialect) DefaultValueStr() string { + return "DEFAULT VALUES" +} + // BuildKeyName returns a valid key name (foreign key, index key) for the given table, field and reference func (DefaultForeignKeyNamer) BuildKeyName(kind, tableName string, fields ...string) string { keyName := fmt.Sprintf("%s_%s_%s", kind, tableName, strings.Join(fields, "_")) diff --git a/dialect_mysql.go b/dialect_mysql.go index 1feed1f6..b162bade 100644 --- a/dialect_mysql.go +++ b/dialect_mysql.go @@ -185,3 +185,7 @@ func (s mysql) BuildKeyName(kind, tableName string, fields ...string) string { return fmt.Sprintf("%s%x", string(destRunes), bs) } + +func (mysql) DefaultValueStr() string { + return "VALUES()" +} diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go index 1dd5fb69..e0606465 100644 --- a/dialects/mssql/mssql.go +++ b/dialects/mssql/mssql.go @@ -183,6 +183,10 @@ func (mssql) LastInsertIDReturningSuffix(tableName, columnName string) string { return "" } +func (mssql) DefaultValueStr() string { + return "DEFAULT VALUES" +} + func currentDatabaseAndTable(dialect gorm.Dialect, tableName string) (string, string) { if strings.Contains(tableName, ".") { splitStrings := strings.SplitN(tableName, ".", 2) From cfd1cc586aff992a165730b734e496bca1e79d8e Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 13 Feb 2018 08:32:22 +0800 Subject: [PATCH 059/881] Add 2D array support, close #1201 --- query_test.go | 30 ++++++++++++++++++++++++++++++ scope.go | 16 ++++++++++++++++ 2 files changed, 46 insertions(+) diff --git a/query_test.go b/query_test.go index 80ebd473..fac7d4d8 100644 --- a/query_test.go +++ b/query_test.go @@ -222,6 +222,36 @@ func TestSearchWithPlainSQL(t *testing.T) { } } +func TestSearchWithTwoDimensionalArray(t *testing.T) { + var users []User + user1 := User{Name: "2DSearchUser1", Age: 1, Birthday: parseTime("2000-1-1")} + user2 := User{Name: "2DSearchUser2", Age: 10, Birthday: parseTime("2010-1-1")} + user3 := User{Name: "2DSearchUser3", Age: 20, Birthday: parseTime("2020-1-1")} + DB.Create(&user1) + DB.Create(&user2) + DB.Create(&user3) + + if dialect := DB.Dialect().GetName(); dialect == "mysql" || dialect == "postgres" { + if err := DB.Where("(name, age) IN (?)", [][]interface{}{{"2DSearchUser1", 1}, {"2DSearchUser2", 10}}).Find(&users).Error; err != nil { + t.Errorf("No error should happen when query with 2D array, but got %v", err) + + if len(users) != 2 { + t.Errorf("Should find 2 users with 2D array, but got %v", len(users)) + } + } + } + + if dialect := DB.Dialect().GetName(); dialect == "mssql" { + if err := DB.Joins("JOIN (VALUES ?) AS x (col1, col2) ON x.col1 = name AND x.col2 = age", [][]interface{}{{"2DSearchUser1", 1}, {"2DSearchUser2", 10}}).Find(&users).Error; err != nil { + t.Errorf("No error should happen when query with 2D array, but got %v", err) + + if len(users) != 2 { + t.Errorf("Should find 2 users with 2D array, but got %v", len(users)) + } + } + } +} + func TestSearchWithStruct(t *testing.T) { user1 := User{Name: "StructSearchUser1", Age: 1, Birthday: parseTime("2000-1-1")} user2 := User{Name: "StructSearchUser2", Age: 10, Birthday: parseTime("2010-1-1")} diff --git a/scope.go b/scope.go index 3fe4675d..cdb772ca 100644 --- a/scope.go +++ b/scope.go @@ -606,6 +606,22 @@ func (scope *Scope) buildCondition(clause map[string]interface{}, include bool) replacements = append(replacements, scope.AddToVars(arg)) } else if b, ok := arg.([]byte); ok { replacements = append(replacements, scope.AddToVars(b)) + } else if as, ok := arg.([][]interface{}); ok { + var tempMarks []string + for _, a := range as { + var arrayMarks []string + for _, v := range a { + arrayMarks = append(arrayMarks, scope.AddToVars(v)) + } + + if len(arrayMarks) > 0 { + tempMarks = append(tempMarks, fmt.Sprintf("(%v)", strings.Join(arrayMarks, ","))) + } + } + + if len(tempMarks) > 0 { + replacements = append(replacements, strings.Join(tempMarks, ",")) + } } else if values := reflect.ValueOf(arg); values.Len() > 0 { var tempMarks []string for i := 0; i < values.Len(); i++ { From fe3c94cd2d1eb99270a029e652fc5494e7106ebe Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 13 Feb 2018 09:18:42 +0800 Subject: [PATCH 060/881] Add Take method, close #1228 --- main.go | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/main.go b/main.go index d342571d..4bbaadab 100644 --- a/main.go +++ b/main.go @@ -280,6 +280,13 @@ func (s *DB) First(out interface{}, where ...interface{}) *DB { inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db } +// Take return a record that match given conditions, the order will depend on the database implementation +func (s *DB) Take(out interface{}, where ...interface{}) *DB { + newScope := s.NewScope(out) + newScope.Search.Limit(1) + return newScope.inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db +} + // Last find last record that match given conditions, order by primary key func (s *DB) Last(out interface{}, where ...interface{}) *DB { newScope := s.NewScope(out) From 67c4280c5721f23bdc13c74733bad922637c5ec1 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 13 Feb 2018 10:00:07 +0800 Subject: [PATCH 061/881] Fix support embedded pointer type struct, close #1450 --- embedded_struct_test.go | 18 ++++++++++++++++++ scope.go | 3 +++ 2 files changed, 21 insertions(+) diff --git a/embedded_struct_test.go b/embedded_struct_test.go index 91dd0563..5f8ece57 100644 --- a/embedded_struct_test.go +++ b/embedded_struct_test.go @@ -71,3 +71,21 @@ func TestSaveAndQueryEmbeddedStruct(t *testing.T) { } } } + +func TestEmbeddedPointerTypeStruct(t *testing.T) { + type HNPost struct { + *BasePost + Upvotes int32 + } + + DB.Create(&HNPost{BasePost: &BasePost{Title: "embedded_pointer_type"}}) + + var hnPost HNPost + if err := DB.First(&hnPost, "title = ?", "embedded_pointer_type").Error; err != nil { + t.Errorf("No error should happen when find embedded pointer type, but got %v", err) + } + + if hnPost.Title != "embedded_pointer_type" { + t.Errorf("Should find correct value for embedded pointer type") + } +} diff --git a/scope.go b/scope.go index cdb772ca..14baf631 100644 --- a/scope.go +++ b/scope.go @@ -115,6 +115,9 @@ func (scope *Scope) Fields() []*Field { if isStruct { fieldValue := indirectScopeValue for _, name := range structField.Names { + if fieldValue.Kind() == reflect.Ptr && fieldValue.IsNil() { + fieldValue.Set(reflect.New(fieldValue.Type().Elem())) + } fieldValue = reflect.Indirect(fieldValue).FieldByName(name) } fields = append(fields, &Field{StructField: structField, Field: fieldValue, IsBlank: isBlank(fieldValue)}) From becd777b1e2f4a0ce705bfac3a80517ab8ebbb2b Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 13 Feb 2018 12:37:39 +0800 Subject: [PATCH 062/881] Fix unicode chars in SQL --- scope.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/scope.go b/scope.go index 14baf631..25077efc 100644 --- a/scope.go +++ b/scope.go @@ -649,12 +649,12 @@ func (scope *Scope) buildCondition(clause map[string]interface{}, include bool) buff := bytes.NewBuffer([]byte{}) i := 0 - for pos := range str { - if str[pos] == '?' { + for _, s := range str { + if s == '?' { buff.WriteString(replacements[i]) i++ } else { - buff.WriteByte(str[pos]) + buff.WriteRune(s) } } From 1fb623dfbba585fd0c22473d13b1bfdb54d382ca Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 13 Feb 2018 17:59:29 +0800 Subject: [PATCH 063/881] Update README --- README.md | 31 ++++++++++--------------------- 1 file changed, 10 insertions(+), 21 deletions(-) diff --git a/README.md b/README.md index 6ff49b87..7a861f39 100644 --- a/README.md +++ b/README.md @@ -2,9 +2,12 @@ The fantastic ORM library for Golang, aims to be developer friendly. -[![Join the chat at https://gitter.im/jinzhu/gorm](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/jinzhu/gorm?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) [![go report card](https://goreportcard.com/badge/github.com/jinzhu/gorm "go report card")](https://goreportcard.com/report/github.com/jinzhu/gorm) [![wercker status](https://app.wercker.com/status/8596cace912c9947dd9c8542ecc8cb8b/s/master "wercker status")](https://app.wercker.com/project/byKey/8596cace912c9947dd9c8542ecc8cb8b) +[![Join the chat at https://gitter.im/jinzhu/gorm](https://img.shields.io/gitter/room/jinzhu/gorm.svg)](https://gitter.im/jinzhu/gorm?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) +[![Open Collective Backer](https://opencollective.com/gorm/tiers/backer/badge.svg?label=backer&color=brightgreen "Open Collective Backer")](https://opencollective.com/gorm) +[![Open Collective Sponsor](https://opencollective.com/gorm/tiers/sponsor/badge.svg?label=sponsor&color=brightgreen "Open Collective Sponsor")](https://opencollective.com/gorm) +[![MIT license](http://img.shields.io/badge/license-MIT-brightgreen.svg)](http://opensource.org/licenses/MIT) [![GoDoc](https://godoc.org/github.com/jinzhu/gorm?status.svg)](https://godoc.org/github.com/jinzhu/gorm) ## Overview @@ -24,28 +27,14 @@ The fantastic ORM library for Golang, aims to be developer friendly. ## Getting Started -* GORM Guides [jinzhu.github.com/gorm](https://jinzhu.github.io/gorm) +* GORM Guides [http://gorm.io](http://gorm.io) -## Upgrading To V1.0 +## Contributing -* [CHANGELOG](https://jinzhu.github.io/gorm/changelog.html) - -## Supporting the project - -[![http://patreon.com/jinzhu](https://c5.patreon.com/external/logo/become_a_patron_button.png)](http://patreon.com/jinzhu) - -## Author - -**jinzhu** - -* -* -* - -## Contributors - -https://github.com/jinzhu/gorm/graphs/contributors +[Become a backer or sponsor on Open Collective](http://opencollective.com/gorm) +[Become a backer or sponsor on Patreon](http://patreon.com/jinzhu) ## License -Released under the [MIT License](https://github.com/jinzhu/gorm/blob/master/License). +© 2013~`time.Now()`, Jinzhu +Released under the [MIT License](https://github.com/jinzhu/gorm/blob/master/License) From 6e1387b44c64dce50b89c2f56ed425f5f73e417c Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 13 Feb 2018 18:12:09 +0800 Subject: [PATCH 064/881] Update README --- README.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 7a861f39..caebbcfb 100644 --- a/README.md +++ b/README.md @@ -32,9 +32,11 @@ The fantastic ORM library for Golang, aims to be developer friendly. ## Contributing [Become a backer or sponsor on Open Collective](http://opencollective.com/gorm) + [Become a backer or sponsor on Patreon](http://patreon.com/jinzhu) ## License -© 2013~`time.Now()`, Jinzhu +© Jinzhu, 2013~time.Now + Released under the [MIT License](https://github.com/jinzhu/gorm/blob/master/License) From 55945afb346c0ca3e62f9cb44d73ff62bc2cce2e Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 17 Feb 2018 00:33:52 +0800 Subject: [PATCH 065/881] Update README --- README.md | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index caebbcfb..0c5c7ea6 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,7 @@ The fantastic ORM library for Golang, aims to be developer friendly. * Full-Featured ORM (almost) * Associations (Has One, Has Many, Belongs To, Many To Many, Polymorphism) -* Callbacks (Before/After Create/Save/Update/Delete/Find) +* Hooks (Before/After Create/Save/Update/Delete/Find) * Preloading (eager loading) * Transactions * Composite Primary Key @@ -31,9 +31,7 @@ The fantastic ORM library for Golang, aims to be developer friendly. ## Contributing -[Become a backer or sponsor on Open Collective](http://opencollective.com/gorm) - -[Become a backer or sponsor on Patreon](http://patreon.com/jinzhu) +[You can help to deliver a better GORM, check out things you can do](http://gorm.io/contribute.html) ## License From 58e34726dfc069b558038efbaa25555f182d1f7a Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 18 Feb 2018 09:00:03 +0800 Subject: [PATCH 066/881] Don't access scanner's fields if already defined data type --- dialect.go | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/dialect.go b/dialect.go index b20bfd5b..5f6439c1 100644 --- a/dialect.go +++ b/dialect.go @@ -94,14 +94,16 @@ var ParseFieldStructForDialect = func(field *StructField, dialect Dialect) (fiel } // Get scanner's real value - var getScannerValue func(reflect.Value) - getScannerValue = func(value reflect.Value) { - fieldValue = value - if _, isScanner := reflect.New(fieldValue.Type()).Interface().(sql.Scanner); isScanner && fieldValue.Kind() == reflect.Struct { - getScannerValue(fieldValue.Field(0)) + if dataType == "" { + var getScannerValue func(reflect.Value) + getScannerValue = func(value reflect.Value) { + fieldValue = value + if _, isScanner := reflect.New(fieldValue.Type()).Interface().(sql.Scanner); isScanner && fieldValue.Kind() == reflect.Struct { + getScannerValue(fieldValue.Field(0)) + } } + getScannerValue(fieldValue) } - getScannerValue(fieldValue) // Default Size if num, ok := field.TagSettings["SIZE"]; ok { From 48a20a6e9f3f4d26095df82c3337efec6db0a6fc Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 22 Feb 2018 12:04:12 +0800 Subject: [PATCH 067/881] Add SubQuery method --- main.go | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/main.go b/main.go index 4bbaadab..c26e05c8 100644 --- a/main.go +++ b/main.go @@ -177,6 +177,15 @@ func (s *DB) QueryExpr() *expr { return Expr(scope.SQL, scope.SQLVars...) } +// SubQuery returns the query as sub query +func (s *DB) SubQuery() *expr { + scope := s.NewScope(s.Value) + scope.InstanceSet("skip_bindvar", true) + scope.prepareQuerySQL() + + return Expr(fmt.Sprintf("(%v)", scope.SQL), scope.SQLVars...) +} + // Where return a new relation, filter records with given conditions, accepts `map`, `struct` or `string` as conditions, refer http://jinzhu.github.io/gorm/crud.html#query func (s *DB) Where(query interface{}, args ...interface{}) *DB { return s.clone().search.Where(query, args...).db From a12c2a2e13b0f644647dbd369a88b01fac109bd5 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 27 Feb 2018 10:48:59 +0800 Subject: [PATCH 068/881] Remove mysql8 from CI --- wercker.yml | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/wercker.yml b/wercker.yml index 2f2370b3..0c3e73ef 100644 --- a/wercker.yml +++ b/wercker.yml @@ -9,13 +9,6 @@ services: MYSQL_USER: gorm MYSQL_PASSWORD: gorm MYSQL_RANDOM_ROOT_PASSWORD: "yes" - - name: mysql - id: mysql:8 - env: - MYSQL_DATABASE: gorm - MYSQL_USER: gorm - MYSQL_PASSWORD: gorm - MYSQL_RANDOM_ROOT_PASSWORD: "yes" - name: mysql57 id: mysql:5.7 env: @@ -109,11 +102,6 @@ build: code: | GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mariadb:3306)/gorm?charset=utf8&parseTime=True" go test ./... - - script: - name: test mysql - code: | - GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql:3306)/gorm?charset=utf8&parseTime=True" go test ./... - - script: name: test mysql5.7 code: | From 6ed508ec6a4ecb3531899a69cbc746ccf65a4166 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 28 Feb 2018 07:43:56 +0800 Subject: [PATCH 069/881] Fix panic with raw SQL --- scope.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scope.go b/scope.go index 25077efc..150ac710 100644 --- a/scope.go +++ b/scope.go @@ -650,7 +650,7 @@ func (scope *Scope) buildCondition(clause map[string]interface{}, include bool) buff := bytes.NewBuffer([]byte{}) i := 0 for _, s := range str { - if s == '?' { + if s == '?' && len(replacements) > i { buff.WriteString(replacements[i]) i++ } else { From 52c5c8127cf4aeffde3e0aa9222640832075a90f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sa=C3=BAl=20Ortega?= Date: Thu, 15 Mar 2018 09:35:31 -0500 Subject: [PATCH 070/881] Support for UTF8 names on DB (#1793) --- scope.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scope.go b/scope.go index 150ac710..2f39e073 100644 --- a/scope.go +++ b/scope.go @@ -692,12 +692,12 @@ func (scope *Scope) buildSelectQuery(clause map[string]interface{}) (str string) buff := bytes.NewBuffer([]byte{}) i := 0 - for pos := range str { + for pos, char := range str { if str[pos] == '?' { buff.WriteString(replacements[i]) i++ } else { - buff.WriteByte(str[pos]) + buff.WriteRune(char) } } From 919c6db4f854e4feaae94202ae29da4e3779de49 Mon Sep 17 00:00:00 2001 From: Giuseppe Date: Mon, 16 Apr 2018 16:18:51 +0200 Subject: [PATCH 071/881] Do not panic if Begin().Error was ignored (#1830) --- main.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/main.go b/main.go index c26e05c8..ffee4ec6 100644 --- a/main.go +++ b/main.go @@ -491,7 +491,8 @@ func (s *DB) Begin() *DB { // Commit commit a transaction func (s *DB) Commit() *DB { - if db, ok := s.db.(sqlTx); ok && db != nil { + var emptySQLTx *sql.Tx + if db, ok := s.db.(sqlTx); ok && db != nil && db != emptySQLTx { s.AddError(db.Commit()) } else { s.AddError(ErrInvalidTransaction) From 6842b49a1ad0feb6b93be830fe63a682cf853ada Mon Sep 17 00:00:00 2001 From: Shane Date: Mon, 16 Apr 2018 07:20:02 -0700 Subject: [PATCH 072/881] fix scope.removeForeignKey method (#1841) --- scope.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scope.go b/scope.go index 2f39e073..397ccf0b 100644 --- a/scope.go +++ b/scope.go @@ -1215,7 +1215,7 @@ func (scope *Scope) addForeignKey(field string, dest string, onDelete string, on } func (scope *Scope) removeForeignKey(field string, dest string) { - keyName := scope.Dialect().BuildKeyName(scope.TableName(), field, dest) + keyName := scope.Dialect().BuildKeyName(scope.TableName(), field, dest, "foreign") if !scope.Dialect().HasForeignKey(scope.TableName(), keyName) { return From 35efe68ba71d571e64ccd1ee62830c30a53ed967 Mon Sep 17 00:00:00 2001 From: Daniel McDonald Date: Wed, 2 May 2018 07:37:51 -0700 Subject: [PATCH 073/881] add simple input validation on gorm.Open function (#1855) Simply check if the passed-in database source meets the expected types and, if not, early return with error. --- main.go | 2 ++ main_test.go | 17 +++++++++++++++++ 2 files changed, 19 insertions(+) diff --git a/main.go b/main.go index ffee4ec6..c8a43e8c 100644 --- a/main.go +++ b/main.go @@ -61,6 +61,8 @@ func Open(dialect string, args ...interface{}) (db *DB, err error) { dbSQL, err = sql.Open(driver, source) case SQLCommon: dbSQL = value + default: + return nil, fmt.Errorf("invalid database source: %v is not a valid type", value) } db = &DB{ diff --git a/main_test.go b/main_test.go index 66c46af0..265e0be7 100644 --- a/main_test.go +++ b/main_test.go @@ -8,6 +8,7 @@ import ( "path/filepath" "reflect" "strconv" + "strings" "testing" "time" @@ -79,6 +80,22 @@ func OpenTestConnection() (db *gorm.DB, err error) { return } +func TestOpen_ReturnsError_WithBadArgs(t *testing.T) { + stringRef := "foo" + testCases := []interface{}{42, time.Now(), &stringRef} + for _, tc := range testCases { + t.Run(fmt.Sprintf("%v", tc), func(t *testing.T) { + _, err := gorm.Open("postgresql", tc) + if err == nil { + t.Error("Should got error with invalid database source") + } + if !strings.HasPrefix(err.Error(), "invalid database source:") { + t.Errorf("Should got error starting with \"invalid database source:\", but got %q", err.Error()) + } + }) + } +} + func TestStringPrimaryKey(t *testing.T) { type UUIDStruct struct { ID string `gorm:"primary_key"` From 9044197ef935c0969d94cbcfba55ccb94d269bed Mon Sep 17 00:00:00 2001 From: Illya Busigin Date: Wed, 2 May 2018 09:38:52 -0500 Subject: [PATCH 074/881] Adding GetDialect function (#1869) --- dialect.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/dialect.go b/dialect.go index 5f6439c1..506a6e86 100644 --- a/dialect.go +++ b/dialect.go @@ -72,6 +72,12 @@ func RegisterDialect(name string, dialect Dialect) { dialectsMap[name] = dialect } +// GetDialect gets the dialect for the specified dialect name +func GetDialect(name string) (dialect Dialect, ok bool) { + dialect, ok = dialectsMap[name] + return +} + // ParseFieldStructForDialect get field's sql data type var ParseFieldStructForDialect = func(field *StructField, dialect Dialect) (fieldValue reflect.Value, sqlType string, size int, additionalType string) { // Get redirected field type From a58b98acee2f3bf213b2cb0f1fe1468f236c9aec Mon Sep 17 00:00:00 2001 From: lrita Date: Sat, 12 May 2018 14:28:15 +0800 Subject: [PATCH 075/881] Do not panic if Begin().Error was ignored (#1830) (#1881) --- main.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/main.go b/main.go index c8a43e8c..25c3a06b 100644 --- a/main.go +++ b/main.go @@ -504,7 +504,8 @@ func (s *DB) Commit() *DB { // Rollback rollback a transaction func (s *DB) Rollback() *DB { - if db, ok := s.db.(sqlTx); ok && db != nil { + var emptySQLTx *sql.Tx + if db, ok := s.db.(sqlTx); ok && db != nil && db != emptySQLTx { s.AddError(db.Rollback()) } else { s.AddError(ErrInvalidTransaction) From 82eb9f8a5bbb5e6b929d2f0ae5b934e6a253f94e Mon Sep 17 00:00:00 2001 From: Olga Kleitsa Date: Sat, 12 May 2018 09:29:00 +0300 Subject: [PATCH 076/881] included actual sql query to discover fi foreign key with the same name exists in a specific table of the database in use (#1896) --- dialects/mssql/mssql.go | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go index e0606465..a8d3c45a 100644 --- a/dialects/mssql/mssql.go +++ b/dialects/mssql/mssql.go @@ -130,7 +130,14 @@ func (s mssql) RemoveIndex(tableName string, indexName string) error { } func (s mssql) HasForeignKey(tableName string, foreignKeyName string) bool { - return false + var count int + currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) + s.db.QueryRow(`SELECT count(*) + FROM sys.foreign_keys as F inner join sys.tables as T on F.parent_object_id=T.object_id + inner join information_schema.tables as I on I.TABLE_NAME = T.name + WHERE F.name = ? + AND T.Name = ? AND I.TABLE_CATALOG = ?;`, foreignKeyName, tableName, currentDatabase).Scan(&count) + return count > 0 } func (s mssql) HasTable(tableName string) bool { From 1907bff3732cb4c612e4118137d8f3c8829cc8c6 Mon Sep 17 00:00:00 2001 From: ia Date: Mon, 25 Jun 2018 07:06:58 +0200 Subject: [PATCH 077/881] all: gofmt (#1956) Run standard gofmt command on project root. - go version go1.10.3 darwin/amd64 Signed-off-by: ia --- dialects/postgres/postgres.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dialects/postgres/postgres.go b/dialects/postgres/postgres.go index 1d0dcb60..424e8bdc 100644 --- a/dialects/postgres/postgres.go +++ b/dialects/postgres/postgres.go @@ -4,11 +4,11 @@ import ( "database/sql" "database/sql/driver" - _ "github.com/lib/pq" - "github.com/lib/pq/hstore" "encoding/json" "errors" "fmt" + _ "github.com/lib/pq" + "github.com/lib/pq/hstore" ) type Hstore map[string]*string From 0fd395ab37aefd2d50854f0556a4311dccc6f45a Mon Sep 17 00:00:00 2001 From: Masaki Yoshida Date: Mon, 25 Jun 2018 14:07:53 +0900 Subject: [PATCH 078/881] Fix ToDBName (#1941) Don't place '_' before number. - NG: SHA256Hash -> sha_256_hash - OK: SHA256Hash -> sha256_hash --- utils.go | 12 +++++++----- utils_test.go | 3 +++ 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/utils.go b/utils.go index dfaae939..99b532c5 100644 --- a/utils.go +++ b/utils.go @@ -78,16 +78,18 @@ func ToDBName(name string) string { } var ( - value = commonInitialismsReplacer.Replace(name) - buf = bytes.NewBufferString("") - lastCase, currCase, nextCase strCase + value = commonInitialismsReplacer.Replace(name) + buf = bytes.NewBufferString("") + lastCase, currCase, nextCase, nextNumber strCase ) for i, v := range value[:len(value)-1] { nextCase = strCase(value[i+1] >= 'A' && value[i+1] <= 'Z') + nextNumber = strCase(value[i+1] >= '0' && value[i+1] <= '9') + if i > 0 { if currCase == upper { - if lastCase == upper && nextCase == upper { + if lastCase == upper && (nextCase == upper || nextNumber == upper) { buf.WriteRune(v) } else { if value[i-1] != '_' && value[i+1] != '_' { @@ -97,7 +99,7 @@ func ToDBName(name string) string { } } else { buf.WriteRune(v) - if i == len(value)-2 && nextCase == upper { + if i == len(value)-2 && (nextCase == upper && nextNumber == lower) { buf.WriteRune('_') } } diff --git a/utils_test.go b/utils_test.go index 152296d2..086c4450 100644 --- a/utils_test.go +++ b/utils_test.go @@ -15,6 +15,9 @@ func TestToDBNameGenerateFriendlyName(t *testing.T) { "AbcAndJkl": "abc_and_jkl", "EmployeeID": "employee_id", "SKU_ID": "sku_id", + "UTF8": "utf8", + "Level1": "level1", + "SHA256Hash": "sha256_hash", "FieldX": "field_x", "HTTPAndSMTP": "http_and_smtp", "HTTPServerHandlerForURLID": "http_server_handler_for_url_id", From dbb25e94879f463c699430a74d29c9557e15a60f Mon Sep 17 00:00:00 2001 From: Louis Brauer Date: Fri, 27 Jul 2018 01:30:57 +0200 Subject: [PATCH 079/881] Adding json type for mssql dialect, similar to postgres.Jsonb (#1934) * Adding json type for mssql dialect, similar to postgres.Jsonb * Adding proper comments --- dialects/mssql/mssql.go | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go index a8d3c45a..731721cb 100644 --- a/dialects/mssql/mssql.go +++ b/dialects/mssql/mssql.go @@ -1,12 +1,16 @@ package mssql import ( + "database/sql/driver" + "encoding/json" + "errors" "fmt" "reflect" "strconv" "strings" "time" + // Importing mssql driver package only in dialect file, otherwide not needed _ "github.com/denisenkom/go-mssqldb" "github.com/jinzhu/gorm" ) @@ -201,3 +205,27 @@ func currentDatabaseAndTable(dialect gorm.Dialect, tableName string) (string, st } return dialect.CurrentDatabase(), tableName } + +// JSON type to support easy handling of JSON data in character table fields +// using golang json.RawMessage for deferred decoding/encoding +type JSON struct { + json.RawMessage +} + +// Value get value of JSON +func (j JSON) Value() (driver.Value, error) { + if len(j.RawMessage) == 0 { + return nil, nil + } + return j.MarshalJSON() +} + +// Scan scan value into JSON +func (j *JSON) Scan(value interface{}) error { + str, ok := value.(string) + if !ok { + return errors.New(fmt.Sprint("Failed to unmarshal JSONB value (strcast):", value)) + } + bytes := []byte(str) + return json.Unmarshal(bytes, j) +} From ac3ec858a6c375a466f613c86b053726abbe3755 Mon Sep 17 00:00:00 2001 From: Kevin Date: Thu, 26 Jul 2018 19:35:53 -0400 Subject: [PATCH 080/881] Edit DB.clone(), DB.Dialect(), and Scope.Dialect() preserve transactions (#1939) * Edit DB.clone(), DB.Dialect(), and Scope.Dialect() preserve transactions. * Adds a test case for tables creations and autoMigrate in the same transaction. --- main.go | 5 ++++- migration_test.go | 47 +++++++++++++++++++++++++++++++++++++++++++++++ scope.go | 2 +- 3 files changed, 52 insertions(+), 2 deletions(-) diff --git a/main.go b/main.go index 25c3a06b..3a5d6b0c 100644 --- a/main.go +++ b/main.go @@ -119,7 +119,7 @@ func (s *DB) CommonDB() SQLCommon { // Dialect get dialect func (s *DB) Dialect() Dialect { - return s.parent.dialect + return s.dialect } // Callback return `Callbacks` container, you could add/change/delete callbacks with it @@ -484,6 +484,8 @@ func (s *DB) Begin() *DB { if db, ok := c.db.(sqlDb); ok && db != nil { tx, err := db.Begin() c.db = interface{}(tx).(SQLCommon) + + c.dialect.SetDB(c.db) c.AddError(err) } else { c.AddError(ErrCantStartTransaction) @@ -748,6 +750,7 @@ func (s *DB) clone() *DB { Value: s.Value, Error: s.Error, blockGlobalUpdate: s.blockGlobalUpdate, + dialect: newDialect(s.dialect.GetName(), s.db), } for key, value := range s.values { diff --git a/migration_test.go b/migration_test.go index 7c694485..78555dcc 100644 --- a/migration_test.go +++ b/migration_test.go @@ -398,6 +398,53 @@ func TestAutoMigration(t *testing.T) { } } +func TestCreateAndAutomigrateTransaction(t *testing.T) { + tx := DB.Begin() + + func() { + type Bar struct { + ID uint + } + DB.DropTableIfExists(&Bar{}) + + if ok := DB.HasTable("bars"); ok { + t.Errorf("Table should not exist, but does") + } + + if ok := tx.HasTable("bars"); ok { + t.Errorf("Table should not exist, but does") + } + }() + + func() { + type Bar struct { + Name string + } + err := tx.CreateTable(&Bar{}).Error + + if err != nil { + t.Errorf("Should have been able to create the table, but couldn't: %s", err) + } + + if ok := tx.HasTable(&Bar{}); !ok { + t.Errorf("The transaction should be able to see the table") + } + }() + + func() { + type Bar struct { + Stuff string + } + + err := tx.AutoMigrate(&Bar{}).Error + if err != nil { + t.Errorf("Should have been able to alter the table, but couldn't") + } + }() + + tx.Rollback() +} + type MultipleIndexes struct { ID int64 UserID int64 `sql:"unique_index:uix_multipleindexes_user_name,uix_multipleindexes_user_email;index:idx_multipleindexes_user_other"` diff --git a/scope.go b/scope.go index 397ccf0b..5eb98963 100644 --- a/scope.go +++ b/scope.go @@ -63,7 +63,7 @@ func (scope *Scope) SQLDB() SQLCommon { // Dialect get dialect func (scope *Scope) Dialect() Dialect { - return scope.db.parent.dialect + return scope.db.dialect } // Quote used to quote string to escape them for database From 588e2eef5d9c33b11ee52895ad5cdfab0d6648e6 Mon Sep 17 00:00:00 2001 From: David Zhang Date: Fri, 27 Jul 2018 07:38:02 +0800 Subject: [PATCH 081/881] Fix typo in query_test (#1977) --- query_test.go | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/query_test.go b/query_test.go index fac7d4d8..15bf8b3c 100644 --- a/query_test.go +++ b/query_test.go @@ -181,17 +181,17 @@ func TestSearchWithPlainSQL(t *testing.T) { scopedb.Where("birthday > ?", parseTime("2000-1-1")).Find(&users) if len(users) != 2 { - t.Errorf("Should found 2 users's birthday > 2000-1-1, but got %v", len(users)) + t.Errorf("Should found 2 users' birthday > 2000-1-1, but got %v", len(users)) } scopedb.Where("birthday > ?", "2002-10-10").Find(&users) if len(users) != 2 { - t.Errorf("Should found 2 users's birthday >= 2002-10-10, but got %v", len(users)) + t.Errorf("Should found 2 users' birthday >= 2002-10-10, but got %v", len(users)) } scopedb.Where("birthday >= ?", "2010-1-1").Where("birthday < ?", "2020-1-1").Find(&users) if len(users) != 1 { - t.Errorf("Should found 1 users's birthday < 2020-1-1 and >= 2010-1-1, but got %v", len(users)) + t.Errorf("Should found 1 users' birthday < 2020-1-1 and >= 2010-1-1, but got %v", len(users)) } DB.Where("name in (?)", []string{user1.Name, user2.Name}).Find(&users) @@ -532,28 +532,28 @@ func TestNot(t *testing.T) { DB.Table("users").Where("name = ?", "user3").Count(&name3Count) DB.Not("name", "user3").Find(&users4) if len(users1)-len(users4) != int(name3Count) { - t.Errorf("Should find all users's name not equal 3") + t.Errorf("Should find all users' name not equal 3") } DB.Not("name = ?", "user3").Find(&users4) if len(users1)-len(users4) != int(name3Count) { - t.Errorf("Should find all users's name not equal 3") + t.Errorf("Should find all users' name not equal 3") } DB.Not("name <> ?", "user3").Find(&users4) if len(users4) != int(name3Count) { - t.Errorf("Should find all users's name not equal 3") + t.Errorf("Should find all users' name not equal 3") } DB.Not(User{Name: "user3"}).Find(&users5) if len(users1)-len(users5) != int(name3Count) { - t.Errorf("Should find all users's name not equal 3") + t.Errorf("Should find all users' name not equal 3") } DB.Not(map[string]interface{}{"name": "user3"}).Find(&users6) if len(users1)-len(users6) != int(name3Count) { - t.Errorf("Should find all users's name not equal 3") + t.Errorf("Should find all users' name not equal 3") } DB.Not(map[string]interface{}{"name": "user3", "company_id": nil}).Find(&users7) @@ -563,14 +563,14 @@ func TestNot(t *testing.T) { DB.Not("name", []string{"user3"}).Find(&users8) if len(users1)-len(users8) != int(name3Count) { - t.Errorf("Should find all users's name not equal 3") + t.Errorf("Should find all users' name not equal 3") } var name2Count int64 DB.Table("users").Where("name = ?", "user2").Count(&name2Count) DB.Not("name", []string{"user3", "user2"}).Find(&users9) if len(users1)-len(users9) != (int(name3Count) + int(name2Count)) { - t.Errorf("Should find all users's name not equal 3") + t.Errorf("Should find all users' name not equal 3") } } From d68403b29dbf3086b2335f6381545462d96808bc Mon Sep 17 00:00:00 2001 From: antness Date: Fri, 27 Jul 2018 02:43:09 +0300 Subject: [PATCH 082/881] do not close wrapped *sql.DB (#1985) --- main.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/main.go b/main.go index 3a5d6b0c..de6ce428 100644 --- a/main.go +++ b/main.go @@ -48,6 +48,7 @@ func Open(dialect string, args ...interface{}) (db *DB, err error) { } var source string var dbSQL SQLCommon + var ownDbSQL bool switch value := args[0].(type) { case string: @@ -59,8 +60,10 @@ func Open(dialect string, args ...interface{}) (db *DB, err error) { source = args[1].(string) } dbSQL, err = sql.Open(driver, source) + ownDbSQL = true case SQLCommon: dbSQL = value + ownDbSQL = false default: return nil, fmt.Errorf("invalid database source: %v is not a valid type", value) } @@ -78,7 +81,7 @@ func Open(dialect string, args ...interface{}) (db *DB, err error) { } // Send a ping to make sure the database connection is alive. if d, ok := dbSQL.(*sql.DB); ok { - if err = d.Ping(); err != nil { + if err = d.Ping(); err != nil && ownDbSQL { d.Close() } } From 409121d9e394922787885b001d148a05e3a42b6c Mon Sep 17 00:00:00 2001 From: Alexey <10kdmg@gmail.com> Date: Fri, 27 Jul 2018 02:43:49 +0300 Subject: [PATCH 083/881] Fixed mysql query syntax for FK removal (#1993) --- scope.go | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/scope.go b/scope.go index 5eb98963..a05c1d61 100644 --- a/scope.go +++ b/scope.go @@ -1216,11 +1216,17 @@ func (scope *Scope) addForeignKey(field string, dest string, onDelete string, on func (scope *Scope) removeForeignKey(field string, dest string) { keyName := scope.Dialect().BuildKeyName(scope.TableName(), field, dest, "foreign") - if !scope.Dialect().HasForeignKey(scope.TableName(), keyName) { return } - var query = `ALTER TABLE %s DROP CONSTRAINT %s;` + var mysql mysql + var query string + if scope.Dialect().GetName() == mysql.GetName() { + query = `ALTER TABLE %s DROP FOREIGN KEY %s;` + } else { + query = `ALTER TABLE %s DROP CONSTRAINT %s;` + } + scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName))).Exec() } From 0e04d414d59f3154d700692bda0d7649d0e101b3 Mon Sep 17 00:00:00 2001 From: Artemij Shepelev Date: Sun, 19 Aug 2018 02:09:21 +0300 Subject: [PATCH 084/881] Race fix. Changes modelStructsMap implementation from map with mutex to sync.Map (#2022) * fix (https://github.com/jinzhu/gorm/issues/1407) * changed map with mutex to sync.Map (https://github.com/jinzhu/gorm/issues/1407) * removed newModelStructsMap func * commit to rerun pipeline, comment changed --- main.go | 3 ++- model_struct.go | 31 +++++-------------------------- 2 files changed, 7 insertions(+), 27 deletions(-) diff --git a/main.go b/main.go index de6ce428..993e19b1 100644 --- a/main.go +++ b/main.go @@ -6,6 +6,7 @@ import ( "fmt" "reflect" "strings" + "sync" "time" ) @@ -162,7 +163,7 @@ func (s *DB) HasBlockGlobalUpdate() bool { // SingularTable use singular table by default func (s *DB) SingularTable(enable bool) { - modelStructsMap = newModelStructsMap() + modelStructsMap = sync.Map{} s.parent.singularTable = enable } diff --git a/model_struct.go b/model_struct.go index f571e2e8..8506fe87 100644 --- a/model_struct.go +++ b/model_struct.go @@ -17,28 +17,7 @@ var DefaultTableNameHandler = func(db *DB, defaultTableName string) string { return defaultTableName } -type safeModelStructsMap struct { - m map[reflect.Type]*ModelStruct - l *sync.RWMutex -} - -func (s *safeModelStructsMap) Set(key reflect.Type, value *ModelStruct) { - s.l.Lock() - defer s.l.Unlock() - s.m[key] = value -} - -func (s *safeModelStructsMap) Get(key reflect.Type) *ModelStruct { - s.l.RLock() - defer s.l.RUnlock() - return s.m[key] -} - -func newModelStructsMap() *safeModelStructsMap { - return &safeModelStructsMap{l: new(sync.RWMutex), m: make(map[reflect.Type]*ModelStruct)} -} - -var modelStructsMap = newModelStructsMap() +var modelStructsMap sync.Map // ModelStruct model definition type ModelStruct struct { @@ -48,7 +27,7 @@ type ModelStruct struct { defaultTableName string } -// TableName get model's table name +// TableName returns model's table name func (s *ModelStruct) TableName(db *DB) string { if s.defaultTableName == "" && db != nil && s.ModelType != nil { // Set default table name @@ -152,8 +131,8 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } // Get Cached model struct - if value := modelStructsMap.Get(reflectType); value != nil { - return value + if value, ok := modelStructsMap.Load(reflectType); ok && value != nil { + return value.(*ModelStruct) } modelStruct.ModelType = reflectType @@ -601,7 +580,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } } - modelStructsMap.Set(reflectType, &modelStruct) + modelStructsMap.Store(reflectType, &modelStruct) return &modelStruct } From 31ec9255cdc16482f5bef2ceb996ba75ba750a8a Mon Sep 17 00:00:00 2001 From: Elliott <617942+ellman121@users.noreply.github.com> Date: Sun, 19 Aug 2018 01:11:27 +0200 Subject: [PATCH 085/881] Setting gorm:auto_preload to false now prevents preloading (#2031) --- callback_query_preload.go | 10 ++++++++-- preload_test.go | 25 +++++++++++++++++++++++++ 2 files changed, 33 insertions(+), 2 deletions(-) diff --git a/callback_query_preload.go b/callback_query_preload.go index 30f6b585..481bfbe3 100644 --- a/callback_query_preload.go +++ b/callback_query_preload.go @@ -14,8 +14,14 @@ func preloadCallback(scope *Scope) { return } - if _, ok := scope.Get("gorm:auto_preload"); ok { - autoPreload(scope) + if ap, ok := scope.Get("gorm:auto_preload"); ok { + // If gorm:auto_preload IS NOT a bool then auto preload. + // Else if it IS a bool, use the value + if apb, ok := ap.(bool); !ok { + autoPreload(scope) + } else if apb { + autoPreload(scope) + } } if scope.Search.preload == nil || scope.HasError() { diff --git a/preload_test.go b/preload_test.go index 311ad0be..1db625c9 100644 --- a/preload_test.go +++ b/preload_test.go @@ -123,6 +123,31 @@ func TestAutoPreload(t *testing.T) { } } +func TestAutoPreloadFalseDoesntPreload(t *testing.T) { + user1 := getPreloadUser("auto_user1") + DB.Save(user1) + + preloadDB := DB.Set("gorm:auto_preload", false).Where("role = ?", "Preload") + var user User + preloadDB.Find(&user) + + if user.BillingAddress.Address1 != "" { + t.Error("AutoPreload was set to fasle, but still fetched data") + } + + user2 := getPreloadUser("auto_user2") + DB.Save(user2) + + var users []User + preloadDB.Find(&users) + + for _, user := range users { + if user.BillingAddress.Address1 != "" { + t.Error("AutoPreload was set to fasle, but still fetched data") + } + } +} + func TestNestedPreload1(t *testing.T) { type ( Level1 struct { From 53995294ef73980d6eacee993ffa8bcdf769a0e2 Mon Sep 17 00:00:00 2001 From: hector <1069315972@qq.com> Date: Sun, 19 Aug 2018 07:13:16 +0800 Subject: [PATCH 086/881] Change buildCondition TableName to struct's TableName when query is interface{} (#2011) --- scope.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scope.go b/scope.go index a05c1d61..ca861d8a 100644 --- a/scope.go +++ b/scope.go @@ -586,10 +586,10 @@ func (scope *Scope) buildCondition(clause map[string]interface{}, include bool) scope.Err(fmt.Errorf("invalid query condition: %v", value)) return } - + scopeQuotedTableName := newScope.QuotedTableName() for _, field := range newScope.Fields() { if !field.IsIgnored && !field.IsBlank { - sqls = append(sqls, fmt.Sprintf("(%v.%v %s %v)", quotedTableName, scope.Quote(field.DBName), equalSQL, scope.AddToVars(field.Field.Interface()))) + sqls = append(sqls, fmt.Sprintf("(%v.%v %s %v)", scopeQuotedTableName, scope.Quote(field.DBName), equalSQL, scope.AddToVars(field.Field.Interface()))) } } return strings.Join(sqls, " AND ") From 32455088f24d6b1e9a502fb8e40fdc16139dbea8 Mon Sep 17 00:00:00 2001 From: Eason Lin Date: Sun, 19 Aug 2018 07:14:33 +0800 Subject: [PATCH 087/881] doc: document ErrRecordNotFound error more clear (#2015) * doc: document ErrRecordNotFound error more clear * fix goimports * fix goimports * undo change --- errors.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/errors.go b/errors.go index da2cf13c..27c9a92d 100644 --- a/errors.go +++ b/errors.go @@ -6,7 +6,7 @@ import ( ) var ( - // ErrRecordNotFound record not found error, happens when haven't find any matched data when looking up with a struct + // ErrRecordNotFound record not found error, happens when only haven't find any matched data when looking up with a struct, finding a slice won't return this error ErrRecordNotFound = errors.New("record not found") // ErrInvalidSQL invalid SQL error, happens when you passed invalid SQL ErrInvalidSQL = errors.New("invalid SQL") From 6f58f8a52cc3ad21950402d1adaa09682e07ec2f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adem=20=C3=96zay?= Date: Mon, 10 Sep 2018 00:52:20 +0300 Subject: [PATCH 088/881] added naming strategy option for db, table and column names (#2040) --- model_struct.go | 12 ++--- naming.go | 124 ++++++++++++++++++++++++++++++++++++++++++++++++ naming_test.go | 69 +++++++++++++++++++++++++++ scope.go | 4 +- utils.go | 61 ------------------------ utils_test.go | 35 -------------- 6 files changed, 201 insertions(+), 104 deletions(-) create mode 100644 naming.go create mode 100644 naming_test.go delete mode 100644 utils_test.go diff --git a/model_struct.go b/model_struct.go index 8506fe87..5b5be618 100644 --- a/model_struct.go +++ b/model_struct.go @@ -34,7 +34,7 @@ func (s *ModelStruct) TableName(db *DB) string { if tabler, ok := reflect.New(s.ModelType).Interface().(tabler); ok { s.defaultTableName = tabler.TableName() } else { - tableName := ToDBName(s.ModelType.Name()) + tableName := ToTableName(s.ModelType.Name()) if db == nil || !db.parent.singularTable { tableName = inflection.Plural(tableName) } @@ -105,7 +105,7 @@ type Relationship struct { func getForeignField(column string, fields []*StructField) *StructField { for _, field := range fields { - if field.Name == column || field.DBName == column || field.DBName == ToDBName(column) { + if field.Name == column || field.DBName == column || field.DBName == ToColumnName(column) { return field } } @@ -269,7 +269,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { // if defined join table's foreign key relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBNames[idx]) } else { - defaultJointableForeignKey := ToDBName(reflectType.Name()) + "_" + foreignField.DBName + defaultJointableForeignKey := ToColumnName(reflectType.Name()) + "_" + foreignField.DBName relationship.ForeignDBNames = append(relationship.ForeignDBNames, defaultJointableForeignKey) } } @@ -300,7 +300,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationJoinTableDBNames[idx]) } else { // join table foreign keys for association - joinTableDBName := ToDBName(elemType.Name()) + "_" + field.DBName + joinTableDBName := ToColumnName(elemType.Name()) + "_" + field.DBName relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, joinTableDBName) } } @@ -308,7 +308,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } joinTableHandler := JoinTableHandler{} - joinTableHandler.Setup(relationship, many2many, reflectType, elemType) + joinTableHandler.Setup(relationship, ToTableName(many2many), reflectType, elemType) relationship.JoinTableHandler = &joinTableHandler field.Relationship = relationship } else { @@ -566,7 +566,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { if value, ok := field.TagSettings["COLUMN"]; ok { field.DBName = value } else { - field.DBName = ToDBName(fieldStruct.Name) + field.DBName = ToColumnName(fieldStruct.Name) } modelStruct.StructFields = append(modelStruct.StructFields, field) diff --git a/naming.go b/naming.go new file mode 100644 index 00000000..6b0a4fdd --- /dev/null +++ b/naming.go @@ -0,0 +1,124 @@ +package gorm + +import ( + "bytes" + "strings" +) + +// Namer is a function type which is given a string and return a string +type Namer func(string) string + +// NamingStrategy represents naming strategies +type NamingStrategy struct { + DB Namer + Table Namer + Column Namer +} + +// TheNamingStrategy is being initialized with defaultNamingStrategy +var TheNamingStrategy = &NamingStrategy{ + DB: defaultNamer, + Table: defaultNamer, + Column: defaultNamer, +} + +// AddNamingStrategy sets the naming strategy +func AddNamingStrategy(ns *NamingStrategy) { + if ns.DB == nil { + ns.DB = defaultNamer + } + if ns.Table == nil { + ns.Table = defaultNamer + } + if ns.Column == nil { + ns.Column = defaultNamer + } + TheNamingStrategy = ns +} + +// DBName alters the given name by DB +func (ns *NamingStrategy) DBName(name string) string { + return ns.DB(name) +} + +// TableName alters the given name by Table +func (ns *NamingStrategy) TableName(name string) string { + return ns.Table(name) +} + +// ColumnName alters the given name by Column +func (ns *NamingStrategy) ColumnName(name string) string { + return ns.Column(name) +} + +// ToDBName convert string to db name +func ToDBName(name string) string { + return TheNamingStrategy.DBName(name) +} + +// ToTableName convert string to table name +func ToTableName(name string) string { + return TheNamingStrategy.TableName(name) +} + +// ToColumnName convert string to db name +func ToColumnName(name string) string { + return TheNamingStrategy.ColumnName(name) +} + +var smap = newSafeMap() + +func defaultNamer(name string) string { + const ( + lower = false + upper = true + ) + + if v := smap.Get(name); v != "" { + return v + } + + if name == "" { + return "" + } + + var ( + value = commonInitialismsReplacer.Replace(name) + buf = bytes.NewBufferString("") + lastCase, currCase, nextCase, nextNumber bool + ) + + for i, v := range value[:len(value)-1] { + nextCase = bool(value[i+1] >= 'A' && value[i+1] <= 'Z') + nextNumber = bool(value[i+1] >= '0' && value[i+1] <= '9') + + if i > 0 { + if currCase == upper { + if lastCase == upper && (nextCase == upper || nextNumber == upper) { + buf.WriteRune(v) + } else { + if value[i-1] != '_' && value[i+1] != '_' { + buf.WriteRune('_') + } + buf.WriteRune(v) + } + } else { + buf.WriteRune(v) + if i == len(value)-2 && (nextCase == upper && nextNumber == lower) { + buf.WriteRune('_') + } + } + } else { + currCase = upper + buf.WriteRune(v) + } + lastCase = currCase + currCase = nextCase + } + + buf.WriteByte(value[len(value)-1]) + + s := strings.ToLower(buf.String()) + smap.Set(name, s) + return s +} diff --git a/naming_test.go b/naming_test.go new file mode 100644 index 00000000..0c6f7713 --- /dev/null +++ b/naming_test.go @@ -0,0 +1,69 @@ +package gorm_test + +import ( + "testing" + + "github.com/jinzhu/gorm" +) + +func TestTheNamingStrategy(t *testing.T) { + + cases := []struct { + name string + namer gorm.Namer + expected string + }{ + {name: "auth", expected: "auth", namer: gorm.TheNamingStrategy.DB}, + {name: "userRestrictions", expected: "user_restrictions", namer: gorm.TheNamingStrategy.Table}, + {name: "clientID", expected: "client_id", namer: gorm.TheNamingStrategy.Column}, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + result := c.namer(c.name) + if result != c.expected { + t.Errorf("error in naming strategy. expected: %v got :%v\n", c.expected, result) + } + }) + } + +} + +func TestNamingStrategy(t *testing.T) { + + dbNameNS := func(name string) string { + return "db_" + name + } + tableNameNS := func(name string) string { + return "tbl_" + name + } + columnNameNS := func(name string) string { + return "col_" + name + } + + ns := &gorm.NamingStrategy{ + DB: dbNameNS, + Table: tableNameNS, + Column: columnNameNS, + } + + cases := []struct { + name string + namer gorm.Namer + expected string + }{ + {name: "auth", expected: "db_auth", namer: ns.DB}, + {name: "user", expected: "tbl_user", namer: ns.Table}, + {name: "password", expected: "col_password", namer: ns.Column}, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + result := c.namer(c.name) + if result != c.expected { + t.Errorf("error in naming strategy. expected: %v got :%v\n", c.expected, result) + } + }) + } + +} diff --git a/scope.go b/scope.go index ca861d8a..fbf7634e 100644 --- a/scope.go +++ b/scope.go @@ -134,7 +134,7 @@ func (scope *Scope) Fields() []*Field { // FieldByName find `gorm.Field` with field name or db name func (scope *Scope) FieldByName(name string) (field *Field, ok bool) { var ( - dbName = ToDBName(name) + dbName = ToColumnName(name) mostMatchedField *Field ) @@ -880,7 +880,7 @@ func convertInterfaceToMap(values interface{}, withIgnoredField bool) map[string switch reflectValue.Kind() { case reflect.Map: for _, key := range reflectValue.MapKeys() { - attrs[ToDBName(key.Interface().(string))] = reflectValue.MapIndex(key).Interface() + attrs[ToColumnName(key.Interface().(string))] = reflectValue.MapIndex(key).Interface() } default: for _, field := range (&Scope{Value: values}).Fields() { diff --git a/utils.go b/utils.go index 99b532c5..ad700b98 100644 --- a/utils.go +++ b/utils.go @@ -1,7 +1,6 @@ package gorm import ( - "bytes" "database/sql/driver" "fmt" "reflect" @@ -58,66 +57,6 @@ func newSafeMap() *safeMap { return &safeMap{l: new(sync.RWMutex), m: make(map[string]string)} } -var smap = newSafeMap() - -type strCase bool - -const ( - lower strCase = false - upper strCase = true -) - -// ToDBName convert string to db name -func ToDBName(name string) string { - if v := smap.Get(name); v != "" { - return v - } - - if name == "" { - return "" - } - - var ( - value = commonInitialismsReplacer.Replace(name) - buf = bytes.NewBufferString("") - lastCase, currCase, nextCase, nextNumber strCase - ) - - for i, v := range value[:len(value)-1] { - nextCase = strCase(value[i+1] >= 'A' && value[i+1] <= 'Z') - nextNumber = strCase(value[i+1] >= '0' && value[i+1] <= '9') - - if i > 0 { - if currCase == upper { - if lastCase == upper && (nextCase == upper || nextNumber == upper) { - buf.WriteRune(v) - } else { - if value[i-1] != '_' && value[i+1] != '_' { - buf.WriteRune('_') - } - buf.WriteRune(v) - } - } else { - buf.WriteRune(v) - if i == len(value)-2 && (nextCase == upper && nextNumber == lower) { - buf.WriteRune('_') - } - } - } else { - currCase = upper - buf.WriteRune(v) - } - lastCase = currCase - currCase = nextCase - } - - buf.WriteByte(value[len(value)-1]) - - s := strings.ToLower(buf.String()) - smap.Set(name, s) - return s -} - // SQL expression type expr struct { expr string diff --git a/utils_test.go b/utils_test.go deleted file mode 100644 index 086c4450..00000000 --- a/utils_test.go +++ /dev/null @@ -1,35 +0,0 @@ -package gorm_test - -import ( - "testing" - - "github.com/jinzhu/gorm" -) - -func TestToDBNameGenerateFriendlyName(t *testing.T) { - var maps = map[string]string{ - "": "", - "X": "x", - "ThisIsATest": "this_is_a_test", - "PFAndESI": "pf_and_esi", - "AbcAndJkl": "abc_and_jkl", - "EmployeeID": "employee_id", - "SKU_ID": "sku_id", - "UTF8": "utf8", - "Level1": "level1", - "SHA256Hash": "sha256_hash", - "FieldX": "field_x", - "HTTPAndSMTP": "http_and_smtp", - "HTTPServerHandlerForURLID": "http_server_handler_for_url_id", - "UUID": "uuid", - "HTTPURL": "http_url", - "HTTP_URL": "http_url", - "ThisIsActuallyATestSoWeMayBeAbleToUseThisCodeInGormPackageAlsoIdCanBeUsedAtTheEndAsID": "this_is_actually_a_test_so_we_may_be_able_to_use_this_code_in_gorm_package_also_id_can_be_used_at_the_end_as_id", - } - - for key, value := range maps { - if gorm.ToDBName(key) != value { - t.Errorf("%v ToDBName should equal %v, but got %v", key, value, gorm.ToDBName(key)) - } - } -} From dc3b2476c4eb61c37424a1ca2f46859e4e6fcd81 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 10 Sep 2018 06:03:41 +0800 Subject: [PATCH 089/881] Don't save ignored fields into database --- callback_create.go | 2 +- scope.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/callback_create.go b/callback_create.go index e7fe6f86..2ab05d3b 100644 --- a/callback_create.go +++ b/callback_create.go @@ -59,7 +59,7 @@ func createCallback(scope *Scope) { for _, field := range scope.Fields() { if scope.changeableField(field) { - if field.IsNormal { + if field.IsNormal && !field.IsIgnored { if field.IsBlank && field.HasDefaultValue { blankColumnsWithDefaultValue = append(blankColumnsWithDefaultValue, scope.Quote(field.DBName)) scope.InstanceSet("gorm:blank_columns_with_default_value", blankColumnsWithDefaultValue) diff --git a/scope.go b/scope.go index fbf7634e..7d6ba1c0 100644 --- a/scope.go +++ b/scope.go @@ -907,7 +907,7 @@ func (scope *Scope) updatedAttrsWithValues(value interface{}) (results map[strin results[field.DBName] = value } else { err := field.Set(value) - if field.IsNormal { + if field.IsNormal && !field.IsIgnored { hasUpdate = true if err == ErrUnaddressable { results[field.DBName] = value From 71b7f19aad77eaf99a90324c7d2ac5634eaefca8 Mon Sep 17 00:00:00 2001 From: Xy Ziemba Date: Sun, 9 Sep 2018 15:12:58 -0700 Subject: [PATCH 090/881] Fix scanning identical column names occurring >2 times (#2080) Fix the indexing logic used in selectedColumnsMap to skip fields that have already been seen. The values of selectedColumns map must be indexed relative to fields, not relative to selectFields. --- main_test.go | 54 ++++++++++++++++++++++++++++++++++++++++++++++++++++ scope.go | 6 ++++-- 2 files changed, 58 insertions(+), 2 deletions(-) diff --git a/main_test.go b/main_test.go index 265e0be7..11c4bb87 100644 --- a/main_test.go +++ b/main_test.go @@ -581,6 +581,60 @@ func TestJoins(t *testing.T) { } } +type JoinedIds struct { + UserID int64 `gorm:"column:id"` + BillingAddressID int64 `gorm:"column:id"` + EmailID int64 `gorm:"column:id"` +} + +func TestScanIdenticalColumnNames(t *testing.T) { + var user = User{ + Name: "joinsIds", + Email: "joinIds@example.com", + BillingAddress: Address{ + Address1: "One Park Place", + }, + Emails: []Email{{Email: "join1@example.com"}, {Email: "join2@example.com"}}, + } + DB.Save(&user) + + var users []JoinedIds + DB.Select("users.id, addresses.id, emails.id").Table("users"). + Joins("left join addresses on users.billing_address_id = addresses.id"). + Joins("left join emails on emails.user_id = users.id"). + Where("name = ?", "joinsIds").Scan(&users) + + if len(users) != 2 { + t.Fatal("should find two rows using left join") + } + + if user.Id != users[0].UserID { + t.Errorf("Expected result row to contain UserID %d, but got %d", user.Id, users[0].UserID) + } + if user.Id != users[1].UserID { + t.Errorf("Expected result row to contain UserID %d, but got %d", user.Id, users[1].UserID) + } + + if user.BillingAddressID.Int64 != users[0].BillingAddressID { + t.Errorf("Expected result row to contain BillingAddressID %d, but got %d", user.BillingAddressID.Int64, users[0].BillingAddressID) + } + if user.BillingAddressID.Int64 != users[1].BillingAddressID { + t.Errorf("Expected result row to contain BillingAddressID %d, but got %d", user.BillingAddressID.Int64, users[0].BillingAddressID) + } + + if users[0].EmailID == users[1].EmailID { + t.Errorf("Email ids should be unique. Got %d and %d", users[0].EmailID, users[1].EmailID) + } + + if int64(user.Emails[0].Id) != users[0].EmailID && int64(user.Emails[1].Id) != users[0].EmailID { + t.Errorf("Expected result row ID to be either %d or %d, but was %d", user.Emails[0].Id, user.Emails[1].Id, users[0].EmailID) + } + + if int64(user.Emails[0].Id) != users[1].EmailID && int64(user.Emails[1].Id) != users[1].EmailID { + t.Errorf("Expected result row ID to be either %d or %d, but was %d", user.Emails[0].Id, user.Emails[1].Id, users[1].EmailID) + } +} + func TestJoinsWithSelect(t *testing.T) { type result struct { Name string diff --git a/scope.go b/scope.go index 7d6ba1c0..ce80ab86 100644 --- a/scope.go +++ b/scope.go @@ -486,8 +486,10 @@ func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field) { values[index] = &ignored selectFields = fields + offset := 0 if idx, ok := selectedColumnsMap[column]; ok { - selectFields = selectFields[idx+1:] + offset = idx + 1 + selectFields = selectFields[offset:] } for fieldIndex, field := range selectFields { @@ -501,7 +503,7 @@ func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field) { resetFields[index] = field } - selectedColumnsMap[column] = fieldIndex + selectedColumnsMap[column] = offset + fieldIndex if field.IsNormal { break From 12607e8bdf4a724492d53d8c788edc77ad4439e7 Mon Sep 17 00:00:00 2001 From: kuangzhiqiang Date: Mon, 10 Sep 2018 06:14:05 +0800 Subject: [PATCH 091/881] for go1.11 go mod (#2072) when used go1.11 gomodules the code dir will be `$GOPATH/pkg/mod/github.com/jinzhu/gorm@*/` fileWithLineNum check failed --- utils.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/utils.go b/utils.go index ad700b98..8489538c 100644 --- a/utils.go +++ b/utils.go @@ -25,8 +25,8 @@ var NowFunc = func() time.Time { var commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UID", "UI", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"} var commonInitialismsReplacer *strings.Replacer -var goSrcRegexp = regexp.MustCompile(`jinzhu/gorm/.*.go`) -var goTestRegexp = regexp.MustCompile(`jinzhu/gorm/.*test.go`) +var goSrcRegexp = regexp.MustCompile(`jinzhu/gorm(@.*)?/.*.go`) +var goTestRegexp = regexp.MustCompile(`jinzhu/gorm(@.*)?/.*test.go`) func init() { var commonInitialismsForReplacer []string From d3e666a1e086a020905e3f6cf293941806520d97 Mon Sep 17 00:00:00 2001 From: Ikhtiyor <33823221+iahmedov@users.noreply.github.com> Date: Mon, 10 Sep 2018 03:25:26 +0500 Subject: [PATCH 092/881] save_associations:true should store related item (#2067) * save_associations:true should store related item, save_associations priority on related objects * code quality --- callback_save.go | 6 ++-- main_test.go | 88 +++++++++++++++++++++++++++++++++++++++++++++++ migration_test.go | 10 +++++- 3 files changed, 100 insertions(+), 4 deletions(-) diff --git a/callback_save.go b/callback_save.go index ef267141..ebfd0b34 100644 --- a/callback_save.go +++ b/callback_save.go @@ -21,9 +21,7 @@ func saveAssociationCheck(scope *Scope, field *Field) (autoUpdate bool, autoCrea if v, ok := value.(string); ok { v = strings.ToLower(v) - if v == "false" || v != "skip" { - return false - } + return v == "true" } return true @@ -36,9 +34,11 @@ func saveAssociationCheck(scope *Scope, field *Field) (autoUpdate bool, autoCrea if value, ok := scope.Get("gorm:save_associations"); ok { autoUpdate = checkTruth(value) autoCreate = autoUpdate + saveReference = autoUpdate } else if value, ok := field.TagSettings["SAVE_ASSOCIATIONS"]; ok { autoUpdate = checkTruth(value) autoCreate = autoUpdate + saveReference = autoUpdate } if value, ok := scope.Get("gorm:association_autoupdate"); ok { diff --git a/main_test.go b/main_test.go index 11c4bb87..94d2fa39 100644 --- a/main_test.go +++ b/main_test.go @@ -933,6 +933,94 @@ func TestOpenWithOneParameter(t *testing.T) { } } +func TestSaveAssociations(t *testing.T) { + db := DB.New() + deltaAddressCount := 0 + if err := db.Model(&Address{}).Count(&deltaAddressCount).Error; err != nil { + t.Errorf("failed to fetch address count") + t.FailNow() + } + + placeAddress := &Address{ + Address1: "somewhere on earth", + } + ownerAddress1 := &Address{ + Address1: "near place address", + } + ownerAddress2 := &Address{ + Address1: "address2", + } + db.Create(placeAddress) + + addressCountShouldBe := func(t *testing.T, expectedCount int) { + countFromDB := 0 + t.Helper() + err := db.Model(&Address{}).Count(&countFromDB).Error + if err != nil { + t.Error("failed to fetch address count") + } + if countFromDB != expectedCount { + t.Errorf("address count mismatch: %d", countFromDB) + } + } + addressCountShouldBe(t, deltaAddressCount+1) + + // owner address should be created, place address should be reused + place1 := &Place{ + PlaceAddressID: placeAddress.ID, + PlaceAddress: placeAddress, + OwnerAddress: ownerAddress1, + } + err := db.Create(place1).Error + if err != nil { + t.Errorf("failed to store place: %s", err.Error()) + } + addressCountShouldBe(t, deltaAddressCount+2) + + // owner address should be created again, place address should be reused + place2 := &Place{ + PlaceAddressID: placeAddress.ID, + PlaceAddress: &Address{ + ID: 777, + Address1: "address1", + }, + OwnerAddress: ownerAddress2, + OwnerAddressID: 778, + } + err = db.Create(place2).Error + if err != nil { + t.Errorf("failed to store place: %s", err.Error()) + } + addressCountShouldBe(t, deltaAddressCount+3) + + count := 0 + db.Model(&Place{}).Where(&Place{ + PlaceAddressID: placeAddress.ID, + OwnerAddressID: ownerAddress1.ID, + }).Count(&count) + if count != 1 { + t.Errorf("only one instance of (%d, %d) should be available, found: %d", + placeAddress.ID, ownerAddress1.ID, count) + } + + db.Model(&Place{}).Where(&Place{ + PlaceAddressID: placeAddress.ID, + OwnerAddressID: ownerAddress2.ID, + }).Count(&count) + if count != 1 { + t.Errorf("only one instance of (%d, %d) should be available, found: %d", + placeAddress.ID, ownerAddress2.ID, count) + } + + db.Model(&Place{}).Where(&Place{ + PlaceAddressID: placeAddress.ID, + }).Count(&count) + if count != 2 { + t.Errorf("two instances of (%d) should be available, found: %d", + placeAddress.ID, count) + } +} + func TestBlockGlobalUpdate(t *testing.T) { db := DB.New() db.Create(&Toy{Name: "Stuffed Animal", OwnerType: "Nobody"}) diff --git a/migration_test.go b/migration_test.go index 78555dcc..3fb06648 100644 --- a/migration_test.go +++ b/migration_test.go @@ -118,6 +118,14 @@ type Company struct { Owner *User `sql:"-"` } +type Place struct { + Id int64 + PlaceAddressID int + PlaceAddress *Address `gorm:"save_associations:false"` + OwnerAddressID int + OwnerAddress *Address `gorm:"save_associations:true"` +} + type EncryptedData []byte func (data *EncryptedData) Scan(value interface{}) error { @@ -284,7 +292,7 @@ func runMigration() { DB.Exec(fmt.Sprintf("drop table %v;", table)) } - values := []interface{}{&Short{}, &ReallyLongThingThatReferencesShort{}, &ReallyLongTableNameToTestMySQLNameLengthLimit{}, &NotSoLongTableName{}, &Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}, &Animal{}, &User{}, &JoinTable{}, &Post{}, &Category{}, &Comment{}, &Cat{}, &Dog{}, &Hamster{}, &Toy{}, &ElementWithIgnoredField{}} + values := []interface{}{&Short{}, &ReallyLongThingThatReferencesShort{}, &ReallyLongTableNameToTestMySQLNameLengthLimit{}, &NotSoLongTableName{}, &Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}, &Animal{}, &User{}, &JoinTable{}, &Post{}, &Category{}, &Comment{}, &Cat{}, &Dog{}, &Hamster{}, &Toy{}, &ElementWithIgnoredField{}, &Place{}} for _, value := range values { DB.DropTable(value) } From 73e7561e20e8e554ec54463ccbed38e426aad17f Mon Sep 17 00:00:00 2001 From: Aaron Leung Date: Sun, 9 Sep 2018 15:26:29 -0700 Subject: [PATCH 093/881] Use sync.Map for DB.values (#2064) * Replace the regular map with a sync.Map to avoid fatal concurrent map reads/writes * fix the formatting --- main.go | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/main.go b/main.go index 993e19b1..364d8e8e 100644 --- a/main.go +++ b/main.go @@ -22,7 +22,7 @@ type DB struct { logMode int logger logger search *search - values map[string]interface{} + values sync.Map // global db parent *DB @@ -72,7 +72,6 @@ func Open(dialect string, args ...interface{}) (db *DB, err error) { db = &DB{ db: dbSQL, logger: defaultLogger, - values: map[string]interface{}{}, callbacks: DefaultCallback, dialect: newDialect(dialect, dbSQL), } @@ -680,13 +679,13 @@ func (s *DB) Set(name string, value interface{}) *DB { // InstantSet instant set setting, will affect current db func (s *DB) InstantSet(name string, value interface{}) *DB { - s.values[name] = value + s.values.Store(name, value) return s } // Get get setting by name func (s *DB) Get(name string) (value interface{}, ok bool) { - value, ok = s.values[name] + value, ok = s.values.Load(name) return } @@ -750,16 +749,16 @@ func (s *DB) clone() *DB { parent: s.parent, logger: s.logger, logMode: s.logMode, - values: map[string]interface{}{}, Value: s.Value, Error: s.Error, blockGlobalUpdate: s.blockGlobalUpdate, dialect: newDialect(s.dialect.GetName(), s.db), } - for key, value := range s.values { - db.values[key] = value - } + s.values.Range(func(k, v interface{}) bool { + db.values.Store(k, v) + return true + }) if s.search == nil { db.search = &search{limit: -1, offset: -1} From 012d1479740ec593b0c07f0372e0111c01c3b34a Mon Sep 17 00:00:00 2001 From: maddie Date: Mon, 10 Sep 2018 06:45:55 +0800 Subject: [PATCH 094/881] Improve preload speed (#2058) All credits to @vanjapt who came up with this patch. Closes #1672 --- callback_query_preload.go | 33 ++++++++++++++++++++++----------- 1 file changed, 22 insertions(+), 11 deletions(-) diff --git a/callback_query_preload.go b/callback_query_preload.go index 481bfbe3..46405c38 100644 --- a/callback_query_preload.go +++ b/callback_query_preload.go @@ -161,14 +161,17 @@ func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) ) if indirectScopeValue.Kind() == reflect.Slice { + foreignValuesToResults := make(map[string]reflect.Value) + for i := 0; i < resultsValue.Len(); i++ { + result := resultsValue.Index(i) + foreignValues := toString(getValueFromFields(result, relation.ForeignFieldNames)) + foreignValuesToResults[foreignValues] = result + } for j := 0; j < indirectScopeValue.Len(); j++ { - for i := 0; i < resultsValue.Len(); i++ { - result := resultsValue.Index(i) - foreignValues := getValueFromFields(result, relation.ForeignFieldNames) - if indirectValue := indirect(indirectScopeValue.Index(j)); equalAsString(getValueFromFields(indirectValue, relation.AssociationForeignFieldNames), foreignValues) { - indirectValue.FieldByName(field.Name).Set(result) - break - } + indirectValue := indirect(indirectScopeValue.Index(j)) + valueString := toString(getValueFromFields(indirectValue, relation.AssociationForeignFieldNames)) + if result, found := foreignValuesToResults[valueString]; found { + indirectValue.FieldByName(field.Name).Set(result) } } } else { @@ -255,13 +258,21 @@ func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{ indirectScopeValue = scope.IndirectValue() ) + foreignFieldToObjects := make(map[string][]*reflect.Value) + if indirectScopeValue.Kind() == reflect.Slice { + for j := 0; j < indirectScopeValue.Len(); j++ { + object := indirect(indirectScopeValue.Index(j)) + valueString := toString(getValueFromFields(object, relation.ForeignFieldNames)) + foreignFieldToObjects[valueString] = append(foreignFieldToObjects[valueString], &object) + } + } + for i := 0; i < resultsValue.Len(); i++ { result := resultsValue.Index(i) if indirectScopeValue.Kind() == reflect.Slice { - value := getValueFromFields(result, relation.AssociationForeignFieldNames) - for j := 0; j < indirectScopeValue.Len(); j++ { - object := indirect(indirectScopeValue.Index(j)) - if equalAsString(getValueFromFields(object, relation.ForeignFieldNames), value) { + valueString := toString(getValueFromFields(result, relation.AssociationForeignFieldNames)) + if objects, found := foreignFieldToObjects[valueString]; found { + for _, object := range objects { object.FieldByName(field.Name).Set(result) } } From 26fde9110f932df8cb5cc24396e7a54a6d3a94c2 Mon Sep 17 00:00:00 2001 From: Gustavo Brunoro Date: Sun, 9 Sep 2018 19:47:18 -0300 Subject: [PATCH 095/881] getValueFromFields doesn't panic on nil pointers (#2021) * `IsValid()` won't return `false` for nil pointers unless Value is wrapped in a `reflect.Indirect`. --- utils.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils.go b/utils.go index 8489538c..e58e57a5 100644 --- a/utils.go +++ b/utils.go @@ -206,7 +206,7 @@ func getValueFromFields(value reflect.Value, fieldNames []string) (results []int // as FieldByName could panic if indirectValue := reflect.Indirect(value); indirectValue.IsValid() { for _, fieldName := range fieldNames { - if fieldValue := indirectValue.FieldByName(fieldName); fieldValue.IsValid() { + if fieldValue := reflect.Indirect(indirectValue.FieldByName(fieldName)); fieldValue.IsValid() { result := fieldValue.Interface() if r, ok := result.(driver.Valuer); ok { result, _ = r.Value() From 588b598f9fbf9a0c84b6ec18f617940b045c54d4 Mon Sep 17 00:00:00 2001 From: Phillip Shipley Date: Sun, 9 Sep 2018 18:50:22 -0400 Subject: [PATCH 096/881] Fix issue updating models with foreign key constraints (#1988) * fix update callback to not try to write zero values when field has default value * fix to update callback for gorm tests --- callback_update.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/callback_update.go b/callback_update.go index 373bd726..f6ba0ffd 100644 --- a/callback_update.go +++ b/callback_update.go @@ -76,7 +76,9 @@ func updateCallback(scope *Scope) { for _, field := range scope.Fields() { if scope.changeableField(field) { if !field.IsPrimaryKey && field.IsNormal { - sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) + if !field.IsForeignKey || !field.IsBlank || !field.HasDefaultValue { + sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) + } } else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" { for _, foreignKey := range relationship.ForeignDBNames { if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) { From 282f11af1900a36646b607797273056d76350223 Mon Sep 17 00:00:00 2001 From: gstvg <28798827+gstvg@users.noreply.github.com> Date: Sun, 9 Sep 2018 19:52:32 -0300 Subject: [PATCH 097/881] Support only preloading (#1926) * add support for only preloading relations on an already populated model * Update callback_query.go comments --- callback_query.go | 5 +++++ main.go | 5 +++++ 2 files changed, 10 insertions(+) diff --git a/callback_query.go b/callback_query.go index ba10cc7d..593e5d30 100644 --- a/callback_query.go +++ b/callback_query.go @@ -18,6 +18,11 @@ func queryCallback(scope *Scope) { if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip { return } + + //we are only preloading relations, dont touch base model + if _, skip := scope.InstanceGet("gorm:only_preload"); skip { + return + } defer scope.trace(NowFunc()) diff --git a/main.go b/main.go index 364d8e8e..4dbda61e 100644 --- a/main.go +++ b/main.go @@ -314,6 +314,11 @@ func (s *DB) Find(out interface{}, where ...interface{}) *DB { return s.NewScope(out).inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db } +//Preloads preloads relations, don`t touch out +func (s *DB) Preloads(out interface{}) *DB { + return s.NewScope(out).InstanceSet("gorm:only_preload", 1).callCallbacks(s.parent.callbacks.queries).db +} + // Scan scan value to a struct func (s *DB) Scan(dest interface{}) *DB { return s.NewScope(s.Value).Set("gorm:query_destination", dest).callCallbacks(s.parent.callbacks.queries).db From 123d4f50ef8a8209ee8434daa41c6045a9111864 Mon Sep 17 00:00:00 2001 From: Eyal Posener Date: Mon, 10 Sep 2018 02:11:00 +0300 Subject: [PATCH 098/881] lock TagSettings structure when modified (#1796) The map is modified in different places in the code which results in race conditions on execution. This commit locks the map with read-write lock when it is modified --- callback_query_preload.go | 2 +- callback_save.go | 8 ++--- dialect.go | 10 +++--- dialect_common.go | 2 +- dialect_mysql.go | 22 ++++++------ dialect_postgres.go | 6 ++-- dialect_sqlite3.go | 4 +-- dialects/mssql/mssql.go | 8 ++--- field_test.go | 2 +- main.go | 2 +- model_struct.go | 73 +++++++++++++++++++++++++++------------ scope.go | 12 +++---- 12 files changed, 90 insertions(+), 61 deletions(-) diff --git a/callback_query_preload.go b/callback_query_preload.go index 46405c38..d7c8a133 100644 --- a/callback_query_preload.go +++ b/callback_query_preload.go @@ -100,7 +100,7 @@ func autoPreload(scope *Scope) { continue } - if val, ok := field.TagSettings["PRELOAD"]; ok { + if val, ok := field.TagSettingsGet("PRELOAD"); ok { if preload, err := strconv.ParseBool(val); err != nil { scope.Err(errors.New("invalid preload option")) return diff --git a/callback_save.go b/callback_save.go index ebfd0b34..3b4e0589 100644 --- a/callback_save.go +++ b/callback_save.go @@ -35,7 +35,7 @@ func saveAssociationCheck(scope *Scope, field *Field) (autoUpdate bool, autoCrea autoUpdate = checkTruth(value) autoCreate = autoUpdate saveReference = autoUpdate - } else if value, ok := field.TagSettings["SAVE_ASSOCIATIONS"]; ok { + } else if value, ok := field.TagSettingsGet("SAVE_ASSOCIATIONS"); ok { autoUpdate = checkTruth(value) autoCreate = autoUpdate saveReference = autoUpdate @@ -43,19 +43,19 @@ func saveAssociationCheck(scope *Scope, field *Field) (autoUpdate bool, autoCrea if value, ok := scope.Get("gorm:association_autoupdate"); ok { autoUpdate = checkTruth(value) - } else if value, ok := field.TagSettings["ASSOCIATION_AUTOUPDATE"]; ok { + } else if value, ok := field.TagSettingsGet("ASSOCIATION_AUTOUPDATE"); ok { autoUpdate = checkTruth(value) } if value, ok := scope.Get("gorm:association_autocreate"); ok { autoCreate = checkTruth(value) - } else if value, ok := field.TagSettings["ASSOCIATION_AUTOCREATE"]; ok { + } else if value, ok := field.TagSettingsGet("ASSOCIATION_AUTOCREATE"); ok { autoCreate = checkTruth(value) } if value, ok := scope.Get("gorm:association_save_reference"); ok { saveReference = checkTruth(value) - } else if value, ok := field.TagSettings["ASSOCIATION_SAVE_REFERENCE"]; ok { + } else if value, ok := field.TagSettingsGet("ASSOCIATION_SAVE_REFERENCE"); ok { saveReference = checkTruth(value) } } diff --git a/dialect.go b/dialect.go index 506a6e86..27b308af 100644 --- a/dialect.go +++ b/dialect.go @@ -83,7 +83,7 @@ var ParseFieldStructForDialect = func(field *StructField, dialect Dialect) (fiel // Get redirected field type var ( reflectType = field.Struct.Type - dataType = field.TagSettings["TYPE"] + dataType, _ = field.TagSettingsGet("TYPE") ) for reflectType.Kind() == reflect.Ptr { @@ -112,15 +112,17 @@ var ParseFieldStructForDialect = func(field *StructField, dialect Dialect) (fiel } // Default Size - if num, ok := field.TagSettings["SIZE"]; ok { + if num, ok := field.TagSettingsGet("SIZE"); ok { size, _ = strconv.Atoi(num) } else { size = 255 } // Default type from tag setting - additionalType = field.TagSettings["NOT NULL"] + " " + field.TagSettings["UNIQUE"] - if value, ok := field.TagSettings["DEFAULT"]; ok { + notNull, _ := field.TagSettingsGet("NOT NULL") + unique, _ := field.TagSettingsGet("UNIQUE") + additionalType = notNull + " " + unique + if value, ok := field.TagSettingsGet("DEFAULT"); ok { additionalType = additionalType + " DEFAULT " + value } diff --git a/dialect_common.go b/dialect_common.go index b9f0c7da..a479be79 100644 --- a/dialect_common.go +++ b/dialect_common.go @@ -39,7 +39,7 @@ func (commonDialect) Quote(key string) string { } func (s *commonDialect) fieldCanAutoIncrement(field *StructField) bool { - if value, ok := field.TagSettings["AUTO_INCREMENT"]; ok { + if value, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok { return strings.ToLower(value) != "false" } return field.IsPrimaryKey diff --git a/dialect_mysql.go b/dialect_mysql.go index b162bade..5d63e5cd 100644 --- a/dialect_mysql.go +++ b/dialect_mysql.go @@ -33,9 +33,9 @@ func (s *mysql) DataTypeOf(field *StructField) string { // MySQL allows only one auto increment column per table, and it must // be a KEY column. - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok { - if _, ok = field.TagSettings["INDEX"]; !ok && !field.IsPrimaryKey { - delete(field.TagSettings, "AUTO_INCREMENT") + if _, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok { + if _, ok = field.TagSettingsGet("INDEX"); !ok && !field.IsPrimaryKey { + field.TagSettingsDelete("AUTO_INCREMENT") } } @@ -45,42 +45,42 @@ func (s *mysql) DataTypeOf(field *StructField) string { sqlType = "boolean" case reflect.Int8: if s.fieldCanAutoIncrement(field) { - field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" + field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") sqlType = "tinyint AUTO_INCREMENT" } else { sqlType = "tinyint" } case reflect.Int, reflect.Int16, reflect.Int32: if s.fieldCanAutoIncrement(field) { - field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" + field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") sqlType = "int AUTO_INCREMENT" } else { sqlType = "int" } case reflect.Uint8: if s.fieldCanAutoIncrement(field) { - field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" + field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") sqlType = "tinyint unsigned AUTO_INCREMENT" } else { sqlType = "tinyint unsigned" } case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uintptr: if s.fieldCanAutoIncrement(field) { - field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" + field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") sqlType = "int unsigned AUTO_INCREMENT" } else { sqlType = "int unsigned" } case reflect.Int64: if s.fieldCanAutoIncrement(field) { - field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" + field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") sqlType = "bigint AUTO_INCREMENT" } else { sqlType = "bigint" } case reflect.Uint64: if s.fieldCanAutoIncrement(field) { - field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" + field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") sqlType = "bigint unsigned AUTO_INCREMENT" } else { sqlType = "bigint unsigned" @@ -96,11 +96,11 @@ func (s *mysql) DataTypeOf(field *StructField) string { case reflect.Struct: if _, ok := dataValue.Interface().(time.Time); ok { precision := "" - if p, ok := field.TagSettings["PRECISION"]; ok { + if p, ok := field.TagSettingsGet("PRECISION"); ok { precision = fmt.Sprintf("(%s)", p) } - if _, ok := field.TagSettings["NOT NULL"]; ok { + if _, ok := field.TagSettingsGet("NOT NULL"); ok { sqlType = fmt.Sprintf("timestamp%v", precision) } else { sqlType = fmt.Sprintf("timestamp%v NULL", precision) diff --git a/dialect_postgres.go b/dialect_postgres.go index c44c6a5b..53d31388 100644 --- a/dialect_postgres.go +++ b/dialect_postgres.go @@ -34,14 +34,14 @@ func (s *postgres) DataTypeOf(field *StructField) string { sqlType = "boolean" case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uintptr: if s.fieldCanAutoIncrement(field) { - field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" + field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") sqlType = "serial" } else { sqlType = "integer" } case reflect.Int64, reflect.Uint32, reflect.Uint64: if s.fieldCanAutoIncrement(field) { - field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" + field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") sqlType = "bigserial" } else { sqlType = "bigint" @@ -49,7 +49,7 @@ func (s *postgres) DataTypeOf(field *StructField) string { case reflect.Float32, reflect.Float64: sqlType = "numeric" case reflect.String: - if _, ok := field.TagSettings["SIZE"]; !ok { + if _, ok := field.TagSettingsGet("SIZE"); !ok { size = 0 // if SIZE haven't been set, use `text` as the default type, as there are no performance different } diff --git a/dialect_sqlite3.go b/dialect_sqlite3.go index f26f6be3..5f96c363 100644 --- a/dialect_sqlite3.go +++ b/dialect_sqlite3.go @@ -29,14 +29,14 @@ func (s *sqlite3) DataTypeOf(field *StructField) string { sqlType = "bool" case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: if s.fieldCanAutoIncrement(field) { - field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" + field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") sqlType = "integer primary key autoincrement" } else { sqlType = "integer" } case reflect.Int64, reflect.Uint64: if s.fieldCanAutoIncrement(field) { - field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" + field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") sqlType = "integer primary key autoincrement" } else { sqlType = "bigint" diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go index 731721cb..6c424bc1 100644 --- a/dialects/mssql/mssql.go +++ b/dialects/mssql/mssql.go @@ -18,7 +18,7 @@ import ( func setIdentityInsert(scope *gorm.Scope) { if scope.Dialect().GetName() == "mssql" { for _, field := range scope.PrimaryFields() { - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok && !field.IsBlank { + if _, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok && !field.IsBlank { scope.NewDB().Exec(fmt.Sprintf("SET IDENTITY_INSERT %v ON", scope.TableName())) scope.InstanceSet("mssql:identity_insert_on", true) } @@ -70,14 +70,14 @@ func (s *mssql) DataTypeOf(field *gorm.StructField) string { sqlType = "bit" case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: if s.fieldCanAutoIncrement(field) { - field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" + field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") sqlType = "int IDENTITY(1,1)" } else { sqlType = "int" } case reflect.Int64, reflect.Uint64: if s.fieldCanAutoIncrement(field) { - field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" + field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") sqlType = "bigint IDENTITY(1,1)" } else { sqlType = "bigint" @@ -116,7 +116,7 @@ func (s *mssql) DataTypeOf(field *gorm.StructField) string { } func (s mssql) fieldCanAutoIncrement(field *gorm.StructField) bool { - if value, ok := field.TagSettings["AUTO_INCREMENT"]; ok { + if value, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok { return value != "FALSE" } return field.IsPrimaryKey diff --git a/field_test.go b/field_test.go index 30e9a778..c3afdff5 100644 --- a/field_test.go +++ b/field_test.go @@ -43,7 +43,7 @@ func TestCalculateField(t *testing.T) { if field, ok := scope.FieldByName("embedded_name"); !ok { t.Errorf("should find embedded field") - } else if _, ok := field.TagSettings["NOT NULL"]; !ok { + } else if _, ok := field.TagSettingsGet("NOT NULL"); !ok { t.Errorf("should find embedded field's tag settings") } } diff --git a/main.go b/main.go index 4dbda61e..17c75ed3 100644 --- a/main.go +++ b/main.go @@ -699,7 +699,7 @@ func (s *DB) SetJoinTableHandler(source interface{}, column string, handler Join scope := s.NewScope(source) for _, field := range scope.GetModelStruct().StructFields { if field.Name == column || field.DBName == column { - if many2many := field.TagSettings["MANY2MANY"]; many2many != "" { + if many2many, _ := field.TagSettingsGet("MANY2MANY"); many2many != "" { source := (&Scope{Value: source}).GetModelStruct().ModelType destination := (&Scope{Value: reflect.New(field.Struct.Type).Interface()}).GetModelStruct().ModelType handler.Setup(field.Relationship, many2many, source, destination) diff --git a/model_struct.go b/model_struct.go index 5b5be618..12860e67 100644 --- a/model_struct.go +++ b/model_struct.go @@ -60,6 +60,30 @@ type StructField struct { Struct reflect.StructField IsForeignKey bool Relationship *Relationship + + tagSettingsLock sync.RWMutex +} + +// TagSettingsSet Sets a tag in the tag settings map +func (s *StructField) TagSettingsSet(key, val string) { + s.tagSettingsLock.Lock() + defer s.tagSettingsLock.Unlock() + s.TagSettings[key] = val +} + +// TagSettingsGet returns a tag from the tag settings +func (s *StructField) TagSettingsGet(key string) (string, bool) { + s.tagSettingsLock.RLock() + defer s.tagSettingsLock.RUnlock() + val, ok := s.TagSettings[key] + return val, ok +} + +// TagSettingsDelete deletes a tag +func (s *StructField) TagSettingsDelete(key string) { + s.tagSettingsLock.Lock() + defer s.tagSettingsLock.Unlock() + delete(s.TagSettings, key) } func (structField *StructField) clone() *StructField { @@ -83,6 +107,9 @@ func (structField *StructField) clone() *StructField { clone.Relationship = &relationship } + // copy the struct field tagSettings, they should be read-locked while they are copied + structField.tagSettingsLock.Lock() + defer structField.tagSettingsLock.Unlock() for key, value := range structField.TagSettings { clone.TagSettings[key] = value } @@ -149,19 +176,19 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } // is ignored field - if _, ok := field.TagSettings["-"]; ok { + if _, ok := field.TagSettingsGet("-"); ok { field.IsIgnored = true } else { - if _, ok := field.TagSettings["PRIMARY_KEY"]; ok { + if _, ok := field.TagSettingsGet("PRIMARY_KEY"); ok { field.IsPrimaryKey = true modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field) } - if _, ok := field.TagSettings["DEFAULT"]; ok { + if _, ok := field.TagSettingsGet("DEFAULT"); ok { field.HasDefaultValue = true } - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok && !field.IsPrimaryKey { + if _, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok && !field.IsPrimaryKey { field.HasDefaultValue = true } @@ -177,8 +204,8 @@ func (scope *Scope) GetModelStruct() *ModelStruct { if indirectType.Kind() == reflect.Struct { for i := 0; i < indirectType.NumField(); i++ { for key, value := range parseTagSetting(indirectType.Field(i).Tag) { - if _, ok := field.TagSettings[key]; !ok { - field.TagSettings[key] = value + if _, ok := field.TagSettingsGet(key); !ok { + field.TagSettingsSet(key, value) } } } @@ -186,17 +213,17 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } else if _, isTime := fieldValue.(*time.Time); isTime { // is time field.IsNormal = true - } else if _, ok := field.TagSettings["EMBEDDED"]; ok || fieldStruct.Anonymous { + } else if _, ok := field.TagSettingsGet("EMBEDDED"); ok || fieldStruct.Anonymous { // is embedded struct for _, subField := range scope.New(fieldValue).GetModelStruct().StructFields { subField = subField.clone() subField.Names = append([]string{fieldStruct.Name}, subField.Names...) - if prefix, ok := field.TagSettings["EMBEDDED_PREFIX"]; ok { + if prefix, ok := field.TagSettingsGet("EMBEDDED_PREFIX"); ok { subField.DBName = prefix + subField.DBName } if subField.IsPrimaryKey { - if _, ok := subField.TagSettings["PRIMARY_KEY"]; ok { + if _, ok := subField.TagSettingsGet("PRIMARY_KEY"); ok { modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, subField) } else { subField.IsPrimaryKey = false @@ -227,13 +254,13 @@ func (scope *Scope) GetModelStruct() *ModelStruct { elemType = field.Struct.Type ) - if foreignKey := field.TagSettings["FOREIGNKEY"]; foreignKey != "" { + if foreignKey, _ := field.TagSettingsGet("FOREIGNKEY"); foreignKey != "" { foreignKeys = strings.Split(foreignKey, ",") } - if foreignKey := field.TagSettings["ASSOCIATION_FOREIGNKEY"]; foreignKey != "" { + if foreignKey, _ := field.TagSettingsGet("ASSOCIATION_FOREIGNKEY"); foreignKey != "" { associationForeignKeys = strings.Split(foreignKey, ",") - } else if foreignKey := field.TagSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" { + } else if foreignKey, _ := field.TagSettingsGet("ASSOCIATIONFOREIGNKEY"); foreignKey != "" { associationForeignKeys = strings.Split(foreignKey, ",") } @@ -242,13 +269,13 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } if elemType.Kind() == reflect.Struct { - if many2many := field.TagSettings["MANY2MANY"]; many2many != "" { + if many2many, _ := field.TagSettingsGet("MANY2MANY"); many2many != "" { relationship.Kind = "many_to_many" { // Foreign Keys for Source joinTableDBNames := []string{} - if foreignKey := field.TagSettings["JOINTABLE_FOREIGNKEY"]; foreignKey != "" { + if foreignKey, _ := field.TagSettingsGet("JOINTABLE_FOREIGNKEY"); foreignKey != "" { joinTableDBNames = strings.Split(foreignKey, ",") } @@ -279,7 +306,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { { // Foreign Keys for Association (Destination) associationJoinTableDBNames := []string{} - if foreignKey := field.TagSettings["ASSOCIATION_JOINTABLE_FOREIGNKEY"]; foreignKey != "" { + if foreignKey, _ := field.TagSettingsGet("ASSOCIATION_JOINTABLE_FOREIGNKEY"); foreignKey != "" { associationJoinTableDBNames = strings.Split(foreignKey, ",") } @@ -317,7 +344,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { var toFields = toScope.GetStructFields() relationship.Kind = "has_many" - if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" { + if polymorphic, _ := field.TagSettingsGet("POLYMORPHIC"); polymorphic != "" { // Dog has many toys, tag polymorphic is Owner, then associationType is Owner // Toy use OwnerID, OwnerType ('dogs') as foreign key if polymorphicType := getForeignField(polymorphic+"Type", toFields); polymorphicType != nil { @@ -325,7 +352,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { relationship.PolymorphicType = polymorphicType.Name relationship.PolymorphicDBName = polymorphicType.DBName // if Dog has multiple set of toys set name of the set (instead of default 'dogs') - if value, ok := field.TagSettings["POLYMORPHIC_VALUE"]; ok { + if value, ok := field.TagSettingsGet("POLYMORPHIC_VALUE"); ok { relationship.PolymorphicValue = value } else { relationship.PolymorphicValue = scope.TableName() @@ -407,17 +434,17 @@ func (scope *Scope) GetModelStruct() *ModelStruct { tagAssociationForeignKeys []string ) - if foreignKey := field.TagSettings["FOREIGNKEY"]; foreignKey != "" { + if foreignKey, _ := field.TagSettingsGet("FOREIGNKEY"); foreignKey != "" { tagForeignKeys = strings.Split(foreignKey, ",") } - if foreignKey := field.TagSettings["ASSOCIATION_FOREIGNKEY"]; foreignKey != "" { + if foreignKey, _ := field.TagSettingsGet("ASSOCIATION_FOREIGNKEY"); foreignKey != "" { tagAssociationForeignKeys = strings.Split(foreignKey, ",") - } else if foreignKey := field.TagSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" { + } else if foreignKey, _ := field.TagSettingsGet("ASSOCIATIONFOREIGNKEY"); foreignKey != "" { tagAssociationForeignKeys = strings.Split(foreignKey, ",") } - if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" { + if polymorphic, _ := field.TagSettingsGet("POLYMORPHIC"); polymorphic != "" { // Cat has one toy, tag polymorphic is Owner, then associationType is Owner // Toy use OwnerID, OwnerType ('cats') as foreign key if polymorphicType := getForeignField(polymorphic+"Type", toFields); polymorphicType != nil { @@ -425,7 +452,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { relationship.PolymorphicType = polymorphicType.Name relationship.PolymorphicDBName = polymorphicType.DBName // if Cat has several different types of toys set name for each (instead of default 'cats') - if value, ok := field.TagSettings["POLYMORPHIC_VALUE"]; ok { + if value, ok := field.TagSettingsGet("POLYMORPHIC_VALUE"); ok { relationship.PolymorphicValue = value } else { relationship.PolymorphicValue = scope.TableName() @@ -563,7 +590,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } // Even it is ignored, also possible to decode db value into the field - if value, ok := field.TagSettings["COLUMN"]; ok { + if value, ok := field.TagSettingsGet("COLUMN"); ok { field.DBName = value } else { field.DBName = ToColumnName(fieldStruct.Name) diff --git a/scope.go b/scope.go index ce80ab86..fa521ca2 100644 --- a/scope.go +++ b/scope.go @@ -1115,8 +1115,8 @@ func (scope *Scope) createJoinTable(field *StructField) { if field, ok := scope.FieldByName(fieldName); ok { foreignKeyStruct := field.clone() foreignKeyStruct.IsPrimaryKey = false - foreignKeyStruct.TagSettings["IS_JOINTABLE_FOREIGNKEY"] = "true" - delete(foreignKeyStruct.TagSettings, "AUTO_INCREMENT") + foreignKeyStruct.TagSettingsSet("IS_JOINTABLE_FOREIGNKEY", "true") + foreignKeyStruct.TagSettingsDelete("AUTO_INCREMENT") sqlTypes = append(sqlTypes, scope.Quote(relationship.ForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(foreignKeyStruct)) primaryKeys = append(primaryKeys, scope.Quote(relationship.ForeignDBNames[idx])) } @@ -1126,8 +1126,8 @@ func (scope *Scope) createJoinTable(field *StructField) { if field, ok := toScope.FieldByName(fieldName); ok { foreignKeyStruct := field.clone() foreignKeyStruct.IsPrimaryKey = false - foreignKeyStruct.TagSettings["IS_JOINTABLE_FOREIGNKEY"] = "true" - delete(foreignKeyStruct.TagSettings, "AUTO_INCREMENT") + foreignKeyStruct.TagSettingsSet("IS_JOINTABLE_FOREIGNKEY", "true") + foreignKeyStruct.TagSettingsDelete("AUTO_INCREMENT") sqlTypes = append(sqlTypes, scope.Quote(relationship.AssociationForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(foreignKeyStruct)) primaryKeys = append(primaryKeys, scope.Quote(relationship.AssociationForeignDBNames[idx])) } @@ -1262,7 +1262,7 @@ func (scope *Scope) autoIndex() *Scope { var uniqueIndexes = map[string][]string{} for _, field := range scope.GetStructFields() { - if name, ok := field.TagSettings["INDEX"]; ok { + if name, ok := field.TagSettingsGet("INDEX"); ok { names := strings.Split(name, ",") for _, name := range names { @@ -1273,7 +1273,7 @@ func (scope *Scope) autoIndex() *Scope { } } - if name, ok := field.TagSettings["UNIQUE_INDEX"]; ok { + if name, ok := field.TagSettingsGet("UNIQUE_INDEX"); ok { names := strings.Split(name, ",") for _, name := range names { From 5be9bd34135805e0332b993378864b159784d8a8 Mon Sep 17 00:00:00 2001 From: ch3rub1m Date: Fri, 14 Sep 2018 15:53:49 +0800 Subject: [PATCH 099/881] Rollback transaction when a panic happens in callback (#2093) --- scope.go | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/scope.go b/scope.go index fa521ca2..378025bd 100644 --- a/scope.go +++ b/scope.go @@ -855,6 +855,14 @@ func (scope *Scope) inlineCondition(values ...interface{}) *Scope { } func (scope *Scope) callCallbacks(funcs []*func(s *Scope)) *Scope { + defer func() { + if err := recover(); err != nil { + if db, ok := scope.db.db.(sqlTx); ok { + db.Rollback() + } + panic(err) + } + }() for _, f := range funcs { (*f)(scope) if scope.skipLeft { From f6260a00852946a10a57e8bb9f505f19bc9389b7 Mon Sep 17 00:00:00 2001 From: Artemij Shepelev Date: Sat, 22 Sep 2018 14:59:11 +0300 Subject: [PATCH 100/881] Second part of the defaultTableName field race fix (#2060) * fix (https://github.com/jinzhu/gorm/issues/1407) * changed map with mutex to sync.Map (https://github.com/jinzhu/gorm/issues/1407) * removed newModelStructsMap func * commit to rerun pipeline, comment changed * fix race with defaultTableName field (again) --- model_struct.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/model_struct.go b/model_struct.go index 12860e67..8c27e209 100644 --- a/model_struct.go +++ b/model_struct.go @@ -24,11 +24,16 @@ type ModelStruct struct { PrimaryFields []*StructField StructFields []*StructField ModelType reflect.Type + defaultTableName string + l sync.Mutex } // TableName returns model's table name func (s *ModelStruct) TableName(db *DB) string { + s.l.Lock() + defer s.l.Unlock() + if s.defaultTableName == "" && db != nil && s.ModelType != nil { // Set default table name if tabler, ok := reflect.New(s.ModelType).Interface().(tabler); ok { From 742154be9a26e849f02d296073c077e0a7c23828 Mon Sep 17 00:00:00 2001 From: "Iskander (Alex) Sharipov" Date: Sun, 7 Oct 2018 03:49:37 +0300 Subject: [PATCH 101/881] rewrite if-else chain as switch statement (#2121) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit From effective Go: https://golang.org/doc/effective_go.html#switch > It's therefore possible—and idiomatic—to write an if-else-if-else chain as a switch. --- association.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/association.go b/association.go index 8c6d9864..1b7744b5 100644 --- a/association.go +++ b/association.go @@ -267,15 +267,16 @@ func (association *Association) Count() int { query = scope.DB() ) - if relationship.Kind == "many_to_many" { + switch relationship.Kind { + case "many_to_many": query = relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, query, scope.Value) - } else if relationship.Kind == "has_many" || relationship.Kind == "has_one" { + case "has_many", "has_one": primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, scope.Value) query = query.Where( fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)..., ) - } else if relationship.Kind == "belongs_to" { + case "belongs_to": primaryKeys := scope.getColumnAsArray(relationship.ForeignFieldNames, scope.Value) query = query.Where( fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(primaryKeys)), From 50c61291de2f96a25627c55adcfda719ff5adae8 Mon Sep 17 00:00:00 2001 From: RikiyaFujii Date: Sat, 3 Nov 2018 22:55:52 +0900 Subject: [PATCH 102/881] add comment (#2163) * add comment * typo --- association.go | 1 + 1 file changed, 1 insertion(+) diff --git a/association.go b/association.go index 1b7744b5..a73344fe 100644 --- a/association.go +++ b/association.go @@ -368,6 +368,7 @@ func (association *Association) saveAssociations(values ...interface{}) *Associa return association } +// setErr set error when the error is not nil. And return Association. func (association *Association) setErr(err error) *Association { if err != nil { association.Error = err From 68f5d25d640b04d1b302993b609b2b1c693432ad Mon Sep 17 00:00:00 2001 From: teresy <43420401+teresy@users.noreply.github.com> Date: Sat, 3 Nov 2018 09:56:27 -0400 Subject: [PATCH 103/881] simplify cases of strings.Index with strings.Contains (#2162) --- scope.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scope.go b/scope.go index 378025bd..806ccb7d 100644 --- a/scope.go +++ b/scope.go @@ -68,7 +68,7 @@ func (scope *Scope) Dialect() Dialect { // Quote used to quote string to escape them for database func (scope *Scope) Quote(str string) string { - if strings.Index(str, ".") != -1 { + if strings.Contains(str, ".") { newStrs := []string{} for _, str := range strings.Split(str, ".") { newStrs = append(newStrs, scope.Dialect().Quote(str)) @@ -330,7 +330,7 @@ func (scope *Scope) TableName() string { // QuotedTableName return quoted table name func (scope *Scope) QuotedTableName() (name string) { if scope.Search != nil && len(scope.Search.tableName) > 0 { - if strings.Index(scope.Search.tableName, " ") != -1 { + if strings.Contains(scope.Search.tableName, " ") { return scope.Search.tableName } return scope.Quote(scope.Search.tableName) From 472c70caa40267cb89fd8facb07fe6454b578626 Mon Sep 17 00:00:00 2001 From: Jun Jie Nan Date: Sat, 3 Nov 2018 22:14:39 +0800 Subject: [PATCH 104/881] Check valuer interface before scan value (#2155) Scan interface only accept int64, float64, bool, []byte, string, time.Time or nil. When do scan, it's better to check whether the type support valuer interface and do convert. --- field.go | 10 +++++++++- field_test.go | 18 ++++++++++++++++++ 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/field.go b/field.go index 11c410b0..acd06e20 100644 --- a/field.go +++ b/field.go @@ -2,6 +2,7 @@ package gorm import ( "database/sql" + "database/sql/driver" "errors" "fmt" "reflect" @@ -44,7 +45,14 @@ func (field *Field) Set(value interface{}) (err error) { if reflectValue.Type().ConvertibleTo(fieldValue.Type()) { fieldValue.Set(reflectValue.Convert(fieldValue.Type())) } else if scanner, ok := fieldValue.Addr().Interface().(sql.Scanner); ok { - err = scanner.Scan(reflectValue.Interface()) + v := reflectValue.Interface() + if valuer, ok := v.(driver.Valuer); ok { + if v, err = valuer.Value(); err == nil { + err = scanner.Scan(v) + } + } else { + err = scanner.Scan(v) + } } else { err = fmt.Errorf("could not convert argument of field %s from %s to %s", field.Name, reflectValue.Type(), fieldValue.Type()) } diff --git a/field_test.go b/field_test.go index c3afdff5..03a3b3b7 100644 --- a/field_test.go +++ b/field_test.go @@ -3,6 +3,7 @@ package gorm_test import ( "testing" + "github.com/gofrs/uuid" "github.com/jinzhu/gorm" ) @@ -47,3 +48,20 @@ func TestCalculateField(t *testing.T) { t.Errorf("should find embedded field's tag settings") } } + +func TestFieldSet(t *testing.T) { + type TestFieldSetNullUUID struct { + NullUUID uuid.NullUUID + } + scope := DB.NewScope(&TestFieldSetNullUUID{}) + field := scope.Fields()[0] + err := field.Set(uuid.FromStringOrNil("3034d44a-da03-11e8-b366-4a00070b9f00")) + if err != nil { + t.Fatal(err) + } + if id, ok := field.Field.Addr().Interface().(*uuid.NullUUID); !ok { + t.Fatal() + } else if !id.Valid || id.UUID.String() != "3034d44a-da03-11e8-b366-4a00070b9f00" { + t.Fatal(id) + } +} From 5ad6f621e6f59672f5b5061df85b243436fde048 Mon Sep 17 00:00:00 2001 From: Sai Date: Thu, 13 Dec 2018 22:04:51 +0900 Subject: [PATCH 105/881] logMode codes more readable (#2216) --- main.go | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/main.go b/main.go index 17c75ed3..c1197bc9 100644 --- a/main.go +++ b/main.go @@ -19,7 +19,7 @@ type DB struct { // single db db SQLCommon blockGlobalUpdate bool - logMode int + logMode logModeValue logger logger search *search values sync.Map @@ -31,6 +31,14 @@ type DB struct { singularTable bool } +type logModeValue int + +const ( + defaultLogMode logModeValue = iota + noLogMode + detailedLogMode +) + // Open initialize a new db connection, need to import driver first, e.g: // // import _ "github.com/go-sql-driver/mysql" @@ -141,9 +149,9 @@ func (s *DB) SetLogger(log logger) { // LogMode set log mode, `true` for detailed logs, `false` for no log, default, will only print error logs func (s *DB) LogMode(enable bool) *DB { if enable { - s.logMode = 2 + s.logMode = detailedLogMode } else { - s.logMode = 1 + s.logMode = noLogMode } return s } @@ -716,7 +724,7 @@ func (s *DB) SetJoinTableHandler(source interface{}, column string, handler Join func (s *DB) AddError(err error) error { if err != nil { if err != ErrRecordNotFound { - if s.logMode == 0 { + if s.logMode == defaultLogMode { go s.print(fileWithLineNum(), err) } else { s.log(err) @@ -780,13 +788,13 @@ func (s *DB) print(v ...interface{}) { } func (s *DB) log(v ...interface{}) { - if s != nil && s.logMode == 2 { + if s != nil && s.logMode == detailedLogMode { s.print(append([]interface{}{"log", fileWithLineNum()}, v...)...) } } func (s *DB) slog(sql string, t time.Time, vars ...interface{}) { - if s.logMode == 2 { + if s.logMode == detailedLogMode { s.print("sql", fileWithLineNum(), NowFunc().Sub(t), sql, vars, s.RowsAffected) } } From 447d578628011308498d9316838f59f93834967c Mon Sep 17 00:00:00 2001 From: Zed Date: Wed, 2 Jan 2019 21:23:43 +0800 Subject: [PATCH 106/881] amended comments in error.go for clarity and grammar; for more polish when using IDEs (e.g. VSCODE) that show comments as help text (#2182) --- errors.go | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/errors.go b/errors.go index 27c9a92d..d5ef8d57 100644 --- a/errors.go +++ b/errors.go @@ -6,11 +6,11 @@ import ( ) var ( - // ErrRecordNotFound record not found error, happens when only haven't find any matched data when looking up with a struct, finding a slice won't return this error + // ErrRecordNotFound returns a "record not found error". Occurs only when attempting to query the database with a struct; querying with a slice won't return this error ErrRecordNotFound = errors.New("record not found") - // ErrInvalidSQL invalid SQL error, happens when you passed invalid SQL + // ErrInvalidSQL occurs when you attempt a query with invalid SQL ErrInvalidSQL = errors.New("invalid SQL") - // ErrInvalidTransaction invalid transaction when you are trying to `Commit` or `Rollback` + // ErrInvalidTransaction occurs when you are trying to `Commit` or `Rollback` ErrInvalidTransaction = errors.New("no valid transaction") // ErrCantStartTransaction can't start transaction when you are trying to start one with `Begin` ErrCantStartTransaction = errors.New("can't start transaction") @@ -21,7 +21,7 @@ var ( // Errors contains all happened errors type Errors []error -// IsRecordNotFoundError returns current error has record not found error or not +// IsRecordNotFoundError returns true if error contains a RecordNotFound error func IsRecordNotFoundError(err error) bool { if errs, ok := err.(Errors); ok { for _, err := range errs { @@ -33,12 +33,12 @@ func IsRecordNotFoundError(err error) bool { return err == ErrRecordNotFound } -// GetErrors gets all happened errors +// GetErrors gets all errors that have occurred and returns a slice of errors (Error type) func (errs Errors) GetErrors() []error { return errs } -// Add adds an error +// Add adds an error to a given slice of errors func (errs Errors) Add(newErrors ...error) Errors { for _, err := range newErrors { if err == nil { @@ -62,7 +62,7 @@ func (errs Errors) Add(newErrors ...error) Errors { return errs } -// Error format happened errors +// Error takes a slice of all errors that have occurred and returns it as a formatted string func (errs Errors) Error() string { var errors = []string{} for _, e := range errs { From ac6c89ec0cb95e921ddf43759f1f1f367d9e587c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E9=B9=8F?= Date: Wed, 2 Jan 2019 21:25:37 +0800 Subject: [PATCH 107/881] =?UTF-8?q?search=E4=B8=8D=E9=9C=80=E8=A6=81?= =?UTF-8?q?=E5=86=8Dclone=EF=BC=8CdbClone=E5=86=85=E7=9A=84search=E5=B7=B2?= =?UTF-8?q?=E7=BB=8F=E6=98=AF=E4=B8=80=E4=B8=AA=E5=85=A8=E6=96=B0=E7=9A=84?= =?UTF-8?q?=E4=BA=86=20(#2179)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- main.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/main.go b/main.go index c1197bc9..34a6ddc8 100644 --- a/main.go +++ b/main.go @@ -178,7 +178,7 @@ func (s *DB) SingularTable(enable bool) { func (s *DB) NewScope(value interface{}) *Scope { dbClone := s.clone() dbClone.Value = value - return &Scope{db: dbClone, Search: dbClone.search.clone(), Value: value} + return &Scope{db: dbClone, Search: dbClone.search, Value: value} } // QueryExpr returns the query as expr object From e2cfd6be3b09b548be8c4d349490bf563cb1ee13 Mon Sep 17 00:00:00 2001 From: David Zhang Date: Wed, 2 Jan 2019 21:27:17 +0800 Subject: [PATCH 108/881] LintFix: Make receiver name of structField consistent (#2164) * Make receiver name of structField consistent * Change s to sf --- model_struct.go | 66 ++++++++++++++++++++++++------------------------- 1 file changed, 33 insertions(+), 33 deletions(-) diff --git a/model_struct.go b/model_struct.go index 8c27e209..08e741fe 100644 --- a/model_struct.go +++ b/model_struct.go @@ -21,12 +21,12 @@ var modelStructsMap sync.Map // ModelStruct model definition type ModelStruct struct { - PrimaryFields []*StructField - StructFields []*StructField - ModelType reflect.Type + PrimaryFields []*StructField + StructFields []*StructField + ModelType reflect.Type defaultTableName string - l sync.Mutex + l sync.Mutex } // TableName returns model's table name @@ -70,52 +70,52 @@ type StructField struct { } // TagSettingsSet Sets a tag in the tag settings map -func (s *StructField) TagSettingsSet(key, val string) { - s.tagSettingsLock.Lock() - defer s.tagSettingsLock.Unlock() - s.TagSettings[key] = val +func (sf *StructField) TagSettingsSet(key, val string) { + sf.tagSettingsLock.Lock() + defer sf.tagSettingsLock.Unlock() + sf.TagSettings[key] = val } // TagSettingsGet returns a tag from the tag settings -func (s *StructField) TagSettingsGet(key string) (string, bool) { - s.tagSettingsLock.RLock() - defer s.tagSettingsLock.RUnlock() - val, ok := s.TagSettings[key] +func (sf *StructField) TagSettingsGet(key string) (string, bool) { + sf.tagSettingsLock.RLock() + defer sf.tagSettingsLock.RUnlock() + val, ok := sf.TagSettings[key] return val, ok } // TagSettingsDelete deletes a tag -func (s *StructField) TagSettingsDelete(key string) { - s.tagSettingsLock.Lock() - defer s.tagSettingsLock.Unlock() - delete(s.TagSettings, key) +func (sf *StructField) TagSettingsDelete(key string) { + sf.tagSettingsLock.Lock() + defer sf.tagSettingsLock.Unlock() + delete(sf.TagSettings, key) } -func (structField *StructField) clone() *StructField { +func (sf *StructField) clone() *StructField { clone := &StructField{ - DBName: structField.DBName, - Name: structField.Name, - Names: structField.Names, - IsPrimaryKey: structField.IsPrimaryKey, - IsNormal: structField.IsNormal, - IsIgnored: structField.IsIgnored, - IsScanner: structField.IsScanner, - HasDefaultValue: structField.HasDefaultValue, - Tag: structField.Tag, + DBName: sf.DBName, + Name: sf.Name, + Names: sf.Names, + IsPrimaryKey: sf.IsPrimaryKey, + IsNormal: sf.IsNormal, + IsIgnored: sf.IsIgnored, + IsScanner: sf.IsScanner, + HasDefaultValue: sf.HasDefaultValue, + Tag: sf.Tag, TagSettings: map[string]string{}, - Struct: structField.Struct, - IsForeignKey: structField.IsForeignKey, + Struct: sf.Struct, + IsForeignKey: sf.IsForeignKey, } - if structField.Relationship != nil { - relationship := *structField.Relationship + if sf.Relationship != nil { + relationship := *sf.Relationship clone.Relationship = &relationship } // copy the struct field tagSettings, they should be read-locked while they are copied - structField.tagSettingsLock.Lock() - defer structField.tagSettingsLock.Unlock() - for key, value := range structField.TagSettings { + sf.tagSettingsLock.Lock() + defer sf.tagSettingsLock.Unlock() + for key, value := range sf.TagSettings { clone.TagSettings[key] = value } From a6382da48500a7adfe8a3f75eedc89a34644f54f Mon Sep 17 00:00:00 2001 From: Edgar Fournival Date: Wed, 2 Jan 2019 14:28:02 +0100 Subject: [PATCH 109/881] Do not set CreatedAt if blank during Save (#2207) --- callback_update.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/callback_update.go b/callback_update.go index f6ba0ffd..c52162c8 100644 --- a/callback_update.go +++ b/callback_update.go @@ -75,7 +75,7 @@ func updateCallback(scope *Scope) { } else { for _, field := range scope.Fields() { if scope.changeableField(field) { - if !field.IsPrimaryKey && field.IsNormal { + if !field.IsPrimaryKey && field.IsNormal && (field.Name != "CreatedAt" || !field.IsBlank) { if !field.IsForeignKey || !field.IsBlank || !field.HasDefaultValue { sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) } From 8316f94b72719208b2d939c70f3824287e62ea5d Mon Sep 17 00:00:00 2001 From: Brent Hughes Date: Wed, 2 Jan 2019 07:28:46 -0600 Subject: [PATCH 110/881] Fix Panic in test scenerio (#2131) I have found that there are times when testing that if I did not create the database through Open() it will not have the parent set and cause a panic when it hits this code path. --- model_struct.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/model_struct.go b/model_struct.go index 08e741fe..9e93db63 100644 --- a/model_struct.go +++ b/model_struct.go @@ -40,7 +40,7 @@ func (s *ModelStruct) TableName(db *DB) string { s.defaultTableName = tabler.TableName() } else { tableName := ToTableName(s.ModelType.Name()) - if db == nil || !db.parent.singularTable { + if db == nil || (db.parent != nil && !db.parent.singularTable) { tableName = inflection.Plural(tableName) } s.defaultTableName = tableName From 9f1a7f53511168c0567b4b4b4f10ab7d21265174 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=9C=BB=E8=9C=93=E7=89=B9=E6=B4=BE=E5=91=98?= Date: Wed, 2 Jan 2019 21:32:08 +0800 Subject: [PATCH 111/881] optimize getColumnAsArray (#2196) --- scope.go | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/scope.go b/scope.go index 806ccb7d..90e16782 100644 --- a/scope.go +++ b/scope.go @@ -1309,6 +1309,7 @@ func (scope *Scope) autoIndex() *Scope { } func (scope *Scope) getColumnAsArray(columns []string, values ...interface{}) (results [][]interface{}) { + resultMap := make(map[string][]interface{}) for _, value := range values { indirectValue := indirect(reflect.ValueOf(value)) @@ -1327,7 +1328,10 @@ func (scope *Scope) getColumnAsArray(columns []string, values ...interface{}) (r } if hasValue { - results = append(results, result) + h := fmt.Sprint(result...) + if _, exist := resultMap[h]; !exist { + resultMap[h] = result + } } } case reflect.Struct: @@ -1342,11 +1346,16 @@ func (scope *Scope) getColumnAsArray(columns []string, values ...interface{}) (r } if hasValue { - results = append(results, result) + h := fmt.Sprint(result...) + if _, exist := resultMap[h]; !exist { + resultMap[h] = result + } } } } - + for _, v := range resultMap { + results = append(results, v) + } return } From 8494ecdc9857e74477cd95965df2f0297fe6a461 Mon Sep 17 00:00:00 2001 From: aixiaoxiang Date: Sun, 10 Feb 2019 15:37:39 +0800 Subject: [PATCH 112/881] Better log output int8, int, int16, int32, int64, float32, float64, bool. (#2258) * Better log output int, int16, int32, int64, int8, float32, float64. * Better log output bool --- logger.go | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/logger.go b/logger.go index 4324a2e4..10a1b805 100644 --- a/logger.go +++ b/logger.go @@ -63,7 +63,13 @@ var LogFormatter = func(values ...interface{}) (messages []interface{}) { formattedValues = append(formattedValues, "NULL") } } else { - formattedValues = append(formattedValues, fmt.Sprintf("'%v'", value)) + switch value.(type) { + case int8, int, int16, int32, int64, float32, float64, bool: + formattedValues = append(formattedValues, fmt.Sprintf("%v", value)) + break + default: + formattedValues = append(formattedValues, fmt.Sprintf("'%v'", value)) + } } } else { formattedValues = append(formattedValues, "NULL") From 906799fef2f895116d915e1793314ab9053b400d Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 10 Feb 2019 15:39:40 +0800 Subject: [PATCH 113/881] Better log output for uint* --- logger.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/logger.go b/logger.go index 10a1b805..484bc022 100644 --- a/logger.go +++ b/logger.go @@ -64,9 +64,8 @@ var LogFormatter = func(values ...interface{}) (messages []interface{}) { } } else { switch value.(type) { - case int8, int, int16, int32, int64, float32, float64, bool: + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64, bool: formattedValues = append(formattedValues, fmt.Sprintf("%v", value)) - break default: formattedValues = append(formattedValues, fmt.Sprintf("'%v'", value)) } From 4b13e079fcea637fcb166ee1752c8d80601e3ef0 Mon Sep 17 00:00:00 2001 From: Satoshi Inoue Date: Sun, 10 Mar 2019 08:29:21 +0900 Subject: [PATCH 114/881] go modules (#2279) --- go.mod | 3 +++ go.sum | 2 ++ 2 files changed, 5 insertions(+) create mode 100644 go.mod create mode 100644 go.sum diff --git a/go.mod b/go.mod new file mode 100644 index 00000000..fa0883b8 --- /dev/null +++ b/go.mod @@ -0,0 +1,3 @@ +module github.com/istsh/gorm + +require github.com/jinzhu/inflection v0.0.0-20180308033659-04140366298a diff --git a/go.sum b/go.sum new file mode 100644 index 00000000..e2e8e11f --- /dev/null +++ b/go.sum @@ -0,0 +1,2 @@ +github.com/jinzhu/inflection v0.0.0-20180308033659-04140366298a h1:eeaG9XMUvRBYXJi4pg1ZKM7nxc5AfXfojeLLW7O5J3k= +github.com/jinzhu/inflection v0.0.0-20180308033659-04140366298a/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= From f3a0fc1566e32840934fc895dcbbff7101cc621c Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 10 Mar 2019 18:37:07 +0800 Subject: [PATCH 115/881] Fix go.mod --- go.mod | 16 ++++++- go.sum | 149 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 163 insertions(+), 2 deletions(-) diff --git a/go.mod b/go.mod index fa0883b8..f675334d 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,15 @@ -module github.com/istsh/gorm +module github.com/jinzhu/gorm -require github.com/jinzhu/inflection v0.0.0-20180308033659-04140366298a +require ( + cloud.google.com/go v0.36.0 // indirect + github.com/denisenkom/go-mssqldb v0.0.0-20190204142019-df6d76eb9289 + github.com/erikstmartin/go-testdb v0.0.0-20160219214506-8d10e4a1bae5 + github.com/go-sql-driver/mysql v1.4.1 + github.com/gofrs/uuid v3.2.0+incompatible + github.com/jinzhu/inflection v0.0.0-20180308033659-04140366298a + github.com/jinzhu/now v1.0.0 + github.com/lib/pq v1.0.0 + github.com/mattn/go-sqlite3 v1.10.0 + golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2 // indirect + google.golang.org/appengine v1.4.0 // indirect +) diff --git a/go.sum b/go.sum index e2e8e11f..25f61146 100644 --- a/go.sum +++ b/go.sum @@ -1,2 +1,151 @@ +cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +cloud.google.com/go v0.31.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +cloud.google.com/go v0.36.0 h1:+aCSj7tOo2LODWVEuZDZeGCckdt6MlSF+X/rB3wUiS8= +cloud.google.com/go v0.36.0/go.mod h1:RUoy9p/M4ge0HzT8L+SDZ8jg+Q6fth0CiBuhFJpSV40= +dmitri.shuralyov.com/app/changes v0.0.0-20180602232624-0a106ad413e3/go.mod h1:Yl+fi1br7+Rr3LqpNJf1/uxUdtRUV+Tnj0o93V2B9MU= +dmitri.shuralyov.com/html/belt v0.0.0-20180602232347-f7d459c86be0/go.mod h1:JLBrvjyP0v+ecvNYvCpyZgu5/xkfAUhi6wJj28eUfSU= +dmitri.shuralyov.com/service/change v0.0.0-20181023043359-a85b471d5412/go.mod h1:a1inKt/atXimZ4Mv927x+r7UpyzRUf4emIoiiSC2TN4= +dmitri.shuralyov.com/state v0.0.0-20180228185332-28bcc343414c/go.mod h1:0PRwlb0D6DFvNNtx+9ybjezNCa8XF0xaYcETyp6rHWU= +git.apache.org/thrift.git v0.0.0-20180902110319-2566ecd5d999/go.mod h1:fPE2ZNJGynbRyZ4dJvy6G277gSllfV2HJqblrnkyeyg= +github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/anmitsu/go-shlex v0.0.0-20161002113705-648efa622239/go.mod h1:2FmKhYUyUczH0OGQWaF5ceTx0UBShxjsH6f8oGKYe2c= +github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= +github.com/bradfitz/go-smtpd v0.0.0-20170404230938-deb6d6237625/go.mod h1:HYsPBTaaSFSlLx/70C2HPIMNZpVV8+vt/A+FMnYP11g= +github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= +github.com/coreos/go-systemd v0.0.0-20181012123002-c6f51f82210d/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/denisenkom/go-mssqldb v0.0.0-20190204142019-df6d76eb9289 h1:U+DzmGUpc/dOjREgbyyChPhdDIFwPYnVk+/5YcAa194= +github.com/denisenkom/go-mssqldb v0.0.0-20190204142019-df6d76eb9289/go.mod h1:xN/JuLBIz4bjkxNmByTiV1IbhfnYb6oo99phBn4Eqhc= +github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= +github.com/erikstmartin/go-testdb v0.0.0-20160219214506-8d10e4a1bae5 h1:Yzb9+7DPaBjB8zlTR87/ElzFsnQfuHnVUVqpZZIcV5Y= +github.com/erikstmartin/go-testdb v0.0.0-20160219214506-8d10e4a1bae5/go.mod h1:a2zkGnVExMxdzMo3M0Hi/3sEU+cWnZpSni0O6/Yb/P0= +github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568/go.mod h1:xEzjJPgXI435gkrCt3MPfRiAkVrwSbHsst4LCFVfpJc= +github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= +github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= +github.com/gliderlabs/ssh v0.1.1/go.mod h1:U7qILu1NlMHj9FlMhZLlkCdDnU1DBEAqr0aevW3Awn0= +github.com/go-sql-driver/mysql v1.4.1 h1:g24URVg0OFbNUTx9qqY1IRZ9D9z3iPyi5zKhQZpNwpA= +github.com/go-sql-driver/mysql v1.4.1/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w= +github.com/gofrs/uuid v3.2.0+incompatible h1:y12jRkkFxsd7GpqdSZ+/KCs/fJbqpEXSGd4+jfEaewE= +github.com/gofrs/uuid v3.2.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= +github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= +github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= +github.com/golang/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:tluoj9z5200jBnyusfRPU2LqT6J+DAorxEvtC7LHB+E= +github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= +github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= +github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= +github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= +github.com/google/go-github v17.0.0+incompatible/go.mod h1:zLgOLi98H3fifZn+44m+umXrS52loVEgC2AApnigrVQ= +github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck= +github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= +github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= +github.com/googleapis/gax-go v2.0.0+incompatible/go.mod h1:SFVmujtThgffbyetf+mdk2eWhX2bMyUtNHzFKcPA9HY= +github.com/googleapis/gax-go/v2 v2.0.3/go.mod h1:LLvjysVCY1JZeum8Z6l8qUty8fiNwE08qbEPm1M08qg= +github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= +github.com/gregjones/httpcache v0.0.0-20180305231024-9cad4c3443a7/go.mod h1:FecbI9+v66THATjSRHfNgh1IVFe/9kFxbXtjV0ctIMA= +github.com/grpc-ecosystem/grpc-gateway v1.5.0/go.mod h1:RSKVYQBd5MCa4OVpNdGskqpgL2+G+NZTnrVHpWWfpdw= +github.com/jellevandenhooff/dkim v0.0.0-20150330215556-f50fe3d243e1/go.mod h1:E0B/fFc00Y+Rasa88328GlI/XbtyysCtTHZS8h7IrBU= github.com/jinzhu/inflection v0.0.0-20180308033659-04140366298a h1:eeaG9XMUvRBYXJi4pg1ZKM7nxc5AfXfojeLLW7O5J3k= github.com/jinzhu/inflection v0.0.0-20180308033659-04140366298a/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/jinzhu/now v1.0.0 h1:6WV8LvwPpDhKjo5U9O6b4+xdG/jTXNPwlDme/MTo8Ns= +github.com/jinzhu/now v1.0.0/go.mod h1:oHTiXerJ20+SfYcrdlBO7rzZRJWGwSTQ0iUY2jI6Gfc= +github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/pty v1.1.3/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/lib/pq v1.0.0 h1:X5PMW56eZitiTeO7tKzZxFCSpbFZJtkMMooicw2us9A= +github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/mattn/go-sqlite3 v1.10.0 h1:jbhqpg7tQe4SupckyijYiy0mJJ/pRyHvXf7JdWK860o= +github.com/mattn/go-sqlite3 v1.10.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= +github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= +github.com/microcosm-cc/bluemonday v1.0.1/go.mod h1:hsXNsILzKxV+sX77C5b8FSuKF00vh2OMYv+xgHpAMF4= +github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJEU3ofeGjhHklVoIGuVj85JJwZ6kWPaJwCIxgnFmo= +github.com/neelance/sourcemap v0.0.0-20151028013722-8c68805598ab/go.mod h1:Qr6/a/Q4r9LP1IltGz7tA7iOK1WonHEYhu1HRBA7ZiM= +github.com/openzipkin/zipkin-go v0.1.1/go.mod h1:NtoC/o8u3JlF1lSlyPNswIbeQH9bJTmOf0Erfk+hxe8= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/prometheus/client_golang v0.8.0/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= +github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= +github.com/prometheus/common v0.0.0-20180801064454-c7de2306084e/go.mod h1:daVV7qP5qjZbuso7PdcryaAu0sAZbrN9i7WWcTMWvro= +github.com/prometheus/procfs v0.0.0-20180725123919-05ee40e3a273/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= +github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g= +github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= +github.com/shurcooL/component v0.0.0-20170202220835-f88ec8f54cc4/go.mod h1:XhFIlyj5a1fBNx5aJTbKoIq0mNaPvOagO+HjB3EtxrY= +github.com/shurcooL/events v0.0.0-20181021180414-410e4ca65f48/go.mod h1:5u70Mqkb5O5cxEA8nxTsgrgLehJeAw6Oc4Ab1c/P1HM= +github.com/shurcooL/github_flavored_markdown v0.0.0-20181002035957-2122de532470/go.mod h1:2dOwnU2uBioM+SGy2aZoq1f/Sd1l9OkAeAUvjSyvgU0= +github.com/shurcooL/go v0.0.0-20180423040247-9e1955d9fb6e/go.mod h1:TDJrrUr11Vxrven61rcy3hJMUqaf/CLWYhHNPmT14Lk= +github.com/shurcooL/go-goon v0.0.0-20170922171312-37c2f522c041/go.mod h1:N5mDOmsrJOB+vfqUK+7DmDyjhSLIIBnXo9lvZJj3MWQ= +github.com/shurcooL/gofontwoff v0.0.0-20180329035133-29b52fc0a18d/go.mod h1:05UtEgK5zq39gLST6uB0cf3NEHjETfB4Fgr3Gx5R9Vw= +github.com/shurcooL/gopherjslib v0.0.0-20160914041154-feb6d3990c2c/go.mod h1:8d3azKNyqcHP1GaQE/c6dDgjkgSx2BZ4IoEi4F1reUI= +github.com/shurcooL/highlight_diff v0.0.0-20170515013008-09bb4053de1b/go.mod h1:ZpfEhSmds4ytuByIcDnOLkTHGUI6KNqRNPDLHDk+mUU= +github.com/shurcooL/highlight_go v0.0.0-20181028180052-98c3abbbae20/go.mod h1:UDKB5a1T23gOMUJrI+uSuH0VRDStOiUVSjBTRDVBVag= +github.com/shurcooL/home v0.0.0-20181020052607-80b7ffcb30f9/go.mod h1:+rgNQw2P9ARFAs37qieuu7ohDNQ3gds9msbT2yn85sg= +github.com/shurcooL/htmlg v0.0.0-20170918183704-d01228ac9e50/go.mod h1:zPn1wHpTIePGnXSHpsVPWEktKXHr6+SS6x/IKRb7cpw= +github.com/shurcooL/httperror v0.0.0-20170206035902-86b7830d14cc/go.mod h1:aYMfkZ6DWSJPJ6c4Wwz3QtW22G7mf/PEgaB9k/ik5+Y= +github.com/shurcooL/httpfs v0.0.0-20171119174359-809beceb2371/go.mod h1:ZY1cvUeJuFPAdZ/B6v7RHavJWZn2YPVFQ1OSXhCGOkg= +github.com/shurcooL/httpgzip v0.0.0-20180522190206-b1c53ac65af9/go.mod h1:919LwcH0M7/W4fcZ0/jy0qGght1GIhqyS/EgWGH2j5Q= +github.com/shurcooL/issues v0.0.0-20181008053335-6292fdc1e191/go.mod h1:e2qWDig5bLteJ4fwvDAc2NHzqFEthkqn7aOZAOpj+PQ= +github.com/shurcooL/issuesapp v0.0.0-20180602232740-048589ce2241/go.mod h1:NPpHK2TI7iSaM0buivtFUc9offApnI0Alt/K8hcHy0I= +github.com/shurcooL/notifications v0.0.0-20181007000457-627ab5aea122/go.mod h1:b5uSkrEVM1jQUspwbixRBhaIjIzL2xazXp6kntxYle0= +github.com/shurcooL/octicon v0.0.0-20181028054416-fa4f57f9efb2/go.mod h1:eWdoE5JD4R5UVWDucdOPg1g2fqQRq78IQa9zlOV1vpQ= +github.com/shurcooL/reactions v0.0.0-20181006231557-f2e0b4ca5b82/go.mod h1:TCR1lToEk4d2s07G3XGfz2QrgHXg4RJBvjrOozvoWfk= +github.com/shurcooL/sanitized_anchor_name v0.0.0-20170918181015-86672fcb3f95/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= +github.com/shurcooL/users v0.0.0-20180125191416-49c67e49c537/go.mod h1:QJTqeLYEDaXHZDBsXlPCDqdhQuJkuw4NOtaxYe3xii4= +github.com/shurcooL/webdavfs v0.0.0-20170829043945-18c3829fa133/go.mod h1:hKmq5kWdCj2z2KEozexVbfEZIWiTjhE0+UjmZgPqehw= +github.com/sourcegraph/annotate v0.0.0-20160123013949-f4cad6c6324d/go.mod h1:UdhH50NIW0fCiwBSr0co2m7BnFLdv4fQTgdqdJTHFeE= +github.com/sourcegraph/syntaxhighlight v0.0.0-20170531221838-bd320f5d308e/go.mod h1:HuIsMU8RRBOtsCgI77wP899iHVBQpCmg4ErYMZB+2IA= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07/go.mod h1:kDXzergiv9cbyO7IOYJZWg1U88JhDg3PB6klq9Hg2pA= +go.opencensus.io v0.18.0/go.mod h1:vKdFvxhtzZ9onBp9VKHK8z/sRpBMnKAsufL7wlDrCOA= +go4.org v0.0.0-20180809161055-417644f6feb5/go.mod h1:MkTOUMDaeVYJUOUsaDXIhWPZYa1yOyC1qaOBpL57BhE= +golang.org/x/build v0.0.0-20190111050920-041ab4dc3f9d/go.mod h1:OWs+y06UdEOHN4y+MfF/py+xQ/tYqIWW03b70/CG9Rw= +golang.org/x/crypto v0.0.0-20181030102418-4d3f4d9ffa16/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2 h1:VklqNMn3ovrHsnt90PveolxSbWFaJdECFbxSq0Mqo2M= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= +golang.org/x/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= +golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= +golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20181029044818-c44066c5c816/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20181106065722-10aee1819953/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= +golang.org/x/oauth2 v0.0.0-20181017192945-9dcd33a902f4/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= +golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= +golang.org/x/perf v0.0.0-20180704124530-6e6d33e29852/go.mod h1:JLpeXjPJfIyPr5TlbXLkXWLhP8nz10XfvxElABhCtcw= +golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20181029174526-d69651ed3497/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20181030000716-a0a13e073c7b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +google.golang.org/api v0.0.0-20180910000450-7ca32eb868bf/go.mod h1:4mhQ8q/RsB7i+udVvVy5NUi08OU8ZlA0gRVgrF7VFY0= +google.golang.org/api v0.0.0-20181030000543-1d582fd0359e/go.mod h1:4mhQ8q/RsB7i+udVvVy5NUi08OU8ZlA0gRVgrF7VFY0= +google.golang.org/api v0.1.0/go.mod h1:UGEZY7KEX120AnNLIHFMKIo4obdJhkp2tPbaPlQx13Y= +google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= +google.golang.org/appengine v1.2.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= +google.golang.org/appengine v1.3.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= +google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= +google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= +google.golang.org/genproto v0.0.0-20180831171423-11092d34479b/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= +google.golang.org/genproto v0.0.0-20181029155118-b69ba1387ce2/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= +google.golang.org/genproto v0.0.0-20181202183823-bd91e49a0898/go.mod h1:7Ep/1NZk928CDR8SjdVbjWNpdIf6nzjE3BTgJDr2Atg= +google.golang.org/genproto v0.0.0-20190201180003-4b09977fb922/go.mod h1:L3J43x8/uS+qIUoksaLKe6OS3nUKxOKuIFz1sl2/jx4= +google.golang.org/grpc v1.14.0/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw= +google.golang.org/grpc v1.16.0/go.mod h1:0JHn/cJsOMiMfNA9+DeHDlAU7KAAB5GDlYFpa9MZMio= +google.golang.org/grpc v1.17.0/go.mod h1:6QZJwpn2B+Zp71q/5VxRsJ6NXXVCE5NRUHRo+f3cWCs= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw= +gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +grpc.go4.org v0.0.0-20170609214715-11d0a25b4919/go.mod h1:77eQGdRu53HpSqPFJFmuJdjuHRquDANNeA4x7B8WQ9o= +honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +sourcegraph.com/sourcegraph/go-diff v0.5.0/go.mod h1:kuch7UrkMzY0X+p9CRK03kfuPQ2zzQcaEFbx8wA8rck= +sourcegraph.com/sqs/pbtypes v0.0.0-20180604144634-d3ebe8f20ae4/go.mod h1:ketZ/q3QxT9HOBeFhu6RdvsftgpsbFHBF5Cas6cDKZ0= From d7ef7871a424f1652bf706a0a454a452693400ce Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 10 Mar 2019 19:33:49 +0800 Subject: [PATCH 116/881] Fix tests --- callback_query.go | 2 +- main.go | 9 ++++++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/callback_query.go b/callback_query.go index 593e5d30..7facc42b 100644 --- a/callback_query.go +++ b/callback_query.go @@ -18,7 +18,7 @@ func queryCallback(scope *Scope) { if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip { return } - + //we are only preloading relations, dont touch base model if _, skip := scope.InstanceGet("gorm:only_preload"); skip { return diff --git a/main.go b/main.go index 34a6ddc8..f52ba27b 100644 --- a/main.go +++ b/main.go @@ -178,7 +178,13 @@ func (s *DB) SingularTable(enable bool) { func (s *DB) NewScope(value interface{}) *Scope { dbClone := s.clone() dbClone.Value = value - return &Scope{db: dbClone, Search: dbClone.search, Value: value} + scope := &Scope{db: dbClone, Value: value} + if s.search != nil { + scope.Search = s.search.clone() + } else { + scope.Search = &search{} + } + return scope } // QueryExpr returns the query as expr object @@ -298,6 +304,7 @@ func (s *DB) Assign(attrs ...interface{}) *DB { func (s *DB) First(out interface{}, where ...interface{}) *DB { newScope := s.NewScope(out) newScope.Search.Limit(1) + return newScope.Set("gorm:order_by_primary_key", "ASC"). inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db } From c721a198a7ae3b9d68d3aed38d9d7d5bc55f3084 Mon Sep 17 00:00:00 2001 From: haoc7 Date: Sun, 10 Mar 2019 20:01:57 +0800 Subject: [PATCH 117/881] create table add column comment (#2298) --- dialect.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/dialect.go b/dialect.go index 27b308af..cdc4278e 100644 --- a/dialect.go +++ b/dialect.go @@ -126,6 +126,10 @@ var ParseFieldStructForDialect = func(field *StructField, dialect Dialect) (fiel additionalType = additionalType + " DEFAULT " + value } + if value, ok := field.TagSettingsGet("COMMENT"); ok { + additionalType = additionalType + " COMMENT " + value + } + return fieldValue, dataType, size, strings.TrimSpace(additionalType) } From d239c4cab8a0cb09643a79567450d66ac972ba6c Mon Sep 17 00:00:00 2001 From: kuangzhiqiang Date: Sun, 10 Mar 2019 20:03:55 +0800 Subject: [PATCH 118/881] error log show trace file (#2296) --- main.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/main.go b/main.go index f52ba27b..fda63d29 100644 --- a/main.go +++ b/main.go @@ -732,7 +732,7 @@ func (s *DB) AddError(err error) error { if err != nil { if err != ErrRecordNotFound { if s.logMode == defaultLogMode { - go s.print(fileWithLineNum(), err) + go s.print("error", fileWithLineNum(), err) } else { s.log(err) } From 8b07437717e71c2ff00602ae19f8353ba10aafbb Mon Sep 17 00:00:00 2001 From: Ali Koyuncu Date: Sun, 10 Mar 2019 14:17:21 +0200 Subject: [PATCH 119/881] add mysql insert modifiers (#2269) --- callback_create.go | 13 +++++++++++-- create_test.go | 17 +++++++++++++++++ 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/callback_create.go b/callback_create.go index 2ab05d3b..763a2dfd 100644 --- a/callback_create.go +++ b/callback_create.go @@ -83,11 +83,18 @@ func createCallback(scope *Scope) { quotedTableName = scope.QuotedTableName() primaryField = scope.PrimaryField() extraOption string + insertModifier string ) if str, ok := scope.Get("gorm:insert_option"); ok { extraOption = fmt.Sprint(str) } + if str, ok := scope.Get("gorm:insert_modifier"); ok { + insertModifier = strings.ToUpper(fmt.Sprint(str)) + if insertModifier == "INTO" { + insertModifier = "" + } + } if primaryField != nil { returningColumn = scope.Quote(primaryField.DBName) @@ -97,7 +104,8 @@ func createCallback(scope *Scope) { if len(columns) == 0 { scope.Raw(fmt.Sprintf( - "INSERT INTO %v %v%v%v", + "INSERT %v INTO %v %v%v%v", + addExtraSpaceIfExist(insertModifier), quotedTableName, scope.Dialect().DefaultValueStr(), addExtraSpaceIfExist(extraOption), @@ -105,7 +113,8 @@ func createCallback(scope *Scope) { )) } else { scope.Raw(fmt.Sprintf( - "INSERT INTO %v (%v) VALUES (%v)%v%v", + "INSERT %v INTO %v (%v) VALUES (%v)%v%v", + addExtraSpaceIfExist(insertModifier), scope.QuotedTableName(), strings.Join(columns, ","), strings.Join(placeholders, ","), diff --git a/create_test.go b/create_test.go index 92560643..450dd8a4 100644 --- a/create_test.go +++ b/create_test.go @@ -229,3 +229,20 @@ func TestOmitWithCreate(t *testing.T) { t.Errorf("Should not create omitted relationships") } } + +func TestCreateIgnore(t *testing.T) { + float := 35.03554004971999 + now := time.Now() + user := User{Name: "CreateUser", Age: 18, Birthday: &now, UserNum: Num(111), PasswordHash: []byte{'f', 'a', 'k', '4'}, Latitude: float} + + if !DB.NewRecord(user) || !DB.NewRecord(&user) { + t.Error("User should be new record before create") + } + + if count := DB.Create(&user).RowsAffected; count != 1 { + t.Error("There should be one record be affected when create record") + } + if DB.Dialect().GetName() == "mysql" && DB.Set("gorm:insert_modifier", "IGNORE").Create(&user).Error != nil { + t.Error("Should ignore duplicate user insert by insert modifier:IGNORE ") + } +} From 26e8799a192569dcc22efd1d43f96a0bb1bafe81 Mon Sep 17 00:00:00 2001 From: Wendell Sun Date: Mon, 11 Mar 2019 19:56:03 +0800 Subject: [PATCH 120/881] fix the case that using Having on Count --- main_test.go | 26 ++++++++++++++++++++++++++ scope.go | 11 +++++++++-- 2 files changed, 35 insertions(+), 2 deletions(-) diff --git a/main_test.go b/main_test.go index 94d2fa39..ac40c32b 100644 --- a/main_test.go +++ b/main_test.go @@ -1059,6 +1059,32 @@ func TestBlockGlobalUpdate(t *testing.T) { } } +func TestCountWithHaving(t *testing.T) { + db := DB.New() + db.Delete(User{}) + defer db.Delete(User{}) + + DB.Create(getPreparedUser("user1", "pluck_user")) + DB.Create(getPreparedUser("user2", "pluck_user")) + user3:=getPreparedUser("user3", "pluck_user") + user3.Languages=[]Language{} + DB.Create(user3) + + var count int + err:=db.Model(User{}).Select("users.id"). + Joins("LEFT JOIN user_languages ON user_languages.user_id = users.id"). + Joins("LEFT JOIN languages ON user_languages.language_id = languages.id"). + Group("users.id").Having("COUNT(languages.id) > 1").Count(&count).Error + + if err != nil { + t.Error("Unexpected error on query count with having") + } + + if count!=2{ + t.Error("Unexpected result on query count with having") + } +} + func BenchmarkGorm(b *testing.B) { b.N = 2000 for x := 0; x < b.N; x++ { diff --git a/scope.go b/scope.go index 90e16782..7fa64b19 100644 --- a/scope.go +++ b/scope.go @@ -1007,8 +1007,15 @@ func (scope *Scope) pluck(column string, value interface{}) *Scope { func (scope *Scope) count(value interface{}) *Scope { if query, ok := scope.Search.selects["query"]; !ok || !countingQueryRegexp.MatchString(fmt.Sprint(query)) { if len(scope.Search.group) != 0 { - scope.Search.Select("count(*) FROM ( SELECT count(*) as name ") - scope.Search.group += " ) AS count_table" + if len(scope.Search.havingConditions) != 0 { + scope.prepareQuerySQL() + scope.Search = &search{} + scope.Search.Select("count(*)") + scope.Search.Table(fmt.Sprintf("( %s ) AS count_table", scope.SQL)) + } else { + scope.Search.Select("count(*) FROM ( SELECT count(*) as name ") + scope.Search.group += " ) AS count_table" + } } else { scope.Search.Select("count(*)") } From 2fb2c0d3b20dd20a2fc8017c4f0b302ee6069a88 Mon Sep 17 00:00:00 2001 From: Wendell Sun Date: Thu, 14 Mar 2019 02:33:42 +0800 Subject: [PATCH 121/881] return empty slice for many2many if no asscociation was found --- callback_query_preload.go | 16 +++++++++++----- preload_test.go | 1 + 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/callback_query_preload.go b/callback_query_preload.go index d7c8a133..a936180a 100644 --- a/callback_query_preload.go +++ b/callback_query_preload.go @@ -391,14 +391,20 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface key := toString(getValueFromFields(indirectScopeValue, foreignFieldNames)) fieldsSourceMap[key] = append(fieldsSourceMap[key], indirectScopeValue.FieldByName(field.Name)) } - for source, link := range linkHash { - for i, field := range fieldsSourceMap[source] { + + for source, fields := range fieldsSourceMap { + for _, f := range fields { //If not 0 this means Value is a pointer and we already added preloaded models to it - if fieldsSourceMap[source][i].Len() != 0 { + if f.Len() != 0 { continue } - field.Set(reflect.Append(fieldsSourceMap[source][i], link...)) - } + v := reflect.MakeSlice(f.Type(), 0, 0) + if len(linkHash[source]) > 0 { + v = reflect.Append(f, linkHash[source]...) + } + + f.Set(v) + } } } diff --git a/preload_test.go b/preload_test.go index 1db625c9..1a6a5d49 100644 --- a/preload_test.go +++ b/preload_test.go @@ -771,6 +771,7 @@ func TestNestedPreload11(t *testing.T) { levelB3 := &LevelB3{ Value: "bar", LevelB1ID: sql.NullInt64{Valid: true, Int64: int64(levelB1.ID)}, + LevelB2s: []*LevelB2{}, } if err := DB.Create(levelB3).Error; err != nil { t.Error(err) From 14e0507fd2d31c10406811fe10f2c024e98d0b93 Mon Sep 17 00:00:00 2001 From: Wendell Sun Date: Thu, 14 Mar 2019 12:12:38 +0800 Subject: [PATCH 122/881] fix the table name of many2many --- customize_column_test.go | 11 +++++++++++ model_struct.go | 2 +- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/customize_column_test.go b/customize_column_test.go index 5e19d6f4..c236ac24 100644 --- a/customize_column_test.go +++ b/customize_column_test.go @@ -289,6 +289,9 @@ type SelfReferencingUser struct { func TestSelfReferencingMany2ManyColumn(t *testing.T) { DB.DropTable(&SelfReferencingUser{}, "UserFriends") DB.AutoMigrate(&SelfReferencingUser{}) + if !DB.HasTable("UserFriends") { + t.Errorf("auto migrate error, table UserFriends should be created") + } friend1 := SelfReferencingUser{Name: "friend1_m2m"} if err := DB.Create(&friend1).Error; err != nil { @@ -313,6 +316,14 @@ func TestSelfReferencingMany2ManyColumn(t *testing.T) { t.Errorf("Should find created friends correctly") } + var count int + if err := DB.Table("UserFriends").Count(&count).Error; err != nil { + t.Errorf("no error should happen, but got %v", err) + } + if count == 0 { + t.Errorf("table UserFriends should have records") + } + var newUser = SelfReferencingUser{} if err := DB.Preload("Friends").First(&newUser, "id = ?", user.ID).Error; err != nil { diff --git a/model_struct.go b/model_struct.go index 9e93db63..a1e6c0e2 100644 --- a/model_struct.go +++ b/model_struct.go @@ -340,7 +340,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } joinTableHandler := JoinTableHandler{} - joinTableHandler.Setup(relationship, ToTableName(many2many), reflectType, elemType) + joinTableHandler.Setup(relationship, many2many, reflectType, elemType) relationship.JoinTableHandler = &joinTableHandler field.Relationship = relationship } else { From bc5d3f07a8036de43115bdd04ce0da2f0d929d62 Mon Sep 17 00:00:00 2001 From: JUN JIE NAN Date: Fri, 5 Apr 2019 07:59:02 +0800 Subject: [PATCH 123/881] Removed the deps on uuid and appengine (#2354) gofrs/uuid was used in testing only, and go module count testing depends in. This patch removed the gofrs/uuid depends, and appengine as well. --- field_test.go | 71 ++++++++++++++++++++++++++++++++++++++++++++++++--- go.mod | 2 -- go.sum | 4 +-- 3 files changed, 68 insertions(+), 9 deletions(-) diff --git a/field_test.go b/field_test.go index 03a3b3b7..715661f0 100644 --- a/field_test.go +++ b/field_test.go @@ -1,9 +1,11 @@ package gorm_test import ( + "database/sql/driver" + "encoding/hex" + "fmt" "testing" - "github.com/gofrs/uuid" "github.com/jinzhu/gorm" ) @@ -49,17 +51,78 @@ func TestCalculateField(t *testing.T) { } } +type UUID [16]byte + +type NullUUID struct { + UUID + Valid bool +} + +func FromString(input string) (u UUID) { + src := []byte(input) + return FromBytes(src) +} + +func FromBytes(src []byte) (u UUID) { + dst := u[:] + hex.Decode(dst[0:4], src[0:8]) + hex.Decode(dst[4:6], src[9:13]) + hex.Decode(dst[6:8], src[14:18]) + hex.Decode(dst[8:10], src[19:23]) + hex.Decode(dst[10:], src[24:]) + return +} + +func (u UUID) String() string { + buf := make([]byte, 36) + src := u[:] + hex.Encode(buf[0:8], src[0:4]) + buf[8] = '-' + hex.Encode(buf[9:13], src[4:6]) + buf[13] = '-' + hex.Encode(buf[14:18], src[6:8]) + buf[18] = '-' + hex.Encode(buf[19:23], src[8:10]) + buf[23] = '-' + hex.Encode(buf[24:], src[10:]) + return string(buf) +} + +func (u UUID) Value() (driver.Value, error) { + return u.String(), nil +} + +func (u *UUID) Scan(src interface{}) error { + switch src := src.(type) { + case UUID: // support gorm convert from UUID to NullUUID + *u = src + return nil + case []byte: + *u = FromBytes(src) + return nil + case string: + *u = FromString(src) + return nil + } + return fmt.Errorf("uuid: cannot convert %T to UUID", src) +} + +func (u *NullUUID) Scan(src interface{}) error { + u.Valid = true + return u.UUID.Scan(src) +} + func TestFieldSet(t *testing.T) { type TestFieldSetNullUUID struct { - NullUUID uuid.NullUUID + NullUUID NullUUID } scope := DB.NewScope(&TestFieldSetNullUUID{}) field := scope.Fields()[0] - err := field.Set(uuid.FromStringOrNil("3034d44a-da03-11e8-b366-4a00070b9f00")) + err := field.Set(FromString("3034d44a-da03-11e8-b366-4a00070b9f00")) if err != nil { t.Fatal(err) } - if id, ok := field.Field.Addr().Interface().(*uuid.NullUUID); !ok { + if id, ok := field.Field.Addr().Interface().(*NullUUID); !ok { t.Fatal() } else if !id.Valid || id.UUID.String() != "3034d44a-da03-11e8-b366-4a00070b9f00" { t.Fatal(id) diff --git a/go.mod b/go.mod index f675334d..024f73ca 100644 --- a/go.mod +++ b/go.mod @@ -5,11 +5,9 @@ require ( github.com/denisenkom/go-mssqldb v0.0.0-20190204142019-df6d76eb9289 github.com/erikstmartin/go-testdb v0.0.0-20160219214506-8d10e4a1bae5 github.com/go-sql-driver/mysql v1.4.1 - github.com/gofrs/uuid v3.2.0+incompatible github.com/jinzhu/inflection v0.0.0-20180308033659-04140366298a github.com/jinzhu/now v1.0.0 github.com/lib/pq v1.0.0 github.com/mattn/go-sqlite3 v1.10.0 golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2 // indirect - google.golang.org/appengine v1.4.0 // indirect ) diff --git a/go.sum b/go.sum index 25f61146..894ee21b 100644 --- a/go.sum +++ b/go.sum @@ -25,8 +25,6 @@ github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeME github.com/gliderlabs/ssh v0.1.1/go.mod h1:U7qILu1NlMHj9FlMhZLlkCdDnU1DBEAqr0aevW3Awn0= github.com/go-sql-driver/mysql v1.4.1 h1:g24URVg0OFbNUTx9qqY1IRZ9D9z3iPyi5zKhQZpNwpA= github.com/go-sql-driver/mysql v1.4.1/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w= -github.com/gofrs/uuid v3.2.0+incompatible h1:y12jRkkFxsd7GpqdSZ+/KCs/fJbqpEXSGd4+jfEaewE= -github.com/gofrs/uuid v3.2.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:tluoj9z5200jBnyusfRPU2LqT6J+DAorxEvtC7LHB+E= @@ -34,6 +32,7 @@ github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfb github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= +github.com/google/go-cmp v0.2.0 h1:+dTQ8DZQJz0Mb/HjFlkptS1FeQ4cWSnN941F8aEG4SQ= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= github.com/google/go-github v17.0.0+incompatible/go.mod h1:zLgOLi98H3fifZn+44m+umXrS52loVEgC2AApnigrVQ= github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck= @@ -132,7 +131,6 @@ google.golang.org/api v0.1.0/go.mod h1:UGEZY7KEX120AnNLIHFMKIo4obdJhkp2tPbaPlQx1 google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= google.golang.org/appengine v1.2.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/appengine v1.3.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= -google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= google.golang.org/genproto v0.0.0-20180831171423-11092d34479b/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= google.golang.org/genproto v0.0.0-20181029155118-b69ba1387ce2/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= From 071b657418cccdab721e549108b6b6cf8a1b7361 Mon Sep 17 00:00:00 2001 From: Jony4 Date: Fri, 5 Apr 2019 08:00:48 +0800 Subject: [PATCH 124/881] fix TagSettings' map has "":"" value (#2372) --- model_struct.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/model_struct.go b/model_struct.go index 9e93db63..194bcfdc 100644 --- a/model_struct.go +++ b/model_struct.go @@ -625,6 +625,9 @@ func (scope *Scope) GetStructFields() (fields []*StructField) { func parseTagSetting(tags reflect.StructTag) map[string]string { setting := map[string]string{} for _, str := range []string{tags.Get("sql"), tags.Get("gorm")} { + if str == "" { + continue + } tags := strings.Split(str, ";") for _, value := range tags { v := strings.Split(value, ":") From 1c62bf1e5794f9db023e7a3f450788e071bd7bd3 Mon Sep 17 00:00:00 2001 From: Momo733 <1550526230@qq.com> Date: Sat, 13 Apr 2019 14:23:35 +0800 Subject: [PATCH 125/881] fix save err when specify a table name s.New() will clear all search conditions and search value,when I use Table() to set a table name. Then FirstOrCreate() will use struct name as my database table name,so It doesn't work. --- main.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/main.go b/main.go index fda63d29..927bd5af 100644 --- a/main.go +++ b/main.go @@ -444,7 +444,7 @@ func (s *DB) Save(value interface{}) *DB { if !scope.PrimaryKeyZero() { newDB := scope.callCallbacks(s.parent.callbacks.updates).db if newDB.Error == nil && newDB.RowsAffected == 0 { - return s.New().FirstOrCreate(value) + return s.FirstOrCreate(value) } return newDB } From da037b0454eef67dee736aebd58efc1e7376184f Mon Sep 17 00:00:00 2001 From: Emir Beganovic Date: Thu, 11 Apr 2019 17:28:26 +0400 Subject: [PATCH 126/881] Cleanup go.mod --- go.mod | 6 +-- go.sum | 142 +++++++++++++++++++++++++++++++++------------------------ 2 files changed, 84 insertions(+), 64 deletions(-) diff --git a/go.mod b/go.mod index 024f73ca..89ca68d8 100644 --- a/go.mod +++ b/go.mod @@ -1,13 +1,9 @@ module github.com/jinzhu/gorm require ( - cloud.google.com/go v0.36.0 // indirect - github.com/denisenkom/go-mssqldb v0.0.0-20190204142019-df6d76eb9289 - github.com/erikstmartin/go-testdb v0.0.0-20160219214506-8d10e4a1bae5 + github.com/denisenkom/go-mssqldb v0.0.0-20190401154936-ce35bd87d4b3 github.com/go-sql-driver/mysql v1.4.1 github.com/jinzhu/inflection v0.0.0-20180308033659-04140366298a - github.com/jinzhu/now v1.0.0 github.com/lib/pq v1.0.0 github.com/mattn/go-sqlite3 v1.10.0 - golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2 // indirect ) diff --git a/go.sum b/go.sum index 894ee21b..a984e572 100644 --- a/go.sum +++ b/go.sum @@ -1,149 +1,173 @@ cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= cloud.google.com/go v0.31.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= -cloud.google.com/go v0.36.0 h1:+aCSj7tOo2LODWVEuZDZeGCckdt6MlSF+X/rB3wUiS8= -cloud.google.com/go v0.36.0/go.mod h1:RUoy9p/M4ge0HzT8L+SDZ8jg+Q6fth0CiBuhFJpSV40= -dmitri.shuralyov.com/app/changes v0.0.0-20180602232624-0a106ad413e3/go.mod h1:Yl+fi1br7+Rr3LqpNJf1/uxUdtRUV+Tnj0o93V2B9MU= -dmitri.shuralyov.com/html/belt v0.0.0-20180602232347-f7d459c86be0/go.mod h1:JLBrvjyP0v+ecvNYvCpyZgu5/xkfAUhi6wJj28eUfSU= -dmitri.shuralyov.com/service/change v0.0.0-20181023043359-a85b471d5412/go.mod h1:a1inKt/atXimZ4Mv927x+r7UpyzRUf4emIoiiSC2TN4= -dmitri.shuralyov.com/state v0.0.0-20180228185332-28bcc343414c/go.mod h1:0PRwlb0D6DFvNNtx+9ybjezNCa8XF0xaYcETyp6rHWU= +cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +cloud.google.com/go v0.37.4 h1:glPeL3BQJsbF6aIIYfZizMwc5LTYz250bDMjttbBGAU= +cloud.google.com/go v0.37.4/go.mod h1:NHPJ89PdicEuT9hdPXMROBD91xc5uRDxsMtSB16k7hw= git.apache.org/thrift.git v0.0.0-20180902110319-2566ecd5d999/go.mod h1:fPE2ZNJGynbRyZ4dJvy6G277gSllfV2HJqblrnkyeyg= +git.apache.org/thrift.git v0.12.0/go.mod h1:fPE2ZNJGynbRyZ4dJvy6G277gSllfV2HJqblrnkyeyg= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/Shopify/sarama v1.19.0/go.mod h1:FVkBWblsNy7DGZRfXLU0O9RCGt5g3g3yEuWXgklEdEo= +github.com/Shopify/toxiproxy v2.1.4+incompatible/go.mod h1:OXgGpZ6Cli1/URJOF1DMxUHB2q5Ap20/P/eIdh4G0pI= +github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= +github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= github.com/anmitsu/go-shlex v0.0.0-20161002113705-648efa622239/go.mod h1:2FmKhYUyUczH0OGQWaF5ceTx0UBShxjsH6f8oGKYe2c= +github.com/apache/thrift v0.12.0/go.mod h1:cp2SuWMxlEZw2r+iP2GNCdIi4C1qmUzdZFSVb+bacwQ= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= github.com/bradfitz/go-smtpd v0.0.0-20170404230938-deb6d6237625/go.mod h1:HYsPBTaaSFSlLx/70C2HPIMNZpVV8+vt/A+FMnYP11g= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/coreos/go-systemd v0.0.0-20181012123002-c6f51f82210d/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/denisenkom/go-mssqldb v0.0.0-20190204142019-df6d76eb9289 h1:U+DzmGUpc/dOjREgbyyChPhdDIFwPYnVk+/5YcAa194= -github.com/denisenkom/go-mssqldb v0.0.0-20190204142019-df6d76eb9289/go.mod h1:xN/JuLBIz4bjkxNmByTiV1IbhfnYb6oo99phBn4Eqhc= -github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= -github.com/erikstmartin/go-testdb v0.0.0-20160219214506-8d10e4a1bae5 h1:Yzb9+7DPaBjB8zlTR87/ElzFsnQfuHnVUVqpZZIcV5Y= -github.com/erikstmartin/go-testdb v0.0.0-20160219214506-8d10e4a1bae5/go.mod h1:a2zkGnVExMxdzMo3M0Hi/3sEU+cWnZpSni0O6/Yb/P0= +github.com/denisenkom/go-mssqldb v0.0.0-20190401154936-ce35bd87d4b3/go.mod h1:EcO5fNtMZHCMjAvj8LE6T+5bphSdR6LQ75n+m1TtsFI= +github.com/eapache/go-resiliency v1.1.0/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs= +github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU= +github.com/eapache/queue v1.1.0/go.mod h1:6eCeP0CKFpHLu8blIFXhExK/dRa7WDZfr6jVFPTqq+I= github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568/go.mod h1:xEzjJPgXI435gkrCt3MPfRiAkVrwSbHsst4LCFVfpJc= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= github.com/gliderlabs/ssh v0.1.1/go.mod h1:U7qILu1NlMHj9FlMhZLlkCdDnU1DBEAqr0aevW3Awn0= -github.com/go-sql-driver/mysql v1.4.1 h1:g24URVg0OFbNUTx9qqY1IRZ9D9z3iPyi5zKhQZpNwpA= +github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= +github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE= github.com/go-sql-driver/mysql v1.4.1/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w= +github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= +github.com/gogo/protobuf v1.2.0/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:tluoj9z5200jBnyusfRPU2LqT6J+DAorxEvtC7LHB+E= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= -github.com/google/go-cmp v0.2.0 h1:+dTQ8DZQJz0Mb/HjFlkptS1FeQ4cWSnN941F8aEG4SQ= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= github.com/google/go-github v17.0.0+incompatible/go.mod h1:zLgOLi98H3fifZn+44m+umXrS52loVEgC2AApnigrVQ= github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck= github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= github.com/googleapis/gax-go v2.0.0+incompatible/go.mod h1:SFVmujtThgffbyetf+mdk2eWhX2bMyUtNHzFKcPA9HY= -github.com/googleapis/gax-go/v2 v2.0.3/go.mod h1:LLvjysVCY1JZeum8Z6l8qUty8fiNwE08qbEPm1M08qg= -github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= +github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg= +github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg= +github.com/gorilla/mux v1.6.2/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs= github.com/gregjones/httpcache v0.0.0-20180305231024-9cad4c3443a7/go.mod h1:FecbI9+v66THATjSRHfNgh1IVFe/9kFxbXtjV0ctIMA= github.com/grpc-ecosystem/grpc-gateway v1.5.0/go.mod h1:RSKVYQBd5MCa4OVpNdGskqpgL2+G+NZTnrVHpWWfpdw= +github.com/grpc-ecosystem/grpc-gateway v1.6.2/go.mod h1:RSKVYQBd5MCa4OVpNdGskqpgL2+G+NZTnrVHpWWfpdw= +github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= +github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/jellevandenhooff/dkim v0.0.0-20150330215556-f50fe3d243e1/go.mod h1:E0B/fFc00Y+Rasa88328GlI/XbtyysCtTHZS8h7IrBU= github.com/jinzhu/inflection v0.0.0-20180308033659-04140366298a h1:eeaG9XMUvRBYXJi4pg1ZKM7nxc5AfXfojeLLW7O5J3k= github.com/jinzhu/inflection v0.0.0-20180308033659-04140366298a/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= -github.com/jinzhu/now v1.0.0 h1:6WV8LvwPpDhKjo5U9O6b4+xdG/jTXNPwlDme/MTo8Ns= -github.com/jinzhu/now v1.0.0/go.mod h1:oHTiXerJ20+SfYcrdlBO7rzZRJWGwSTQ0iUY2jI6Gfc= +github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU= +github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= -github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= -github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= github.com/kr/pty v1.1.3/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= -github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= -github.com/lib/pq v1.0.0 h1:X5PMW56eZitiTeO7tKzZxFCSpbFZJtkMMooicw2us9A= github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= -github.com/mattn/go-sqlite3 v1.10.0 h1:jbhqpg7tQe4SupckyijYiy0mJJ/pRyHvXf7JdWK860o= github.com/mattn/go-sqlite3 v1.10.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= -github.com/microcosm-cc/bluemonday v1.0.1/go.mod h1:hsXNsILzKxV+sX77C5b8FSuKF00vh2OMYv+xgHpAMF4= -github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJEU3ofeGjhHklVoIGuVj85JJwZ6kWPaJwCIxgnFmo= -github.com/neelance/sourcemap v0.0.0-20151028013722-8c68805598ab/go.mod h1:Qr6/a/Q4r9LP1IltGz7tA7iOK1WonHEYhu1HRBA7ZiM= +github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= +github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= +github.com/onsi/ginkgo v1.7.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= +github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= github.com/openzipkin/zipkin-go v0.1.1/go.mod h1:NtoC/o8u3JlF1lSlyPNswIbeQH9bJTmOf0Erfk+hxe8= +github.com/openzipkin/zipkin-go v0.1.3/go.mod h1:NtoC/o8u3JlF1lSlyPNswIbeQH9bJTmOf0Erfk+hxe8= +github.com/openzipkin/zipkin-go v0.1.6/go.mod h1:QgAqvLzwWbR/WpD4A3cGpPtJrZXNIiJc5AZX7/PBEpw= +github.com/pierrec/lz4 v2.0.5+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY= +github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/client_golang v0.8.0/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= +github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= +github.com/prometheus/client_golang v0.9.3-0.20190127221311-3c4408c8b829/go.mod h1:p2iRAGwDERtqlqzRXnrOVns+ignqQo//hLXqYxZYVNs= github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= +github.com/prometheus/client_model v0.0.0-20190115171406-56726106282f/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= github.com/prometheus/common v0.0.0-20180801064454-c7de2306084e/go.mod h1:daVV7qP5qjZbuso7PdcryaAu0sAZbrN9i7WWcTMWvro= +github.com/prometheus/common v0.2.0/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4= github.com/prometheus/procfs v0.0.0-20180725123919-05ee40e3a273/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= -github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g= -github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= -github.com/shurcooL/component v0.0.0-20170202220835-f88ec8f54cc4/go.mod h1:XhFIlyj5a1fBNx5aJTbKoIq0mNaPvOagO+HjB3EtxrY= -github.com/shurcooL/events v0.0.0-20181021180414-410e4ca65f48/go.mod h1:5u70Mqkb5O5cxEA8nxTsgrgLehJeAw6Oc4Ab1c/P1HM= -github.com/shurcooL/github_flavored_markdown v0.0.0-20181002035957-2122de532470/go.mod h1:2dOwnU2uBioM+SGy2aZoq1f/Sd1l9OkAeAUvjSyvgU0= -github.com/shurcooL/go v0.0.0-20180423040247-9e1955d9fb6e/go.mod h1:TDJrrUr11Vxrven61rcy3hJMUqaf/CLWYhHNPmT14Lk= -github.com/shurcooL/go-goon v0.0.0-20170922171312-37c2f522c041/go.mod h1:N5mDOmsrJOB+vfqUK+7DmDyjhSLIIBnXo9lvZJj3MWQ= -github.com/shurcooL/gofontwoff v0.0.0-20180329035133-29b52fc0a18d/go.mod h1:05UtEgK5zq39gLST6uB0cf3NEHjETfB4Fgr3Gx5R9Vw= -github.com/shurcooL/gopherjslib v0.0.0-20160914041154-feb6d3990c2c/go.mod h1:8d3azKNyqcHP1GaQE/c6dDgjkgSx2BZ4IoEi4F1reUI= -github.com/shurcooL/highlight_diff v0.0.0-20170515013008-09bb4053de1b/go.mod h1:ZpfEhSmds4ytuByIcDnOLkTHGUI6KNqRNPDLHDk+mUU= -github.com/shurcooL/highlight_go v0.0.0-20181028180052-98c3abbbae20/go.mod h1:UDKB5a1T23gOMUJrI+uSuH0VRDStOiUVSjBTRDVBVag= -github.com/shurcooL/home v0.0.0-20181020052607-80b7ffcb30f9/go.mod h1:+rgNQw2P9ARFAs37qieuu7ohDNQ3gds9msbT2yn85sg= -github.com/shurcooL/htmlg v0.0.0-20170918183704-d01228ac9e50/go.mod h1:zPn1wHpTIePGnXSHpsVPWEktKXHr6+SS6x/IKRb7cpw= -github.com/shurcooL/httperror v0.0.0-20170206035902-86b7830d14cc/go.mod h1:aYMfkZ6DWSJPJ6c4Wwz3QtW22G7mf/PEgaB9k/ik5+Y= -github.com/shurcooL/httpfs v0.0.0-20171119174359-809beceb2371/go.mod h1:ZY1cvUeJuFPAdZ/B6v7RHavJWZn2YPVFQ1OSXhCGOkg= -github.com/shurcooL/httpgzip v0.0.0-20180522190206-b1c53ac65af9/go.mod h1:919LwcH0M7/W4fcZ0/jy0qGght1GIhqyS/EgWGH2j5Q= -github.com/shurcooL/issues v0.0.0-20181008053335-6292fdc1e191/go.mod h1:e2qWDig5bLteJ4fwvDAc2NHzqFEthkqn7aOZAOpj+PQ= -github.com/shurcooL/issuesapp v0.0.0-20180602232740-048589ce2241/go.mod h1:NPpHK2TI7iSaM0buivtFUc9offApnI0Alt/K8hcHy0I= -github.com/shurcooL/notifications v0.0.0-20181007000457-627ab5aea122/go.mod h1:b5uSkrEVM1jQUspwbixRBhaIjIzL2xazXp6kntxYle0= -github.com/shurcooL/octicon v0.0.0-20181028054416-fa4f57f9efb2/go.mod h1:eWdoE5JD4R5UVWDucdOPg1g2fqQRq78IQa9zlOV1vpQ= -github.com/shurcooL/reactions v0.0.0-20181006231557-f2e0b4ca5b82/go.mod h1:TCR1lToEk4d2s07G3XGfz2QrgHXg4RJBvjrOozvoWfk= -github.com/shurcooL/sanitized_anchor_name v0.0.0-20170918181015-86672fcb3f95/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= -github.com/shurcooL/users v0.0.0-20180125191416-49c67e49c537/go.mod h1:QJTqeLYEDaXHZDBsXlPCDqdhQuJkuw4NOtaxYe3xii4= -github.com/shurcooL/webdavfs v0.0.0-20170829043945-18c3829fa133/go.mod h1:hKmq5kWdCj2z2KEozexVbfEZIWiTjhE0+UjmZgPqehw= -github.com/sourcegraph/annotate v0.0.0-20160123013949-f4cad6c6324d/go.mod h1:UdhH50NIW0fCiwBSr0co2m7BnFLdv4fQTgdqdJTHFeE= -github.com/sourcegraph/syntaxhighlight v0.0.0-20170531221838-bd320f5d308e/go.mod h1:HuIsMU8RRBOtsCgI77wP899iHVBQpCmg4ErYMZB+2IA= +github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= +github.com/prometheus/procfs v0.0.0-20190117184657-bf6a532e95b1/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= +github.com/rcrowley/go-metrics v0.0.0-20181016184325-3113b8401b8a/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4= +github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= +github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07/go.mod h1:kDXzergiv9cbyO7IOYJZWg1U88JhDg3PB6klq9Hg2pA= go.opencensus.io v0.18.0/go.mod h1:vKdFvxhtzZ9onBp9VKHK8z/sRpBMnKAsufL7wlDrCOA= +go.opencensus.io v0.19.1/go.mod h1:gug0GbSHa8Pafr0d2urOSgoXHZ6x/RUlaiT0d9pqb4A= +go.opencensus.io v0.19.2/go.mod h1:NO/8qkisMZLZ1FCsKNqtJPwc8/TaclWyY0B6wcYNg9M= go4.org v0.0.0-20180809161055-417644f6feb5/go.mod h1:MkTOUMDaeVYJUOUsaDXIhWPZYa1yOyC1qaOBpL57BhE= -golang.org/x/build v0.0.0-20190111050920-041ab4dc3f9d/go.mod h1:OWs+y06UdEOHN4y+MfF/py+xQ/tYqIWW03b70/CG9Rw= +golang.org/x/build v0.0.0-20190314133821-5284462c4bec/go.mod h1:atTaCNAy0f16Ah5aV1gMSwgiKVHwu/JncqDpuRr7lS4= +golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20181030102418-4d3f4d9ffa16/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2 h1:VklqNMn3ovrHsnt90PveolxSbWFaJdECFbxSq0Mqo2M= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= +golang.org/x/lint v0.0.0-20181217174547-8f45f776aaf1/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= +golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= +golang.org/x/lint v0.0.0-20190301231843-5614ed5bae6f/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20181029044818-c44066c5c816/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20181106065722-10aee1819953/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190125091013-d26f9f9a57f3/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20181017192945-9dcd33a902f4/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= +golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/perf v0.0.0-20180704124530-6e6d33e29852/go.mod h1:JLpeXjPJfIyPr5TlbXLkXWLhP8nz10XfvxElABhCtcw= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181029174526-d69651ed3497/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20181122145206-62eef0e2fa9b/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20181218192612-074acd46bca6/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20181030000716-a0a13e073c7b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20181219222714-6e267b5cc78e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= +golang.org/x/tools v0.0.0-20190312170243-e65039ee4138/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= google.golang.org/api v0.0.0-20180910000450-7ca32eb868bf/go.mod h1:4mhQ8q/RsB7i+udVvVy5NUi08OU8ZlA0gRVgrF7VFY0= google.golang.org/api v0.0.0-20181030000543-1d582fd0359e/go.mod h1:4mhQ8q/RsB7i+udVvVy5NUi08OU8ZlA0gRVgrF7VFY0= -google.golang.org/api v0.1.0/go.mod h1:UGEZY7KEX120AnNLIHFMKIo4obdJhkp2tPbaPlQx13Y= +google.golang.org/api v0.0.0-20181220000619-583d854617af/go.mod h1:4mhQ8q/RsB7i+udVvVy5NUi08OU8ZlA0gRVgrF7VFY0= +google.golang.org/api v0.2.0/go.mod h1:IfRCZScioGtypHNTlz3gFk67J8uePVW7uDTBzXuIkhU= +google.golang.org/api v0.3.0/go.mod h1:IuvZyQh8jgscv8qWfQ4ABd8m7hEudgBFM/EdhA3BnXw= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= google.golang.org/appengine v1.2.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/appengine v1.3.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= +google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= google.golang.org/genproto v0.0.0-20180831171423-11092d34479b/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= google.golang.org/genproto v0.0.0-20181029155118-b69ba1387ce2/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= -google.golang.org/genproto v0.0.0-20181202183823-bd91e49a0898/go.mod h1:7Ep/1NZk928CDR8SjdVbjWNpdIf6nzjE3BTgJDr2Atg= -google.golang.org/genproto v0.0.0-20190201180003-4b09977fb922/go.mod h1:L3J43x8/uS+qIUoksaLKe6OS3nUKxOKuIFz1sl2/jx4= +google.golang.org/genproto v0.0.0-20181219182458-5a97ab628bfb/go.mod h1:7Ep/1NZk928CDR8SjdVbjWNpdIf6nzjE3BTgJDr2Atg= +google.golang.org/genproto v0.0.0-20190307195333-5fe7a883aa19/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= google.golang.org/grpc v1.14.0/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw= google.golang.org/grpc v1.16.0/go.mod h1:0JHn/cJsOMiMfNA9+DeHDlAU7KAAB5GDlYFpa9MZMio= google.golang.org/grpc v1.17.0/go.mod h1:6QZJwpn2B+Zp71q/5VxRsJ6NXXVCE5NRUHRo+f3cWCs= +google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= +gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= grpc.go4.org v0.0.0-20170609214715-11d0a25b4919/go.mod h1:77eQGdRu53HpSqPFJFmuJdjuHRquDANNeA4x7B8WQ9o= honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +honnef.co/go/tools v0.0.0-20180920025451-e3ad64cb4ed3/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= -sourcegraph.com/sourcegraph/go-diff v0.5.0/go.mod h1:kuch7UrkMzY0X+p9CRK03kfuPQ2zzQcaEFbx8wA8rck= -sourcegraph.com/sqs/pbtypes v0.0.0-20180604144634-d3ebe8f20ae4/go.mod h1:ketZ/q3QxT9HOBeFhu6RdvsftgpsbFHBF5Cas6cDKZ0= From 59594877dafa901578dd80e390f2a25a236aaaeb Mon Sep 17 00:00:00 2001 From: Emir Beganovic Date: Sun, 14 Apr 2019 11:38:06 +0400 Subject: [PATCH 127/881] Fix unsafe concurrent SingularTable method call --- main.go | 4 +++- main_test.go | 33 +++++++++++++++++++++++++++++---- model_struct.go | 17 +++++++++++++++-- 3 files changed, 47 insertions(+), 7 deletions(-) diff --git a/main.go b/main.go index fda63d29..cc8ac68c 100644 --- a/main.go +++ b/main.go @@ -12,6 +12,7 @@ import ( // DB contains information for current db connection type DB struct { + sync.Mutex Value interface{} Error error RowsAffected int64 @@ -170,7 +171,8 @@ func (s *DB) HasBlockGlobalUpdate() bool { // SingularTable use singular table by default func (s *DB) SingularTable(enable bool) { - modelStructsMap = sync.Map{} + s.parent.Lock() + defer s.parent.Unlock() s.parent.singularTable = enable } diff --git a/main_test.go b/main_test.go index ac40c32b..1dc30093 100644 --- a/main_test.go +++ b/main_test.go @@ -9,6 +9,7 @@ import ( "reflect" "strconv" "strings" + "sync" "testing" "time" @@ -277,6 +278,30 @@ func TestTableName(t *testing.T) { DB.SingularTable(false) } +func TestTableNameConcurrently(t *testing.T) { + DB := DB.Model("") + if DB.NewScope(Order{}).TableName() != "orders" { + t.Errorf("Order's table name should be orders") + } + + var wg sync.WaitGroup + wg.Add(10) + + for i := 1; i <= 10; i++ { + go func(db *gorm.DB) { + DB.SingularTable(true) + wg.Done() + }(DB) + } + wg.Wait() + + if DB.NewScope(Order{}).TableName() != "order" { + t.Errorf("Order's singular table name should be order") + } + + DB.SingularTable(false) +} + func TestNullValues(t *testing.T) { DB.DropTable(&NullValue{}) DB.AutoMigrate(&NullValue{}) @@ -1066,12 +1091,12 @@ func TestCountWithHaving(t *testing.T) { DB.Create(getPreparedUser("user1", "pluck_user")) DB.Create(getPreparedUser("user2", "pluck_user")) - user3:=getPreparedUser("user3", "pluck_user") - user3.Languages=[]Language{} + user3 := getPreparedUser("user3", "pluck_user") + user3.Languages = []Language{} DB.Create(user3) var count int - err:=db.Model(User{}).Select("users.id"). + err := db.Model(User{}).Select("users.id"). Joins("LEFT JOIN user_languages ON user_languages.user_id = users.id"). Joins("LEFT JOIN languages ON user_languages.language_id = languages.id"). Group("users.id").Having("COUNT(languages.id) > 1").Count(&count).Error @@ -1080,7 +1105,7 @@ func TestCountWithHaving(t *testing.T) { t.Error("Unexpected error on query count with having") } - if count!=2{ + if count != 2 { t.Error("Unexpected result on query count with having") } } diff --git a/model_struct.go b/model_struct.go index f646910a..8d6313fb 100644 --- a/model_struct.go +++ b/model_struct.go @@ -40,9 +40,11 @@ func (s *ModelStruct) TableName(db *DB) string { s.defaultTableName = tabler.TableName() } else { tableName := ToTableName(s.ModelType.Name()) + db.parent.Lock() if db == nil || (db.parent != nil && !db.parent.singularTable) { tableName = inflection.Plural(tableName) } + db.parent.Unlock() s.defaultTableName = tableName } } @@ -163,7 +165,18 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } // Get Cached model struct - if value, ok := modelStructsMap.Load(reflectType); ok && value != nil { + isSingularTable := false + if scope.db != nil && scope.db.parent != nil { + scope.db.parent.Lock() + isSingularTable = scope.db.parent.singularTable + scope.db.parent.Unlock() + } + + hashKey := struct { + singularTable bool + reflectType reflect.Type + }{isSingularTable, reflectType} + if value, ok := modelStructsMap.Load(hashKey); ok && value != nil { return value.(*ModelStruct) } @@ -612,7 +625,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } } - modelStructsMap.Store(reflectType, &modelStruct) + modelStructsMap.Store(hashKey, &modelStruct) return &modelStruct } From b4927348aebb1e84df37aa432c64ebb1c1ae3edb Mon Sep 17 00:00:00 2001 From: Emir Beganovic Date: Sun, 14 Apr 2019 11:40:05 +0400 Subject: [PATCH 128/881] gofmt --- preload_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/preload_test.go b/preload_test.go index 1a6a5d49..dd29fb5e 100644 --- a/preload_test.go +++ b/preload_test.go @@ -1677,7 +1677,7 @@ func TestPreloadManyToManyCallbacks(t *testing.T) { lvl := Level1{ Name: "l1", Level2s: []Level2{ - Level2{Name: "l2-1"}, Level2{Name: "l2-2"}, + {Name: "l2-1"}, {Name: "l2-2"}, }, } DB.Save(&lvl) From b923e78e811c9bf9a244c6fb0983443101a4332b Mon Sep 17 00:00:00 2001 From: Emir Beganovic Date: Sun, 14 Apr 2019 12:23:26 +0400 Subject: [PATCH 129/881] Verbose go get output --- wercker.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wercker.yml b/wercker.yml index 0c3e73ef..43ad8209 100644 --- a/wercker.yml +++ b/wercker.yml @@ -83,7 +83,7 @@ build: code: | cd $WERCKER_SOURCE_DIR go version - go get -t ./... + go get -t -v ./... # Build the project - script: From 96d52f25b09fae789adb0c97ccf36f343a8f08fc Mon Sep 17 00:00:00 2001 From: Emir Beganovic Date: Sun, 14 Apr 2019 12:41:14 +0400 Subject: [PATCH 130/881] Use RWMutex --- main.go | 2 +- model_struct.go | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/main.go b/main.go index cc8ac68c..16820353 100644 --- a/main.go +++ b/main.go @@ -12,7 +12,7 @@ import ( // DB contains information for current db connection type DB struct { - sync.Mutex + sync.RWMutex Value interface{} Error error RowsAffected int64 diff --git a/model_struct.go b/model_struct.go index 8d6313fb..bfab49c0 100644 --- a/model_struct.go +++ b/model_struct.go @@ -40,11 +40,11 @@ func (s *ModelStruct) TableName(db *DB) string { s.defaultTableName = tabler.TableName() } else { tableName := ToTableName(s.ModelType.Name()) - db.parent.Lock() + db.parent.RLock() if db == nil || (db.parent != nil && !db.parent.singularTable) { tableName = inflection.Plural(tableName) } - db.parent.Unlock() + db.parent.RUnlock() s.defaultTableName = tableName } } @@ -167,9 +167,9 @@ func (scope *Scope) GetModelStruct() *ModelStruct { // Get Cached model struct isSingularTable := false if scope.db != nil && scope.db.parent != nil { - scope.db.parent.Lock() + scope.db.parent.RLock() isSingularTable = scope.db.parent.singularTable - scope.db.parent.Unlock() + scope.db.parent.RUnlock() } hashKey := struct { From cd0f3ae41a86cdd5884e14147336542a81294fd6 Mon Sep 17 00:00:00 2001 From: Emir Beganovic Date: Sun, 14 Apr 2019 12:41:23 +0400 Subject: [PATCH 131/881] Run tests with race detector --- wercker.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wercker.yml b/wercker.yml index 43ad8209..de351fd2 100644 --- a/wercker.yml +++ b/wercker.yml @@ -95,7 +95,7 @@ build: - script: name: test sqlite code: | - go test ./... + go test -race -v ./... - script: name: test mariadb From ef9d2070bbed3d9186f8e0aa1b86c55b20411a55 Mon Sep 17 00:00:00 2001 From: Emir Beganovic Date: Sun, 14 Apr 2019 12:46:05 +0400 Subject: [PATCH 132/881] Run tests with race detector --- wercker.yml | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/wercker.yml b/wercker.yml index de351fd2..98234583 100644 --- a/wercker.yml +++ b/wercker.yml @@ -100,49 +100,49 @@ build: - script: name: test mariadb code: | - GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mariadb:3306)/gorm?charset=utf8&parseTime=True" go test ./... + GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mariadb:3306)/gorm?charset=utf8&parseTime=True" go test -race ./... - script: name: test mysql5.7 code: | - GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql57:3306)/gorm?charset=utf8&parseTime=True" go test ./... + GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql57:3306)/gorm?charset=utf8&parseTime=True" go test -race ./... - script: name: test mysql5.6 code: | - GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql56:3306)/gorm?charset=utf8&parseTime=True" go test ./... + GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql56:3306)/gorm?charset=utf8&parseTime=True" go test -race ./... - script: name: test mysql5.5 code: | - GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql55:3306)/gorm?charset=utf8&parseTime=True" go test ./... + GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql55:3306)/gorm?charset=utf8&parseTime=True" go test -race ./... - script: name: test postgres code: | - GORM_DIALECT=postgres GORM_DSN="host=postgres user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test ./... + GORM_DIALECT=postgres GORM_DSN="host=postgres user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test -race ./... - script: name: test postgres96 code: | - GORM_DIALECT=postgres GORM_DSN="host=postgres96 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test ./... + GORM_DIALECT=postgres GORM_DSN="host=postgres96 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test -race ./... - script: name: test postgres95 code: | - GORM_DIALECT=postgres GORM_DSN="host=postgres95 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test ./... + GORM_DIALECT=postgres GORM_DSN="host=postgres95 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test -race ./... - script: name: test postgres94 code: | - GORM_DIALECT=postgres GORM_DSN="host=postgres94 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test ./... + GORM_DIALECT=postgres GORM_DSN="host=postgres94 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test -race ./... - script: name: test postgres93 code: | - GORM_DIALECT=postgres GORM_DSN="host=postgres93 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test ./... + GORM_DIALECT=postgres GORM_DSN="host=postgres93 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test -race ./... - script: name: test mssql code: | - GORM_DIALECT=mssql GORM_DSN="sqlserver://gorm:LoremIpsum86@mssql:1433?database=gorm" go test ./... + GORM_DIALECT=mssql GORM_DSN="sqlserver://gorm:LoremIpsum86@mssql:1433?database=gorm" go test -race ./... From 7bc35615034c1d6994088c2cc925086fba6f565e Mon Sep 17 00:00:00 2001 From: Shunsuke Otani Date: Sun, 14 Apr 2019 22:11:29 +0900 Subject: [PATCH 133/881] Don't set NULL if timestamp column is Primary Key (#2332) --- dialect_mysql.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dialect_mysql.go b/dialect_mysql.go index 5d63e5cd..89b638b3 100644 --- a/dialect_mysql.go +++ b/dialect_mysql.go @@ -100,7 +100,7 @@ func (s *mysql) DataTypeOf(field *StructField) string { precision = fmt.Sprintf("(%s)", p) } - if _, ok := field.TagSettingsGet("NOT NULL"); ok { + if _, ok := field.TagSettingsGet("NOT NULL"); ok || field.IsPrimaryKey { sqlType = fmt.Sprintf("timestamp%v", precision) } else { sqlType = fmt.Sprintf("timestamp%v NULL", precision) From 8d1e6bc0f8e9710dcba60a1b8e4ec3f47e8bf8ea Mon Sep 17 00:00:00 2001 From: Dmitry Zenovich Date: Fri, 19 Apr 2019 14:41:30 +0300 Subject: [PATCH 134/881] remove old elements from the output parameter of Pluck() --- main_test.go | 31 +++++++++++++++++++++++++++++++ scope.go | 4 ++++ 2 files changed, 35 insertions(+) diff --git a/main_test.go b/main_test.go index 1dc30093..4100c7f8 100644 --- a/main_test.go +++ b/main_test.go @@ -1110,6 +1110,37 @@ func TestCountWithHaving(t *testing.T) { } } +func TestPluck(t *testing.T) { + db := DB.New() + db.Delete(User{}) + defer db.Delete(User{}) + + DB.Create(&User{Id: 1, Name: "user1"}) + DB.Create(&User{Id: 2, Name: "user2"}) + DB.Create(&User{Id: 3, Name: "user3"}) + + var ids []int64 + err := db.Model(User{}).Order("id").Pluck("id", &ids).Error + + if err != nil { + t.Error("Unexpected error on pluck") + } + + if len(ids) != 3 || ids[0] != 1 || ids[1] != 2 || ids[2] != 3 { + t.Error("Unexpected result on pluck") + } + + err = db.Model(User{}).Order("id").Pluck("id", &ids).Error + + if err != nil { + t.Error("Unexpected error on pluck again") + } + + if len(ids) != 3 || ids[0] != 1 || ids[1] != 2 || ids[2] != 3 { + t.Error("Unexpected result on pluck again") + } +} + func BenchmarkGorm(b *testing.B) { b.N = 2000 for x := 0; x < b.N; x++ { diff --git a/scope.go b/scope.go index 7fa64b19..0767bb66 100644 --- a/scope.go +++ b/scope.go @@ -984,6 +984,10 @@ func (scope *Scope) pluck(column string, value interface{}) *Scope { return scope } + if dest.Len() > 0 { + dest.Set(reflect.Zero(dest.Type())) + } + if query, ok := scope.Search.selects["query"]; !ok || !scope.isQueryForColumn(query, column) { scope.Search.Select(column) } From adc8e9b706101707f6138e7832293fb7450b38a7 Mon Sep 17 00:00:00 2001 From: Dmitry Zenovich Date: Fri, 19 Apr 2019 14:48:52 +0300 Subject: [PATCH 135/881] apply gorm:query_option in Count() --- callback_row_query.go | 8 +++++++- main_test.go | 23 +++++++++++++++++++++++ 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/callback_row_query.go b/callback_row_query.go index c2ff4a08..687b0039 100644 --- a/callback_row_query.go +++ b/callback_row_query.go @@ -1,6 +1,9 @@ package gorm -import "database/sql" +import ( + "database/sql" + "fmt" +) // Define callbacks for row query func init() { @@ -20,6 +23,9 @@ type RowsQueryResult struct { func rowQueryCallback(scope *Scope) { if result, ok := scope.InstanceGet("row_query_result"); ok { scope.prepareQuerySQL() + if str, ok := scope.Get("gorm:query_option"); ok { + scope.SQL += addExtraSpaceIfExist(fmt.Sprint(str)) + } if rowResult, ok := result.(*RowQueryResult); ok { rowResult.Row = scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...) diff --git a/main_test.go b/main_test.go index 1dc30093..a0d95369 100644 --- a/main_test.go +++ b/main_test.go @@ -1110,6 +1110,29 @@ func TestCountWithHaving(t *testing.T) { } } +func TestCountWithQueryOption(t *testing.T) { + db := DB.New() + db.Delete(User{}) + defer db.Delete(User{}) + + DB.Create(&User{Name: "user1"}) + DB.Create(&User{Name: "user2"}) + DB.Create(&User{Name: "user3"}) + + var count int + err := db.Model(User{}).Select("users.id"). + Set("gorm:query_option", "WHERE users.name='user2'"). + Count(&count).Error + + if err != nil { + t.Error("Unexpected error on query count with query_option") + } + + if count != 1 { + t.Error("Unexpected result on query count with query_option") + } +} + func BenchmarkGorm(b *testing.B) { b.N = 2000 for x := 0; x < b.N; x++ { From 09a868b381e19e41f1d99bb38a75290976d5b9ed Mon Sep 17 00:00:00 2001 From: zaneli Date: Mon, 15 Apr 2019 17:46:50 +0900 Subject: [PATCH 136/881] Handle syntax to specify an index prefix length --- dialect.go | 3 +++ dialect_common.go | 9 ++++++++- dialect_mysql.go | 15 ++++++++++++++- dialects/mssql/mssql.go | 5 +++++ migration_test.go | 39 +++++++++++++++++++++++++++++++++++++++ scope.go | 6 ++++-- 6 files changed, 73 insertions(+), 4 deletions(-) diff --git a/dialect.go b/dialect.go index cdc4278e..831c0a8e 100644 --- a/dialect.go +++ b/dialect.go @@ -48,6 +48,9 @@ type Dialect interface { // BuildKeyName returns a valid key name (foreign key, index key) for the given table, field and reference BuildKeyName(kind, tableName string, fields ...string) string + // NormalizeIndexAndColumn returns valid index name and column name depending on each dialect + NormalizeIndexAndColumn(indexName, columnName string) (string, string) + // CurrentDatabase return current database name CurrentDatabase() string } diff --git a/dialect_common.go b/dialect_common.go index a479be79..e3a5b702 100644 --- a/dialect_common.go +++ b/dialect_common.go @@ -9,6 +9,8 @@ import ( "time" ) +var keyNameRegex = regexp.MustCompile("[^a-zA-Z0-9]+") + // DefaultForeignKeyNamer contains the default foreign key name generator method type DefaultForeignKeyNamer struct { } @@ -166,10 +168,15 @@ func (commonDialect) DefaultValueStr() string { // BuildKeyName returns a valid key name (foreign key, index key) for the given table, field and reference func (DefaultForeignKeyNamer) BuildKeyName(kind, tableName string, fields ...string) string { keyName := fmt.Sprintf("%s_%s_%s", kind, tableName, strings.Join(fields, "_")) - keyName = regexp.MustCompile("[^a-zA-Z0-9]+").ReplaceAllString(keyName, "_") + keyName = keyNameRegex.ReplaceAllString(keyName, "_") return keyName } +// NormalizeIndexAndColumn returns argument's index name and column name without doing anything +func (commonDialect) NormalizeIndexAndColumn(indexName, columnName string) (string, string) { + return indexName, columnName +} + // IsByteArrayOrSlice returns true of the reflected value is an array or slice func IsByteArrayOrSlice(value reflect.Value) bool { return (value.Kind() == reflect.Array || value.Kind() == reflect.Slice) && value.Type().Elem() == reflect.TypeOf(uint8(0)) diff --git a/dialect_mysql.go b/dialect_mysql.go index 89b638b3..5a1ad708 100644 --- a/dialect_mysql.go +++ b/dialect_mysql.go @@ -11,6 +11,8 @@ import ( "unicode/utf8" ) +var mysqlIndexRegex = regexp.MustCompile(`^(.+)\((\d+)\)$`) + type mysql struct { commonDialect } @@ -178,7 +180,7 @@ func (s mysql) BuildKeyName(kind, tableName string, fields ...string) string { bs := h.Sum(nil) // sha1 is 40 characters, keep first 24 characters of destination - destRunes := []rune(regexp.MustCompile("[^a-zA-Z0-9]+").ReplaceAllString(fields[0], "_")) + destRunes := []rune(keyNameRegex.ReplaceAllString(fields[0], "_")) if len(destRunes) > 24 { destRunes = destRunes[:24] } @@ -186,6 +188,17 @@ func (s mysql) BuildKeyName(kind, tableName string, fields ...string) string { return fmt.Sprintf("%s%x", string(destRunes), bs) } +// NormalizeIndexAndColumn returns index name and column name for specify an index prefix length if needed +func (mysql) NormalizeIndexAndColumn(indexName, columnName string) (string, string) { + submatch := mysqlIndexRegex.FindStringSubmatch(indexName) + if len(submatch) != 3 { + return indexName, columnName + } + indexName = submatch[1] + columnName = fmt.Sprintf("%s(%s)", columnName, submatch[2]) + return indexName, columnName +} + func (mysql) DefaultValueStr() string { return "VALUES()" } diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go index 6c424bc1..8c2360fc 100644 --- a/dialects/mssql/mssql.go +++ b/dialects/mssql/mssql.go @@ -198,6 +198,11 @@ func (mssql) DefaultValueStr() string { return "DEFAULT VALUES" } +// NormalizeIndexAndColumn returns argument's index name and column name without doing anything +func (mssql) NormalizeIndexAndColumn(indexName, columnName string) (string, string) { + return indexName, columnName +} + func currentDatabaseAndTable(dialect gorm.Dialect, tableName string) (string, string) { if strings.Contains(tableName, ".") { splitStrings := strings.SplitN(tableName, ".", 2) diff --git a/migration_test.go b/migration_test.go index 3fb06648..d94ec9ec 100644 --- a/migration_test.go +++ b/migration_test.go @@ -538,3 +538,42 @@ func TestModifyColumnType(t *testing.T) { t.Errorf("No error should happen when ModifyColumn, but got %v", err) } } + +func TestIndexWithPrefixLength(t *testing.T) { + if dialect := os.Getenv("GORM_DIALECT"); dialect != "mysql" { + t.Skip("Skipping this because only mysql support setting an index prefix length") + } + + type IndexWithPrefix struct { + gorm.Model + Name string + Description string `gorm:"type:text;index:idx_index_with_prefixes_length(100)"` + } + type IndexesWithPrefix struct { + gorm.Model + Name string + Description1 string `gorm:"type:text;index:idx_index_with_prefixes_length(100)"` + Description2 string `gorm:"type:text;index:idx_index_with_prefixes_length(100)"` + } + type IndexesWithPrefixAndWithoutPrefix struct { + gorm.Model + Name string `gorm:"index:idx_index_with_prefixes_length"` + Description string `gorm:"type:text;index:idx_index_with_prefixes_length(100)"` + } + tables := []interface{}{&IndexWithPrefix{}, &IndexesWithPrefix{}, &IndexesWithPrefixAndWithoutPrefix{}} + for _, table := range tables { + scope := DB.NewScope(table) + tableName := scope.TableName() + t.Run(fmt.Sprintf("Create index with prefix length: %s", tableName), func(t *testing.T) { + if err := DB.DropTableIfExists(table).Error; err != nil { + t.Errorf("Failed to drop %s table: %v", tableName, err) + } + if err := DB.CreateTable(table).Error; err != nil { + t.Errorf("Failed to create %s table: %v", tableName, err) + } + if !scope.Dialect().HasIndex(tableName, "idx_index_with_prefixes_length") { + t.Errorf("Failed to create %s table index:", tableName) + } + }) + } +} diff --git a/scope.go b/scope.go index 7fa64b19..01355103 100644 --- a/scope.go +++ b/scope.go @@ -1284,7 +1284,8 @@ func (scope *Scope) autoIndex() *Scope { if name == "INDEX" || name == "" { name = scope.Dialect().BuildKeyName("idx", scope.TableName(), field.DBName) } - indexes[name] = append(indexes[name], field.DBName) + name, column := scope.Dialect().NormalizeIndexAndColumn(name, field.DBName) + indexes[name] = append(indexes[name], column) } } @@ -1295,7 +1296,8 @@ func (scope *Scope) autoIndex() *Scope { if name == "UNIQUE_INDEX" || name == "" { name = scope.Dialect().BuildKeyName("uix", scope.TableName(), field.DBName) } - uniqueIndexes[name] = append(uniqueIndexes[name], field.DBName) + name, column := scope.Dialect().NormalizeIndexAndColumn(name, field.DBName) + uniqueIndexes[name] = append(uniqueIndexes[name], column) } } } From d9cfa3cb1289042eb4a25137579c77a61d4bcdc5 Mon Sep 17 00:00:00 2001 From: Emir Beganovic Date: Tue, 30 Apr 2019 11:12:47 +0400 Subject: [PATCH 137/881] Update to latest go-mssqldb --- go.mod | 6 ++++- go.sum | 76 +++++++++++++++------------------------------------------- 2 files changed, 24 insertions(+), 58 deletions(-) diff --git a/go.mod b/go.mod index 89ca68d8..4f6671f5 100644 --- a/go.mod +++ b/go.mod @@ -1,9 +1,13 @@ module github.com/jinzhu/gorm require ( - github.com/denisenkom/go-mssqldb v0.0.0-20190401154936-ce35bd87d4b3 + github.com/denisenkom/go-mssqldb v0.0.0-20190423183735-731ef375ac02 + github.com/erikstmartin/go-testdb v0.0.0-20160219214506-8d10e4a1bae5 github.com/go-sql-driver/mysql v1.4.1 github.com/jinzhu/inflection v0.0.0-20180308033659-04140366298a + github.com/jinzhu/now v1.0.0 github.com/lib/pq v1.0.0 github.com/mattn/go-sqlite3 v1.10.0 + golang.org/x/crypto v0.0.0-20190426145343-a29dc8fdc734 // indirect + google.golang.org/appengine v1.5.0 // indirect ) diff --git a/go.sum b/go.sum index a984e572..478c7353 100644 --- a/go.sum +++ b/go.sum @@ -1,124 +1,100 @@ cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= -cloud.google.com/go v0.31.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= cloud.google.com/go v0.37.4 h1:glPeL3BQJsbF6aIIYfZizMwc5LTYz250bDMjttbBGAU= cloud.google.com/go v0.37.4/go.mod h1:NHPJ89PdicEuT9hdPXMROBD91xc5uRDxsMtSB16k7hw= -git.apache.org/thrift.git v0.0.0-20180902110319-2566ecd5d999/go.mod h1:fPE2ZNJGynbRyZ4dJvy6G277gSllfV2HJqblrnkyeyg= -git.apache.org/thrift.git v0.12.0/go.mod h1:fPE2ZNJGynbRyZ4dJvy6G277gSllfV2HJqblrnkyeyg= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/Shopify/sarama v1.19.0/go.mod h1:FVkBWblsNy7DGZRfXLU0O9RCGt5g3g3yEuWXgklEdEo= github.com/Shopify/toxiproxy v2.1.4+incompatible/go.mod h1:OXgGpZ6Cli1/URJOF1DMxUHB2q5Ap20/P/eIdh4G0pI= github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= -github.com/anmitsu/go-shlex v0.0.0-20161002113705-648efa622239/go.mod h1:2FmKhYUyUczH0OGQWaF5ceTx0UBShxjsH6f8oGKYe2c= github.com/apache/thrift v0.12.0/go.mod h1:cp2SuWMxlEZw2r+iP2GNCdIi4C1qmUzdZFSVb+bacwQ= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= -github.com/bradfitz/go-smtpd v0.0.0-20170404230938-deb6d6237625/go.mod h1:HYsPBTaaSFSlLx/70C2HPIMNZpVV8+vt/A+FMnYP11g= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= -github.com/coreos/go-systemd v0.0.0-20181012123002-c6f51f82210d/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/denisenkom/go-mssqldb v0.0.0-20190401154936-ce35bd87d4b3/go.mod h1:EcO5fNtMZHCMjAvj8LE6T+5bphSdR6LQ75n+m1TtsFI= +github.com/denisenkom/go-mssqldb v0.0.0-20190423183735-731ef375ac02 h1:PS3xfVPa8N84AzoWZHFCbA0+ikz4f4skktfjQoNMsgk= +github.com/denisenkom/go-mssqldb v0.0.0-20190423183735-731ef375ac02/go.mod h1:zAg7JM8CkOJ43xKXIj7eRO9kmWm/TW578qo+oDO6tuM= github.com/eapache/go-resiliency v1.1.0/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs= github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU= github.com/eapache/queue v1.1.0/go.mod h1:6eCeP0CKFpHLu8blIFXhExK/dRa7WDZfr6jVFPTqq+I= -github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568/go.mod h1:xEzjJPgXI435gkrCt3MPfRiAkVrwSbHsst4LCFVfpJc= +github.com/erikstmartin/go-testdb v0.0.0-20160219214506-8d10e4a1bae5 h1:Yzb9+7DPaBjB8zlTR87/ElzFsnQfuHnVUVqpZZIcV5Y= +github.com/erikstmartin/go-testdb v0.0.0-20160219214506-8d10e4a1bae5/go.mod h1:a2zkGnVExMxdzMo3M0Hi/3sEU+cWnZpSni0O6/Yb/P0= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= -github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= -github.com/gliderlabs/ssh v0.1.1/go.mod h1:U7qILu1NlMHj9FlMhZLlkCdDnU1DBEAqr0aevW3Awn0= github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE= +github.com/go-sql-driver/mysql v1.4.1 h1:g24URVg0OFbNUTx9qqY1IRZ9D9z3iPyi5zKhQZpNwpA= github.com/go-sql-driver/mysql v1.4.1/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/gogo/protobuf v1.2.0/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= -github.com/golang/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:tluoj9z5200jBnyusfRPU2LqT6J+DAorxEvtC7LHB+E= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= +github.com/google/go-cmp v0.2.0 h1:+dTQ8DZQJz0Mb/HjFlkptS1FeQ4cWSnN941F8aEG4SQ= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= -github.com/google/go-github v17.0.0+incompatible/go.mod h1:zLgOLi98H3fifZn+44m+umXrS52loVEgC2AApnigrVQ= -github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck= github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= -github.com/googleapis/gax-go v2.0.0+incompatible/go.mod h1:SFVmujtThgffbyetf+mdk2eWhX2bMyUtNHzFKcPA9HY= github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg= github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg= github.com/gorilla/mux v1.6.2/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs= -github.com/gregjones/httpcache v0.0.0-20180305231024-9cad4c3443a7/go.mod h1:FecbI9+v66THATjSRHfNgh1IVFe/9kFxbXtjV0ctIMA= -github.com/grpc-ecosystem/grpc-gateway v1.5.0/go.mod h1:RSKVYQBd5MCa4OVpNdGskqpgL2+G+NZTnrVHpWWfpdw= -github.com/grpc-ecosystem/grpc-gateway v1.6.2/go.mod h1:RSKVYQBd5MCa4OVpNdGskqpgL2+G+NZTnrVHpWWfpdw= github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= -github.com/jellevandenhooff/dkim v0.0.0-20150330215556-f50fe3d243e1/go.mod h1:E0B/fFc00Y+Rasa88328GlI/XbtyysCtTHZS8h7IrBU= github.com/jinzhu/inflection v0.0.0-20180308033659-04140366298a h1:eeaG9XMUvRBYXJi4pg1ZKM7nxc5AfXfojeLLW7O5J3k= github.com/jinzhu/inflection v0.0.0-20180308033659-04140366298a/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/jinzhu/now v1.0.0 h1:6WV8LvwPpDhKjo5U9O6b4+xdG/jTXNPwlDme/MTo8Ns= +github.com/jinzhu/now v1.0.0/go.mod h1:oHTiXerJ20+SfYcrdlBO7rzZRJWGwSTQ0iUY2jI6Gfc= github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU= github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= -github.com/kr/pty v1.1.3/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/lib/pq v1.0.0 h1:X5PMW56eZitiTeO7tKzZxFCSpbFZJtkMMooicw2us9A= github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/mattn/go-sqlite3 v1.10.0 h1:jbhqpg7tQe4SupckyijYiy0mJJ/pRyHvXf7JdWK860o= github.com/mattn/go-sqlite3 v1.10.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/ginkgo v1.7.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= -github.com/openzipkin/zipkin-go v0.1.1/go.mod h1:NtoC/o8u3JlF1lSlyPNswIbeQH9bJTmOf0Erfk+hxe8= -github.com/openzipkin/zipkin-go v0.1.3/go.mod h1:NtoC/o8u3JlF1lSlyPNswIbeQH9bJTmOf0Erfk+hxe8= github.com/openzipkin/zipkin-go v0.1.6/go.mod h1:QgAqvLzwWbR/WpD4A3cGpPtJrZXNIiJc5AZX7/PBEpw= github.com/pierrec/lz4 v2.0.5+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/prometheus/client_golang v0.8.0/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= github.com/prometheus/client_golang v0.9.3-0.20190127221311-3c4408c8b829/go.mod h1:p2iRAGwDERtqlqzRXnrOVns+ignqQo//hLXqYxZYVNs= github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= github.com/prometheus/client_model v0.0.0-20190115171406-56726106282f/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= -github.com/prometheus/common v0.0.0-20180801064454-c7de2306084e/go.mod h1:daVV7qP5qjZbuso7PdcryaAu0sAZbrN9i7WWcTMWvro= github.com/prometheus/common v0.2.0/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4= -github.com/prometheus/procfs v0.0.0-20180725123919-05ee40e3a273/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= github.com/prometheus/procfs v0.0.0-20190117184657-bf6a532e95b1/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= github.com/rcrowley/go-metrics v0.0.0-20181016184325-3113b8401b8a/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4= github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= -github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07/go.mod h1:kDXzergiv9cbyO7IOYJZWg1U88JhDg3PB6klq9Hg2pA= -go.opencensus.io v0.18.0/go.mod h1:vKdFvxhtzZ9onBp9VKHK8z/sRpBMnKAsufL7wlDrCOA= -go.opencensus.io v0.19.1/go.mod h1:gug0GbSHa8Pafr0d2urOSgoXHZ6x/RUlaiT0d9pqb4A= -go.opencensus.io v0.19.2/go.mod h1:NO/8qkisMZLZ1FCsKNqtJPwc8/TaclWyY0B6wcYNg9M= -go4.org v0.0.0-20180809161055-417644f6feb5/go.mod h1:MkTOUMDaeVYJUOUsaDXIhWPZYa1yOyC1qaOBpL57BhE= -golang.org/x/build v0.0.0-20190314133821-5284462c4bec/go.mod h1:atTaCNAy0f16Ah5aV1gMSwgiKVHwu/JncqDpuRr7lS4= +go.opencensus.io v0.20.1/go.mod h1:6WKK9ahsWS3RSO+PY9ZHZUfv2irvY6gN279GOPZjmmk= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= -golang.org/x/crypto v0.0.0-20181030102418-4d3f4d9ffa16/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20190426145343-a29dc8fdc734 h1:p/H982KKEjUnLJkM3tt/LemDnOc1GiZL5FCVlORJ5zo= +golang.org/x/crypto v0.0.0-20190426145343-a29dc8fdc734/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= -golang.org/x/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= -golang.org/x/lint v0.0.0-20181217174547-8f45f776aaf1/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= golang.org/x/lint v0.0.0-20190301231843-5614ed5bae6f/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20181029044818-c44066c5c816/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20181106065722-10aee1819953/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190125091013-d26f9f9a57f3/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= -golang.org/x/oauth2 v0.0.0-20181017192945-9dcd33a902f4/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= -golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= -golang.org/x/perf v0.0.0-20180704124530-6e6d33e29852/go.mod h1:JLpeXjPJfIyPr5TlbXLkXWLhP8nz10XfvxElABhCtcw= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -126,48 +102,34 @@ golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20181029174526-d69651ed3497/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181122145206-62eef0e2fa9b/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20181218192612-074acd46bca6/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20181219222714-6e267b5cc78e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= golang.org/x/tools v0.0.0-20190312170243-e65039ee4138/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= -google.golang.org/api v0.0.0-20180910000450-7ca32eb868bf/go.mod h1:4mhQ8q/RsB7i+udVvVy5NUi08OU8ZlA0gRVgrF7VFY0= -google.golang.org/api v0.0.0-20181030000543-1d582fd0359e/go.mod h1:4mhQ8q/RsB7i+udVvVy5NUi08OU8ZlA0gRVgrF7VFY0= -google.golang.org/api v0.0.0-20181220000619-583d854617af/go.mod h1:4mhQ8q/RsB7i+udVvVy5NUi08OU8ZlA0gRVgrF7VFY0= -google.golang.org/api v0.2.0/go.mod h1:IfRCZScioGtypHNTlz3gFk67J8uePVW7uDTBzXuIkhU= -google.golang.org/api v0.3.0/go.mod h1:IuvZyQh8jgscv8qWfQ4ABd8m7hEudgBFM/EdhA3BnXw= +google.golang.org/api v0.3.1/go.mod h1:6wY9I6uQWHQ8EM57III9mq/AjF+i8G65rmVagqKMtkk= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= -google.golang.org/appengine v1.2.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= -google.golang.org/appengine v1.3.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= +google.golang.org/appengine v1.5.0 h1:KxkO13IPW4Lslp2bz+KHP2E3gtFlrIGNThxkZQ3g+4c= +google.golang.org/appengine v1.5.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= -google.golang.org/genproto v0.0.0-20180831171423-11092d34479b/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= -google.golang.org/genproto v0.0.0-20181029155118-b69ba1387ce2/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= -google.golang.org/genproto v0.0.0-20181219182458-5a97ab628bfb/go.mod h1:7Ep/1NZk928CDR8SjdVbjWNpdIf6nzjE3BTgJDr2Atg= google.golang.org/genproto v0.0.0-20190307195333-5fe7a883aa19/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= -google.golang.org/grpc v1.14.0/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw= -google.golang.org/grpc v1.16.0/go.mod h1:0JHn/cJsOMiMfNA9+DeHDlAU7KAAB5GDlYFpa9MZMio= +google.golang.org/genproto v0.0.0-20190404172233-64821d5d2107/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= google.golang.org/grpc v1.17.0/go.mod h1:6QZJwpn2B+Zp71q/5VxRsJ6NXXVCE5NRUHRo+f3cWCs= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= -gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -grpc.go4.org v0.0.0-20170609214715-11d0a25b4919/go.mod h1:77eQGdRu53HpSqPFJFmuJdjuHRquDANNeA4x7B8WQ9o= honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= -honnef.co/go/tools v0.0.0-20180920025451-e3ad64cb4ed3/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= From 8931d8a7c3ba54624f373a7bf5a4c9e1e2248465 Mon Sep 17 00:00:00 2001 From: Emir Beganovic Date: Tue, 30 Apr 2019 11:59:39 +0400 Subject: [PATCH 138/881] Update dependencies --- go.mod | 4 +--- go.sum | 13 ++++++------- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/go.mod b/go.mod index 4f6671f5..3ec7aab0 100644 --- a/go.mod +++ b/go.mod @@ -6,8 +6,6 @@ require ( github.com/go-sql-driver/mysql v1.4.1 github.com/jinzhu/inflection v0.0.0-20180308033659-04140366298a github.com/jinzhu/now v1.0.0 - github.com/lib/pq v1.0.0 + github.com/lib/pq v1.1.0 github.com/mattn/go-sqlite3 v1.10.0 - golang.org/x/crypto v0.0.0-20190426145343-a29dc8fdc734 // indirect - google.golang.org/appengine v1.5.0 // indirect ) diff --git a/go.sum b/go.sum index 478c7353..848f7293 100644 --- a/go.sum +++ b/go.sum @@ -32,7 +32,6 @@ github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfb github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= -github.com/google/go-cmp v0.2.0 h1:+dTQ8DZQJz0Mb/HjFlkptS1FeQ4cWSnN941F8aEG4SQ= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= @@ -50,8 +49,8 @@ github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7V github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= -github.com/lib/pq v1.0.0 h1:X5PMW56eZitiTeO7tKzZxFCSpbFZJtkMMooicw2us9A= -github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/lib/pq v1.1.0 h1:/5u4a+KGJptBRqGzPvYQL9p0d/tPR4S31+Tnzj9lEO4= +github.com/lib/pq v1.1.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/mattn/go-sqlite3 v1.10.0 h1:jbhqpg7tQe4SupckyijYiy0mJJ/pRyHvXf7JdWK860o= github.com/mattn/go-sqlite3 v1.10.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= @@ -59,6 +58,7 @@ github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRW github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/ginkgo v1.7.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= +github.com/openzipkin/zipkin-go v0.1.6 h1:yXiysv1CSK7Q5yjGy1710zZGnsbMUIjluWBxtLXHPBo= github.com/openzipkin/zipkin-go v0.1.6/go.mod h1:QgAqvLzwWbR/WpD4A3cGpPtJrZXNIiJc5AZX7/PBEpw= github.com/pierrec/lz4 v2.0.5+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= @@ -77,9 +77,8 @@ github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXf go.opencensus.io v0.20.1/go.mod h1:6WKK9ahsWS3RSO+PY9ZHZUfv2irvY6gN279GOPZjmmk= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c h1:Vj5n4GlwjmQteupaxJ9+0FNOmBrHfq7vN4btdGoDZgI= golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20190426145343-a29dc8fdc734 h1:p/H982KKEjUnLJkM3tt/LemDnOc1GiZL5FCVlORJ5zo= -golang.org/x/crypto v0.0.0-20190426145343-a29dc8fdc734/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= @@ -92,7 +91,6 @@ golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73r golang.org/x/net v0.0.0-20190125091013-d26f9f9a57f3/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -105,7 +103,6 @@ golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5h golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181122145206-62eef0e2fa9b/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= @@ -123,6 +120,8 @@ google.golang.org/genproto v0.0.0-20190307195333-5fe7a883aa19/go.mod h1:VzzqZJRn google.golang.org/genproto v0.0.0-20190404172233-64821d5d2107/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= google.golang.org/grpc v1.17.0/go.mod h1:6QZJwpn2B+Zp71q/5VxRsJ6NXXVCE5NRUHRo+f3cWCs= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= +google.golang.org/grpc v1.20.1 h1:Hz2g2wirWK7H0qIIhGIqRGTuMwTE8HEKFnDZZ7lm9NU= +google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= From b00248862ac8ca12dd54274094d404007173ec2c Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 2 May 2019 22:49:30 +0800 Subject: [PATCH 139/881] Enable codecov --- .codeclimate.yml | 11 ----------- .gitignore | 1 + .travis.yml | 14 ++++++++++++++ 3 files changed, 15 insertions(+), 11 deletions(-) delete mode 100644 .codeclimate.yml create mode 100644 .travis.yml diff --git a/.codeclimate.yml b/.codeclimate.yml deleted file mode 100644 index 51aba50c..00000000 --- a/.codeclimate.yml +++ /dev/null @@ -1,11 +0,0 @@ ---- -engines: - gofmt: - enabled: true - govet: - enabled: true - golint: - enabled: true -ratings: - paths: - - "**.go" diff --git a/.gitignore b/.gitignore index 01dc5ce0..117f92f5 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ documents +coverage.txt _book diff --git a/.travis.yml b/.travis.yml new file mode 100644 index 00000000..8ce9f587 --- /dev/null +++ b/.travis.yml @@ -0,0 +1,14 @@ +language: go + +go: + - 1.12.x + - tip + +before_install: + - go get -t -v ./... + +script: + - go test -race -coverprofile=coverage.txt -covermode=atomic + +after_success: + - bash <(curl -s https://codecov.io/bash) From 12c3abcd450dd39da93e4224ddc4bf9a9195dc4c Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 5 May 2019 14:20:51 +0800 Subject: [PATCH 140/881] Fix codeconv integration --- .travis.yml | 14 -------------- wercker.yml | 6 ++++++ 2 files changed, 6 insertions(+), 14 deletions(-) delete mode 100644 .travis.yml diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index 8ce9f587..00000000 --- a/.travis.yml +++ /dev/null @@ -1,14 +0,0 @@ -language: go - -go: - - 1.12.x - - tip - -before_install: - - go get -t -v ./... - -script: - - go test -race -coverprofile=coverage.txt -covermode=atomic - -after_success: - - bash <(curl -s https://codecov.io/bash) diff --git a/wercker.yml b/wercker.yml index 98234583..35af18da 100644 --- a/wercker.yml +++ b/wercker.yml @@ -146,3 +146,9 @@ build: name: test mssql code: | GORM_DIALECT=mssql GORM_DSN="sqlserver://gorm:LoremIpsum86@mssql:1433?database=gorm" go test -race ./... + + - script: + name: codeconv + code: | + go test -race -coverprofile=coverage.txt -covermode=atomic ./... + bash <(curl -s https://codecov.io/bash) From f9944083aed7a2f81d4172154e9e5284054479ca Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 5 May 2019 14:32:23 +0800 Subject: [PATCH 141/881] Add codecov badge --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 0c5c7ea6..aec2d46d 100644 --- a/README.md +++ b/README.md @@ -4,6 +4,7 @@ The fantastic ORM library for Golang, aims to be developer friendly. [![go report card](https://goreportcard.com/badge/github.com/jinzhu/gorm "go report card")](https://goreportcard.com/report/github.com/jinzhu/gorm) [![wercker status](https://app.wercker.com/status/8596cace912c9947dd9c8542ecc8cb8b/s/master "wercker status")](https://app.wercker.com/project/byKey/8596cace912c9947dd9c8542ecc8cb8b) +[![codecov](https://codecov.io/gh/jinzhu/gorm/branch/master/graph/badge.svg)](https://codecov.io/gh/jinzhu/gorm) [![Join the chat at https://gitter.im/jinzhu/gorm](https://img.shields.io/gitter/room/jinzhu/gorm.svg)](https://gitter.im/jinzhu/gorm?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) [![Open Collective Backer](https://opencollective.com/gorm/tiers/backer/badge.svg?label=backer&color=brightgreen "Open Collective Backer")](https://opencollective.com/gorm) [![Open Collective Sponsor](https://opencollective.com/gorm/tiers/sponsor/badge.svg?label=sponsor&color=brightgreen "Open Collective Sponsor")](https://opencollective.com/gorm) From 50ec201b910b33f3ed5cad46a39873192fbcddad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Emir=20Beganovi=C4=87?= Date: Sun, 5 May 2019 10:47:14 +0400 Subject: [PATCH 142/881] Fix typo --- wercker.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wercker.yml b/wercker.yml index 35af18da..43a3e7ae 100644 --- a/wercker.yml +++ b/wercker.yml @@ -148,7 +148,7 @@ build: GORM_DIALECT=mssql GORM_DSN="sqlserver://gorm:LoremIpsum86@mssql:1433?database=gorm" go test -race ./... - script: - name: codeconv + name: codecov code: | go test -race -coverprofile=coverage.txt -covermode=atomic ./... bash <(curl -s https://codecov.io/bash) From 741cd60b1bd4bd4940e4813a1b301cc864647dd8 Mon Sep 17 00:00:00 2001 From: Emir Beganovic Date: Sun, 5 May 2019 11:24:26 +0400 Subject: [PATCH 143/881] Add test for keeping float precision --- main_test.go | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/main_test.go b/main_test.go index 1fa38b98..b3e87831 100644 --- a/main_test.go +++ b/main_test.go @@ -1164,6 +1164,20 @@ func TestCountWithQueryOption(t *testing.T) { } } +func TestFloatColumnPrecision(t *testing.T) { + type FloatTest struct { + ID string `gorm:"primary_key"` + FloatValue float64 `gorm:"column:float_value" sql:"type:float(255,5);"` + } + DB.DropTable(&FloatTest{}) + DB.AutoMigrate(&FloatTest{}) + + data := FloatTest{ID: "uuid", FloatValue: 112.57315} + if err := DB.Save(&data).Error; err != nil || data.ID != "uuid" || data.FloatValue != 112.57315 { + t.Errorf("Float value should not lose precision") + } +} + func BenchmarkGorm(b *testing.B) { b.N = 2000 for x := 0; x < b.N; x++ { From abe3fa8631a3726d37c4d3d497f6c1d2b698f90d Mon Sep 17 00:00:00 2001 From: Emir Beganovic Date: Sun, 5 May 2019 11:51:05 +0400 Subject: [PATCH 144/881] Run only on MySQL and sqlite --- main_test.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/main_test.go b/main_test.go index b3e87831..25b5940c 100644 --- a/main_test.go +++ b/main_test.go @@ -1165,6 +1165,10 @@ func TestCountWithQueryOption(t *testing.T) { } func TestFloatColumnPrecision(t *testing.T) { + if dialect := os.Getenv("GORM_DIALECT"); dialect != "mysql" && dialect != "sqlite" { + t.Skip() + } + type FloatTest struct { ID string `gorm:"primary_key"` FloatValue float64 `gorm:"column:float_value" sql:"type:float(255,5);"` From 206174c932639e2f6807d94f9aff328772ec2d72 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 5 May 2019 16:23:52 +0800 Subject: [PATCH 145/881] Change gorm.io links to https --- README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index aec2d46d..6d231103 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ The fantastic ORM library for Golang, aims to be developer friendly. [![Join the chat at https://gitter.im/jinzhu/gorm](https://img.shields.io/gitter/room/jinzhu/gorm.svg)](https://gitter.im/jinzhu/gorm?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) [![Open Collective Backer](https://opencollective.com/gorm/tiers/backer/badge.svg?label=backer&color=brightgreen "Open Collective Backer")](https://opencollective.com/gorm) [![Open Collective Sponsor](https://opencollective.com/gorm/tiers/sponsor/badge.svg?label=sponsor&color=brightgreen "Open Collective Sponsor")](https://opencollective.com/gorm) -[![MIT license](http://img.shields.io/badge/license-MIT-brightgreen.svg)](http://opensource.org/licenses/MIT) +[![MIT license](https://img.shields.io/badge/license-MIT-brightgreen.svg)](https://opensource.org/licenses/MIT) [![GoDoc](https://godoc.org/github.com/jinzhu/gorm?status.svg)](https://godoc.org/github.com/jinzhu/gorm) ## Overview @@ -28,11 +28,11 @@ The fantastic ORM library for Golang, aims to be developer friendly. ## Getting Started -* GORM Guides [http://gorm.io](http://gorm.io) +* GORM Guides [https://gorm.io](https://gorm.io) ## Contributing -[You can help to deliver a better GORM, check out things you can do](http://gorm.io/contribute.html) +[You can help to deliver a better GORM, check out things you can do](https://gorm.io/contribute.html) ## License From 394b3a1818b8912cc8f4a4eefeb7a0340ae9ad07 Mon Sep 17 00:00:00 2001 From: Emir Beganovic Date: Sun, 5 May 2019 13:12:03 +0400 Subject: [PATCH 146/881] Fixed nil error when first updates with struct --- main_test.go | 21 +++++++++++++++++++++ scope.go | 10 +++++----- 2 files changed, 26 insertions(+), 5 deletions(-) diff --git a/main_test.go b/main_test.go index 25b5940c..14bf34ac 100644 --- a/main_test.go +++ b/main_test.go @@ -1182,6 +1182,27 @@ func TestFloatColumnPrecision(t *testing.T) { } } +func TestWhereUpdates(t *testing.T) { + type OwnerEntity struct { + gorm.Model + OwnerID uint + OwnerType string + } + + type SomeEntity struct { + gorm.Model + Name string + OwnerEntity OwnerEntity `gorm:"polymorphic:Owner"` + } + + db := DB.Debug() + db.DropTable(&SomeEntity{}) + db.AutoMigrate(&SomeEntity{}) + + a := SomeEntity{Name: "test"} + db.Model(&a).Where(a).Updates(SomeEntity{Name: "test2"}) +} + func BenchmarkGorm(b *testing.B) { b.N = 2000 for x := 0; x < b.N; x++ { diff --git a/scope.go b/scope.go index c6c92d5a..9f8820eb 100644 --- a/scope.go +++ b/scope.go @@ -872,7 +872,7 @@ func (scope *Scope) callCallbacks(funcs []*func(s *Scope)) *Scope { return scope } -func convertInterfaceToMap(values interface{}, withIgnoredField bool) map[string]interface{} { +func convertInterfaceToMap(values interface{}, withIgnoredField bool, db *DB) map[string]interface{} { var attrs = map[string]interface{}{} switch value := values.(type) { @@ -880,7 +880,7 @@ func convertInterfaceToMap(values interface{}, withIgnoredField bool) map[string return value case []interface{}: for _, v := range value { - for key, value := range convertInterfaceToMap(v, withIgnoredField) { + for key, value := range convertInterfaceToMap(v, withIgnoredField, db) { attrs[key] = value } } @@ -893,7 +893,7 @@ func convertInterfaceToMap(values interface{}, withIgnoredField bool) map[string attrs[ToColumnName(key.Interface().(string))] = reflectValue.MapIndex(key).Interface() } default: - for _, field := range (&Scope{Value: values}).Fields() { + for _, field := range (&Scope{Value: values, db: db}).Fields() { if !field.IsBlank && (withIgnoredField || !field.IsIgnored) { attrs[field.DBName] = field.Field.Interface() } @@ -905,12 +905,12 @@ func convertInterfaceToMap(values interface{}, withIgnoredField bool) map[string func (scope *Scope) updatedAttrsWithValues(value interface{}) (results map[string]interface{}, hasUpdate bool) { if scope.IndirectValue().Kind() != reflect.Struct { - return convertInterfaceToMap(value, false), true + return convertInterfaceToMap(value, false, scope.db), true } results = map[string]interface{}{} - for key, value := range convertInterfaceToMap(value, true) { + for key, value := range convertInterfaceToMap(value, true, scope.db) { if field, ok := scope.FieldByName(key); ok && scope.changeableField(field) { if _, ok := value.(*expr); ok { hasUpdate = true From 8b127471f1679b468cc13c5736fa401e16f664d1 Mon Sep 17 00:00:00 2001 From: John Barker Date: Wed, 1 May 2019 15:54:39 -0600 Subject: [PATCH 147/881] Pass logger into Callback{} so that logs are printed consistently --- callback.go | 19 +++++++++++-------- callback_system_test.go | 14 +++++++------- main.go | 2 +- 3 files changed, 19 insertions(+), 16 deletions(-) diff --git a/callback.go b/callback.go index a4382147..f990097b 100644 --- a/callback.go +++ b/callback.go @@ -13,6 +13,7 @@ var DefaultCallback = &Callback{} // Field `rowQueries` contains callbacks will be call when querying object with Row, Rows... // Field `processors` contains all callback processors, will be used to generate above callbacks in order type Callback struct { + logger logger creates []*func(scope *Scope) updates []*func(scope *Scope) deletes []*func(scope *Scope) @@ -23,6 +24,7 @@ type Callback struct { // CallbackProcessor contains callback informations type CallbackProcessor struct { + logger logger name string // current callback's name before string // register current callback before a callback after string // register current callback after a callback @@ -33,8 +35,9 @@ type CallbackProcessor struct { parent *Callback } -func (c *Callback) clone() *Callback { +func (c *Callback) clone(logger logger) *Callback { return &Callback{ + logger: logger, creates: c.creates, updates: c.updates, deletes: c.deletes, @@ -53,28 +56,28 @@ func (c *Callback) clone() *Callback { // scope.Err(errors.New("error")) // }) func (c *Callback) Create() *CallbackProcessor { - return &CallbackProcessor{kind: "create", parent: c} + return &CallbackProcessor{logger: c.logger, kind: "create", parent: c} } // Update could be used to register callbacks for updating object, refer `Create` for usage func (c *Callback) Update() *CallbackProcessor { - return &CallbackProcessor{kind: "update", parent: c} + return &CallbackProcessor{logger: c.logger, kind: "update", parent: c} } // Delete could be used to register callbacks for deleting object, refer `Create` for usage func (c *Callback) Delete() *CallbackProcessor { - return &CallbackProcessor{kind: "delete", parent: c} + return &CallbackProcessor{logger: c.logger, kind: "delete", parent: c} } // Query could be used to register callbacks for querying objects with query methods like `Find`, `First`, `Related`, `Association`... // Refer `Create` for usage func (c *Callback) Query() *CallbackProcessor { - return &CallbackProcessor{kind: "query", parent: c} + return &CallbackProcessor{logger: c.logger, kind: "query", parent: c} } // RowQuery could be used to register callbacks for querying objects with `Row`, `Rows`, refer `Create` for usage func (c *Callback) RowQuery() *CallbackProcessor { - return &CallbackProcessor{kind: "row_query", parent: c} + return &CallbackProcessor{logger: c.logger, kind: "row_query", parent: c} } // After insert a new callback after callback `callbackName`, refer `Callbacks.Create` @@ -93,7 +96,7 @@ func (cp *CallbackProcessor) Before(callbackName string) *CallbackProcessor { func (cp *CallbackProcessor) Register(callbackName string, callback func(scope *Scope)) { if cp.kind == "row_query" { if cp.before == "" && cp.after == "" && callbackName != "gorm:row_query" { - log.Printf("Registing RowQuery callback %v without specify order with Before(), After(), applying Before('gorm:row_query') by default for compatibility...\n", callbackName) + cp.logger.Print("Registing RowQuery callback %v without specify order with Before(), After(), applying Before('gorm:row_query') by default for compatibility...\n", callbackName) cp.before = "gorm:row_query" } } @@ -107,7 +110,7 @@ func (cp *CallbackProcessor) Register(callbackName string, callback func(scope * // Remove a registered callback // db.Callback().Create().Remove("gorm:update_time_stamp_when_create") func (cp *CallbackProcessor) Remove(callbackName string) { - log.Printf("[info] removing callback `%v` from %v\n", callbackName, fileWithLineNum()) + cp.logger.Print("[info] removing callback `%v` from %v\n", callbackName, fileWithLineNum()) cp.name = callbackName cp.remove = true cp.parent.processors = append(cp.parent.processors, cp) diff --git a/callback_system_test.go b/callback_system_test.go index 13ca3f42..2482eda4 100644 --- a/callback_system_test.go +++ b/callback_system_test.go @@ -23,7 +23,7 @@ func afterCreate1(s *Scope) {} func afterCreate2(s *Scope) {} func TestRegisterCallback(t *testing.T) { - var callback = &Callback{} + var callback = &Callback{logger: defaultLogger} callback.Create().Register("before_create1", beforeCreate1) callback.Create().Register("before_create2", beforeCreate2) @@ -37,7 +37,7 @@ func TestRegisterCallback(t *testing.T) { } func TestRegisterCallbackWithOrder(t *testing.T) { - var callback1 = &Callback{} + var callback1 = &Callback{logger: defaultLogger} callback1.Create().Register("before_create1", beforeCreate1) callback1.Create().Register("create", create) callback1.Create().Register("after_create1", afterCreate1) @@ -46,7 +46,7 @@ func TestRegisterCallbackWithOrder(t *testing.T) { t.Errorf("register callback with order") } - var callback2 = &Callback{} + var callback2 = &Callback{logger: defaultLogger} callback2.Update().Register("create", create) callback2.Update().Before("create").Register("before_create1", beforeCreate1) @@ -60,7 +60,7 @@ func TestRegisterCallbackWithOrder(t *testing.T) { } func TestRegisterCallbackWithComplexOrder(t *testing.T) { - var callback1 = &Callback{} + var callback1 = &Callback{logger: defaultLogger} callback1.Query().Before("after_create1").After("before_create1").Register("create", create) callback1.Query().Register("before_create1", beforeCreate1) @@ -70,7 +70,7 @@ func TestRegisterCallbackWithComplexOrder(t *testing.T) { t.Errorf("register callback with order") } - var callback2 = &Callback{} + var callback2 = &Callback{logger: defaultLogger} callback2.Delete().Before("after_create1").After("before_create1").Register("create", create) callback2.Delete().Before("create").Register("before_create1", beforeCreate1) @@ -86,7 +86,7 @@ func TestRegisterCallbackWithComplexOrder(t *testing.T) { func replaceCreate(s *Scope) {} func TestReplaceCallback(t *testing.T) { - var callback = &Callback{} + var callback = &Callback{logger: defaultLogger} callback.Create().Before("after_create1").After("before_create1").Register("create", create) callback.Create().Register("before_create1", beforeCreate1) @@ -99,7 +99,7 @@ func TestReplaceCallback(t *testing.T) { } func TestRemoveCallback(t *testing.T) { - var callback = &Callback{} + var callback = &Callback{logger: defaultLogger} callback.Create().Before("after_create1").After("before_create1").Register("create", create) callback.Create().Register("before_create1", beforeCreate1) diff --git a/main.go b/main.go index 16820353..079a380d 100644 --- a/main.go +++ b/main.go @@ -138,7 +138,7 @@ func (s *DB) Dialect() Dialect { // db.Callback().Create().Register("update_created_at", updateCreated) // Refer https://jinzhu.github.io/gorm/development.html#callbacks func (s *DB) Callback() *Callback { - s.parent.callbacks = s.parent.callbacks.clone() + s.parent.callbacks = s.parent.callbacks.clone(s.logger) return s.parent.callbacks } From 9692c599ad07b4178fd005e6649017d98a8871ad Mon Sep 17 00:00:00 2001 From: Emir Beganovic Date: Wed, 8 May 2019 10:23:31 +0400 Subject: [PATCH 148/881] Fix drop table error with table options --- scope.go | 2 +- scope_test.go | 13 +++++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/scope.go b/scope.go index 9f8820eb..4836196a 100644 --- a/scope.go +++ b/scope.go @@ -1194,7 +1194,7 @@ func (scope *Scope) createTable() *Scope { } func (scope *Scope) dropTable() *Scope { - scope.Raw(fmt.Sprintf("DROP TABLE %v%s", scope.QuotedTableName(), scope.getTableOptions())).Exec() + scope.Raw(fmt.Sprintf("DROP TABLE %v", scope.QuotedTableName())).Exec() return scope } diff --git a/scope_test.go b/scope_test.go index 3018f350..f7f1ed08 100644 --- a/scope_test.go +++ b/scope_test.go @@ -78,3 +78,16 @@ func TestFailedValuer(t *testing.T) { t.Errorf("The error should be returned from Valuer, but get %v", err) } } + +func TestDropTableWithTableOptions(t *testing.T) { + type UserWithOptions struct { + gorm.Model + } + DB.AutoMigrate(&UserWithOptions{}) + + DB = DB.Set("gorm:table_options", "CHARSET=utf8") + err := DB.DropTable(&UserWithOptions{}).Error + if err != nil { + t.Errorf("Table must be dropped, got error %s", err) + } +} From bb3c74467dacc7106f53b72be50025bac724f89f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Emir=20Beganovi=C4=87?= Date: Wed, 8 May 2019 10:26:49 +0400 Subject: [PATCH 149/881] Update two more places --- callback.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/callback.go b/callback.go index f990097b..4ffc2d62 100644 --- a/callback.go +++ b/callback.go @@ -123,7 +123,7 @@ func (cp *CallbackProcessor) Remove(callbackName string) { // scope.SetColumn("Updated", now) // }) func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *Scope)) { - log.Printf("[info] replacing callback `%v` from %v\n", callbackName, fileWithLineNum()) + cp.logger.Printf("[info] replacing callback `%v` from %v\n", callbackName, fileWithLineNum()) cp.name = callbackName cp.processor = &callback cp.replace = true @@ -162,7 +162,7 @@ func sortProcessors(cps []*CallbackProcessor) []*func(scope *Scope) { for _, cp := range cps { // show warning message the callback name already exists if index := getRIndex(allNames, cp.name); index > -1 && !cp.replace && !cp.remove { - log.Printf("[warning] duplicated callback `%v` from %v\n", cp.name, fileWithLineNum()) + cp.logger.Printf("[warning] duplicated callback `%v` from %v\n", cp.name, fileWithLineNum()) } allNames = append(allNames, cp.name) } From 6c53214a2992d832c228db49a4f1e3992fce0475 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Emir=20Beganovi=C4=87?= Date: Wed, 8 May 2019 10:49:00 +0400 Subject: [PATCH 150/881] Use Print method --- callback.go | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/callback.go b/callback.go index 4ffc2d62..42ebc800 100644 --- a/callback.go +++ b/callback.go @@ -1,6 +1,9 @@ package gorm -import "log" +import ( + "fmt" + "log" +) // DefaultCallback default callbacks defined by gorm var DefaultCallback = &Callback{} @@ -96,7 +99,7 @@ func (cp *CallbackProcessor) Before(callbackName string) *CallbackProcessor { func (cp *CallbackProcessor) Register(callbackName string, callback func(scope *Scope)) { if cp.kind == "row_query" { if cp.before == "" && cp.after == "" && callbackName != "gorm:row_query" { - cp.logger.Print("Registing RowQuery callback %v without specify order with Before(), After(), applying Before('gorm:row_query') by default for compatibility...\n", callbackName) + cp.logger.Print(fmt.Sprintf("Registering RowQuery callback %v without specify order with Before(), After(), applying Before('gorm:row_query') by default for compatibility...\n", callbackName)) cp.before = "gorm:row_query" } } @@ -110,7 +113,7 @@ func (cp *CallbackProcessor) Register(callbackName string, callback func(scope * // Remove a registered callback // db.Callback().Create().Remove("gorm:update_time_stamp_when_create") func (cp *CallbackProcessor) Remove(callbackName string) { - cp.logger.Print("[info] removing callback `%v` from %v\n", callbackName, fileWithLineNum()) + cp.logger.Print(fmt.Sprintf("[info] removing callback `%v` from %v\n", callbackName, fileWithLineNum())) cp.name = callbackName cp.remove = true cp.parent.processors = append(cp.parent.processors, cp) @@ -123,7 +126,7 @@ func (cp *CallbackProcessor) Remove(callbackName string) { // scope.SetColumn("Updated", now) // }) func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *Scope)) { - cp.logger.Printf("[info] replacing callback `%v` from %v\n", callbackName, fileWithLineNum()) + cp.logger.Print(fmt.Sprintf("[info] replacing callback `%v` from %v\n", callbackName, fileWithLineNum())) cp.name = callbackName cp.processor = &callback cp.replace = true @@ -162,7 +165,7 @@ func sortProcessors(cps []*CallbackProcessor) []*func(scope *Scope) { for _, cp := range cps { // show warning message the callback name already exists if index := getRIndex(allNames, cp.name); index > -1 && !cp.replace && !cp.remove { - cp.logger.Printf("[warning] duplicated callback `%v` from %v\n", cp.name, fileWithLineNum()) + cp.logger.Print(fmt.Sprintf("[warning] duplicated callback `%v` from %v\n", cp.name, fileWithLineNum())) } allNames = append(allNames, cp.name) } From 985c3a174ea165d89f5111a16382e9b079099653 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Emir=20Beganovi=C4=87?= Date: Wed, 8 May 2019 10:49:33 +0400 Subject: [PATCH 151/881] Remove unused import --- callback.go | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/callback.go b/callback.go index 42ebc800..6f60511b 100644 --- a/callback.go +++ b/callback.go @@ -1,9 +1,6 @@ package gorm -import ( - "fmt" - "log" -) +import "fmt" // DefaultCallback default callbacks defined by gorm var DefaultCallback = &Callback{} From 62197e576dcd1509eabab9ac9567d6a63d325688 Mon Sep 17 00:00:00 2001 From: Miguel Moll Date: Mon, 10 Jun 2019 08:12:13 -0400 Subject: [PATCH 152/881] Handle error when beginning transaction (#2489) --- scope.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scope.go b/scope.go index 4836196a..0e639c70 100644 --- a/scope.go +++ b/scope.go @@ -402,7 +402,7 @@ func (scope *Scope) InstanceGet(name string) (interface{}, bool) { // Begin start a transaction func (scope *Scope) Begin() *Scope { if db, ok := scope.SQLDB().(sqlDb); ok { - if tx, err := db.Begin(); err == nil { + if tx, err := db.Begin(); scope.Err(err) == nil { scope.db.db = interface{}(tx).(SQLCommon) scope.InstanceSet("gorm:started_transaction", true) } From ea124001902dfe81503bb8192bc397087e951072 Mon Sep 17 00:00:00 2001 From: John Barker Date: Mon, 10 Jun 2019 06:14:44 -0600 Subject: [PATCH 153/881] Don't AddError for Rollback on ErrTxDone (#2434) --- main.go | 4 +++- main_test.go | 16 ++++++++++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/main.go b/main.go index 079a380d..02d67440 100644 --- a/main.go +++ b/main.go @@ -533,7 +533,9 @@ func (s *DB) Commit() *DB { func (s *DB) Rollback() *DB { var emptySQLTx *sql.Tx if db, ok := s.db.(sqlTx); ok && db != nil && db != emptySQLTx { - s.AddError(db.Rollback()) + if err := db.Rollback(); err != nil && err != sql.ErrTxDone { + s.AddError(err) + } } else { s.AddError(ErrInvalidTransaction) } diff --git a/main_test.go b/main_test.go index 14bf34ac..3d922dda 100644 --- a/main_test.go +++ b/main_test.go @@ -421,6 +421,22 @@ func TestTransaction(t *testing.T) { } } +func TestTransaction_NoErrorOnRollbackAfterCommit(t *testing.T) { + tx := DB.Begin() + u := User{Name: "transcation"} + if err := tx.Save(&u).Error; err != nil { + t.Errorf("No error should raise") + } + + if err := tx.Commit().Error; err != nil { + t.Errorf("Commit should not raise error") + } + + if err := tx.Rollback().Error; err != nil { + t.Errorf("Rollback should not raise error") + } +} + func TestRow(t *testing.T) { user1 := User{Name: "RowUser1", Age: 1, Birthday: parseTime("2000-1-1")} user2 := User{Name: "RowUser2", Age: 10, Birthday: parseTime("2010-1-1")} From 44d3060254255c13412b2741c227b1a962984561 Mon Sep 17 00:00:00 2001 From: Adam S Levy Date: Mon, 10 Jun 2019 04:19:39 -0800 Subject: [PATCH 154/881] Add RollbackUnlessCommitted() (#2126) --- main.go | 17 +++++++++++++++++ main_test.go | 34 ++++++++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+) diff --git a/main.go b/main.go index 02d67440..906b7f41 100644 --- a/main.go +++ b/main.go @@ -542,6 +542,23 @@ func (s *DB) Rollback() *DB { return s } +// RollbackUnlessCommitted rollback a transaction if it has not yet been +// committed. +func (s *DB) RollbackUnlessCommitted() *DB { + var emptySQLTx *sql.Tx + if db, ok := s.db.(sqlTx); ok && db != nil && db != emptySQLTx { + err := db.Rollback() + // Ignore the error indicating that the transaction has already + // been committed. + if err != sql.ErrTxDone { + s.AddError(err) + } + } else { + s.AddError(ErrInvalidTransaction) + } + return s +} + // NewRecord check if value's primary key is blank func (s *DB) NewRecord(value interface{}) bool { return s.NewScope(value).PrimaryKeyZero() diff --git a/main_test.go b/main_test.go index 3d922dda..ee038cac 100644 --- a/main_test.go +++ b/main_test.go @@ -419,6 +419,40 @@ func TestTransaction(t *testing.T) { if err := DB.First(&User{}, "name = ?", "transcation-2").Error; err != nil { t.Errorf("Should be able to find committed record") } + + tx3 := DB.Begin() + u3 := User{Name: "transcation-3"} + if err := tx3.Save(&u3).Error; err != nil { + t.Errorf("No error should raise") + } + + if err := tx3.First(&User{}, "name = ?", "transcation-3").Error; err != nil { + t.Errorf("Should find saved record") + } + + tx3.RollbackUnlessCommitted() + + if err := tx.First(&User{}, "name = ?", "transcation").Error; err == nil { + t.Errorf("Should not find record after rollback") + } + + tx4 := DB.Begin() + u4 := User{Name: "transcation-4"} + if err := tx4.Save(&u4).Error; err != nil { + t.Errorf("No error should raise") + } + + if err := tx4.First(&User{}, "name = ?", "transcation-4").Error; err != nil { + t.Errorf("Should find saved record") + } + + tx4.Commit() + + tx4.RollbackUnlessCommitted() + + if err := DB.First(&User{}, "name = ?", "transcation-4").Error; err != nil { + t.Errorf("Should be able to find committed record") + } } func TestTransaction_NoErrorOnRollbackAfterCommit(t *testing.T) { From ac78f05986ab456936afd148e629533d8d819289 Mon Sep 17 00:00:00 2001 From: Hylke Visser Date: Mon, 10 Jun 2019 14:24:05 +0200 Subject: [PATCH 155/881] Don't set primary key's HasDefaultValue to true (#2127) --- model_struct.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/model_struct.go b/model_struct.go index bfab49c0..5234b287 100644 --- a/model_struct.go +++ b/model_struct.go @@ -202,7 +202,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field) } - if _, ok := field.TagSettingsGet("DEFAULT"); ok { + if _, ok := field.TagSettingsGet("DEFAULT"); ok && !field.IsPrimaryKey { field.HasDefaultValue = true } From af01854d3ecae994322b18d71cafdec114de9d81 Mon Sep 17 00:00:00 2001 From: Tyler Stillwater Date: Mon, 10 Jun 2019 06:33:20 -0600 Subject: [PATCH 156/881] Add BeginTx for parity with sql.DB.BeginTx (#2227) --- interface.go | 6 +++++- main.go | 10 ++++++++-- main_test.go | 35 +++++++++++++++++++++++++++++++++++ 3 files changed, 48 insertions(+), 3 deletions(-) diff --git a/interface.go b/interface.go index 55128f7f..fe649231 100644 --- a/interface.go +++ b/interface.go @@ -1,6 +1,9 @@ package gorm -import "database/sql" +import ( + "context" + "database/sql" +) // SQLCommon is the minimal database connection functionality gorm requires. Implemented by *sql.DB. type SQLCommon interface { @@ -12,6 +15,7 @@ type SQLCommon interface { type sqlDb interface { Begin() (*sql.Tx, error) + BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) } type sqlTx interface { diff --git a/main.go b/main.go index 906b7f41..994d1618 100644 --- a/main.go +++ b/main.go @@ -1,6 +1,7 @@ package gorm import ( + "context" "database/sql" "errors" "fmt" @@ -503,11 +504,16 @@ func (s *DB) Debug() *DB { return s.clone().LogMode(true) } -// Begin begin a transaction +// Begin begins a transaction func (s *DB) Begin() *DB { + return s.BeginTx(context.Background(), &sql.TxOptions{}) +} + +// BeginTX begins a transaction with options +func (s *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) *DB { c := s.clone() if db, ok := c.db.(sqlDb); ok && db != nil { - tx, err := db.Begin() + tx, err := db.BeginTx(ctx, opts) c.db = interface{}(tx).(SQLCommon) c.dialect.SetDB(c.db) diff --git a/main_test.go b/main_test.go index ee038cac..81ecf0fe 100644 --- a/main_test.go +++ b/main_test.go @@ -1,6 +1,7 @@ package gorm_test import ( + "context" "database/sql" "database/sql/driver" "fmt" @@ -471,6 +472,40 @@ func TestTransaction_NoErrorOnRollbackAfterCommit(t *testing.T) { } } +func TestTransactionReadonly(t *testing.T) { + dialect := os.Getenv("GORM_DIALECT") + if dialect == "" { + dialect = "sqlite" + } + switch dialect { + case "mssql", "sqlite": + t.Skipf("%s does not support readonly transactions\n", dialect) + } + + tx := DB.Begin() + u := User{Name: "transcation"} + if err := tx.Save(&u).Error; err != nil { + t.Errorf("No error should raise") + } + tx.Commit() + + tx = DB.BeginTx(context.Background(), &sql.TxOptions{ReadOnly: true}) + if err := tx.First(&User{}, "name = ?", "transcation").Error; err != nil { + t.Errorf("Should find saved record") + } + + if sqlTx, ok := tx.CommonDB().(*sql.Tx); !ok || sqlTx == nil { + t.Errorf("Should return the underlying sql.Tx") + } + + u = User{Name: "transcation-2"} + if err := tx.Save(&u).Error; err == nil { + t.Errorf("Error should have been raised in a readonly transaction") + } + + tx.Rollback() +} + func TestRow(t *testing.T) { user1 := User{Name: "RowUser1", Age: 1, Birthday: parseTime("2000-1-1")} user2 := User{Name: "RowUser2", Age: 10, Birthday: parseTime("2010-1-1")} From 712c4655605f094d283047501ae613db9c798850 Mon Sep 17 00:00:00 2001 From: Ruben de Vries Date: Mon, 10 Jun 2019 14:45:42 +0200 Subject: [PATCH 157/881] add an override on the DB instance instead of using the global NowFunc. (#2142) --- callback_create.go | 4 ++-- callback_delete.go | 2 +- callback_query.go | 2 +- callback_update.go | 2 +- create_test.go | 40 ++++++++++++++++++++++++++++++++++++++++ main.go | 20 ++++++++++++++++++++ scope.go | 6 +++--- 7 files changed, 68 insertions(+), 8 deletions(-) diff --git a/callback_create.go b/callback_create.go index 763a2dfd..87aba8ee 100644 --- a/callback_create.go +++ b/callback_create.go @@ -31,7 +31,7 @@ func beforeCreateCallback(scope *Scope) { // updateTimeStampForCreateCallback will set `CreatedAt`, `UpdatedAt` when creating func updateTimeStampForCreateCallback(scope *Scope) { if !scope.HasError() { - now := NowFunc() + now := scope.db.nowFunc() if createdAtField, ok := scope.FieldByName("CreatedAt"); ok { if createdAtField.IsBlank { @@ -50,7 +50,7 @@ func updateTimeStampForCreateCallback(scope *Scope) { // createCallback the callback used to insert data into database func createCallback(scope *Scope) { if !scope.HasError() { - defer scope.trace(NowFunc()) + defer scope.trace(scope.db.nowFunc()) var ( columns, placeholders []string diff --git a/callback_delete.go b/callback_delete.go index 73d90880..50242e48 100644 --- a/callback_delete.go +++ b/callback_delete.go @@ -40,7 +40,7 @@ func deleteCallback(scope *Scope) { "UPDATE %v SET %v=%v%v%v", scope.QuotedTableName(), scope.Quote(deletedAtField.DBName), - scope.AddToVars(NowFunc()), + scope.AddToVars(scope.db.nowFunc()), addExtraSpaceIfExist(scope.CombinedConditionSql()), addExtraSpaceIfExist(extraOption), )).Exec() diff --git a/callback_query.go b/callback_query.go index 7facc42b..e3b3d534 100644 --- a/callback_query.go +++ b/callback_query.go @@ -24,7 +24,7 @@ func queryCallback(scope *Scope) { return } - defer scope.trace(NowFunc()) + defer scope.trace(scope.db.nowFunc()) var ( isSlice, isPtr bool diff --git a/callback_update.go b/callback_update.go index c52162c8..56711d37 100644 --- a/callback_update.go +++ b/callback_update.go @@ -50,7 +50,7 @@ func beforeUpdateCallback(scope *Scope) { // updateTimeStampForUpdateCallback will set `UpdatedAt` when updating func updateTimeStampForUpdateCallback(scope *Scope) { if _, ok := scope.Get("gorm:update_column"); !ok { - scope.SetColumn("UpdatedAt", NowFunc()) + scope.SetColumn("UpdatedAt", scope.db.nowFunc()) } } diff --git a/create_test.go b/create_test.go index 450dd8a4..c80bdcbb 100644 --- a/create_test.go +++ b/create_test.go @@ -101,6 +101,46 @@ func TestCreateWithExistingTimestamp(t *testing.T) { } } +func TestCreateWithNowFuncOverride(t *testing.T) { + user1 := User{Name: "CreateUserTimestampOverride"} + + timeA := now.MustParse("2016-01-01") + + // do DB.New() because we don't want this test to affect other tests + db1 := DB.New() + // set the override to use static timeA + db1.SetNowFuncOverride(func() time.Time { + return timeA + }) + // call .New again to check the override is carried over as well during clone + db1 = db1.New() + + db1.Save(&user1) + + if user1.CreatedAt.UTC().Format(time.RFC3339) != timeA.UTC().Format(time.RFC3339) { + t.Errorf("CreatedAt be using the nowFuncOverride") + } + if user1.UpdatedAt.UTC().Format(time.RFC3339) != timeA.UTC().Format(time.RFC3339) { + t.Errorf("UpdatedAt be using the nowFuncOverride") + } + + // now create another user with a fresh DB.Now() that doesn't have the nowFuncOverride set + // to make sure that setting it only affected the above instance + + user2 := User{Name: "CreateUserTimestampOverrideNoMore"} + + db2 := DB.New() + + db2.Save(&user2) + + if user2.CreatedAt.UTC().Format(time.RFC3339) == timeA.UTC().Format(time.RFC3339) { + t.Errorf("CreatedAt no longer be using the nowFuncOverride") + } + if user2.UpdatedAt.UTC().Format(time.RFC3339) == timeA.UTC().Format(time.RFC3339) { + t.Errorf("UpdatedAt no longer be using the nowFuncOverride") + } +} + type AutoIncrementUser struct { User Sequence uint `gorm:"AUTO_INCREMENT"` diff --git a/main.go b/main.go index 994d1618..1316dbd3 100644 --- a/main.go +++ b/main.go @@ -31,6 +31,9 @@ type DB struct { callbacks *Callback dialect Dialect singularTable bool + + // function to be used to override the creating of a new timestamp + nowFuncOverride func() time.Time } type logModeValue int @@ -158,6 +161,22 @@ func (s *DB) LogMode(enable bool) *DB { return s } +// SetNowFuncOverride set the function to be used when creating a new timestamp +func (s *DB) SetNowFuncOverride(nowFuncOverride func() time.Time) *DB { + s.nowFuncOverride = nowFuncOverride + return s +} + +// Get a new timestamp, using the provided nowFuncOverride on the DB instance if set, +// otherwise defaults to the global NowFunc() +func (s *DB) nowFunc() time.Time { + if s.nowFuncOverride != nil { + return s.nowFuncOverride() + } + + return NowFunc() +} + // BlockGlobalUpdate if true, generates an error on update/delete without where clause. // This is to prevent eventual error with empty objects updates/deletions func (s *DB) BlockGlobalUpdate(enable bool) *DB { @@ -800,6 +819,7 @@ func (s *DB) clone() *DB { Error: s.Error, blockGlobalUpdate: s.blockGlobalUpdate, dialect: newDialect(s.dialect.GetName(), s.db), + nowFuncOverride: s.nowFuncOverride, } s.values.Range(func(k, v interface{}) bool { diff --git a/scope.go b/scope.go index 0e639c70..c962c165 100644 --- a/scope.go +++ b/scope.go @@ -358,7 +358,7 @@ func (scope *Scope) Raw(sql string) *Scope { // Exec perform generated SQL func (scope *Scope) Exec() *Scope { - defer scope.trace(NowFunc()) + defer scope.trace(scope.db.nowFunc()) if !scope.HasError() { if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { @@ -932,7 +932,7 @@ func (scope *Scope) updatedAttrsWithValues(value interface{}) (results map[strin } func (scope *Scope) row() *sql.Row { - defer scope.trace(NowFunc()) + defer scope.trace(scope.db.nowFunc()) result := &RowQueryResult{} scope.InstanceSet("row_query_result", result) @@ -942,7 +942,7 @@ func (scope *Scope) row() *sql.Row { } func (scope *Scope) rows() (*sql.Rows, error) { - defer scope.trace(NowFunc()) + defer scope.trace(scope.db.nowFunc()) result := &RowsQueryResult{} scope.InstanceSet("row_query_result", result) From 9127f7d86e13ff8e57a784b42ab76e0b86e5edf9 Mon Sep 17 00:00:00 2001 From: Miguel Moll Date: Mon, 10 Jun 2019 08:12:13 -0400 Subject: [PATCH 158/881] Handle error when beginning transaction (#2489) --- scope.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scope.go b/scope.go index 4836196a..0e639c70 100644 --- a/scope.go +++ b/scope.go @@ -402,7 +402,7 @@ func (scope *Scope) InstanceGet(name string) (interface{}, bool) { // Begin start a transaction func (scope *Scope) Begin() *Scope { if db, ok := scope.SQLDB().(sqlDb); ok { - if tx, err := db.Begin(); err == nil { + if tx, err := db.Begin(); scope.Err(err) == nil { scope.db.db = interface{}(tx).(SQLCommon) scope.InstanceSet("gorm:started_transaction", true) } From 280dd011a14b84dd8618aed0995fe08e270cb1c2 Mon Sep 17 00:00:00 2001 From: John Barker Date: Mon, 10 Jun 2019 06:14:44 -0600 Subject: [PATCH 159/881] Don't AddError for Rollback on ErrTxDone (#2434) --- main.go | 4 +++- main_test.go | 16 ++++++++++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/main.go b/main.go index 3b058231..6bd006d7 100644 --- a/main.go +++ b/main.go @@ -533,7 +533,9 @@ func (s *DB) Commit() *DB { func (s *DB) Rollback() *DB { var emptySQLTx *sql.Tx if db, ok := s.db.(sqlTx); ok && db != nil && db != emptySQLTx { - s.AddError(db.Rollback()) + if err := db.Rollback(); err != nil && err != sql.ErrTxDone { + s.AddError(err) + } } else { s.AddError(ErrInvalidTransaction) } diff --git a/main_test.go b/main_test.go index 14bf34ac..3d922dda 100644 --- a/main_test.go +++ b/main_test.go @@ -421,6 +421,22 @@ func TestTransaction(t *testing.T) { } } +func TestTransaction_NoErrorOnRollbackAfterCommit(t *testing.T) { + tx := DB.Begin() + u := User{Name: "transcation"} + if err := tx.Save(&u).Error; err != nil { + t.Errorf("No error should raise") + } + + if err := tx.Commit().Error; err != nil { + t.Errorf("Commit should not raise error") + } + + if err := tx.Rollback().Error; err != nil { + t.Errorf("Rollback should not raise error") + } +} + func TestRow(t *testing.T) { user1 := User{Name: "RowUser1", Age: 1, Birthday: parseTime("2000-1-1")} user2 := User{Name: "RowUser2", Age: 10, Birthday: parseTime("2010-1-1")} From f301f86e295525aebe0ae2306e08d8fc576afc2e Mon Sep 17 00:00:00 2001 From: Adam S Levy Date: Mon, 10 Jun 2019 04:19:39 -0800 Subject: [PATCH 160/881] Add RollbackUnlessCommitted() (#2126) --- main.go | 17 +++++++++++++++++ main_test.go | 34 ++++++++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+) diff --git a/main.go b/main.go index 6bd006d7..9bebe6f9 100644 --- a/main.go +++ b/main.go @@ -542,6 +542,23 @@ func (s *DB) Rollback() *DB { return s } +// RollbackUnlessCommitted rollback a transaction if it has not yet been +// committed. +func (s *DB) RollbackUnlessCommitted() *DB { + var emptySQLTx *sql.Tx + if db, ok := s.db.(sqlTx); ok && db != nil && db != emptySQLTx { + err := db.Rollback() + // Ignore the error indicating that the transaction has already + // been committed. + if err != sql.ErrTxDone { + s.AddError(err) + } + } else { + s.AddError(ErrInvalidTransaction) + } + return s +} + // NewRecord check if value's primary key is blank func (s *DB) NewRecord(value interface{}) bool { return s.NewScope(value).PrimaryKeyZero() diff --git a/main_test.go b/main_test.go index 3d922dda..ee038cac 100644 --- a/main_test.go +++ b/main_test.go @@ -419,6 +419,40 @@ func TestTransaction(t *testing.T) { if err := DB.First(&User{}, "name = ?", "transcation-2").Error; err != nil { t.Errorf("Should be able to find committed record") } + + tx3 := DB.Begin() + u3 := User{Name: "transcation-3"} + if err := tx3.Save(&u3).Error; err != nil { + t.Errorf("No error should raise") + } + + if err := tx3.First(&User{}, "name = ?", "transcation-3").Error; err != nil { + t.Errorf("Should find saved record") + } + + tx3.RollbackUnlessCommitted() + + if err := tx.First(&User{}, "name = ?", "transcation").Error; err == nil { + t.Errorf("Should not find record after rollback") + } + + tx4 := DB.Begin() + u4 := User{Name: "transcation-4"} + if err := tx4.Save(&u4).Error; err != nil { + t.Errorf("No error should raise") + } + + if err := tx4.First(&User{}, "name = ?", "transcation-4").Error; err != nil { + t.Errorf("Should find saved record") + } + + tx4.Commit() + + tx4.RollbackUnlessCommitted() + + if err := DB.First(&User{}, "name = ?", "transcation-4").Error; err != nil { + t.Errorf("Should be able to find committed record") + } } func TestTransaction_NoErrorOnRollbackAfterCommit(t *testing.T) { From cf9b85ed90acf96933b70e8bae0e4dc28a0f9687 Mon Sep 17 00:00:00 2001 From: Hylke Visser Date: Mon, 10 Jun 2019 14:24:05 +0200 Subject: [PATCH 161/881] Don't set primary key's HasDefaultValue to true (#2127) --- model_struct.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/model_struct.go b/model_struct.go index bfab49c0..5234b287 100644 --- a/model_struct.go +++ b/model_struct.go @@ -202,7 +202,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field) } - if _, ok := field.TagSettingsGet("DEFAULT"); ok { + if _, ok := field.TagSettingsGet("DEFAULT"); ok && !field.IsPrimaryKey { field.HasDefaultValue = true } From fec06da6a3120c30068765b8959b2d6bf36a50e6 Mon Sep 17 00:00:00 2001 From: Tyler Stillwater Date: Mon, 10 Jun 2019 06:33:20 -0600 Subject: [PATCH 162/881] Add BeginTx for parity with sql.DB.BeginTx (#2227) --- interface.go | 6 +++++- main.go | 10 ++++++++-- main_test.go | 35 +++++++++++++++++++++++++++++++++++ 3 files changed, 48 insertions(+), 3 deletions(-) diff --git a/interface.go b/interface.go index 55128f7f..fe649231 100644 --- a/interface.go +++ b/interface.go @@ -1,6 +1,9 @@ package gorm -import "database/sql" +import ( + "context" + "database/sql" +) // SQLCommon is the minimal database connection functionality gorm requires. Implemented by *sql.DB. type SQLCommon interface { @@ -12,6 +15,7 @@ type SQLCommon interface { type sqlDb interface { Begin() (*sql.Tx, error) + BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) } type sqlTx interface { diff --git a/main.go b/main.go index 9bebe6f9..3093ec80 100644 --- a/main.go +++ b/main.go @@ -1,6 +1,7 @@ package gorm import ( + "context" "database/sql" "errors" "fmt" @@ -503,11 +504,16 @@ func (s *DB) Debug() *DB { return s.clone().LogMode(true) } -// Begin begin a transaction +// Begin begins a transaction func (s *DB) Begin() *DB { + return s.BeginTx(context.Background(), &sql.TxOptions{}) +} + +// BeginTX begins a transaction with options +func (s *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) *DB { c := s.clone() if db, ok := c.db.(sqlDb); ok && db != nil { - tx, err := db.Begin() + tx, err := db.BeginTx(ctx, opts) c.db = interface{}(tx).(SQLCommon) c.dialect.SetDB(c.db) diff --git a/main_test.go b/main_test.go index ee038cac..81ecf0fe 100644 --- a/main_test.go +++ b/main_test.go @@ -1,6 +1,7 @@ package gorm_test import ( + "context" "database/sql" "database/sql/driver" "fmt" @@ -471,6 +472,40 @@ func TestTransaction_NoErrorOnRollbackAfterCommit(t *testing.T) { } } +func TestTransactionReadonly(t *testing.T) { + dialect := os.Getenv("GORM_DIALECT") + if dialect == "" { + dialect = "sqlite" + } + switch dialect { + case "mssql", "sqlite": + t.Skipf("%s does not support readonly transactions\n", dialect) + } + + tx := DB.Begin() + u := User{Name: "transcation"} + if err := tx.Save(&u).Error; err != nil { + t.Errorf("No error should raise") + } + tx.Commit() + + tx = DB.BeginTx(context.Background(), &sql.TxOptions{ReadOnly: true}) + if err := tx.First(&User{}, "name = ?", "transcation").Error; err != nil { + t.Errorf("Should find saved record") + } + + if sqlTx, ok := tx.CommonDB().(*sql.Tx); !ok || sqlTx == nil { + t.Errorf("Should return the underlying sql.Tx") + } + + u = User{Name: "transcation-2"} + if err := tx.Save(&u).Error; err == nil { + t.Errorf("Error should have been raised in a readonly transaction") + } + + tx.Rollback() +} + func TestRow(t *testing.T) { user1 := User{Name: "RowUser1", Age: 1, Birthday: parseTime("2000-1-1")} user2 := User{Name: "RowUser2", Age: 10, Birthday: parseTime("2010-1-1")} From c44c6027fb2e96a42b290bc73975efe933a6c44d Mon Sep 17 00:00:00 2001 From: Ruben de Vries Date: Mon, 10 Jun 2019 14:45:42 +0200 Subject: [PATCH 163/881] add an override on the DB instance instead of using the global NowFunc. (#2142) --- callback_create.go | 4 ++-- callback_delete.go | 2 +- callback_query.go | 2 +- callback_update.go | 2 +- create_test.go | 40 ++++++++++++++++++++++++++++++++++++++++ main.go | 20 ++++++++++++++++++++ scope.go | 6 +++--- 7 files changed, 68 insertions(+), 8 deletions(-) diff --git a/callback_create.go b/callback_create.go index 763a2dfd..87aba8ee 100644 --- a/callback_create.go +++ b/callback_create.go @@ -31,7 +31,7 @@ func beforeCreateCallback(scope *Scope) { // updateTimeStampForCreateCallback will set `CreatedAt`, `UpdatedAt` when creating func updateTimeStampForCreateCallback(scope *Scope) { if !scope.HasError() { - now := NowFunc() + now := scope.db.nowFunc() if createdAtField, ok := scope.FieldByName("CreatedAt"); ok { if createdAtField.IsBlank { @@ -50,7 +50,7 @@ func updateTimeStampForCreateCallback(scope *Scope) { // createCallback the callback used to insert data into database func createCallback(scope *Scope) { if !scope.HasError() { - defer scope.trace(NowFunc()) + defer scope.trace(scope.db.nowFunc()) var ( columns, placeholders []string diff --git a/callback_delete.go b/callback_delete.go index 73d90880..50242e48 100644 --- a/callback_delete.go +++ b/callback_delete.go @@ -40,7 +40,7 @@ func deleteCallback(scope *Scope) { "UPDATE %v SET %v=%v%v%v", scope.QuotedTableName(), scope.Quote(deletedAtField.DBName), - scope.AddToVars(NowFunc()), + scope.AddToVars(scope.db.nowFunc()), addExtraSpaceIfExist(scope.CombinedConditionSql()), addExtraSpaceIfExist(extraOption), )).Exec() diff --git a/callback_query.go b/callback_query.go index 7facc42b..e3b3d534 100644 --- a/callback_query.go +++ b/callback_query.go @@ -24,7 +24,7 @@ func queryCallback(scope *Scope) { return } - defer scope.trace(NowFunc()) + defer scope.trace(scope.db.nowFunc()) var ( isSlice, isPtr bool diff --git a/callback_update.go b/callback_update.go index c52162c8..56711d37 100644 --- a/callback_update.go +++ b/callback_update.go @@ -50,7 +50,7 @@ func beforeUpdateCallback(scope *Scope) { // updateTimeStampForUpdateCallback will set `UpdatedAt` when updating func updateTimeStampForUpdateCallback(scope *Scope) { if _, ok := scope.Get("gorm:update_column"); !ok { - scope.SetColumn("UpdatedAt", NowFunc()) + scope.SetColumn("UpdatedAt", scope.db.nowFunc()) } } diff --git a/create_test.go b/create_test.go index 450dd8a4..c80bdcbb 100644 --- a/create_test.go +++ b/create_test.go @@ -101,6 +101,46 @@ func TestCreateWithExistingTimestamp(t *testing.T) { } } +func TestCreateWithNowFuncOverride(t *testing.T) { + user1 := User{Name: "CreateUserTimestampOverride"} + + timeA := now.MustParse("2016-01-01") + + // do DB.New() because we don't want this test to affect other tests + db1 := DB.New() + // set the override to use static timeA + db1.SetNowFuncOverride(func() time.Time { + return timeA + }) + // call .New again to check the override is carried over as well during clone + db1 = db1.New() + + db1.Save(&user1) + + if user1.CreatedAt.UTC().Format(time.RFC3339) != timeA.UTC().Format(time.RFC3339) { + t.Errorf("CreatedAt be using the nowFuncOverride") + } + if user1.UpdatedAt.UTC().Format(time.RFC3339) != timeA.UTC().Format(time.RFC3339) { + t.Errorf("UpdatedAt be using the nowFuncOverride") + } + + // now create another user with a fresh DB.Now() that doesn't have the nowFuncOverride set + // to make sure that setting it only affected the above instance + + user2 := User{Name: "CreateUserTimestampOverrideNoMore"} + + db2 := DB.New() + + db2.Save(&user2) + + if user2.CreatedAt.UTC().Format(time.RFC3339) == timeA.UTC().Format(time.RFC3339) { + t.Errorf("CreatedAt no longer be using the nowFuncOverride") + } + if user2.UpdatedAt.UTC().Format(time.RFC3339) == timeA.UTC().Format(time.RFC3339) { + t.Errorf("UpdatedAt no longer be using the nowFuncOverride") + } +} + type AutoIncrementUser struct { User Sequence uint `gorm:"AUTO_INCREMENT"` diff --git a/main.go b/main.go index 3093ec80..ec84906b 100644 --- a/main.go +++ b/main.go @@ -31,6 +31,9 @@ type DB struct { callbacks *Callback dialect Dialect singularTable bool + + // function to be used to override the creating of a new timestamp + nowFuncOverride func() time.Time } type logModeValue int @@ -158,6 +161,22 @@ func (s *DB) LogMode(enable bool) *DB { return s } +// SetNowFuncOverride set the function to be used when creating a new timestamp +func (s *DB) SetNowFuncOverride(nowFuncOverride func() time.Time) *DB { + s.nowFuncOverride = nowFuncOverride + return s +} + +// Get a new timestamp, using the provided nowFuncOverride on the DB instance if set, +// otherwise defaults to the global NowFunc() +func (s *DB) nowFunc() time.Time { + if s.nowFuncOverride != nil { + return s.nowFuncOverride() + } + + return NowFunc() +} + // BlockGlobalUpdate if true, generates an error on update/delete without where clause. // This is to prevent eventual error with empty objects updates/deletions func (s *DB) BlockGlobalUpdate(enable bool) *DB { @@ -800,6 +819,7 @@ func (s *DB) clone() *DB { Error: s.Error, blockGlobalUpdate: s.blockGlobalUpdate, dialect: newDialect(s.dialect.GetName(), s.db), + nowFuncOverride: s.nowFuncOverride, } s.values.Range(func(k, v interface{}) bool { diff --git a/scope.go b/scope.go index 0e639c70..c962c165 100644 --- a/scope.go +++ b/scope.go @@ -358,7 +358,7 @@ func (scope *Scope) Raw(sql string) *Scope { // Exec perform generated SQL func (scope *Scope) Exec() *Scope { - defer scope.trace(NowFunc()) + defer scope.trace(scope.db.nowFunc()) if !scope.HasError() { if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { @@ -932,7 +932,7 @@ func (scope *Scope) updatedAttrsWithValues(value interface{}) (results map[strin } func (scope *Scope) row() *sql.Row { - defer scope.trace(NowFunc()) + defer scope.trace(scope.db.nowFunc()) result := &RowQueryResult{} scope.InstanceSet("row_query_result", result) @@ -942,7 +942,7 @@ func (scope *Scope) row() *sql.Row { } func (scope *Scope) rows() (*sql.Rows, error) { - defer scope.trace(NowFunc()) + defer scope.trace(scope.db.nowFunc()) result := &RowsQueryResult{} scope.InstanceSet("row_query_result", result) From 153ce22c99edba93882f1a2352f412edd966e8ea Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 11 Jun 2019 17:30:14 +0800 Subject: [PATCH 164/881] Test Save with specfied table name --- main.go | 2 +- main_test.go | 15 ++++++++++++--- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/main.go b/main.go index ec84906b..e24638a6 100644 --- a/main.go +++ b/main.go @@ -466,7 +466,7 @@ func (s *DB) Save(value interface{}) *DB { if !scope.PrimaryKeyZero() { newDB := scope.callCallbacks(s.parent.callbacks.updates).db if newDB.Error == nil && newDB.RowsAffected == 0 { - return s.FirstOrCreate(value) + return s.New().Table(scope.TableName()).FirstOrCreate(value) } return newDB } diff --git a/main_test.go b/main_test.go index 81ecf0fe..35474cf3 100644 --- a/main_test.go +++ b/main_test.go @@ -44,13 +44,13 @@ func OpenTestConnection() (db *gorm.DB, err error) { case "mysql": fmt.Println("testing mysql...") if dbDSN == "" { - dbDSN = "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True" + dbDSN = "gorm:gorm@tcp(localhost:3306)/gorm?charset=utf8&parseTime=True" } db, err = gorm.Open("mysql", dbDSN) case "postgres": fmt.Println("testing postgres...") if dbDSN == "" { - dbDSN = "user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable" + dbDSN = "user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" } db, err = gorm.Open("postgres", dbDSN) case "mssql": @@ -61,7 +61,7 @@ func OpenTestConnection() (db *gorm.DB, err error) { // sp_changedbowner 'gorm'; fmt.Println("testing mssql...") if dbDSN == "" { - dbDSN = "sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm" + dbDSN = "sqlserver://gorm:LoremIpsum86@localhost:1433?database=gorm" } db, err = gorm.Open("mssql", dbDSN) default: @@ -178,6 +178,15 @@ func TestSetTable(t *testing.T) { t.Errorf("Query from specified table") } + var user User + DB.Table("deleted_users").First(&user, "name = ?", "DeletedUser") + + user.Age = 20 + DB.Table("deleted_users").Save(&user) + if DB.Table("deleted_users").First(&user, "name = ? AND age = ?", "DeletedUser", 20).RecordNotFound() { + t.Errorf("Failed to found updated user") + } + DB.Save(getPreparedUser("normal_user", "reset_table")) DB.Table("deleted_users").Save(getPreparedUser("deleted_user", "reset_table")) var user1, user2, user3 User From 781a8183906a286ba46024ffe2cf94f957acffa4 Mon Sep 17 00:00:00 2001 From: Momo733 <1550526230@qq.com> Date: Sat, 13 Apr 2019 14:23:35 +0800 Subject: [PATCH 165/881] fix save err when specify a table name s.New() will clear all search conditions and search value,when I use Table() to set a table name. Then FirstOrCreate() will use struct name as my database table name,so It doesn't work. --- main.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/main.go b/main.go index 1316dbd3..ec84906b 100644 --- a/main.go +++ b/main.go @@ -466,7 +466,7 @@ func (s *DB) Save(value interface{}) *DB { if !scope.PrimaryKeyZero() { newDB := scope.callCallbacks(s.parent.callbacks.updates).db if newDB.Error == nil && newDB.RowsAffected == 0 { - return s.New().FirstOrCreate(value) + return s.FirstOrCreate(value) } return newDB } From ff430cad49df63e2758d1bbd4a7c0048a57cabfd Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 13 Jun 2019 11:21:13 +0800 Subject: [PATCH 166/881] Update tests --- main_test.go | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/main_test.go b/main_test.go index 35474cf3..46b3e7a6 100644 --- a/main_test.go +++ b/main_test.go @@ -1,5 +1,9 @@ package gorm_test +// Run tests +// $ docker-compose up +// $ ./test_all.sh + import ( "context" "database/sql" @@ -44,13 +48,13 @@ func OpenTestConnection() (db *gorm.DB, err error) { case "mysql": fmt.Println("testing mysql...") if dbDSN == "" { - dbDSN = "gorm:gorm@tcp(localhost:3306)/gorm?charset=utf8&parseTime=True" + dbDSN = "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True" } db, err = gorm.Open("mysql", dbDSN) case "postgres": fmt.Println("testing postgres...") if dbDSN == "" { - dbDSN = "user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" + dbDSN = "user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable" } db, err = gorm.Open("postgres", dbDSN) case "mssql": @@ -61,7 +65,7 @@ func OpenTestConnection() (db *gorm.DB, err error) { // sp_changedbowner 'gorm'; fmt.Println("testing mssql...") if dbDSN == "" { - dbDSN = "sqlserver://gorm:LoremIpsum86@localhost:1433?database=gorm" + dbDSN = "sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm" } db, err = gorm.Open("mssql", dbDSN) default: From 835ca6ca93ee96ac7967c22dfd0ee030810db604 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 13 Jun 2019 11:48:19 +0800 Subject: [PATCH 167/881] Update wercker.yml to include mysql 8 --- wercker.yml | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/wercker.yml b/wercker.yml index 43a3e7ae..c74fa4d4 100644 --- a/wercker.yml +++ b/wercker.yml @@ -9,6 +9,13 @@ services: MYSQL_USER: gorm MYSQL_PASSWORD: gorm MYSQL_RANDOM_ROOT_PASSWORD: "yes" + - name: mysql + id: mysql:latest + env: + MYSQL_DATABASE: gorm + MYSQL_USER: gorm + MYSQL_PASSWORD: gorm + MYSQL_RANDOM_ROOT_PASSWORD: "yes" - name: mysql57 id: mysql:5.7 env: @@ -23,13 +30,6 @@ services: MYSQL_USER: gorm MYSQL_PASSWORD: gorm MYSQL_RANDOM_ROOT_PASSWORD: "yes" - - name: mysql55 - id: mysql:5.5 - env: - MYSQL_DATABASE: gorm - MYSQL_USER: gorm - MYSQL_PASSWORD: gorm - MYSQL_RANDOM_ROOT_PASSWORD: "yes" - name: postgres id: postgres:latest env: @@ -102,6 +102,11 @@ build: code: | GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mariadb:3306)/gorm?charset=utf8&parseTime=True" go test -race ./... + - script: + name: test mysql + code: | + GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql:3306)/gorm?charset=utf8&parseTime=True" go test -race ./... + - script: name: test mysql5.7 code: | @@ -112,11 +117,6 @@ build: code: | GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql56:3306)/gorm?charset=utf8&parseTime=True" go test -race ./... - - script: - name: test mysql5.5 - code: | - GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql55:3306)/gorm?charset=utf8&parseTime=True" go test -race ./... - - script: name: test postgres code: | From 5acd5e20e684478441ac08a3b1e4a622451d5fb9 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 13 Jun 2019 12:20:11 +0800 Subject: [PATCH 168/881] Remove Debug mode from test code --- main_test.go | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/main_test.go b/main_test.go index 46b3e7a6..68bf7419 100644 --- a/main_test.go +++ b/main_test.go @@ -1293,12 +1293,11 @@ func TestWhereUpdates(t *testing.T) { OwnerEntity OwnerEntity `gorm:"polymorphic:Owner"` } - db := DB.Debug() - db.DropTable(&SomeEntity{}) - db.AutoMigrate(&SomeEntity{}) + DB.DropTable(&SomeEntity{}) + DB.AutoMigrate(&SomeEntity{}) a := SomeEntity{Name: "test"} - db.Model(&a).Where(a).Updates(SomeEntity{Name: "test2"}) + DB.Model(&a).Where(a).Updates(SomeEntity{Name: "test2"}) } func BenchmarkGorm(b *testing.B) { From 01b66011427614f01e84a473b0303c917179f2a0 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 13 Jun 2019 14:42:55 +0800 Subject: [PATCH 169/881] Update go.mod --- go.mod | 10 ++++++---- go.sum | 23 ++++++++++------------- 2 files changed, 16 insertions(+), 17 deletions(-) diff --git a/go.mod b/go.mod index 3ec7aab0..d2424b3f 100644 --- a/go.mod +++ b/go.mod @@ -1,11 +1,13 @@ module github.com/jinzhu/gorm +go 1.12 + require ( - github.com/denisenkom/go-mssqldb v0.0.0-20190423183735-731ef375ac02 + github.com/denisenkom/go-mssqldb v0.0.0-20190515213511-eb9f6a1743f3 github.com/erikstmartin/go-testdb v0.0.0-20160219214506-8d10e4a1bae5 github.com/go-sql-driver/mysql v1.4.1 - github.com/jinzhu/inflection v0.0.0-20180308033659-04140366298a - github.com/jinzhu/now v1.0.0 - github.com/lib/pq v1.1.0 + github.com/jinzhu/inflection v1.0.0 + github.com/jinzhu/now v1.0.1 + github.com/lib/pq v1.1.1 github.com/mattn/go-sqlite3 v1.10.0 ) diff --git a/go.sum b/go.sum index 848f7293..d9d073e6 100644 --- a/go.sum +++ b/go.sum @@ -11,8 +11,8 @@ github.com/apache/thrift v0.12.0/go.mod h1:cp2SuWMxlEZw2r+iP2GNCdIi4C1qmUzdZFSVb github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/denisenkom/go-mssqldb v0.0.0-20190423183735-731ef375ac02 h1:PS3xfVPa8N84AzoWZHFCbA0+ikz4f4skktfjQoNMsgk= -github.com/denisenkom/go-mssqldb v0.0.0-20190423183735-731ef375ac02/go.mod h1:zAg7JM8CkOJ43xKXIj7eRO9kmWm/TW578qo+oDO6tuM= +github.com/denisenkom/go-mssqldb v0.0.0-20190515213511-eb9f6a1743f3 h1:tkum0XDgfR0jcVVXuTsYv/erY2NnEDqwRojbxR1rBYA= +github.com/denisenkom/go-mssqldb v0.0.0-20190515213511-eb9f6a1743f3/go.mod h1:zAg7JM8CkOJ43xKXIj7eRO9kmWm/TW578qo+oDO6tuM= github.com/eapache/go-resiliency v1.1.0/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs= github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU= github.com/eapache/queue v1.1.0/go.mod h1:6eCeP0CKFpHLu8blIFXhExK/dRa7WDZfr6jVFPTqq+I= @@ -32,6 +32,7 @@ github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfb github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= +github.com/google/go-cmp v0.2.0 h1:+dTQ8DZQJz0Mb/HjFlkptS1FeQ4cWSnN941F8aEG4SQ= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= @@ -40,17 +41,17 @@ github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51 github.com/gorilla/mux v1.6.2/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs= github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= -github.com/jinzhu/inflection v0.0.0-20180308033659-04140366298a h1:eeaG9XMUvRBYXJi4pg1ZKM7nxc5AfXfojeLLW7O5J3k= -github.com/jinzhu/inflection v0.0.0-20180308033659-04140366298a/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= -github.com/jinzhu/now v1.0.0 h1:6WV8LvwPpDhKjo5U9O6b4+xdG/jTXNPwlDme/MTo8Ns= -github.com/jinzhu/now v1.0.0/go.mod h1:oHTiXerJ20+SfYcrdlBO7rzZRJWGwSTQ0iUY2jI6Gfc= +github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= +github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/jinzhu/now v1.0.1 h1:HjfetcXq097iXP0uoPCdnM4Efp5/9MsM0/M+XOTeR3M= +github.com/jinzhu/now v1.0.1/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU= github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= -github.com/lib/pq v1.1.0 h1:/5u4a+KGJptBRqGzPvYQL9p0d/tPR4S31+Tnzj9lEO4= -github.com/lib/pq v1.1.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/lib/pq v1.1.1 h1:sJZmqHoEaY7f+NPP8pgLB/WxulyR3fewgCM2qaSlBb4= +github.com/lib/pq v1.1.1/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/mattn/go-sqlite3 v1.10.0 h1:jbhqpg7tQe4SupckyijYiy0mJJ/pRyHvXf7JdWK860o= github.com/mattn/go-sqlite3 v1.10.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= @@ -58,7 +59,6 @@ github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRW github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/ginkgo v1.7.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= -github.com/openzipkin/zipkin-go v0.1.6 h1:yXiysv1CSK7Q5yjGy1710zZGnsbMUIjluWBxtLXHPBo= github.com/openzipkin/zipkin-go v0.1.6/go.mod h1:QgAqvLzwWbR/WpD4A3cGpPtJrZXNIiJc5AZX7/PBEpw= github.com/pierrec/lz4 v2.0.5+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= @@ -112,16 +112,13 @@ golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3 golang.org/x/tools v0.0.0-20190312170243-e65039ee4138/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= google.golang.org/api v0.3.1/go.mod h1:6wY9I6uQWHQ8EM57III9mq/AjF+i8G65rmVagqKMtkk= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= +google.golang.org/appengine v1.4.0 h1:/wp5JvzpHIxhs/dumFmF7BXTf3Z+dd4uXta4kVyO508= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= -google.golang.org/appengine v1.5.0 h1:KxkO13IPW4Lslp2bz+KHP2E3gtFlrIGNThxkZQ3g+4c= -google.golang.org/appengine v1.5.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= google.golang.org/genproto v0.0.0-20190307195333-5fe7a883aa19/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= google.golang.org/genproto v0.0.0-20190404172233-64821d5d2107/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= google.golang.org/grpc v1.17.0/go.mod h1:6QZJwpn2B+Zp71q/5VxRsJ6NXXVCE5NRUHRo+f3cWCs= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= -google.golang.org/grpc v1.20.1 h1:Hz2g2wirWK7H0qIIhGIqRGTuMwTE8HEKFnDZZ7lm9NU= -google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= From beb591e642787c6790afb9ff48310a819829acb6 Mon Sep 17 00:00:00 2001 From: zaneli Date: Mon, 24 Jun 2019 20:38:13 +0900 Subject: [PATCH 170/881] Fix function name of comment --- main.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/main.go b/main.go index e24638a6..67e5f58e 100644 --- a/main.go +++ b/main.go @@ -528,7 +528,7 @@ func (s *DB) Begin() *DB { return s.BeginTx(context.Background(), &sql.TxOptions{}) } -// BeginTX begins a transaction with options +// BeginTx begins a transaction with options func (s *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) *DB { c := s.clone() if db, ok := c.db.(sqlDb); ok && db != nil { From e3cc5ea4d403078a370e299629da56cd011b6583 Mon Sep 17 00:00:00 2001 From: Herpiko Dwi Aguno Date: Fri, 21 Jun 2019 21:29:12 +0700 Subject: [PATCH 171/881] Fix #2517 : Check for incomplete parentheses to prevent SQL injection. --- query_test.go | 17 +++++++++++++++++ scope.go | 21 +++++++++++++++++++++ 2 files changed, 38 insertions(+) diff --git a/query_test.go b/query_test.go index 15bf8b3c..2b7e0dff 100644 --- a/query_test.go +++ b/query_test.go @@ -133,6 +133,23 @@ func TestStringPrimaryKeyForNumericValueStartingWithZero(t *testing.T) { t.Errorf("Fetch a record from with a string primary key for a numeric value starting with zero should work, but failed, zip code is %v", address.ZipCode) } } +func TestStringAgainstIncompleteParentheses(t *testing.T) { + type AddressByZipCode struct { + ZipCode string `gorm:"primary_key"` + Address string + } + + DB.AutoMigrate(&AddressByZipCode{}) + DB.Create(&AddressByZipCode{ZipCode: "00502", Address: "Holtsville"}) + + var address AddressByZipCode + var addresses []AddressByZipCode + _ = DB.First(&address, "address_by_zip_codes=00502)) UNION ALL SELECT NULL,version(),current_database(),NULL,NULL,NULL,NULL,NULL--").Find(&addresses).GetErrors() + if len(addresses) > 0 { + t.Errorf("Fetch a record from with a string that has incomplete parentheses should be fail, zip code is %v", address.ZipCode) + } + +} func TestFindAsSliceOfPointers(t *testing.T) { DB.Save(&User{Name: "user"}) diff --git a/scope.go b/scope.go index c962c165..541fe522 100644 --- a/scope.go +++ b/scope.go @@ -277,6 +277,23 @@ func (scope *Scope) AddToVars(value interface{}) string { return scope.Dialect().BindVar(len(scope.SQLVars)) } +// IsCompleteParentheses check if the string has complete parentheses to prevent SQL injection +func (scope *Scope) IsCompleteParentheses(value string) bool { + count := 0 + for i, _ := range value { + if value[i] == 40 { // ( + count++ + } else if value[i] == 41 { // ) + count-- + } + if count < 0 { + break + } + i++ + } + return count == 0 +} + // SelectAttrs return selected attributes func (scope *Scope) SelectAttrs() []string { if scope.selectAttrs == nil { @@ -556,6 +573,10 @@ func (scope *Scope) buildCondition(clause map[string]interface{}, include bool) } if value != "" { + if !scope.IsCompleteParentheses(value) { + scope.Err(fmt.Errorf("incomplete parentheses found: %v", value)) + return + } if !include { if comparisonRegexp.MatchString(value) { str = fmt.Sprintf("NOT (%v)", value) From 2a3ab99a081dc14b29dfd4df42d4c59ba1814d21 Mon Sep 17 00:00:00 2001 From: haoc7 Date: Mon, 2 Sep 2019 09:44:50 +0800 Subject: [PATCH 172/881] fix insert timezero 0001-01-01 (#2635) --- logger.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/logger.go b/logger.go index 484bc022..a42f2727 100644 --- a/logger.go +++ b/logger.go @@ -49,7 +49,11 @@ var LogFormatter = func(values ...interface{}) (messages []interface{}) { if indirectValue.IsValid() { value = indirectValue.Interface() if t, ok := value.(time.Time); ok { - formattedValues = append(formattedValues, fmt.Sprintf("'%v'", t.Format("2006-01-02 15:04:05"))) + if t.IsZero() { + formattedValues = append(formattedValues, fmt.Sprintf("'%v'", "0000-00-00 00:00:00")) + } else { + formattedValues = append(formattedValues, fmt.Sprintf("'%v'", t.Format("2006-01-02 15:04:05"))) + } } else if b, ok := value.([]byte); ok { if str := string(b); isPrintable(str) { formattedValues = append(formattedValues, fmt.Sprintf("'%v'", str)) From b9548541168d54697fed015b99e732c12f2289ec Mon Sep 17 00:00:00 2001 From: Steve Ellis Date: Thu, 12 Sep 2019 10:13:59 -0400 Subject: [PATCH 173/881] bump mattn/go-sqlite3 to v1.11.0 (#2565) --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index d2424b3f..2d2fec37 100644 --- a/go.mod +++ b/go.mod @@ -9,5 +9,5 @@ require ( github.com/jinzhu/inflection v1.0.0 github.com/jinzhu/now v1.0.1 github.com/lib/pq v1.1.1 - github.com/mattn/go-sqlite3 v1.10.0 + github.com/mattn/go-sqlite3 v1.11.0 ) diff --git a/go.sum b/go.sum index d9d073e6..c43559bf 100644 --- a/go.sum +++ b/go.sum @@ -52,8 +52,8 @@ github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxv github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= github.com/lib/pq v1.1.1 h1:sJZmqHoEaY7f+NPP8pgLB/WxulyR3fewgCM2qaSlBb4= github.com/lib/pq v1.1.1/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= -github.com/mattn/go-sqlite3 v1.10.0 h1:jbhqpg7tQe4SupckyijYiy0mJJ/pRyHvXf7JdWK860o= -github.com/mattn/go-sqlite3 v1.10.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= +github.com/mattn/go-sqlite3 v1.11.0 h1:LDdKkqtYlom37fkvqs8rMPFKAMe8+SgjbwZ6ex1/A/Q= +github.com/mattn/go-sqlite3 v1.11.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= From d5cafb5db15c1c6026005bfe0b41220cf2513887 Mon Sep 17 00:00:00 2001 From: Shunsuke Otani Date: Thu, 12 Sep 2019 23:16:05 +0900 Subject: [PATCH 174/881] Fix CallbackProcessor.Get() for removed or replaced same name callback (#2548) --- callback.go | 10 +++++++--- callbacks_test.go | 48 ++++++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 52 insertions(+), 6 deletions(-) diff --git a/callback.go b/callback.go index 6f60511b..202af06e 100644 --- a/callback.go +++ b/callback.go @@ -135,11 +135,15 @@ func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *S // db.Callback().Create().Get("gorm:create") func (cp *CallbackProcessor) Get(callbackName string) (callback func(scope *Scope)) { for _, p := range cp.parent.processors { - if p.name == callbackName && p.kind == cp.kind && !cp.remove { - return *p.processor + if p.name == callbackName && p.kind == cp.kind { + if p.remove { + callback = nil + } else { + callback = *p.processor + } } } - return nil + return } // getRIndex get right index from string slice diff --git a/callbacks_test.go b/callbacks_test.go index a58913d7..c1a1d5e4 100644 --- a/callbacks_test.go +++ b/callbacks_test.go @@ -2,11 +2,10 @@ package gorm_test import ( "errors" - - "github.com/jinzhu/gorm" - "reflect" "testing" + + "github.com/jinzhu/gorm" ) func (s *Product) BeforeCreate() (err error) { @@ -175,3 +174,46 @@ func TestCallbacksWithErrors(t *testing.T) { t.Errorf("Record shouldn't be deleted because of an error happened in after delete callback") } } + +func TestGetCallback(t *testing.T) { + scope := DB.NewScope(nil) + + if DB.Callback().Create().Get("gorm:test_callback") != nil { + t.Errorf("`gorm:test_callback` should be nil") + } + + DB.Callback().Create().Register("gorm:test_callback", func(scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 1) }) + callback := DB.Callback().Create().Get("gorm:test_callback") + if callback == nil { + t.Errorf("`gorm:test_callback` should be non-nil") + } + callback(scope) + if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 1 { + t.Errorf("`gorm:test_callback_value` should be `1, true` but `%v, %v`", v, ok) + } + + DB.Callback().Create().Replace("gorm:test_callback", func(scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 2) }) + callback = DB.Callback().Create().Get("gorm:test_callback") + if callback == nil { + t.Errorf("`gorm:test_callback` should be non-nil") + } + callback(scope) + if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 2 { + t.Errorf("`gorm:test_callback_value` should be `2, true` but `%v, %v`", v, ok) + } + + DB.Callback().Create().Remove("gorm:test_callback") + if DB.Callback().Create().Get("gorm:test_callback") != nil { + t.Errorf("`gorm:test_callback` should be nil") + } + + DB.Callback().Create().Register("gorm:test_callback", func(scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 3) }) + callback = DB.Callback().Create().Get("gorm:test_callback") + if callback == nil { + t.Errorf("`gorm:test_callback` should be non-nil") + } + callback(scope) + if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 3 { + t.Errorf("`gorm:test_callback_value` should be `3, true` but `%v, %v`", v, ok) + } +} From 13f19a503687379fcf3080a49e4b2f4482355b75 Mon Sep 17 00:00:00 2001 From: Shunsuke Otani Date: Thu, 12 Sep 2019 23:16:52 +0900 Subject: [PATCH 175/881] Uncapitalize error strings (#2533) --- callback_delete.go | 2 +- callback_update.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/callback_delete.go b/callback_delete.go index 50242e48..48b97acb 100644 --- a/callback_delete.go +++ b/callback_delete.go @@ -17,7 +17,7 @@ func init() { // beforeDeleteCallback will invoke `BeforeDelete` method before deleting func beforeDeleteCallback(scope *Scope) { if scope.DB().HasBlockGlobalUpdate() && !scope.hasConditions() { - scope.Err(errors.New("Missing WHERE clause while deleting")) + scope.Err(errors.New("missing WHERE clause while deleting")) return } if !scope.HasError() { diff --git a/callback_update.go b/callback_update.go index 56711d37..699e534b 100644 --- a/callback_update.go +++ b/callback_update.go @@ -34,7 +34,7 @@ func assignUpdatingAttributesCallback(scope *Scope) { // beforeUpdateCallback will invoke `BeforeSave`, `BeforeUpdate` method before updating func beforeUpdateCallback(scope *Scope) { if scope.DB().HasBlockGlobalUpdate() && !scope.hasConditions() { - scope.Err(errors.New("Missing WHERE clause while updating")) + scope.Err(errors.New("missing WHERE clause while updating")) return } if _, ok := scope.Get("gorm:update_column"); !ok { From 0c98e7d712e2fdc3a191a7cd2a37fabfce3768f2 Mon Sep 17 00:00:00 2001 From: Christian Muehlhaeuser Date: Thu, 12 Sep 2019 16:17:31 +0200 Subject: [PATCH 176/881] Fixed import formatting to match goimports (#2568) --- dialects/postgres/postgres.go | 1 + 1 file changed, 1 insertion(+) diff --git a/dialects/postgres/postgres.go b/dialects/postgres/postgres.go index 424e8bdc..e6c088b1 100644 --- a/dialects/postgres/postgres.go +++ b/dialects/postgres/postgres.go @@ -7,6 +7,7 @@ import ( "encoding/json" "errors" "fmt" + _ "github.com/lib/pq" "github.com/lib/pq/hstore" ) From 81c17a7e2529c59efc4e74c5b32c1fb71fb12fa2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Emir=20Beganovi=C4=87?= Date: Wed, 25 Sep 2019 13:22:43 +0200 Subject: [PATCH 177/881] Revert "Fix #2517 : Check for incomplete parentheses to prevent SQL injection." (#2674) This reverts commit e3cc5ea4d403078a370e299629da56cd011b6583. --- query_test.go | 17 ----------------- scope.go | 21 --------------------- 2 files changed, 38 deletions(-) diff --git a/query_test.go b/query_test.go index 2b7e0dff..15bf8b3c 100644 --- a/query_test.go +++ b/query_test.go @@ -133,23 +133,6 @@ func TestStringPrimaryKeyForNumericValueStartingWithZero(t *testing.T) { t.Errorf("Fetch a record from with a string primary key for a numeric value starting with zero should work, but failed, zip code is %v", address.ZipCode) } } -func TestStringAgainstIncompleteParentheses(t *testing.T) { - type AddressByZipCode struct { - ZipCode string `gorm:"primary_key"` - Address string - } - - DB.AutoMigrate(&AddressByZipCode{}) - DB.Create(&AddressByZipCode{ZipCode: "00502", Address: "Holtsville"}) - - var address AddressByZipCode - var addresses []AddressByZipCode - _ = DB.First(&address, "address_by_zip_codes=00502)) UNION ALL SELECT NULL,version(),current_database(),NULL,NULL,NULL,NULL,NULL--").Find(&addresses).GetErrors() - if len(addresses) > 0 { - t.Errorf("Fetch a record from with a string that has incomplete parentheses should be fail, zip code is %v", address.ZipCode) - } - -} func TestFindAsSliceOfPointers(t *testing.T) { DB.Save(&User{Name: "user"}) diff --git a/scope.go b/scope.go index 541fe522..c962c165 100644 --- a/scope.go +++ b/scope.go @@ -277,23 +277,6 @@ func (scope *Scope) AddToVars(value interface{}) string { return scope.Dialect().BindVar(len(scope.SQLVars)) } -// IsCompleteParentheses check if the string has complete parentheses to prevent SQL injection -func (scope *Scope) IsCompleteParentheses(value string) bool { - count := 0 - for i, _ := range value { - if value[i] == 40 { // ( - count++ - } else if value[i] == 41 { // ) - count-- - } - if count < 0 { - break - } - i++ - } - return count == 0 -} - // SelectAttrs return selected attributes func (scope *Scope) SelectAttrs() []string { if scope.selectAttrs == nil { @@ -573,10 +556,6 @@ func (scope *Scope) buildCondition(clause map[string]interface{}, include bool) } if value != "" { - if !scope.IsCompleteParentheses(value) { - scope.Err(fmt.Errorf("incomplete parentheses found: %v", value)) - return - } if !include { if comparisonRegexp.MatchString(value) { str = fmt.Sprintf("NOT (%v)", value) From e5d0267c0bee4a92af603ea570fa9121e6440b11 Mon Sep 17 00:00:00 2001 From: Jay Chung Date: Sat, 5 Oct 2019 12:12:47 +0800 Subject: [PATCH 178/881] Fix typo of example code --- callback.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/callback.go b/callback.go index 202af06e..719b0a78 100644 --- a/callback.go +++ b/callback.go @@ -119,8 +119,8 @@ func (cp *CallbackProcessor) Remove(callbackName string) { // Replace a registered callback with new callback // db.Callback().Create().Replace("gorm:update_time_stamp_when_create", func(*Scope) { -// scope.SetColumn("Created", now) -// scope.SetColumn("Updated", now) +// scope.SetColumn("CreatedAt", now) +// scope.SetColumn("UpdatedAt", now) // }) func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *Scope)) { cp.logger.Print(fmt.Sprintf("[info] replacing callback `%v` from %v\n", callbackName, fileWithLineNum())) From 820b5f244abf7ef16f362de39b19adfef31fff2d Mon Sep 17 00:00:00 2001 From: Alex Stockwell Date: Thu, 17 Oct 2019 07:54:11 -0700 Subject: [PATCH 179/881] MSSQL Create() fix: Add LastInsertIDReturningSuffix to dialect (#2690) * MSSQL Create() fix: Add LastInsertIDReturningSuffix to dialect Per https://github.com/denisenkom/go-mssqldb/issues/355 * MSSQL Create() fix: Added OUTPUT query to Create() builder --- callback_create.go | 47 +++++++++++++++++++++++++++++------------ dialect.go | 2 ++ dialect_common.go | 4 ++++ dialect_postgres.go | 4 ++++ dialects/mssql/mssql.go | 8 +++++++ 5 files changed, 52 insertions(+), 13 deletions(-) diff --git a/callback_create.go b/callback_create.go index 87aba8ee..3527858b 100644 --- a/callback_create.go +++ b/callback_create.go @@ -101,10 +101,11 @@ func createCallback(scope *Scope) { } lastInsertIDReturningSuffix := scope.Dialect().LastInsertIDReturningSuffix(quotedTableName, returningColumn) + lastInsertIDOutputInterstitial := scope.Dialect().LastInsertIDOutputInterstitial(quotedTableName, returningColumn, columns) if len(columns) == 0 { scope.Raw(fmt.Sprintf( - "INSERT %v INTO %v %v%v%v", + "INSERT%v INTO %v %v%v%v", addExtraSpaceIfExist(insertModifier), quotedTableName, scope.Dialect().DefaultValueStr(), @@ -113,18 +114,19 @@ func createCallback(scope *Scope) { )) } else { scope.Raw(fmt.Sprintf( - "INSERT %v INTO %v (%v) VALUES (%v)%v%v", + "INSERT%v INTO %v (%v)%v VALUES (%v)%v%v", addExtraSpaceIfExist(insertModifier), scope.QuotedTableName(), strings.Join(columns, ","), + addExtraSpaceIfExist(lastInsertIDOutputInterstitial), strings.Join(placeholders, ","), addExtraSpaceIfExist(extraOption), addExtraSpaceIfExist(lastInsertIDReturningSuffix), )) } - // execute create sql - if lastInsertIDReturningSuffix == "" || primaryField == nil { + // execute create sql: no primaryField + if primaryField == nil { if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { // set rows affected count scope.db.RowsAffected, _ = result.RowsAffected() @@ -136,16 +138,35 @@ func createCallback(scope *Scope) { } } } - } else { - if primaryField.Field.CanAddr() { - if err := scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil { - primaryField.IsBlank = false - scope.db.RowsAffected = 1 - } - } else { - scope.Err(ErrUnaddressable) - } + return } + + // execute create sql: lastInsertID implemention for majority of dialects + if lastInsertIDReturningSuffix == "" && lastInsertIDOutputInterstitial == "" { + if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { + // set rows affected count + scope.db.RowsAffected, _ = result.RowsAffected() + + // set primary value to primary field + if primaryField != nil && primaryField.IsBlank { + if primaryValue, err := result.LastInsertId(); scope.Err(err) == nil { + scope.Err(primaryField.Set(primaryValue)) + } + } + } + return + } + + // execute create sql: dialects with additional lastInsertID requirements (currently postgres & mssql) + if primaryField.Field.CanAddr() { + if err := scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil { + primaryField.IsBlank = false + scope.db.RowsAffected = 1 + } + } else { + scope.Err(ErrUnaddressable) + } + return } } diff --git a/dialect.go b/dialect.go index 831c0a8e..b6f95df7 100644 --- a/dialect.go +++ b/dialect.go @@ -40,6 +40,8 @@ type Dialect interface { LimitAndOffsetSQL(limit, offset interface{}) string // SelectFromDummyTable return select values, for most dbs, `SELECT values` just works, mysql needs `SELECT value FROM DUAL` SelectFromDummyTable() string + // LastInsertIDOutputInterstitial most dbs support LastInsertId, but mssql needs to use `OUTPUT` + LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string // LastInsertIdReturningSuffix most dbs support LastInsertId, but postgres needs to use `RETURNING` LastInsertIDReturningSuffix(tableName, columnName string) string // DefaultValueStr diff --git a/dialect_common.go b/dialect_common.go index e3a5b702..16da76dc 100644 --- a/dialect_common.go +++ b/dialect_common.go @@ -157,6 +157,10 @@ func (commonDialect) SelectFromDummyTable() string { return "" } +func (commonDialect) LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string { + return "" +} + func (commonDialect) LastInsertIDReturningSuffix(tableName, columnName string) string { return "" } diff --git a/dialect_postgres.go b/dialect_postgres.go index 53d31388..d2df3131 100644 --- a/dialect_postgres.go +++ b/dialect_postgres.go @@ -120,6 +120,10 @@ func (s postgres) CurrentDatabase() (name string) { return } +func (s postgres) LastInsertIDOutputInterstitial(tableName, key string, columns []string) string { + return "" +} + func (s postgres) LastInsertIDReturningSuffix(tableName, key string) string { return fmt.Sprintf("RETURNING %v.%v", tableName, key) } diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go index 8c2360fc..eb79f7e7 100644 --- a/dialects/mssql/mssql.go +++ b/dialects/mssql/mssql.go @@ -190,6 +190,14 @@ func (mssql) SelectFromDummyTable() string { return "" } +func (mssql) LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string { + if len(columns) == 0 { + // No OUTPUT to query + return "" + } + return fmt.Sprintf("OUTPUT Inserted.%v", columnName) +} + func (mssql) LastInsertIDReturningSuffix(tableName, columnName string) string { return "" } From d2007b3c826bf2f528d8dae0913f77cbac3ef7fd Mon Sep 17 00:00:00 2001 From: Devin Samarin Date: Thu, 17 Oct 2019 07:56:19 -0700 Subject: [PATCH 180/881] Describe name of field for invalid SQL datatypes (#2689) --- dialect_mysql.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dialect_mysql.go b/dialect_mysql.go index 5a1ad708..1addaf36 100644 --- a/dialect_mysql.go +++ b/dialect_mysql.go @@ -120,7 +120,7 @@ func (s *mysql) DataTypeOf(field *StructField) string { } if sqlType == "" { - panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", dataValue.Type().Name(), dataValue.Kind().String())) + panic(fmt.Sprintf("invalid sql type %s (%s) in field %s for mysql", dataValue.Type().Name(), dataValue.Kind().String(), field.Name)) } if strings.TrimSpace(additionalType) == "" { From 7729627ff65324940367a4ea9d068767ac4e79fb Mon Sep 17 00:00:00 2001 From: Lilit Date: Thu, 17 Oct 2019 18:12:01 +0300 Subject: [PATCH 181/881] Fix logging callbacks (#2652) --- callback.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/callback.go b/callback.go index 719b0a78..4d8e72c0 100644 --- a/callback.go +++ b/callback.go @@ -96,7 +96,7 @@ func (cp *CallbackProcessor) Before(callbackName string) *CallbackProcessor { func (cp *CallbackProcessor) Register(callbackName string, callback func(scope *Scope)) { if cp.kind == "row_query" { if cp.before == "" && cp.after == "" && callbackName != "gorm:row_query" { - cp.logger.Print(fmt.Sprintf("Registering RowQuery callback %v without specify order with Before(), After(), applying Before('gorm:row_query') by default for compatibility...\n", callbackName)) + cp.logger.Print("info", fmt.Sprintf("Registering RowQuery callback %v without specify order with Before(), After(), applying Before('gorm:row_query') by default for compatibility...", callbackName)) cp.before = "gorm:row_query" } } @@ -110,7 +110,7 @@ func (cp *CallbackProcessor) Register(callbackName string, callback func(scope * // Remove a registered callback // db.Callback().Create().Remove("gorm:update_time_stamp_when_create") func (cp *CallbackProcessor) Remove(callbackName string) { - cp.logger.Print(fmt.Sprintf("[info] removing callback `%v` from %v\n", callbackName, fileWithLineNum())) + cp.logger.Print("info", fmt.Sprintf("[info] removing callback `%v` from %v", callbackName, fileWithLineNum())) cp.name = callbackName cp.remove = true cp.parent.processors = append(cp.parent.processors, cp) @@ -123,7 +123,7 @@ func (cp *CallbackProcessor) Remove(callbackName string) { // scope.SetColumn("UpdatedAt", now) // }) func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *Scope)) { - cp.logger.Print(fmt.Sprintf("[info] replacing callback `%v` from %v\n", callbackName, fileWithLineNum())) + cp.logger.Print("info", fmt.Sprintf("[info] replacing callback `%v` from %v", callbackName, fileWithLineNum())) cp.name = callbackName cp.processor = &callback cp.replace = true @@ -166,7 +166,7 @@ func sortProcessors(cps []*CallbackProcessor) []*func(scope *Scope) { for _, cp := range cps { // show warning message the callback name already exists if index := getRIndex(allNames, cp.name); index > -1 && !cp.replace && !cp.remove { - cp.logger.Print(fmt.Sprintf("[warning] duplicated callback `%v` from %v\n", cp.name, fileWithLineNum())) + cp.logger.Print("warning", fmt.Sprintf("[warning] duplicated callback `%v` from %v", cp.name, fileWithLineNum())) } allNames = append(allNames, cp.name) } From 120d39b4d6873cb2a5a4b789a031bd2cc8465a12 Mon Sep 17 00:00:00 2001 From: okhowang <3352585+okhowang@users.noreply.github.com> Date: Thu, 17 Oct 2019 23:22:13 +0800 Subject: [PATCH 182/881] use show statement in mysql dialect for compatibility for tencent tdsql (#2643) --- dialect_mysql.go | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/dialect_mysql.go b/dialect_mysql.go index 1addaf36..ac9b3b2e 100644 --- a/dialect_mysql.go +++ b/dialect_mysql.go @@ -2,6 +2,7 @@ package gorm import ( "crypto/sha1" + "database/sql" "fmt" "reflect" "regexp" @@ -161,6 +162,39 @@ func (s mysql) HasForeignKey(tableName string, foreignKeyName string) bool { return count > 0 } +func (s mysql) HasTable(tableName string) bool { + currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) + var name string + if err := s.db.QueryRow(fmt.Sprintf("SHOW TABLES FROM %s WHERE Tables_in_%s = ?", currentDatabase, currentDatabase), tableName).Scan(&name); err != nil { + if err == sql.ErrNoRows { + return false + } + panic(err) + } else { + return true + } +} + +func (s mysql) HasIndex(tableName string, indexName string) bool { + currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) + if rows, err := s.db.Query(fmt.Sprintf("SHOW INDEXES FROM `%s` FROM `%s` WHERE Key_name = ?", tableName, currentDatabase), indexName); err != nil { + panic(err) + } else { + defer rows.Close() + return rows.Next() + } +} + +func (s mysql) HasColumn(tableName string, columnName string) bool { + currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) + if rows, err := s.db.Query(fmt.Sprintf("SHOW COLUMNS FROM `%s` FROM `%s` WHERE Field = ?", tableName, currentDatabase), columnName); err != nil { + panic(err) + } else { + defer rows.Close() + return rows.Next() + } +} + func (s mysql) CurrentDatabase() (name string) { s.db.QueryRow("SELECT DATABASE()").Scan(&name) return From b99f2d827067caef22fdd72c967f597515fba15d Mon Sep 17 00:00:00 2001 From: "lotus.wu" Date: Thu, 17 Oct 2019 23:36:06 +0800 Subject: [PATCH 183/881] 1. suport date time '2070-03-30 00:00:00',timestamp can't support large date time. (#1823) --- dialect_mysql.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/dialect_mysql.go b/dialect_mysql.go index ac9b3b2e..da46d586 100644 --- a/dialect_mysql.go +++ b/dialect_mysql.go @@ -103,10 +103,10 @@ func (s *mysql) DataTypeOf(field *StructField) string { precision = fmt.Sprintf("(%s)", p) } - if _, ok := field.TagSettingsGet("NOT NULL"); ok || field.IsPrimaryKey { - sqlType = fmt.Sprintf("timestamp%v", precision) + if _, ok := field.TagSettings["NOT NULL"]; ok || field.IsPrimaryKey { + sqlType = fmt.Sprintf("DATETIME%v", precision) } else { - sqlType = fmt.Sprintf("timestamp%v NULL", precision) + sqlType = fmt.Sprintf("DATETIME%v NULL", precision) } } default: From a8a530db5a78f0c5719f3ea8b0970de637245da5 Mon Sep 17 00:00:00 2001 From: aimuz Date: Thu, 17 Oct 2019 23:38:37 +0800 Subject: [PATCH 184/881] SetColumn No fields ignored were processed (#2579) --- scope.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scope.go b/scope.go index c962c165..e64a8ba8 100644 --- a/scope.go +++ b/scope.go @@ -225,7 +225,7 @@ func (scope *Scope) SetColumn(column interface{}, value interface{}) error { updateAttrs[field.DBName] = value return field.Set(value) } - if (field.DBName == dbName) || (field.Name == name && mostMatchedField == nil) { + if !field.IsIgnored && ((field.DBName == dbName) || (field.Name == name && mostMatchedField == nil)) { mostMatchedField = field } } From 5b3e40ac12c1b5ad09fbcefc06fa6d7bda415ef3 Mon Sep 17 00:00:00 2001 From: macklin-10x <53452532+macklin-10x@users.noreply.github.com> Date: Thu, 17 Oct 2019 08:44:34 -0700 Subject: [PATCH 185/881] Rename expr type to make it public. (#2604) --- main.go | 6 +++--- scope.go | 6 +++--- search.go | 2 +- utils.go | 6 +++--- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/main.go b/main.go index 67e5f58e..eac28f8a 100644 --- a/main.go +++ b/main.go @@ -209,8 +209,8 @@ func (s *DB) NewScope(value interface{}) *Scope { return scope } -// QueryExpr returns the query as expr object -func (s *DB) QueryExpr() *expr { +// QueryExpr returns the query as SqlExpr object +func (s *DB) QueryExpr() *SqlExpr { scope := s.NewScope(s.Value) scope.InstanceSet("skip_bindvar", true) scope.prepareQuerySQL() @@ -219,7 +219,7 @@ func (s *DB) QueryExpr() *expr { } // SubQuery returns the query as sub query -func (s *DB) SubQuery() *expr { +func (s *DB) SubQuery() *SqlExpr { scope := s.NewScope(s.Value) scope.InstanceSet("skip_bindvar", true) scope.prepareQuerySQL() diff --git a/scope.go b/scope.go index e64a8ba8..eb7525b8 100644 --- a/scope.go +++ b/scope.go @@ -257,7 +257,7 @@ func (scope *Scope) CallMethod(methodName string) { func (scope *Scope) AddToVars(value interface{}) string { _, skipBindVar := scope.InstanceGet("skip_bindvar") - if expr, ok := value.(*expr); ok { + if expr, ok := value.(*SqlExpr); ok { exp := expr.expr for _, arg := range expr.args { if skipBindVar { @@ -785,7 +785,7 @@ func (scope *Scope) orderSQL() string { for _, order := range scope.Search.orders { if str, ok := order.(string); ok { orders = append(orders, scope.quoteIfPossible(str)) - } else if expr, ok := order.(*expr); ok { + } else if expr, ok := order.(*SqlExpr); ok { exp := expr.expr for _, arg := range expr.args { exp = strings.Replace(exp, "?", scope.AddToVars(arg), 1) @@ -912,7 +912,7 @@ func (scope *Scope) updatedAttrsWithValues(value interface{}) (results map[strin for key, value := range convertInterfaceToMap(value, true, scope.db) { if field, ok := scope.FieldByName(key); ok && scope.changeableField(field) { - if _, ok := value.(*expr); ok { + if _, ok := value.(*SqlExpr); ok { hasUpdate = true results[field.DBName] = value } else { diff --git a/search.go b/search.go index 90138595..7c4cc184 100644 --- a/search.go +++ b/search.go @@ -98,7 +98,7 @@ func (s *search) Group(query string) *search { } func (s *search) Having(query interface{}, values ...interface{}) *search { - if val, ok := query.(*expr); ok { + if val, ok := query.(*SqlExpr); ok { s.havingConditions = append(s.havingConditions, map[string]interface{}{"query": val.expr, "args": val.args}) } else { s.havingConditions = append(s.havingConditions, map[string]interface{}{"query": query, "args": values}) diff --git a/utils.go b/utils.go index e58e57a5..d2ae9465 100644 --- a/utils.go +++ b/utils.go @@ -58,15 +58,15 @@ func newSafeMap() *safeMap { } // SQL expression -type expr struct { +type SqlExpr struct { expr string args []interface{} } // Expr generate raw SQL expression, for example: // DB.Model(&product).Update("price", gorm.Expr("price * ? + ?", 2, 100)) -func Expr(expression string, args ...interface{}) *expr { - return &expr{expr: expression, args: args} +func Expr(expression string, args ...interface{}) *SqlExpr { + return &SqlExpr{expr: expression, args: args} } func indirect(reflectValue reflect.Value) reflect.Value { From 5fe32d593fad1bd8005c5fbc90489c9174ce73d6 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 21 Oct 2019 20:20:38 +0800 Subject: [PATCH 186/881] Escape table name for mysql HasTable --- dialect_mysql.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dialect_mysql.go b/dialect_mysql.go index da46d586..ee9a43d3 100644 --- a/dialect_mysql.go +++ b/dialect_mysql.go @@ -165,7 +165,7 @@ func (s mysql) HasForeignKey(tableName string, foreignKeyName string) bool { func (s mysql) HasTable(tableName string) bool { currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) var name string - if err := s.db.QueryRow(fmt.Sprintf("SHOW TABLES FROM %s WHERE Tables_in_%s = ?", currentDatabase, currentDatabase), tableName).Scan(&name); err != nil { + if err := s.db.QueryRow(fmt.Sprintf("SHOW TABLES FROM `%s` WHERE `Tables_in_%s` = ?", currentDatabase, currentDatabase), tableName).Scan(&name); err != nil { if err == sql.ErrNoRows { return false } From 795328fedc12a34cd2ea7483b2d8ee618bca46c6 Mon Sep 17 00:00:00 2001 From: FWangZil <779158078@qq.com> Date: Mon, 21 Oct 2019 20:45:38 +0800 Subject: [PATCH 187/881] fix(HasTable): database name (#2717) * fix(HasTable): database name allow mysql database name with '-' character * docs: add comment --- dialect_mysql.go | 1 + 1 file changed, 1 insertion(+) diff --git a/dialect_mysql.go b/dialect_mysql.go index ee9a43d3..ab6a8a91 100644 --- a/dialect_mysql.go +++ b/dialect_mysql.go @@ -165,6 +165,7 @@ func (s mysql) HasForeignKey(tableName string, foreignKeyName string) bool { func (s mysql) HasTable(tableName string) bool { currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) var name string + // allow mysql database name with '-' character if err := s.db.QueryRow(fmt.Sprintf("SHOW TABLES FROM `%s` WHERE `Tables_in_%s` = ?", currentDatabase, currentDatabase), tableName).Scan(&name); err != nil { if err == sql.ErrNoRows { return false From 530711e724f3d4c678abc73f84be52c807e3df69 Mon Sep 17 00:00:00 2001 From: Ruben de Vries Date: Tue, 22 Oct 2019 11:27:30 +0200 Subject: [PATCH 188/881] fix a race condition on IsForeignKey that is being detected by -race sometimes. --- model_struct.go | 19 ++++++++- model_struct_test.go | 93 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 110 insertions(+), 2 deletions(-) create mode 100644 model_struct_test.go diff --git a/model_struct.go b/model_struct.go index 5234b287..d9e2e90f 100644 --- a/model_struct.go +++ b/model_struct.go @@ -17,6 +17,10 @@ var DefaultTableNameHandler = func(db *DB, defaultTableName string) string { return defaultTableName } +// lock for mutating global cached model metadata +var structsLock sync.Mutex + +// global cache of model metadata var modelStructsMap sync.Map // ModelStruct model definition @@ -419,8 +423,12 @@ func (scope *Scope) GetModelStruct() *ModelStruct { for idx, foreignKey := range foreignKeys { if foreignField := getForeignField(foreignKey, toFields); foreignField != nil { if associationField := getForeignField(associationForeignKeys[idx], modelStruct.StructFields); associationField != nil { - // source foreign keys + // mark field as foreignkey, use global lock to avoid race + structsLock.Lock() foreignField.IsForeignKey = true + structsLock.Unlock() + + // association foreign keys relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, associationField.Name) relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationField.DBName) @@ -523,8 +531,12 @@ func (scope *Scope) GetModelStruct() *ModelStruct { for idx, foreignKey := range foreignKeys { if foreignField := getForeignField(foreignKey, toFields); foreignField != nil { if scopeField := getForeignField(associationForeignKeys[idx], modelStruct.StructFields); scopeField != nil { + // mark field as foreignkey, use global lock to avoid race + structsLock.Lock() foreignField.IsForeignKey = true - // source foreign keys + structsLock.Unlock() + + // association foreign keys relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, scopeField.Name) relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, scopeField.DBName) @@ -582,7 +594,10 @@ func (scope *Scope) GetModelStruct() *ModelStruct { for idx, foreignKey := range foreignKeys { if foreignField := getForeignField(foreignKey, modelStruct.StructFields); foreignField != nil { if associationField := getForeignField(associationForeignKeys[idx], toFields); associationField != nil { + // mark field as foreignkey, use global lock to avoid race + structsLock.Lock() foreignField.IsForeignKey = true + structsLock.Unlock() // association foreign keys relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, associationField.Name) diff --git a/model_struct_test.go b/model_struct_test.go new file mode 100644 index 00000000..2ae419a0 --- /dev/null +++ b/model_struct_test.go @@ -0,0 +1,93 @@ +package gorm_test + +import ( + "sync" + "testing" + + "github.com/jinzhu/gorm" +) + +type ModelA struct { + gorm.Model + Name string + + ModelCs []ModelC `gorm:"foreignkey:OtherAID"` +} + +type ModelB struct { + gorm.Model + Name string + + ModelCs []ModelC `gorm:"foreignkey:OtherBID"` +} + +type ModelC struct { + gorm.Model + Name string + + OtherAID uint64 + OtherA *ModelA `gorm:"foreignkey:OtherAID"` + OtherBID uint64 + OtherB *ModelB `gorm:"foreignkey:OtherBID"` +} + +// This test will try to cause a race condition on the model's foreignkey metadata +func TestModelStructRaceSameModel(t *testing.T) { + // use a WaitGroup to execute as much in-sync as possible + // it's more likely to hit a race condition than without + n := 32 + start := sync.WaitGroup{} + start.Add(n) + + // use another WaitGroup to know when the test is done + done := sync.WaitGroup{} + done.Add(n) + + for i := 0; i < n; i++ { + go func() { + start.Wait() + + // call GetStructFields, this had a race condition before we fixed it + DB.NewScope(&ModelA{}).GetStructFields() + + done.Done() + }() + + start.Done() + } + + done.Wait() +} + +// This test will try to cause a race condition on the model's foreignkey metadata +func TestModelStructRaceDifferentModel(t *testing.T) { + // use a WaitGroup to execute as much in-sync as possible + // it's more likely to hit a race condition than without + n := 32 + start := sync.WaitGroup{} + start.Add(n) + + // use another WaitGroup to know when the test is done + done := sync.WaitGroup{} + done.Add(n) + + for i := 0; i < n; i++ { + i := i + go func() { + start.Wait() + + // call GetStructFields, this had a race condition before we fixed it + if i%2 == 0 { + DB.NewScope(&ModelA{}).GetStructFields() + } else { + DB.NewScope(&ModelB{}).GetStructFields() + } + + done.Done() + }() + + start.Done() + } + + done.Wait() +} From d926a05bec9ab9ee6f8bc6d865c1ccdf9350c74b Mon Sep 17 00:00:00 2001 From: "kouha.shu" Date: Wed, 23 Oct 2019 10:38:05 +0900 Subject: [PATCH 189/881] add warning comment --- main.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/main.go b/main.go index eac28f8a..5dda8838 100644 --- a/main.go +++ b/main.go @@ -433,7 +433,8 @@ func (s *DB) FirstOrCreate(out interface{}, where ...interface{}) *DB { return c } -// Update update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update +// Update update attributes with callbacks. refer: https://jinzhu.github.io/gorm/crud.html#update +// WARNING when update with struct, GORM will not update fields that with zero value func (s *DB) Update(attrs ...interface{}) *DB { return s.Updates(toSearchableMap(attrs...), true) } @@ -480,6 +481,7 @@ func (s *DB) Create(value interface{}) *DB { } // Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition +// WARNING If model has DeletedAt field, GORM will only set field DeletedAt's value to current time func (s *DB) Delete(value interface{}, where ...interface{}) *DB { return s.NewScope(value).inlineCondition(where...).callCallbacks(s.parent.callbacks.deletes).db } From 2ee239a4c07c9e3f9948500cf01f667e89a7986d Mon Sep 17 00:00:00 2001 From: "kouha.shu" Date: Wed, 23 Oct 2019 10:40:34 +0900 Subject: [PATCH 190/881] Update main.go --- main.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/main.go b/main.go index 5dda8838..e39a868a 100644 --- a/main.go +++ b/main.go @@ -433,7 +433,7 @@ func (s *DB) FirstOrCreate(out interface{}, where ...interface{}) *DB { return c } -// Update update attributes with callbacks. refer: https://jinzhu.github.io/gorm/crud.html#update +// Update update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update // WARNING when update with struct, GORM will not update fields that with zero value func (s *DB) Update(attrs ...interface{}) *DB { return s.Updates(toSearchableMap(attrs...), true) From c46c01c11689fa240dc70483f00ffb10dab9141f Mon Sep 17 00:00:00 2001 From: Dom Narducci Date: Fri, 25 Oct 2019 13:51:29 -0700 Subject: [PATCH 191/881] Log callback registration if logger exists for consistency --- callback.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/callback.go b/callback.go index 4d8e72c0..56b2064a 100644 --- a/callback.go +++ b/callback.go @@ -101,6 +101,12 @@ func (cp *CallbackProcessor) Register(callbackName string, callback func(scope * } } + if cp.logger != nil { + // note cp.logger will be nil during the default gorm callback registrations + // as they occur within init() blocks. However, any user-registered callbacks + // will happen after cp.logger exists (as the default logger or user-specified). + cp.logger.Print("info", fmt.Sprintf("[info] registering callback `%v` from %v", callbackName, fileWithLineNum())) + } cp.name = callbackName cp.processor = &callback cp.parent.processors = append(cp.parent.processors, cp) From 59408390c2dce9ca8b48fae08937213e72b24f9a Mon Sep 17 00:00:00 2001 From: Jason Lee Date: Tue, 19 Nov 2019 16:08:00 +0800 Subject: [PATCH 192/881] Add `db.Transaction` method for create Transaction block. (#2767) * Add `db.Transaction` method for create Transaction block. example: ```go func CreateAnimals(db *gorm.DB) error { db.Transaction(func(tx *gorm.DB) error { if err := tx.Create(&Animal{Name: "Giraffe"}).Error; err != nil { // return any error will rollback return err } if err := tx.Create(&Animal{Name: "Lion"}).Error; err != nil { return err } // return nil will commit return nil }) } ``` * Ensure rollback when commit has error. --- main.go | 25 ++++++++++++++++++++++ main_test.go | 60 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 85 insertions(+) diff --git a/main.go b/main.go index e39a868a..48d22c85 100644 --- a/main.go +++ b/main.go @@ -525,6 +525,31 @@ func (s *DB) Debug() *DB { return s.clone().LogMode(true) } +// Transaction start a transaction as a block, +// return error will rollback, otherwise to commit. +func (s *DB) Transaction(fc func(tx *DB) error) (err error) { + tx := s.Begin() + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("%s", r) + tx.Rollback() + return + } + }() + + err = fc(tx) + + if err == nil { + err = tx.Commit().Error + } + + // Makesure rollback when Block error or Commit error + if err != nil { + tx.Rollback() + } + return +} + // Begin begins a transaction func (s *DB) Begin() *DB { return s.BeginTx(context.Background(), &sql.TxOptions{}) diff --git a/main_test.go b/main_test.go index 68bf7419..134672b7 100644 --- a/main_test.go +++ b/main_test.go @@ -8,6 +8,7 @@ import ( "context" "database/sql" "database/sql/driver" + "errors" "fmt" "os" "path/filepath" @@ -469,6 +470,65 @@ func TestTransaction(t *testing.T) { } } +func TestTransactionWithBlock(t *testing.T) { + // rollback + err := DB.Transaction(func(tx *gorm.DB) error { + u := User{Name: "transcation"} + if err := tx.Save(&u).Error; err != nil { + t.Errorf("No error should raise") + } + + if err := tx.First(&User{}, "name = ?", "transcation").Error; err != nil { + t.Errorf("Should find saved record") + } + + return errors.New("the error message") + }) + + if err.Error() != "the error message" { + t.Errorf("Transaction return error will equal the block returns error") + } + + if err := DB.First(&User{}, "name = ?", "transcation").Error; err == nil { + t.Errorf("Should not find record after rollback") + } + + // commit + DB.Transaction(func(tx *gorm.DB) error { + u2 := User{Name: "transcation-2"} + if err := tx.Save(&u2).Error; err != nil { + t.Errorf("No error should raise") + } + + if err := tx.First(&User{}, "name = ?", "transcation-2").Error; err != nil { + t.Errorf("Should find saved record") + } + return nil + }) + + if err := DB.First(&User{}, "name = ?", "transcation-2").Error; err != nil { + t.Errorf("Should be able to find committed record") + } + + // panic will rollback + DB.Transaction(func(tx *gorm.DB) error { + u3 := User{Name: "transcation-3"} + if err := tx.Save(&u3).Error; err != nil { + t.Errorf("No error should raise") + } + + if err := tx.First(&User{}, "name = ?", "transcation-3").Error; err != nil { + t.Errorf("Should find saved record") + } + + panic("force panic") + }) + + if err := DB.First(&User{}, "name = ?", "transcation").Error; err == nil { + t.Errorf("Should not find record after panic rollback") + } +} + func TestTransaction_NoErrorOnRollbackAfterCommit(t *testing.T) { tx := DB.Begin() u := User{Name: "transcation"} From 23f6840776b08a33b8eb1394616abee31a4c9e98 Mon Sep 17 00:00:00 2001 From: zaneli Date: Thu, 31 Oct 2019 02:51:26 +0900 Subject: [PATCH 193/881] Add limit and offset parse error --- dialect.go | 2 +- dialect_common.go | 19 ++++++++++-- dialect_mysql.go | 15 ++++++--- dialects/mssql/mssql.go | 17 +++++++++-- query_test.go | 68 +++++++++++++++++++++++++++++++++++++++++ scope.go | 4 ++- 6 files changed, 113 insertions(+), 12 deletions(-) diff --git a/dialect.go b/dialect.go index b6f95df7..749587f4 100644 --- a/dialect.go +++ b/dialect.go @@ -37,7 +37,7 @@ type Dialect interface { ModifyColumn(tableName string, columnName string, typ string) error // LimitAndOffsetSQL return generated SQL with Limit and Offset, as mssql has special case - LimitAndOffsetSQL(limit, offset interface{}) string + LimitAndOffsetSQL(limit, offset interface{}) (string, error) // SelectFromDummyTable return select values, for most dbs, `SELECT values` just works, mysql needs `SELECT value FROM DUAL` SelectFromDummyTable() string // LastInsertIDOutputInterstitial most dbs support LastInsertId, but mssql needs to use `OUTPUT` diff --git a/dialect_common.go b/dialect_common.go index 16da76dc..950c1986 100644 --- a/dialect_common.go +++ b/dialect_common.go @@ -139,14 +139,23 @@ func (s commonDialect) CurrentDatabase() (name string) { return } -func (commonDialect) LimitAndOffsetSQL(limit, offset interface{}) (sql string) { +// LimitAndOffsetSQL return generated SQL with Limit and Offset +func (s commonDialect) LimitAndOffsetSQL(limit, offset interface{}) (sql string, err error) { if limit != nil { - if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 { + parsedLimit, err := s.parseInt(limit) + if err != nil { + return "", err + } + if parsedLimit >= 0 { sql += fmt.Sprintf(" LIMIT %d", parsedLimit) } } if offset != nil { - if parsedOffset, err := strconv.ParseInt(fmt.Sprint(offset), 0, 0); err == nil && parsedOffset >= 0 { + parsedOffset, err := s.parseInt(offset) + if err != nil { + return "", err + } + if parsedOffset >= 0 { sql += fmt.Sprintf(" OFFSET %d", parsedOffset) } } @@ -181,6 +190,10 @@ func (commonDialect) NormalizeIndexAndColumn(indexName, columnName string) (stri return indexName, columnName } +func (commonDialect) parseInt(value interface{}) (int64, error) { + return strconv.ParseInt(fmt.Sprint(value), 0, 0) +} + // IsByteArrayOrSlice returns true of the reflected value is an array or slice func IsByteArrayOrSlice(value reflect.Value) bool { return (value.Kind() == reflect.Array || value.Kind() == reflect.Slice) && value.Type().Elem() == reflect.TypeOf(uint8(0)) diff --git a/dialect_mysql.go b/dialect_mysql.go index ab6a8a91..b4467ffa 100644 --- a/dialect_mysql.go +++ b/dialect_mysql.go @@ -6,7 +6,6 @@ import ( "fmt" "reflect" "regexp" - "strconv" "strings" "time" "unicode/utf8" @@ -140,13 +139,21 @@ func (s mysql) ModifyColumn(tableName string, columnName string, typ string) err return err } -func (s mysql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) { +func (s mysql) LimitAndOffsetSQL(limit, offset interface{}) (sql string, err error) { if limit != nil { - if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 { + parsedLimit, err := s.parseInt(limit) + if err != nil { + return "", err + } + if parsedLimit >= 0 { sql += fmt.Sprintf(" LIMIT %d", parsedLimit) if offset != nil { - if parsedOffset, err := strconv.ParseInt(fmt.Sprint(offset), 0, 0); err == nil && parsedOffset >= 0 { + parsedOffset, err := s.parseInt(offset) + if err != nil { + return "", err + } + if parsedOffset >= 0 { sql += fmt.Sprintf(" OFFSET %d", parsedOffset) } } diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go index eb79f7e7..43acb379 100644 --- a/dialects/mssql/mssql.go +++ b/dialects/mssql/mssql.go @@ -168,14 +168,25 @@ func (s mssql) CurrentDatabase() (name string) { return } -func (mssql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) { +func (mssql) LimitAndOffsetSQL(limit, offset interface{}) (sql string, err error) { + parseInt := func(value interface{}) (int64, error) { + return strconv.ParseInt(fmt.Sprint(value), 0, 0) + } if offset != nil { - if parsedOffset, err := strconv.ParseInt(fmt.Sprint(offset), 0, 0); err == nil && parsedOffset >= 0 { + parsedOffset, err := parseInt(offset) + if err != nil { + return "", err + } + if parsedOffset >= 0 { sql += fmt.Sprintf(" OFFSET %d ROWS", parsedOffset) } } if limit != nil { - if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 { + parsedLimit, err := parseInt(limit) + if err != nil { + return "", err + } + if parsedLimit >= 0 { if sql == "" { // add default zero offset sql += " OFFSET 0 ROWS" diff --git a/query_test.go b/query_test.go index 15bf8b3c..a23a9e24 100644 --- a/query_test.go +++ b/query_test.go @@ -457,6 +457,74 @@ func TestOffset(t *testing.T) { } } +func TestLimitAndOffsetSQL(t *testing.T) { + user1 := User{Name: "TestLimitAndOffsetSQL1", Age: 10} + user2 := User{Name: "TestLimitAndOffsetSQL2", Age: 20} + user3 := User{Name: "TestLimitAndOffsetSQL3", Age: 30} + user4 := User{Name: "TestLimitAndOffsetSQL4", Age: 40} + user5 := User{Name: "TestLimitAndOffsetSQL5", Age: 50} + if err := DB.Save(&user1).Save(&user2).Save(&user3).Save(&user4).Save(&user5).Error; err != nil { + t.Fatal(err) + } + + tests := []struct { + name string + limit, offset interface{} + users []*User + ok bool + }{ + { + name: "OK", + limit: float64(2), + offset: float64(2), + users: []*User{ + &User{Name: "TestLimitAndOffsetSQL3", Age: 30}, + &User{Name: "TestLimitAndOffsetSQL2", Age: 20}, + }, + ok: true, + }, + { + name: "Limit parse error", + limit: float64(1000000), // 1e+06 + offset: float64(2), + ok: false, + }, + { + name: "Offset parse error", + limit: float64(2), + offset: float64(1000000), // 1e+06 + ok: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var users []*User + err := DB.Where("name LIKE ?", "TestLimitAndOffsetSQL%").Order("age desc").Limit(tt.limit).Offset(tt.offset).Find(&users).Error + if tt.ok { + if err != nil { + t.Errorf("error expected nil, but got %v", err) + } + if len(users) != len(tt.users) { + t.Errorf("users length expected %d, but got %d", len(tt.users), len(users)) + } + for i := range tt.users { + if users[i].Name != tt.users[i].Name { + t.Errorf("users[%d] name expected %s, but got %s", i, tt.users[i].Name, users[i].Name) + } + if users[i].Age != tt.users[i].Age { + t.Errorf("users[%d] age expected %d, but got %d", i, tt.users[i].Age, users[i].Age) + } + } + } else { + if err == nil { + t.Error("error expected not nil, but got nil") + } + } + }) + } +} + func TestOr(t *testing.T) { user1 := User{Name: "OrUser1", Age: 1} user2 := User{Name: "OrUser2", Age: 10} diff --git a/scope.go b/scope.go index eb7525b8..0e9dfd1c 100644 --- a/scope.go +++ b/scope.go @@ -797,7 +797,9 @@ func (scope *Scope) orderSQL() string { } func (scope *Scope) limitAndOffsetSQL() string { - return scope.Dialect().LimitAndOffsetSQL(scope.Search.limit, scope.Search.offset) + sql, err := scope.Dialect().LimitAndOffsetSQL(scope.Search.limit, scope.Search.offset) + scope.Err(err) + return sql } func (scope *Scope) groupSQL() string { From 9827710b60e717b1411611da5b1bf52476aa34cb Mon Sep 17 00:00:00 2001 From: Thomas Tacquet Date: Wed, 27 Nov 2019 15:51:23 +0100 Subject: [PATCH 194/881] bump go-sqlite3 to v1.12.0 to fix go1.13 issues --- go.mod | 2 +- go.sum | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/go.mod b/go.mod index 2d2fec37..87207be4 100644 --- a/go.mod +++ b/go.mod @@ -9,5 +9,5 @@ require ( github.com/jinzhu/inflection v1.0.0 github.com/jinzhu/now v1.0.1 github.com/lib/pq v1.1.1 - github.com/mattn/go-sqlite3 v1.11.0 + github.com/mattn/go-sqlite3 v1.12.0 ) diff --git a/go.sum b/go.sum index c43559bf..9c7e8a54 100644 --- a/go.sum +++ b/go.sum @@ -54,6 +54,8 @@ github.com/lib/pq v1.1.1 h1:sJZmqHoEaY7f+NPP8pgLB/WxulyR3fewgCM2qaSlBb4= github.com/lib/pq v1.1.1/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/mattn/go-sqlite3 v1.11.0 h1:LDdKkqtYlom37fkvqs8rMPFKAMe8+SgjbwZ6ex1/A/Q= github.com/mattn/go-sqlite3 v1.11.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= +github.com/mattn/go-sqlite3 v1.12.0 h1:u/x3mp++qUxvYfulZ4HKOvVO0JWhk7HtE8lWhbGz/Do= +github.com/mattn/go-sqlite3 v1.12.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= From b543a11ca0f9768994c6be4328284b167c1d83ba Mon Sep 17 00:00:00 2001 From: Charles Strahan Date: Thu, 5 Dec 2019 03:54:32 -0600 Subject: [PATCH 195/881] transaction blocks: don't swallow panics (#2774) This improves upon #2767. Previously, the code would swallow any panics, which isn't ideal; panic is intended to be used when a critical error arises, where the process should fail fast instead of trying to limp along. This now defers the any recovery (if desired) to the client code. --- main.go | 11 ++++------- main_test.go | 29 ++++++++++++++++++++--------- 2 files changed, 24 insertions(+), 16 deletions(-) diff --git a/main.go b/main.go index 48d22c85..24fd8382 100644 --- a/main.go +++ b/main.go @@ -528,12 +528,12 @@ func (s *DB) Debug() *DB { // Transaction start a transaction as a block, // return error will rollback, otherwise to commit. func (s *DB) Transaction(fc func(tx *DB) error) (err error) { + panicked := true tx := s.Begin() defer func() { - if r := recover(); r != nil { - err = fmt.Errorf("%s", r) + // Make sure to rollback when panic, Block error or Commit error + if panicked || err != nil { tx.Rollback() - return } }() @@ -543,10 +543,7 @@ func (s *DB) Transaction(fc func(tx *DB) error) (err error) { err = tx.Commit().Error } - // Makesure rollback when Block error or Commit error - if err != nil { - tx.Rollback() - } + panicked = false return } diff --git a/main_test.go b/main_test.go index 134672b7..98ea4694 100644 --- a/main_test.go +++ b/main_test.go @@ -470,6 +470,15 @@ func TestTransaction(t *testing.T) { } } +func assertPanic(t *testing.T, f func()) { + defer func() { + if r := recover(); r == nil { + t.Errorf("The code did not panic") + } + }() + f() +} + func TestTransactionWithBlock(t *testing.T) { // rollback err := DB.Transaction(func(tx *gorm.DB) error { @@ -511,17 +520,19 @@ func TestTransactionWithBlock(t *testing.T) { } // panic will rollback - DB.Transaction(func(tx *gorm.DB) error { - u3 := User{Name: "transcation-3"} - if err := tx.Save(&u3).Error; err != nil { - t.Errorf("No error should raise") - } + assertPanic(t, func() { + DB.Transaction(func(tx *gorm.DB) error { + u3 := User{Name: "transcation-3"} + if err := tx.Save(&u3).Error; err != nil { + t.Errorf("No error should raise") + } - if err := tx.First(&User{}, "name = ?", "transcation-3").Error; err != nil { - t.Errorf("Should find saved record") - } + if err := tx.First(&User{}, "name = ?", "transcation-3").Error; err != nil { + t.Errorf("Should find saved record") + } - panic("force panic") + panic("force panic") + }) }) if err := DB.First(&User{}, "name = ?", "transcation").Error; err == nil { From 2c2fbb99e5234bd22f0659ad82104f6e9adcd63d Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 5 Dec 2019 18:05:12 +0800 Subject: [PATCH 196/881] Upgrade go-sqlite to v2.0.1 --- go.mod | 2 +- go.sum | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/go.mod b/go.mod index 87207be4..4d6eb7fa 100644 --- a/go.mod +++ b/go.mod @@ -9,5 +9,5 @@ require ( github.com/jinzhu/inflection v1.0.0 github.com/jinzhu/now v1.0.1 github.com/lib/pq v1.1.1 - github.com/mattn/go-sqlite3 v1.12.0 + github.com/mattn/go-sqlite3 v2.0.1+incompatible ) diff --git a/go.sum b/go.sum index 9c7e8a54..a9ae14d5 100644 --- a/go.sum +++ b/go.sum @@ -56,6 +56,8 @@ github.com/mattn/go-sqlite3 v1.11.0 h1:LDdKkqtYlom37fkvqs8rMPFKAMe8+SgjbwZ6ex1/A github.com/mattn/go-sqlite3 v1.11.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= github.com/mattn/go-sqlite3 v1.12.0 h1:u/x3mp++qUxvYfulZ4HKOvVO0JWhk7HtE8lWhbGz/Do= github.com/mattn/go-sqlite3 v1.12.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= +github.com/mattn/go-sqlite3 v2.0.1+incompatible h1:xQ15muvnzGBHpIpdrNi1DA5x0+TcBZzsIDwmw9uTHzw= +github.com/mattn/go-sqlite3 v2.0.1+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= From 32ec5c04a6884ad3d85b6e83a77ce66de1a71816 Mon Sep 17 00:00:00 2001 From: Thomas Tacquet Date: Wed, 27 Nov 2019 15:51:23 +0100 Subject: [PATCH 197/881] bump go-sqlite3 to v1.12.0 to fix go1.13 issues --- go.mod | 2 +- go.sum | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/go.mod b/go.mod index 2d2fec37..87207be4 100644 --- a/go.mod +++ b/go.mod @@ -9,5 +9,5 @@ require ( github.com/jinzhu/inflection v1.0.0 github.com/jinzhu/now v1.0.1 github.com/lib/pq v1.1.1 - github.com/mattn/go-sqlite3 v1.11.0 + github.com/mattn/go-sqlite3 v1.12.0 ) diff --git a/go.sum b/go.sum index c43559bf..9c7e8a54 100644 --- a/go.sum +++ b/go.sum @@ -54,6 +54,8 @@ github.com/lib/pq v1.1.1 h1:sJZmqHoEaY7f+NPP8pgLB/WxulyR3fewgCM2qaSlBb4= github.com/lib/pq v1.1.1/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/mattn/go-sqlite3 v1.11.0 h1:LDdKkqtYlom37fkvqs8rMPFKAMe8+SgjbwZ6ex1/A/Q= github.com/mattn/go-sqlite3 v1.11.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= +github.com/mattn/go-sqlite3 v1.12.0 h1:u/x3mp++qUxvYfulZ4HKOvVO0JWhk7HtE8lWhbGz/Do= +github.com/mattn/go-sqlite3 v1.12.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= From 0aba7ff3a0bff05dc25ec027895b5e6789e28bb3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B4=BE=E4=B8=80=E9=A5=BC?= Date: Thu, 5 Dec 2019 18:26:16 +0800 Subject: [PATCH 198/881] Beautify callback log output (#2749) --- logger.go | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/logger.go b/logger.go index a42f2727..b4a362ce 100644 --- a/logger.go +++ b/logger.go @@ -39,6 +39,15 @@ var LogFormatter = func(values ...interface{}) (messages []interface{}) { messages = []interface{}{source, currentTime} + if len(values) == 2 { + //remove the line break + currentTime = currentTime[1:] + //remove the brackets + source = fmt.Sprintf("\033[35m%v\033[0m", values[1]) + + messages = []interface{}{currentTime, source} + } + if level == "sql" { // duration messages = append(messages, fmt.Sprintf(" \033[36;1m[%.2fms]\033[0m ", float64(values[2].(time.Duration).Nanoseconds()/1e4)/100.0)) From e8c07b55316b12d028eecac5e9a49f1b16918e44 Mon Sep 17 00:00:00 2001 From: Shunsuke Otani Date: Thu, 5 Dec 2019 23:57:15 +0900 Subject: [PATCH 199/881] Set nopLogger to DefaultCallback for avoid nil pointer dereference (#2742) --- callback.go | 9 ++------- callbacks_test.go | 30 ++++++++++++++++++++++++++++++ logger.go | 4 ++++ 3 files changed, 36 insertions(+), 7 deletions(-) diff --git a/callback.go b/callback.go index 56b2064a..1f0e3c79 100644 --- a/callback.go +++ b/callback.go @@ -3,7 +3,7 @@ package gorm import "fmt" // DefaultCallback default callbacks defined by gorm -var DefaultCallback = &Callback{} +var DefaultCallback = &Callback{logger: nopLogger{}} // Callback is a struct that contains all CRUD callbacks // Field `creates` contains callbacks will be call when creating object @@ -101,12 +101,7 @@ func (cp *CallbackProcessor) Register(callbackName string, callback func(scope * } } - if cp.logger != nil { - // note cp.logger will be nil during the default gorm callback registrations - // as they occur within init() blocks. However, any user-registered callbacks - // will happen after cp.logger exists (as the default logger or user-specified). - cp.logger.Print("info", fmt.Sprintf("[info] registering callback `%v` from %v", callbackName, fileWithLineNum())) - } + cp.logger.Print("info", fmt.Sprintf("[info] registering callback `%v` from %v", callbackName, fileWithLineNum())) cp.name = callbackName cp.processor = &callback cp.parent.processors = append(cp.parent.processors, cp) diff --git a/callbacks_test.go b/callbacks_test.go index c1a1d5e4..bebd0e38 100644 --- a/callbacks_test.go +++ b/callbacks_test.go @@ -217,3 +217,33 @@ func TestGetCallback(t *testing.T) { t.Errorf("`gorm:test_callback_value` should be `3, true` but `%v, %v`", v, ok) } } + +func TestUseDefaultCallback(t *testing.T) { + createCallbackName := "gorm:test_use_default_callback_for_create" + gorm.DefaultCallback.Create().Register(createCallbackName, func(*gorm.Scope) { + // nop + }) + if gorm.DefaultCallback.Create().Get(createCallbackName) == nil { + t.Errorf("`%s` expected non-nil, but got nil", createCallbackName) + } + gorm.DefaultCallback.Create().Remove(createCallbackName) + if gorm.DefaultCallback.Create().Get(createCallbackName) != nil { + t.Errorf("`%s` expected nil, but got non-nil", createCallbackName) + } + + updateCallbackName := "gorm:test_use_default_callback_for_update" + scopeValueName := "gorm:test_use_default_callback_for_update_value" + gorm.DefaultCallback.Update().Register(updateCallbackName, func(scope *gorm.Scope) { + scope.Set(scopeValueName, 1) + }) + gorm.DefaultCallback.Update().Replace(updateCallbackName, func(scope *gorm.Scope) { + scope.Set(scopeValueName, 2) + }) + + scope := DB.NewScope(nil) + callback := gorm.DefaultCallback.Update().Get(updateCallbackName) + callback(scope) + if v, ok := scope.Get(scopeValueName); !ok || v != 2 { + t.Errorf("`%s` should be `2, true` but `%v, %v`", scopeValueName, v, ok) + } +} diff --git a/logger.go b/logger.go index b4a362ce..88e167dd 100644 --- a/logger.go +++ b/logger.go @@ -135,3 +135,7 @@ type Logger struct { func (logger Logger) Print(values ...interface{}) { logger.Println(LogFormatter(values...)...) } + +type nopLogger struct{} + +func (nopLogger) Print(values ...interface{}) {} From 11e2819f44a6b6e2b21119a9eaf451244abd808b Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 5 Dec 2019 23:13:54 +0800 Subject: [PATCH 200/881] Extract parseInt --- dialect_common.go | 12 ++++-------- dialects/mssql/mssql.go | 19 ++++++++----------- 2 files changed, 12 insertions(+), 19 deletions(-) diff --git a/dialect_common.go b/dialect_common.go index 950c1986..d549510c 100644 --- a/dialect_common.go +++ b/dialect_common.go @@ -142,20 +142,16 @@ func (s commonDialect) CurrentDatabase() (name string) { // LimitAndOffsetSQL return generated SQL with Limit and Offset func (s commonDialect) LimitAndOffsetSQL(limit, offset interface{}) (sql string, err error) { if limit != nil { - parsedLimit, err := s.parseInt(limit) - if err != nil { + if parsedLimit, err := s.parseInt(limit); err != nil { return "", err - } - if parsedLimit >= 0 { + } else if parsedLimit >= 0 { sql += fmt.Sprintf(" LIMIT %d", parsedLimit) } } if offset != nil { - parsedOffset, err := s.parseInt(offset) - if err != nil { + if parsedOffset, err := s.parseInt(offset); err != nil { return "", err - } - if parsedOffset >= 0 { + } else if parsedOffset >= 0 { sql += fmt.Sprintf(" OFFSET %d", parsedOffset) } } diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go index 43acb379..cb2714e0 100644 --- a/dialects/mssql/mssql.go +++ b/dialects/mssql/mssql.go @@ -168,25 +168,22 @@ func (s mssql) CurrentDatabase() (name string) { return } +func parseInt(value interface{}) (int64, error) { + return strconv.ParseInt(fmt.Sprint(value), 0, 0) +} + func (mssql) LimitAndOffsetSQL(limit, offset interface{}) (sql string, err error) { - parseInt := func(value interface{}) (int64, error) { - return strconv.ParseInt(fmt.Sprint(value), 0, 0) - } if offset != nil { - parsedOffset, err := parseInt(offset) - if err != nil { + if parsedOffset, err := parseInt(offset); err != nil { return "", err - } - if parsedOffset >= 0 { + } else if parsedOffset >= 0 { sql += fmt.Sprintf(" OFFSET %d ROWS", parsedOffset) } } if limit != nil { - parsedLimit, err := parseInt(limit) - if err != nil { + if parsedLimit, err := parseInt(limit); err != nil { return "", err - } - if parsedLimit >= 0 { + } else if parsedLimit >= 0 { if sql == "" { // add default zero offset sql += " OFFSET 0 ROWS" From 5490a87fe9f9d72a38cfa641e7965bf48f588b87 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 6 Dec 2019 00:01:40 +0800 Subject: [PATCH 201/881] Should use global NowFunc when trace SQL --- callback_create.go | 2 +- callback_query.go | 2 +- scope.go | 6 +++--- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/callback_create.go b/callback_create.go index 3527858b..5271dc29 100644 --- a/callback_create.go +++ b/callback_create.go @@ -50,7 +50,7 @@ func updateTimeStampForCreateCallback(scope *Scope) { // createCallback the callback used to insert data into database func createCallback(scope *Scope) { if !scope.HasError() { - defer scope.trace(scope.db.nowFunc()) + defer scope.trace(NowFunc()) var ( columns, placeholders []string diff --git a/callback_query.go b/callback_query.go index e3b3d534..7facc42b 100644 --- a/callback_query.go +++ b/callback_query.go @@ -24,7 +24,7 @@ func queryCallback(scope *Scope) { return } - defer scope.trace(scope.db.nowFunc()) + defer scope.trace(NowFunc()) var ( isSlice, isPtr bool diff --git a/scope.go b/scope.go index 0e9dfd1c..d82cadbc 100644 --- a/scope.go +++ b/scope.go @@ -358,7 +358,7 @@ func (scope *Scope) Raw(sql string) *Scope { // Exec perform generated SQL func (scope *Scope) Exec() *Scope { - defer scope.trace(scope.db.nowFunc()) + defer scope.trace(NowFunc()) if !scope.HasError() { if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { @@ -934,7 +934,7 @@ func (scope *Scope) updatedAttrsWithValues(value interface{}) (results map[strin } func (scope *Scope) row() *sql.Row { - defer scope.trace(scope.db.nowFunc()) + defer scope.trace(NowFunc()) result := &RowQueryResult{} scope.InstanceSet("row_query_result", result) @@ -944,7 +944,7 @@ func (scope *Scope) row() *sql.Row { } func (scope *Scope) rows() (*sql.Rows, error) { - defer scope.trace(scope.db.nowFunc()) + defer scope.trace(NowFunc()) result := &RowsQueryResult{} scope.InstanceSet("row_query_result", result) From 9d2b65f8c9604651197b9d864500d05ddce2cc99 Mon Sep 17 00:00:00 2001 From: Dozer Date: Fri, 6 Dec 2019 09:16:51 +0800 Subject: [PATCH 202/881] add query hint support (#2351) * add query hint support * remove add extra space * add test and fix bug * fix ut * fix ut --- callback_query.go | 5 +++++ callback_row_query.go | 5 +++++ main_test.go | 24 ++++++++++++++++++++++++ 3 files changed, 34 insertions(+) diff --git a/callback_query.go b/callback_query.go index 7facc42b..544afd63 100644 --- a/callback_query.go +++ b/callback_query.go @@ -60,6 +60,11 @@ func queryCallback(scope *Scope) { if !scope.HasError() { scope.db.RowsAffected = 0 + + if str, ok := scope.Get("gorm:query_hint"); ok { + scope.SQL = fmt.Sprint(str) + scope.SQL + } + if str, ok := scope.Get("gorm:query_option"); ok { scope.SQL += addExtraSpaceIfExist(fmt.Sprint(str)) } diff --git a/callback_row_query.go b/callback_row_query.go index 687b0039..323b1605 100644 --- a/callback_row_query.go +++ b/callback_row_query.go @@ -23,6 +23,11 @@ type RowsQueryResult struct { func rowQueryCallback(scope *Scope) { if result, ok := scope.InstanceGet("row_query_result"); ok { scope.prepareQuerySQL() + + if str, ok := scope.Get("gorm:query_hint"); ok { + scope.SQL = fmt.Sprint(str) + scope.SQL + } + if str, ok := scope.Get("gorm:query_option"); ok { scope.SQL += addExtraSpaceIfExist(fmt.Sprint(str)) } diff --git a/main_test.go b/main_test.go index 98ea4694..b51fe413 100644 --- a/main_test.go +++ b/main_test.go @@ -1333,6 +1333,30 @@ func TestCountWithQueryOption(t *testing.T) { } } +func TestQueryHint1(t *testing.T) { + db := DB.New() + + _, err := db.Model(User{}).Raw("select 1").Rows() + + if err != nil { + t.Error("Unexpected error on query count with query_option") + } +} + +func TestQueryHint2(t *testing.T) { + type TestStruct struct { + ID string `gorm:"primary_key"` + Name string + } + DB.DropTable(&TestStruct{}) + DB.AutoMigrate(&TestStruct{}) + + data := TestStruct{ID: "uuid", Name: "hello"} + if err := DB.Set("gorm:query_hint", "/*master*/").Save(&data).Error; err != nil { + t.Error("Unexpected error on query count with query_option") + } +} + func TestFloatColumnPrecision(t *testing.T) { if dialect := os.Getenv("GORM_DIALECT"); dialect != "mysql" && dialect != "sqlite" { t.Skip() From f616ccd39773f0b1c6967aab3eb1de4f04dd001f Mon Sep 17 00:00:00 2001 From: misko Date: Mon, 14 Oct 2019 14:13:18 +0800 Subject: [PATCH 203/881] 1. fix bug : https://github.com/jinzhu/gorm/issues/2700 --- main.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/main.go b/main.go index 24fd8382..3db87870 100644 --- a/main.go +++ b/main.go @@ -124,7 +124,10 @@ func (s *DB) Close() error { // DB get `*sql.DB` from current connection // If the underlying database connection is not a *sql.DB, returns nil func (s *DB) DB() *sql.DB { - db, _ := s.db.(*sql.DB) + db, ok := s.db.(*sql.DB) + if !ok { + panic("can't support full GORM on currently status, maybe this is a TX instance.") + } return db } From 79a77d771dee4e4b60e9c543e8663bbc80466670 Mon Sep 17 00:00:00 2001 From: jaden <1336364665@qq.com> Date: Fri, 6 Dec 2019 22:22:28 +0800 Subject: [PATCH 204/881] go.mod: remove unnecessary dependences through upgrade go-mssqldb (#2795) * go.mod: remove unnecessary dependences through upgrade go-mssqldb $ go get -v -u github.com/denisenkom/go-mssqldb && go mod tidy -v go: finding github.com/denisenkom/go-mssqldb latest go: finding github.com/golang-sql/civil latest go: finding golang.org/x/crypto latest unused cloud.google.com/go unused gopkg.in/check.v1 unused gopkg.in/yaml.v2 * mssql: use SCOPE_IDENTITY() if OUTPUT not possible * go-mssqldb: find a up-to-date version pass test -race --- callback_create.go | 5 +- dialects/mssql/mssql.go | 3 +- go.mod | 4 +- go.sum | 122 +++------------------------------------- 4 files changed, 17 insertions(+), 117 deletions(-) diff --git a/callback_create.go b/callback_create.go index 5271dc29..c4d25f37 100644 --- a/callback_create.go +++ b/callback_create.go @@ -100,8 +100,11 @@ func createCallback(scope *Scope) { returningColumn = scope.Quote(primaryField.DBName) } - lastInsertIDReturningSuffix := scope.Dialect().LastInsertIDReturningSuffix(quotedTableName, returningColumn) lastInsertIDOutputInterstitial := scope.Dialect().LastInsertIDOutputInterstitial(quotedTableName, returningColumn, columns) + var lastInsertIDReturningSuffix string + if lastInsertIDOutputInterstitial == "" { + lastInsertIDReturningSuffix = scope.Dialect().LastInsertIDReturningSuffix(quotedTableName, returningColumn) + } if len(columns) == 0 { scope.Raw(fmt.Sprintf( diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go index cb2714e0..a516ed4a 100644 --- a/dialects/mssql/mssql.go +++ b/dialects/mssql/mssql.go @@ -207,7 +207,8 @@ func (mssql) LastInsertIDOutputInterstitial(tableName, columnName string, column } func (mssql) LastInsertIDReturningSuffix(tableName, columnName string) string { - return "" + // https://stackoverflow.com/questions/5228780/how-to-get-last-inserted-id + return "; SELECT SCOPE_IDENTITY()" } func (mssql) DefaultValueStr() string { diff --git a/go.mod b/go.mod index 4d6eb7fa..6e923b9d 100644 --- a/go.mod +++ b/go.mod @@ -3,11 +3,13 @@ module github.com/jinzhu/gorm go 1.12 require ( - github.com/denisenkom/go-mssqldb v0.0.0-20190515213511-eb9f6a1743f3 + github.com/denisenkom/go-mssqldb v0.0.0-20191124224453-732737034ffd github.com/erikstmartin/go-testdb v0.0.0-20160219214506-8d10e4a1bae5 github.com/go-sql-driver/mysql v1.4.1 github.com/jinzhu/inflection v1.0.0 github.com/jinzhu/now v1.0.1 github.com/lib/pq v1.1.1 github.com/mattn/go-sqlite3 v2.0.1+incompatible + golang.org/x/crypto v0.0.0-20191205180655-e7c4368fe9dd // indirect + google.golang.org/appengine v1.4.0 // indirect ) diff --git a/go.sum b/go.sum index a9ae14d5..915b4c21 100644 --- a/go.sum +++ b/go.sum @@ -1,135 +1,29 @@ -cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= -cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= -cloud.google.com/go v0.37.4 h1:glPeL3BQJsbF6aIIYfZizMwc5LTYz250bDMjttbBGAU= -cloud.google.com/go v0.37.4/go.mod h1:NHPJ89PdicEuT9hdPXMROBD91xc5uRDxsMtSB16k7hw= -github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= -github.com/Shopify/sarama v1.19.0/go.mod h1:FVkBWblsNy7DGZRfXLU0O9RCGt5g3g3yEuWXgklEdEo= -github.com/Shopify/toxiproxy v2.1.4+incompatible/go.mod h1:OXgGpZ6Cli1/URJOF1DMxUHB2q5Ap20/P/eIdh4G0pI= -github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= -github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= -github.com/apache/thrift v0.12.0/go.mod h1:cp2SuWMxlEZw2r+iP2GNCdIi4C1qmUzdZFSVb+bacwQ= -github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= -github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= -github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/denisenkom/go-mssqldb v0.0.0-20190515213511-eb9f6a1743f3 h1:tkum0XDgfR0jcVVXuTsYv/erY2NnEDqwRojbxR1rBYA= -github.com/denisenkom/go-mssqldb v0.0.0-20190515213511-eb9f6a1743f3/go.mod h1:zAg7JM8CkOJ43xKXIj7eRO9kmWm/TW578qo+oDO6tuM= -github.com/eapache/go-resiliency v1.1.0/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs= -github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU= -github.com/eapache/queue v1.1.0/go.mod h1:6eCeP0CKFpHLu8blIFXhExK/dRa7WDZfr6jVFPTqq+I= +github.com/denisenkom/go-mssqldb v0.0.0-20191124224453-732737034ffd h1:83Wprp6ROGeiHFAP8WJdI2RoxALQYgdllERc3N5N2DM= +github.com/denisenkom/go-mssqldb v0.0.0-20191124224453-732737034ffd/go.mod h1:xbL0rPBG9cCiLr28tMa8zpbdarY27NDyej4t/EjAShU= github.com/erikstmartin/go-testdb v0.0.0-20160219214506-8d10e4a1bae5 h1:Yzb9+7DPaBjB8zlTR87/ElzFsnQfuHnVUVqpZZIcV5Y= github.com/erikstmartin/go-testdb v0.0.0-20160219214506-8d10e4a1bae5/go.mod h1:a2zkGnVExMxdzMo3M0Hi/3sEU+cWnZpSni0O6/Yb/P0= -github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= -github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= -github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE= github.com/go-sql-driver/mysql v1.4.1 h1:g24URVg0OFbNUTx9qqY1IRZ9D9z3iPyi5zKhQZpNwpA= github.com/go-sql-driver/mysql v1.4.1/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w= -github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= -github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= -github.com/gogo/protobuf v1.2.0/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= -github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= -github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= -github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= +github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe h1:lXe2qZdvpiX5WZkZR4hgp4KJVfY3nMkvmwbVkpv1rVY= +github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= -github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= -github.com/google/go-cmp v0.2.0 h1:+dTQ8DZQJz0Mb/HjFlkptS1FeQ4cWSnN941F8aEG4SQ= -github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= -github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= -github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= -github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg= -github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg= -github.com/gorilla/mux v1.6.2/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs= -github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= -github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/jinzhu/now v1.0.1 h1:HjfetcXq097iXP0uoPCdnM4Efp5/9MsM0/M+XOTeR3M= github.com/jinzhu/now v1.0.1/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= -github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU= -github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w= -github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= -github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= -github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= github.com/lib/pq v1.1.1 h1:sJZmqHoEaY7f+NPP8pgLB/WxulyR3fewgCM2qaSlBb4= github.com/lib/pq v1.1.1/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= -github.com/mattn/go-sqlite3 v1.11.0 h1:LDdKkqtYlom37fkvqs8rMPFKAMe8+SgjbwZ6ex1/A/Q= -github.com/mattn/go-sqlite3 v1.11.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= -github.com/mattn/go-sqlite3 v1.12.0 h1:u/x3mp++qUxvYfulZ4HKOvVO0JWhk7HtE8lWhbGz/Do= -github.com/mattn/go-sqlite3 v1.12.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= github.com/mattn/go-sqlite3 v2.0.1+incompatible h1:xQ15muvnzGBHpIpdrNi1DA5x0+TcBZzsIDwmw9uTHzw= github.com/mattn/go-sqlite3 v2.0.1+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= -github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= -github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= -github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= -github.com/onsi/ginkgo v1.7.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= -github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= -github.com/openzipkin/zipkin-go v0.1.6/go.mod h1:QgAqvLzwWbR/WpD4A3cGpPtJrZXNIiJc5AZX7/PBEpw= -github.com/pierrec/lz4 v2.0.5+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY= -github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= -github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= -github.com/prometheus/client_golang v0.9.3-0.20190127221311-3c4408c8b829/go.mod h1:p2iRAGwDERtqlqzRXnrOVns+ignqQo//hLXqYxZYVNs= -github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= -github.com/prometheus/client_model v0.0.0-20190115171406-56726106282f/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= -github.com/prometheus/common v0.2.0/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4= -github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= -github.com/prometheus/procfs v0.0.0-20190117184657-bf6a532e95b1/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= -github.com/rcrowley/go-metrics v0.0.0-20181016184325-3113b8401b8a/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4= -github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= -github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= -go.opencensus.io v0.20.1/go.mod h1:6WKK9ahsWS3RSO+PY9ZHZUfv2irvY6gN279GOPZjmmk= -golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c h1:Vj5n4GlwjmQteupaxJ9+0FNOmBrHfq7vN4btdGoDZgI= golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= -golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= -golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= -golang.org/x/lint v0.0.0-20190301231843-5614ed5bae6f/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= +golang.org/x/crypto v0.0.0-20191205180655-e7c4368fe9dd h1:GGJVjV8waZKRHrgwvtH66z9ZGVurTD1MT0n1Bb+q4aM= +golang.org/x/crypto v0.0.0-20191205180655-e7c4368fe9dd/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190125091013-d26f9f9a57f3/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= -golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= -golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20181122145206-62eef0e2fa9b/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= -golang.org/x/tools v0.0.0-20190312170243-e65039ee4138/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= -google.golang.org/api v0.3.1/go.mod h1:6wY9I6uQWHQ8EM57III9mq/AjF+i8G65rmVagqKMtkk= -google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= google.golang.org/appengine v1.4.0 h1:/wp5JvzpHIxhs/dumFmF7BXTf3Z+dd4uXta4kVyO508= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= -google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= -google.golang.org/genproto v0.0.0-20190307195333-5fe7a883aa19/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= -google.golang.org/genproto v0.0.0-20190404172233-64821d5d2107/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= -google.golang.org/grpc v1.17.0/go.mod h1:6QZJwpn2B+Zp71q/5VxRsJ6NXXVCE5NRUHRo+f3cWCs= -google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= -gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= -gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= -gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= -honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= -honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= From 7180bd0f27d167f18c253c32d548c7de3adc6b0d Mon Sep 17 00:00:00 2001 From: Mike Zuev <39210290+mszuyev@users.noreply.github.com> Date: Sun, 26 Jan 2020 18:28:32 +0300 Subject: [PATCH 205/881] updated go-sql-driver package (#2859) --- go.mod | 3 +-- go.sum | 8 ++------ 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/go.mod b/go.mod index 6e923b9d..91ff3cb8 100644 --- a/go.mod +++ b/go.mod @@ -5,11 +5,10 @@ go 1.12 require ( github.com/denisenkom/go-mssqldb v0.0.0-20191124224453-732737034ffd github.com/erikstmartin/go-testdb v0.0.0-20160219214506-8d10e4a1bae5 - github.com/go-sql-driver/mysql v1.4.1 + github.com/go-sql-driver/mysql v1.5.0 github.com/jinzhu/inflection v1.0.0 github.com/jinzhu/now v1.0.1 github.com/lib/pq v1.1.1 github.com/mattn/go-sqlite3 v2.0.1+incompatible golang.org/x/crypto v0.0.0-20191205180655-e7c4368fe9dd // indirect - google.golang.org/appengine v1.4.0 // indirect ) diff --git a/go.sum b/go.sum index 915b4c21..e09a0352 100644 --- a/go.sum +++ b/go.sum @@ -2,11 +2,10 @@ github.com/denisenkom/go-mssqldb v0.0.0-20191124224453-732737034ffd h1:83Wprp6RO github.com/denisenkom/go-mssqldb v0.0.0-20191124224453-732737034ffd/go.mod h1:xbL0rPBG9cCiLr28tMa8zpbdarY27NDyej4t/EjAShU= github.com/erikstmartin/go-testdb v0.0.0-20160219214506-8d10e4a1bae5 h1:Yzb9+7DPaBjB8zlTR87/ElzFsnQfuHnVUVqpZZIcV5Y= github.com/erikstmartin/go-testdb v0.0.0-20160219214506-8d10e4a1bae5/go.mod h1:a2zkGnVExMxdzMo3M0Hi/3sEU+cWnZpSni0O6/Yb/P0= -github.com/go-sql-driver/mysql v1.4.1 h1:g24URVg0OFbNUTx9qqY1IRZ9D9z3iPyi5zKhQZpNwpA= -github.com/go-sql-driver/mysql v1.4.1/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w= +github.com/go-sql-driver/mysql v1.5.0 h1:ozyZYNQW3x3HtqT1jira07DN2PArx2v7/mN66gGcHOs= +github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe h1:lXe2qZdvpiX5WZkZR4hgp4KJVfY3nMkvmwbVkpv1rVY= github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= -github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/jinzhu/now v1.0.1 h1:HjfetcXq097iXP0uoPCdnM4Efp5/9MsM0/M+XOTeR3M= @@ -20,10 +19,7 @@ golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c h1:Vj5n4GlwjmQteupaxJ9+0F golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191205180655-e7c4368fe9dd h1:GGJVjV8waZKRHrgwvtH66z9ZGVurTD1MT0n1Bb+q4aM= golang.org/x/crypto v0.0.0-20191205180655-e7c4368fe9dd/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -google.golang.org/appengine v1.4.0 h1:/wp5JvzpHIxhs/dumFmF7BXTf3Z+dd4uXta4kVyO508= -google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= From f0d514e3308c8a53dc09a989b3b69284ce5b63eb Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 28 Jan 2020 16:21:14 +0800 Subject: [PATCH 206/881] Cleanup --- association.go | 377 -------- association_test.go | 1050 -------------------- callback.go | 250 ----- callback_create.go | 197 ---- callback_delete.go | 63 -- callback_query.go | 109 --- callback_query_preload.go | 410 -------- callback_row_query.go | 41 - callback_save.go | 170 ---- callback_system_test.go | 112 --- callback_update.go | 121 --- callbacks_test.go | 249 ----- create_test.go | 288 ------ customize_column_test.go | 357 ------- delete_test.go | 91 -- dialect.go | 147 --- dialect_common.go | 196 ---- dialect_mysql.go | 246 ----- dialect_postgres.go | 147 --- dialect_sqlite3.go | 107 --- dialects/mssql/mssql.go | 253 ----- dialects/mysql/mysql.go | 3 - dialects/postgres/postgres.go | 81 -- dialects/sqlite/sqlite.go | 3 - docker-compose.yml | 30 - embedded_struct_test.go | 91 -- errors.go | 72 -- errors_test.go | 20 - field.go | 66 -- field_test.go | 130 --- go.mod | 13 - go.sum | 25 - interface.go | 24 - join_table_handler.go | 211 ---- join_table_test.go | 117 --- logger.go | 141 --- main.go | 881 ----------------- main_test.go | 1444 ---------------------------- migration_test.go | 579 ----------- model.go | 14 - model_struct.go | 671 ------------- model_struct_test.go | 93 -- multi_primary_keys_test.go | 381 -------- naming.go | 124 --- naming_test.go | 69 -- pointer_test.go | 84 -- polymorphic_test.go | 366 ------- preload_test.go | 1701 --------------------------------- query_test.go | 841 ---------------- scaner_test.go | 139 --- scope.go | 1421 --------------------------- scope_test.go | 93 -- search.go | 153 --- search_test.go | 30 - test_all.sh | 5 - update_test.go | 465 --------- utils.go | 226 ----- wercker.yml | 154 --- 58 files changed, 15942 deletions(-) delete mode 100644 association.go delete mode 100644 association_test.go delete mode 100644 callback.go delete mode 100644 callback_create.go delete mode 100644 callback_delete.go delete mode 100644 callback_query.go delete mode 100644 callback_query_preload.go delete mode 100644 callback_row_query.go delete mode 100644 callback_save.go delete mode 100644 callback_system_test.go delete mode 100644 callback_update.go delete mode 100644 callbacks_test.go delete mode 100644 create_test.go delete mode 100644 customize_column_test.go delete mode 100644 delete_test.go delete mode 100644 dialect.go delete mode 100644 dialect_common.go delete mode 100644 dialect_mysql.go delete mode 100644 dialect_postgres.go delete mode 100644 dialect_sqlite3.go delete mode 100644 dialects/mssql/mssql.go delete mode 100644 dialects/mysql/mysql.go delete mode 100644 dialects/postgres/postgres.go delete mode 100644 dialects/sqlite/sqlite.go delete mode 100644 docker-compose.yml delete mode 100644 embedded_struct_test.go delete mode 100644 errors.go delete mode 100644 errors_test.go delete mode 100644 field.go delete mode 100644 field_test.go delete mode 100644 go.sum delete mode 100644 interface.go delete mode 100644 join_table_handler.go delete mode 100644 join_table_test.go delete mode 100644 logger.go delete mode 100644 main.go delete mode 100644 main_test.go delete mode 100644 migration_test.go delete mode 100644 model.go delete mode 100644 model_struct.go delete mode 100644 model_struct_test.go delete mode 100644 multi_primary_keys_test.go delete mode 100644 naming.go delete mode 100644 naming_test.go delete mode 100644 pointer_test.go delete mode 100644 polymorphic_test.go delete mode 100644 preload_test.go delete mode 100644 query_test.go delete mode 100644 scaner_test.go delete mode 100644 scope.go delete mode 100644 scope_test.go delete mode 100644 search.go delete mode 100644 search_test.go delete mode 100755 test_all.sh delete mode 100644 update_test.go delete mode 100644 utils.go delete mode 100644 wercker.yml diff --git a/association.go b/association.go deleted file mode 100644 index a73344fe..00000000 --- a/association.go +++ /dev/null @@ -1,377 +0,0 @@ -package gorm - -import ( - "errors" - "fmt" - "reflect" -) - -// Association Mode contains some helper methods to handle relationship things easily. -type Association struct { - Error error - scope *Scope - column string - field *Field -} - -// Find find out all related associations -func (association *Association) Find(value interface{}) *Association { - association.scope.related(value, association.column) - return association.setErr(association.scope.db.Error) -} - -// Append append new associations for many2many, has_many, replace current association for has_one, belongs_to -func (association *Association) Append(values ...interface{}) *Association { - if association.Error != nil { - return association - } - - if relationship := association.field.Relationship; relationship.Kind == "has_one" { - return association.Replace(values...) - } - return association.saveAssociations(values...) -} - -// Replace replace current associations with new one -func (association *Association) Replace(values ...interface{}) *Association { - if association.Error != nil { - return association - } - - var ( - relationship = association.field.Relationship - scope = association.scope - field = association.field.Field - newDB = scope.NewDB() - ) - - // Append new values - association.field.Set(reflect.Zero(association.field.Field.Type())) - association.saveAssociations(values...) - - // Belongs To - if relationship.Kind == "belongs_to" { - // Set foreign key to be null when clearing value (length equals 0) - if len(values) == 0 { - // Set foreign key to be nil - var foreignKeyMap = map[string]interface{}{} - for _, foreignKey := range relationship.ForeignDBNames { - foreignKeyMap[foreignKey] = nil - } - association.setErr(newDB.Model(scope.Value).UpdateColumn(foreignKeyMap).Error) - } - } else { - // Polymorphic Relations - if relationship.PolymorphicDBName != "" { - newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), relationship.PolymorphicValue) - } - - // Delete Relations except new created - if len(values) > 0 { - var associationForeignFieldNames, associationForeignDBNames []string - if relationship.Kind == "many_to_many" { - // if many to many relations, get association fields name from association foreign keys - associationScope := scope.New(reflect.New(field.Type()).Interface()) - for idx, dbName := range relationship.AssociationForeignFieldNames { - if field, ok := associationScope.FieldByName(dbName); ok { - associationForeignFieldNames = append(associationForeignFieldNames, field.Name) - associationForeignDBNames = append(associationForeignDBNames, relationship.AssociationForeignDBNames[idx]) - } - } - } else { - // If has one/many relations, use primary keys - for _, field := range scope.New(reflect.New(field.Type()).Interface()).PrimaryFields() { - associationForeignFieldNames = append(associationForeignFieldNames, field.Name) - associationForeignDBNames = append(associationForeignDBNames, field.DBName) - } - } - - newPrimaryKeys := scope.getColumnAsArray(associationForeignFieldNames, field.Interface()) - - if len(newPrimaryKeys) > 0 { - sql := fmt.Sprintf("%v NOT IN (%v)", toQueryCondition(scope, associationForeignDBNames), toQueryMarks(newPrimaryKeys)) - newDB = newDB.Where(sql, toQueryValues(newPrimaryKeys)...) - } - } - - if relationship.Kind == "many_to_many" { - // if many to many relations, delete related relations from join table - var sourceForeignFieldNames []string - - for _, dbName := range relationship.ForeignFieldNames { - if field, ok := scope.FieldByName(dbName); ok { - sourceForeignFieldNames = append(sourceForeignFieldNames, field.Name) - } - } - - if sourcePrimaryKeys := scope.getColumnAsArray(sourceForeignFieldNames, scope.Value); len(sourcePrimaryKeys) > 0 { - newDB = newDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(sourcePrimaryKeys)), toQueryValues(sourcePrimaryKeys)...) - - association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB)) - } - } else if relationship.Kind == "has_one" || relationship.Kind == "has_many" { - // has_one or has_many relations, set foreign key to be nil (TODO or delete them?) - var foreignKeyMap = map[string]interface{}{} - for idx, foreignKey := range relationship.ForeignDBNames { - foreignKeyMap[foreignKey] = nil - if field, ok := scope.FieldByName(relationship.AssociationForeignFieldNames[idx]); ok { - newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) - } - } - - fieldValue := reflect.New(association.field.Field.Type()).Interface() - association.setErr(newDB.Model(fieldValue).UpdateColumn(foreignKeyMap).Error) - } - } - return association -} - -// Delete remove relationship between source & passed arguments, but won't delete those arguments -func (association *Association) Delete(values ...interface{}) *Association { - if association.Error != nil { - return association - } - - var ( - relationship = association.field.Relationship - scope = association.scope - field = association.field.Field - newDB = scope.NewDB() - ) - - if len(values) == 0 { - return association - } - - var deletingResourcePrimaryFieldNames, deletingResourcePrimaryDBNames []string - for _, field := range scope.New(reflect.New(field.Type()).Interface()).PrimaryFields() { - deletingResourcePrimaryFieldNames = append(deletingResourcePrimaryFieldNames, field.Name) - deletingResourcePrimaryDBNames = append(deletingResourcePrimaryDBNames, field.DBName) - } - - deletingPrimaryKeys := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, values...) - - if relationship.Kind == "many_to_many" { - // source value's foreign keys - for idx, foreignKey := range relationship.ForeignDBNames { - if field, ok := scope.FieldByName(relationship.ForeignFieldNames[idx]); ok { - newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) - } - } - - // get association's foreign fields name - var associationScope = scope.New(reflect.New(field.Type()).Interface()) - var associationForeignFieldNames []string - for _, associationDBName := range relationship.AssociationForeignFieldNames { - if field, ok := associationScope.FieldByName(associationDBName); ok { - associationForeignFieldNames = append(associationForeignFieldNames, field.Name) - } - } - - // association value's foreign keys - deletingPrimaryKeys := scope.getColumnAsArray(associationForeignFieldNames, values...) - sql := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(deletingPrimaryKeys)) - newDB = newDB.Where(sql, toQueryValues(deletingPrimaryKeys)...) - - association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB)) - } else { - var foreignKeyMap = map[string]interface{}{} - for _, foreignKey := range relationship.ForeignDBNames { - foreignKeyMap[foreignKey] = nil - } - - if relationship.Kind == "belongs_to" { - // find with deleting relation's foreign keys - primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, values...) - newDB = newDB.Where( - fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)), - toQueryValues(primaryKeys)..., - ) - - // set foreign key to be null if there are some records affected - modelValue := reflect.New(scope.GetModelStruct().ModelType).Interface() - if results := newDB.Model(modelValue).UpdateColumn(foreignKeyMap); results.Error == nil { - if results.RowsAffected > 0 { - scope.updatedAttrsWithValues(foreignKeyMap) - } - } else { - association.setErr(results.Error) - } - } else if relationship.Kind == "has_one" || relationship.Kind == "has_many" { - // find all relations - primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, scope.Value) - newDB = newDB.Where( - fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)), - toQueryValues(primaryKeys)..., - ) - - // only include those deleting relations - newDB = newDB.Where( - fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, deletingResourcePrimaryDBNames), toQueryMarks(deletingPrimaryKeys)), - toQueryValues(deletingPrimaryKeys)..., - ) - - // set matched relation's foreign key to be null - fieldValue := reflect.New(association.field.Field.Type()).Interface() - association.setErr(newDB.Model(fieldValue).UpdateColumn(foreignKeyMap).Error) - } - } - - // Remove deleted records from source's field - if association.Error == nil { - if field.Kind() == reflect.Slice { - leftValues := reflect.Zero(field.Type()) - - for i := 0; i < field.Len(); i++ { - reflectValue := field.Index(i) - primaryKey := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, reflectValue.Interface())[0] - var isDeleted = false - for _, pk := range deletingPrimaryKeys { - if equalAsString(primaryKey, pk) { - isDeleted = true - break - } - } - if !isDeleted { - leftValues = reflect.Append(leftValues, reflectValue) - } - } - - association.field.Set(leftValues) - } else if field.Kind() == reflect.Struct { - primaryKey := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, field.Interface())[0] - for _, pk := range deletingPrimaryKeys { - if equalAsString(primaryKey, pk) { - association.field.Set(reflect.Zero(field.Type())) - break - } - } - } - } - - return association -} - -// Clear remove relationship between source & current associations, won't delete those associations -func (association *Association) Clear() *Association { - return association.Replace() -} - -// Count return the count of current associations -func (association *Association) Count() int { - var ( - count = 0 - relationship = association.field.Relationship - scope = association.scope - fieldValue = association.field.Field.Interface() - query = scope.DB() - ) - - switch relationship.Kind { - case "many_to_many": - query = relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, query, scope.Value) - case "has_many", "has_one": - primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, scope.Value) - query = query.Where( - fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)), - toQueryValues(primaryKeys)..., - ) - case "belongs_to": - primaryKeys := scope.getColumnAsArray(relationship.ForeignFieldNames, scope.Value) - query = query.Where( - fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(primaryKeys)), - toQueryValues(primaryKeys)..., - ) - } - - if relationship.PolymorphicType != "" { - query = query.Where( - fmt.Sprintf("%v.%v = ?", scope.New(fieldValue).QuotedTableName(), scope.Quote(relationship.PolymorphicDBName)), - relationship.PolymorphicValue, - ) - } - - if err := query.Model(fieldValue).Count(&count).Error; err != nil { - association.Error = err - } - return count -} - -// saveAssociations save passed values as associations -func (association *Association) saveAssociations(values ...interface{}) *Association { - var ( - scope = association.scope - field = association.field - relationship = field.Relationship - ) - - saveAssociation := func(reflectValue reflect.Value) { - // value has to been pointer - if reflectValue.Kind() != reflect.Ptr { - reflectPtr := reflect.New(reflectValue.Type()) - reflectPtr.Elem().Set(reflectValue) - reflectValue = reflectPtr - } - - // value has to been saved for many2many - if relationship.Kind == "many_to_many" { - if scope.New(reflectValue.Interface()).PrimaryKeyZero() { - association.setErr(scope.NewDB().Save(reflectValue.Interface()).Error) - } - } - - // Assign Fields - var fieldType = field.Field.Type() - var setFieldBackToValue, setSliceFieldBackToValue bool - if reflectValue.Type().AssignableTo(fieldType) { - field.Set(reflectValue) - } else if reflectValue.Type().Elem().AssignableTo(fieldType) { - // if field's type is struct, then need to set value back to argument after save - setFieldBackToValue = true - field.Set(reflectValue.Elem()) - } else if fieldType.Kind() == reflect.Slice { - if reflectValue.Type().AssignableTo(fieldType.Elem()) { - field.Set(reflect.Append(field.Field, reflectValue)) - } else if reflectValue.Type().Elem().AssignableTo(fieldType.Elem()) { - // if field's type is slice of struct, then need to set value back to argument after save - setSliceFieldBackToValue = true - field.Set(reflect.Append(field.Field, reflectValue.Elem())) - } - } - - if relationship.Kind == "many_to_many" { - association.setErr(relationship.JoinTableHandler.Add(relationship.JoinTableHandler, scope.NewDB(), scope.Value, reflectValue.Interface())) - } else { - association.setErr(scope.NewDB().Select(field.Name).Save(scope.Value).Error) - - if setFieldBackToValue { - reflectValue.Elem().Set(field.Field) - } else if setSliceFieldBackToValue { - reflectValue.Elem().Set(field.Field.Index(field.Field.Len() - 1)) - } - } - } - - for _, value := range values { - reflectValue := reflect.ValueOf(value) - indirectReflectValue := reflect.Indirect(reflectValue) - if indirectReflectValue.Kind() == reflect.Struct { - saveAssociation(reflectValue) - } else if indirectReflectValue.Kind() == reflect.Slice { - for i := 0; i < indirectReflectValue.Len(); i++ { - saveAssociation(indirectReflectValue.Index(i)) - } - } else { - association.setErr(errors.New("invalid value type")) - } - } - return association -} - -// setErr set error when the error is not nil. And return Association. -func (association *Association) setErr(err error) *Association { - if err != nil { - association.Error = err - } - return association -} diff --git a/association_test.go b/association_test.go deleted file mode 100644 index 60d0cf48..00000000 --- a/association_test.go +++ /dev/null @@ -1,1050 +0,0 @@ -package gorm_test - -import ( - "fmt" - "os" - "reflect" - "sort" - "testing" - - "github.com/jinzhu/gorm" -) - -func TestBelongsTo(t *testing.T) { - post := Post{ - Title: "post belongs to", - Body: "body belongs to", - Category: Category{Name: "Category 1"}, - MainCategory: Category{Name: "Main Category 1"}, - } - - if err := DB.Save(&post).Error; err != nil { - t.Error("Got errors when save post", err) - } - - if post.Category.ID == 0 || post.MainCategory.ID == 0 { - t.Errorf("Category's primary key should be updated") - } - - if post.CategoryId.Int64 == 0 || post.MainCategoryId == 0 { - t.Errorf("post's foreign key should be updated") - } - - // Query - var category1 Category - DB.Model(&post).Association("Category").Find(&category1) - if category1.Name != "Category 1" { - t.Errorf("Query belongs to relations with Association") - } - - var mainCategory1 Category - DB.Model(&post).Association("MainCategory").Find(&mainCategory1) - if mainCategory1.Name != "Main Category 1" { - t.Errorf("Query belongs to relations with Association") - } - - var category11 Category - DB.Model(&post).Related(&category11) - if category11.Name != "Category 1" { - t.Errorf("Query belongs to relations with Related") - } - - if DB.Model(&post).Association("Category").Count() != 1 { - t.Errorf("Post's category count should be 1") - } - - if DB.Model(&post).Association("MainCategory").Count() != 1 { - t.Errorf("Post's main category count should be 1") - } - - // Append - var category2 = Category{ - Name: "Category 2", - } - DB.Model(&post).Association("Category").Append(&category2) - - if category2.ID == 0 { - t.Errorf("Category should has ID when created with Append") - } - - var category21 Category - DB.Model(&post).Related(&category21) - - if category21.Name != "Category 2" { - t.Errorf("Category should be updated with Append") - } - - if DB.Model(&post).Association("Category").Count() != 1 { - t.Errorf("Post's category count should be 1") - } - - // Replace - var category3 = Category{ - Name: "Category 3", - } - DB.Model(&post).Association("Category").Replace(&category3) - - if category3.ID == 0 { - t.Errorf("Category should has ID when created with Replace") - } - - var category31 Category - DB.Model(&post).Related(&category31) - if category31.Name != "Category 3" { - t.Errorf("Category should be updated with Replace") - } - - if DB.Model(&post).Association("Category").Count() != 1 { - t.Errorf("Post's category count should be 1") - } - - // Delete - DB.Model(&post).Association("Category").Delete(&category2) - if DB.Model(&post).Related(&Category{}).RecordNotFound() { - t.Errorf("Should not delete any category when Delete a unrelated Category") - } - - if post.Category.Name == "" { - t.Errorf("Post's category should not be reseted when Delete a unrelated Category") - } - - DB.Model(&post).Association("Category").Delete(&category3) - - if post.Category.Name != "" { - t.Errorf("Post's category should be reseted after Delete") - } - - var category41 Category - DB.Model(&post).Related(&category41) - if category41.Name != "" { - t.Errorf("Category should be deleted with Delete") - } - - if count := DB.Model(&post).Association("Category").Count(); count != 0 { - t.Errorf("Post's category count should be 0 after Delete, but got %v", count) - } - - // Clear - DB.Model(&post).Association("Category").Append(&Category{ - Name: "Category 2", - }) - - if DB.Model(&post).Related(&Category{}).RecordNotFound() { - t.Errorf("Should find category after append") - } - - if post.Category.Name == "" { - t.Errorf("Post's category should has value after Append") - } - - DB.Model(&post).Association("Category").Clear() - - if post.Category.Name != "" { - t.Errorf("Post's category should be cleared after Clear") - } - - if !DB.Model(&post).Related(&Category{}).RecordNotFound() { - t.Errorf("Should not find any category after Clear") - } - - if count := DB.Model(&post).Association("Category").Count(); count != 0 { - t.Errorf("Post's category count should be 0 after Clear, but got %v", count) - } - - // Check Association mode with soft delete - category6 := Category{ - Name: "Category 6", - } - DB.Model(&post).Association("Category").Append(&category6) - - if count := DB.Model(&post).Association("Category").Count(); count != 1 { - t.Errorf("Post's category count should be 1 after Append, but got %v", count) - } - - DB.Delete(&category6) - - if count := DB.Model(&post).Association("Category").Count(); count != 0 { - t.Errorf("Post's category count should be 0 after the category has been deleted, but got %v", count) - } - - if err := DB.Model(&post).Association("Category").Find(&Category{}).Error; err == nil { - t.Errorf("Post's category is not findable after Delete") - } - - if count := DB.Unscoped().Model(&post).Association("Category").Count(); count != 1 { - t.Errorf("Post's category count should be 1 when query with Unscoped, but got %v", count) - } - - if err := DB.Unscoped().Model(&post).Association("Category").Find(&Category{}).Error; err != nil { - t.Errorf("Post's category should be findable when query with Unscoped, got %v", err) - } -} - -func TestBelongsToOverrideForeignKey1(t *testing.T) { - type Profile struct { - gorm.Model - Name string - } - - type User struct { - gorm.Model - Profile Profile `gorm:"ForeignKey:ProfileRefer"` - ProfileRefer int - } - - if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok { - if relation.Relationship.Kind != "belongs_to" || - !reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"ProfileRefer"}) || - !reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"ID"}) { - t.Errorf("Override belongs to foreign key with tag") - } - } -} - -func TestBelongsToOverrideForeignKey2(t *testing.T) { - type Profile struct { - gorm.Model - Refer string - Name string - } - - type User struct { - gorm.Model - Profile Profile `gorm:"ForeignKey:ProfileID;AssociationForeignKey:Refer"` - ProfileID int - } - - if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok { - if relation.Relationship.Kind != "belongs_to" || - !reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"ProfileID"}) || - !reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"Refer"}) { - t.Errorf("Override belongs to foreign key with tag") - } - } -} - -func TestHasOne(t *testing.T) { - user := User{ - Name: "has one", - CreditCard: CreditCard{Number: "411111111111"}, - } - - if err := DB.Save(&user).Error; err != nil { - t.Error("Got errors when save user", err.Error()) - } - - if user.CreditCard.UserId.Int64 == 0 { - t.Errorf("CreditCard's foreign key should be updated") - } - - // Query - var creditCard1 CreditCard - DB.Model(&user).Related(&creditCard1) - - if creditCard1.Number != "411111111111" { - t.Errorf("Query has one relations with Related") - } - - var creditCard11 CreditCard - DB.Model(&user).Association("CreditCard").Find(&creditCard11) - - if creditCard11.Number != "411111111111" { - t.Errorf("Query has one relations with Related") - } - - if DB.Model(&user).Association("CreditCard").Count() != 1 { - t.Errorf("User's credit card count should be 1") - } - - // Append - var creditcard2 = CreditCard{ - Number: "411111111112", - } - DB.Model(&user).Association("CreditCard").Append(&creditcard2) - - if creditcard2.ID == 0 { - t.Errorf("Creditcard should has ID when created with Append") - } - - var creditcard21 CreditCard - DB.Model(&user).Related(&creditcard21) - if creditcard21.Number != "411111111112" { - t.Errorf("CreditCard should be updated with Append") - } - - if DB.Model(&user).Association("CreditCard").Count() != 1 { - t.Errorf("User's credit card count should be 1") - } - - // Replace - var creditcard3 = CreditCard{ - Number: "411111111113", - } - DB.Model(&user).Association("CreditCard").Replace(&creditcard3) - - if creditcard3.ID == 0 { - t.Errorf("Creditcard should has ID when created with Replace") - } - - var creditcard31 CreditCard - DB.Model(&user).Related(&creditcard31) - if creditcard31.Number != "411111111113" { - t.Errorf("CreditCard should be updated with Replace") - } - - if DB.Model(&user).Association("CreditCard").Count() != 1 { - t.Errorf("User's credit card count should be 1") - } - - // Delete - DB.Model(&user).Association("CreditCard").Delete(&creditcard2) - var creditcard4 CreditCard - DB.Model(&user).Related(&creditcard4) - if creditcard4.Number != "411111111113" { - t.Errorf("Should not delete credit card when Delete a unrelated CreditCard") - } - - if DB.Model(&user).Association("CreditCard").Count() != 1 { - t.Errorf("User's credit card count should be 1") - } - - DB.Model(&user).Association("CreditCard").Delete(&creditcard3) - if !DB.Model(&user).Related(&CreditCard{}).RecordNotFound() { - t.Errorf("Should delete credit card with Delete") - } - - if DB.Model(&user).Association("CreditCard").Count() != 0 { - t.Errorf("User's credit card count should be 0 after Delete") - } - - // Clear - var creditcard5 = CreditCard{ - Number: "411111111115", - } - DB.Model(&user).Association("CreditCard").Append(&creditcard5) - - if DB.Model(&user).Related(&CreditCard{}).RecordNotFound() { - t.Errorf("Should added credit card with Append") - } - - if DB.Model(&user).Association("CreditCard").Count() != 1 { - t.Errorf("User's credit card count should be 1") - } - - DB.Model(&user).Association("CreditCard").Clear() - if !DB.Model(&user).Related(&CreditCard{}).RecordNotFound() { - t.Errorf("Credit card should be deleted with Clear") - } - - if DB.Model(&user).Association("CreditCard").Count() != 0 { - t.Errorf("User's credit card count should be 0 after Clear") - } - - // Check Association mode with soft delete - var creditcard6 = CreditCard{ - Number: "411111111116", - } - DB.Model(&user).Association("CreditCard").Append(&creditcard6) - - if count := DB.Model(&user).Association("CreditCard").Count(); count != 1 { - t.Errorf("User's credit card count should be 1 after Append, but got %v", count) - } - - DB.Delete(&creditcard6) - - if count := DB.Model(&user).Association("CreditCard").Count(); count != 0 { - t.Errorf("User's credit card count should be 0 after credit card deleted, but got %v", count) - } - - if err := DB.Model(&user).Association("CreditCard").Find(&CreditCard{}).Error; err == nil { - t.Errorf("User's creditcard is not findable after Delete") - } - - if count := DB.Unscoped().Model(&user).Association("CreditCard").Count(); count != 1 { - t.Errorf("User's credit card count should be 1 when query with Unscoped, but got %v", count) - } - - if err := DB.Unscoped().Model(&user).Association("CreditCard").Find(&CreditCard{}).Error; err != nil { - t.Errorf("User's creditcard should be findable when query with Unscoped, got %v", err) - } -} - -func TestHasOneOverrideForeignKey1(t *testing.T) { - type Profile struct { - gorm.Model - Name string - UserRefer uint - } - - type User struct { - gorm.Model - Profile Profile `gorm:"ForeignKey:UserRefer"` - } - - if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok { - if relation.Relationship.Kind != "has_one" || - !reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"UserRefer"}) || - !reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"ID"}) { - t.Errorf("Override belongs to foreign key with tag") - } - } -} - -func TestHasOneOverrideForeignKey2(t *testing.T) { - type Profile struct { - gorm.Model - Name string - UserID uint - } - - type User struct { - gorm.Model - Refer string - Profile Profile `gorm:"ForeignKey:UserID;AssociationForeignKey:Refer"` - } - - if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok { - if relation.Relationship.Kind != "has_one" || - !reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"UserID"}) || - !reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"Refer"}) { - t.Errorf("Override belongs to foreign key with tag") - } - } -} - -func TestHasMany(t *testing.T) { - post := Post{ - Title: "post has many", - Body: "body has many", - Comments: []*Comment{{Content: "Comment 1"}, {Content: "Comment 2"}}, - } - - if err := DB.Save(&post).Error; err != nil { - t.Error("Got errors when save post", err) - } - - for _, comment := range post.Comments { - if comment.PostId == 0 { - t.Errorf("comment's PostID should be updated") - } - } - - var compareComments = func(comments []Comment, contents []string) bool { - var commentContents []string - for _, comment := range comments { - commentContents = append(commentContents, comment.Content) - } - sort.Strings(commentContents) - sort.Strings(contents) - return reflect.DeepEqual(commentContents, contents) - } - - // Query - if DB.First(&Comment{}, "content = ?", "Comment 1").Error != nil { - t.Errorf("Comment 1 should be saved") - } - - var comments1 []Comment - DB.Model(&post).Association("Comments").Find(&comments1) - if !compareComments(comments1, []string{"Comment 1", "Comment 2"}) { - t.Errorf("Query has many relations with Association") - } - - var comments11 []Comment - DB.Model(&post).Related(&comments11) - if !compareComments(comments11, []string{"Comment 1", "Comment 2"}) { - t.Errorf("Query has many relations with Related") - } - - if DB.Model(&post).Association("Comments").Count() != 2 { - t.Errorf("Post's comments count should be 2") - } - - // Append - DB.Model(&post).Association("Comments").Append(&Comment{Content: "Comment 3"}) - - var comments2 []Comment - DB.Model(&post).Related(&comments2) - if !compareComments(comments2, []string{"Comment 1", "Comment 2", "Comment 3"}) { - t.Errorf("Append new record to has many relations") - } - - if DB.Model(&post).Association("Comments").Count() != 3 { - t.Errorf("Post's comments count should be 3 after Append") - } - - // Delete - DB.Model(&post).Association("Comments").Delete(comments11) - - var comments3 []Comment - DB.Model(&post).Related(&comments3) - if !compareComments(comments3, []string{"Comment 3"}) { - t.Errorf("Delete an existing resource for has many relations") - } - - if DB.Model(&post).Association("Comments").Count() != 1 { - t.Errorf("Post's comments count should be 1 after Delete 2") - } - - // Replace - DB.Model(&Post{Id: 999}).Association("Comments").Replace() - - var comments4 []Comment - DB.Model(&post).Related(&comments4) - if len(comments4) == 0 { - t.Errorf("Replace for other resource should not clear all comments") - } - - DB.Model(&post).Association("Comments").Replace(&Comment{Content: "Comment 4"}, &Comment{Content: "Comment 5"}) - - var comments41 []Comment - DB.Model(&post).Related(&comments41) - if !compareComments(comments41, []string{"Comment 4", "Comment 5"}) { - t.Errorf("Replace has many relations") - } - - // Clear - DB.Model(&Post{Id: 999}).Association("Comments").Clear() - - var comments5 []Comment - DB.Model(&post).Related(&comments5) - if len(comments5) == 0 { - t.Errorf("Clear should not clear all comments") - } - - DB.Model(&post).Association("Comments").Clear() - - var comments51 []Comment - DB.Model(&post).Related(&comments51) - if len(comments51) != 0 { - t.Errorf("Clear has many relations") - } - - // Check Association mode with soft delete - var comment6 = Comment{ - Content: "comment 6", - } - DB.Model(&post).Association("Comments").Append(&comment6) - - if count := DB.Model(&post).Association("Comments").Count(); count != 1 { - t.Errorf("post's comments count should be 1 after Append, but got %v", count) - } - - DB.Delete(&comment6) - - if count := DB.Model(&post).Association("Comments").Count(); count != 0 { - t.Errorf("post's comments count should be 0 after comment been deleted, but got %v", count) - } - - var comments6 []Comment - if DB.Model(&post).Association("Comments").Find(&comments6); len(comments6) != 0 { - t.Errorf("post's comments count should be 0 when find with Find, but got %v", len(comments6)) - } - - if count := DB.Unscoped().Model(&post).Association("Comments").Count(); count != 1 { - t.Errorf("post's comments count should be 1 when query with Unscoped, but got %v", count) - } - - var comments61 []Comment - if DB.Unscoped().Model(&post).Association("Comments").Find(&comments61); len(comments61) != 1 { - t.Errorf("post's comments count should be 1 when query with Unscoped, but got %v", len(comments61)) - } -} - -func TestHasManyOverrideForeignKey1(t *testing.T) { - type Profile struct { - gorm.Model - Name string - UserRefer uint - } - - type User struct { - gorm.Model - Profile []Profile `gorm:"ForeignKey:UserRefer"` - } - - if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok { - if relation.Relationship.Kind != "has_many" || - !reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"UserRefer"}) || - !reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"ID"}) { - t.Errorf("Override belongs to foreign key with tag") - } - } -} - -func TestHasManyOverrideForeignKey2(t *testing.T) { - type Profile struct { - gorm.Model - Name string - UserID uint - } - - type User struct { - gorm.Model - Refer string - Profile []Profile `gorm:"ForeignKey:UserID;AssociationForeignKey:Refer"` - } - - if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok { - if relation.Relationship.Kind != "has_many" || - !reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"UserID"}) || - !reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"Refer"}) { - t.Errorf("Override belongs to foreign key with tag") - } - } -} - -func TestManyToMany(t *testing.T) { - DB.Raw("delete from languages") - var languages = []Language{{Name: "ZH"}, {Name: "EN"}} - user := User{Name: "Many2Many", Languages: languages} - DB.Save(&user) - - // Query - var newLanguages []Language - DB.Model(&user).Related(&newLanguages, "Languages") - if len(newLanguages) != len([]string{"ZH", "EN"}) { - t.Errorf("Query many to many relations") - } - - DB.Model(&user).Association("Languages").Find(&newLanguages) - if len(newLanguages) != len([]string{"ZH", "EN"}) { - t.Errorf("Should be able to find many to many relations") - } - - if DB.Model(&user).Association("Languages").Count() != len([]string{"ZH", "EN"}) { - t.Errorf("Count should return correct result") - } - - // Append - DB.Model(&user).Association("Languages").Append(&Language{Name: "DE"}) - if DB.Where("name = ?", "DE").First(&Language{}).RecordNotFound() { - t.Errorf("New record should be saved when append") - } - - languageA := Language{Name: "AA"} - DB.Save(&languageA) - DB.Model(&User{Id: user.Id}).Association("Languages").Append(&languageA) - - languageC := Language{Name: "CC"} - DB.Save(&languageC) - DB.Model(&user).Association("Languages").Append(&[]Language{{Name: "BB"}, languageC}) - - DB.Model(&User{Id: user.Id}).Association("Languages").Append(&[]Language{{Name: "DD"}, {Name: "EE"}}) - - totalLanguages := []string{"ZH", "EN", "DE", "AA", "BB", "CC", "DD", "EE"} - - if DB.Model(&user).Association("Languages").Count() != len(totalLanguages) { - t.Errorf("All appended languages should be saved") - } - - // Delete - user.Languages = []Language{} - DB.Model(&user).Association("Languages").Find(&user.Languages) - - var language Language - DB.Where("name = ?", "EE").First(&language) - DB.Model(&user).Association("Languages").Delete(language, &language) - - if DB.Model(&user).Association("Languages").Count() != len(totalLanguages)-1 || len(user.Languages) != len(totalLanguages)-1 { - t.Errorf("Relations should be deleted with Delete") - } - if DB.Where("name = ?", "EE").First(&Language{}).RecordNotFound() { - t.Errorf("Language EE should not be deleted") - } - - DB.Where("name IN (?)", []string{"CC", "DD"}).Find(&languages) - - user2 := User{Name: "Many2Many_User2", Languages: languages} - DB.Save(&user2) - - DB.Model(&user).Association("Languages").Delete(languages, &languages) - if DB.Model(&user).Association("Languages").Count() != len(totalLanguages)-3 || len(user.Languages) != len(totalLanguages)-3 { - t.Errorf("Relations should be deleted with Delete") - } - - if DB.Model(&user2).Association("Languages").Count() == 0 { - t.Errorf("Other user's relations should not be deleted") - } - - // Replace - var languageB Language - DB.Where("name = ?", "BB").First(&languageB) - DB.Model(&user).Association("Languages").Replace(languageB) - if len(user.Languages) != 1 || DB.Model(&user).Association("Languages").Count() != 1 { - t.Errorf("Relations should be replaced") - } - - DB.Model(&user).Association("Languages").Replace() - if len(user.Languages) != 0 || DB.Model(&user).Association("Languages").Count() != 0 { - t.Errorf("Relations should be replaced with empty") - } - - DB.Model(&user).Association("Languages").Replace(&[]Language{{Name: "FF"}, {Name: "JJ"}}) - if len(user.Languages) != 2 || DB.Model(&user).Association("Languages").Count() != len([]string{"FF", "JJ"}) { - t.Errorf("Relations should be replaced") - } - - // Clear - DB.Model(&user).Association("Languages").Clear() - if len(user.Languages) != 0 || DB.Model(&user).Association("Languages").Count() != 0 { - t.Errorf("Relations should be cleared") - } - - // Check Association mode with soft delete - var language6 = Language{ - Name: "language 6", - } - DB.Model(&user).Association("Languages").Append(&language6) - - if count := DB.Model(&user).Association("Languages").Count(); count != 1 { - t.Errorf("user's languages count should be 1 after Append, but got %v", count) - } - - DB.Delete(&language6) - - if count := DB.Model(&user).Association("Languages").Count(); count != 0 { - t.Errorf("user's languages count should be 0 after language been deleted, but got %v", count) - } - - var languages6 []Language - if DB.Model(&user).Association("Languages").Find(&languages6); len(languages6) != 0 { - t.Errorf("user's languages count should be 0 when find with Find, but got %v", len(languages6)) - } - - if count := DB.Unscoped().Model(&user).Association("Languages").Count(); count != 1 { - t.Errorf("user's languages count should be 1 when query with Unscoped, but got %v", count) - } - - var languages61 []Language - if DB.Unscoped().Model(&user).Association("Languages").Find(&languages61); len(languages61) != 1 { - t.Errorf("user's languages count should be 1 when query with Unscoped, but got %v", len(languages61)) - } -} - -func TestRelated(t *testing.T) { - user := User{ - Name: "jinzhu", - BillingAddress: Address{Address1: "Billing Address - Address 1"}, - ShippingAddress: Address{Address1: "Shipping Address - Address 1"}, - Emails: []Email{{Email: "jinzhu@example.com"}, {Email: "jinzhu-2@example@example.com"}}, - CreditCard: CreditCard{Number: "1234567890"}, - Company: Company{Name: "company1"}, - } - - if err := DB.Save(&user).Error; err != nil { - t.Errorf("No error should happen when saving user") - } - - if user.CreditCard.ID == 0 { - t.Errorf("After user save, credit card should have id") - } - - if user.BillingAddress.ID == 0 { - t.Errorf("After user save, billing address should have id") - } - - if user.Emails[0].Id == 0 { - t.Errorf("After user save, billing address should have id") - } - - var emails []Email - DB.Model(&user).Related(&emails) - if len(emails) != 2 { - t.Errorf("Should have two emails") - } - - var emails2 []Email - DB.Model(&user).Where("email = ?", "jinzhu@example.com").Related(&emails2) - if len(emails2) != 1 { - t.Errorf("Should have two emails") - } - - var emails3 []*Email - DB.Model(&user).Related(&emails3) - if len(emails3) != 2 { - t.Errorf("Should have two emails") - } - - var user1 User - DB.Model(&user).Related(&user1.Emails) - if len(user1.Emails) != 2 { - t.Errorf("Should have only one email match related condition") - } - - var address1 Address - DB.Model(&user).Related(&address1, "BillingAddressId") - if address1.Address1 != "Billing Address - Address 1" { - t.Errorf("Should get billing address from user correctly") - } - - user1 = User{} - DB.Model(&address1).Related(&user1, "BillingAddressId") - if DB.NewRecord(user1) { - t.Errorf("Should get user from address correctly") - } - - var user2 User - DB.Model(&emails[0]).Related(&user2) - if user2.Id != user.Id || user2.Name != user.Name { - t.Errorf("Should get user from email correctly") - } - - var creditcard CreditCard - var user3 User - DB.First(&creditcard, "number = ?", "1234567890") - DB.Model(&creditcard).Related(&user3) - if user3.Id != user.Id || user3.Name != user.Name { - t.Errorf("Should get user from credit card correctly") - } - - if !DB.Model(&CreditCard{}).Related(&User{}).RecordNotFound() { - t.Errorf("RecordNotFound for Related") - } - - var company Company - if DB.Model(&user).Related(&company, "Company").RecordNotFound() || company.Name != "company1" { - t.Errorf("RecordNotFound for Related") - } -} - -func TestForeignKey(t *testing.T) { - for _, structField := range DB.NewScope(&User{}).GetStructFields() { - for _, foreignKey := range []string{"BillingAddressID", "ShippingAddressId", "CompanyID"} { - if structField.Name == foreignKey && !structField.IsForeignKey { - t.Errorf(fmt.Sprintf("%v should be foreign key", foreignKey)) - } - } - } - - for _, structField := range DB.NewScope(&Email{}).GetStructFields() { - for _, foreignKey := range []string{"UserId"} { - if structField.Name == foreignKey && !structField.IsForeignKey { - t.Errorf(fmt.Sprintf("%v should be foreign key", foreignKey)) - } - } - } - - for _, structField := range DB.NewScope(&Post{}).GetStructFields() { - for _, foreignKey := range []string{"CategoryId", "MainCategoryId"} { - if structField.Name == foreignKey && !structField.IsForeignKey { - t.Errorf(fmt.Sprintf("%v should be foreign key", foreignKey)) - } - } - } - - for _, structField := range DB.NewScope(&Comment{}).GetStructFields() { - for _, foreignKey := range []string{"PostId"} { - if structField.Name == foreignKey && !structField.IsForeignKey { - t.Errorf(fmt.Sprintf("%v should be foreign key", foreignKey)) - } - } - } -} - -func testForeignKey(t *testing.T, source interface{}, sourceFieldName string, target interface{}, targetFieldName string) { - if dialect := os.Getenv("GORM_DIALECT"); dialect == "" || dialect == "sqlite" { - // sqlite does not support ADD CONSTRAINT in ALTER TABLE - return - } - targetScope := DB.NewScope(target) - targetTableName := targetScope.TableName() - modelScope := DB.NewScope(source) - modelField, ok := modelScope.FieldByName(sourceFieldName) - if !ok { - t.Fatalf(fmt.Sprintf("Failed to get field by name: %v", sourceFieldName)) - } - targetField, ok := targetScope.FieldByName(targetFieldName) - if !ok { - t.Fatalf(fmt.Sprintf("Failed to get field by name: %v", targetFieldName)) - } - dest := fmt.Sprintf("%v(%v)", targetTableName, targetField.DBName) - err := DB.Model(source).AddForeignKey(modelField.DBName, dest, "CASCADE", "CASCADE").Error - if err != nil { - t.Fatalf(fmt.Sprintf("Failed to create foreign key: %v", err)) - } -} - -func TestLongForeignKey(t *testing.T) { - testForeignKey(t, &NotSoLongTableName{}, "ReallyLongThingID", &ReallyLongTableNameToTestMySQLNameLengthLimit{}, "ID") -} - -func TestLongForeignKeyWithShortDest(t *testing.T) { - testForeignKey(t, &ReallyLongThingThatReferencesShort{}, "ShortID", &Short{}, "ID") -} - -func TestHasManyChildrenWithOneStruct(t *testing.T) { - category := Category{ - Name: "main", - Categories: []Category{ - {Name: "sub1"}, - {Name: "sub2"}, - }, - } - - DB.Save(&category) -} - -func TestAutoSaveBelongsToAssociation(t *testing.T) { - type Company struct { - gorm.Model - Name string - } - - type User struct { - gorm.Model - Name string - CompanyID uint - Company Company `gorm:"association_autoupdate:false;association_autocreate:false;"` - } - - DB.Where("name = ?", "auto_save_association").Delete(&Company{}) - DB.AutoMigrate(&Company{}, &User{}) - - DB.Save(&User{Name: "jinzhu", Company: Company{Name: "auto_save_association"}}) - - if !DB.Where("name = ?", "auto_save_association").First(&Company{}).RecordNotFound() { - t.Errorf("Company auto_save_association should not have been saved when autosave is false") - } - - // if foreign key is set, this should be saved even if association isn't - company := Company{Name: "auto_save_association"} - DB.Save(&company) - - company.Name = "auto_save_association_new_name" - user := User{Name: "jinzhu", Company: company} - - DB.Save(&user) - - if !DB.Where("name = ?", "auto_save_association_new_name").First(&Company{}).RecordNotFound() { - t.Errorf("Company should not have been updated") - } - - if DB.Where("id = ? AND company_id = ?", user.ID, company.ID).First(&User{}).RecordNotFound() { - t.Errorf("User's foreign key should have been saved") - } - - user2 := User{Name: "jinzhu_2", Company: Company{Name: "auto_save_association_2"}} - DB.Set("gorm:association_autocreate", true).Save(&user2) - if DB.Where("name = ?", "auto_save_association_2").First(&Company{}).RecordNotFound() { - t.Errorf("Company auto_save_association_2 should been created when autocreate is true") - } - - user2.Company.Name = "auto_save_association_2_newname" - DB.Set("gorm:association_autoupdate", true).Save(&user2) - - if DB.Where("name = ?", "auto_save_association_2_newname").First(&Company{}).RecordNotFound() { - t.Errorf("Company should been updated") - } -} - -func TestAutoSaveHasOneAssociation(t *testing.T) { - type Company struct { - gorm.Model - UserID uint - Name string - } - - type User struct { - gorm.Model - Name string - Company Company `gorm:"association_autoupdate:false;association_autocreate:false;"` - } - - DB.Where("name = ?", "auto_save_has_one_association").Delete(&Company{}) - DB.AutoMigrate(&Company{}, &User{}) - - DB.Save(&User{Name: "jinzhu", Company: Company{Name: "auto_save_has_one_association"}}) - - if !DB.Where("name = ?", "auto_save_has_one_association").First(&Company{}).RecordNotFound() { - t.Errorf("Company auto_save_has_one_association should not have been saved when autosave is false") - } - - company := Company{Name: "auto_save_has_one_association"} - DB.Save(&company) - - company.Name = "auto_save_has_one_association_new_name" - user := User{Name: "jinzhu", Company: company} - - DB.Save(&user) - - if !DB.Where("name = ?", "auto_save_has_one_association_new_name").First(&Company{}).RecordNotFound() { - t.Errorf("Company should not have been updated") - } - - if !DB.Where("name = ? AND user_id = ?", "auto_save_has_one_association", user.ID).First(&Company{}).RecordNotFound() { - t.Errorf("Company should not have been updated") - } - - if user.Company.UserID == 0 { - t.Errorf("UserID should be assigned") - } - - company.Name = "auto_save_has_one_association_2_new_name" - DB.Set("gorm:association_autoupdate", true).Save(&user) - - if DB.Where("name = ? AND user_id = ?", "auto_save_has_one_association_new_name", user.ID).First(&Company{}).RecordNotFound() { - t.Errorf("Company should been updated") - } - - user2 := User{Name: "jinzhu_2", Company: Company{Name: "auto_save_has_one_association_2"}} - DB.Set("gorm:association_autocreate", true).Save(&user2) - if DB.Where("name = ?", "auto_save_has_one_association_2").First(&Company{}).RecordNotFound() { - t.Errorf("Company auto_save_has_one_association_2 should been created when autocreate is true") - } -} - -func TestAutoSaveMany2ManyAssociation(t *testing.T) { - type Company struct { - gorm.Model - Name string - } - - type User struct { - gorm.Model - Name string - Companies []Company `gorm:"many2many:user_companies;association_autoupdate:false;association_autocreate:false;"` - } - - DB.AutoMigrate(&Company{}, &User{}) - - DB.Save(&User{Name: "jinzhu", Companies: []Company{{Name: "auto_save_m2m_association"}}}) - - if !DB.Where("name = ?", "auto_save_m2m_association").First(&Company{}).RecordNotFound() { - t.Errorf("Company auto_save_m2m_association should not have been saved when autosave is false") - } - - company := Company{Name: "auto_save_m2m_association"} - DB.Save(&company) - - company.Name = "auto_save_m2m_association_new_name" - user := User{Name: "jinzhu", Companies: []Company{company, {Name: "auto_save_m2m_association_new_name_2"}}} - - DB.Save(&user) - - if !DB.Where("name = ?", "auto_save_m2m_association_new_name").First(&Company{}).RecordNotFound() { - t.Errorf("Company should not have been updated") - } - - if !DB.Where("name = ?", "auto_save_m2m_association_new_name_2").First(&Company{}).RecordNotFound() { - t.Errorf("Company should not been created") - } - - if DB.Model(&user).Association("Companies").Count() != 1 { - t.Errorf("Relationship should been saved") - } - - DB.Set("gorm:association_autoupdate", true).Set("gorm:association_autocreate", true).Save(&user) - - if DB.Where("name = ?", "auto_save_m2m_association_new_name").First(&Company{}).RecordNotFound() { - t.Errorf("Company should been updated") - } - - if DB.Where("name = ?", "auto_save_m2m_association_new_name_2").First(&Company{}).RecordNotFound() { - t.Errorf("Company should been created") - } - - if DB.Model(&user).Association("Companies").Count() != 2 { - t.Errorf("Relationship should been updated") - } -} diff --git a/callback.go b/callback.go deleted file mode 100644 index 1f0e3c79..00000000 --- a/callback.go +++ /dev/null @@ -1,250 +0,0 @@ -package gorm - -import "fmt" - -// DefaultCallback default callbacks defined by gorm -var DefaultCallback = &Callback{logger: nopLogger{}} - -// Callback is a struct that contains all CRUD callbacks -// Field `creates` contains callbacks will be call when creating object -// Field `updates` contains callbacks will be call when updating object -// Field `deletes` contains callbacks will be call when deleting object -// Field `queries` contains callbacks will be call when querying object with query methods like Find, First, Related, Association... -// Field `rowQueries` contains callbacks will be call when querying object with Row, Rows... -// Field `processors` contains all callback processors, will be used to generate above callbacks in order -type Callback struct { - logger logger - creates []*func(scope *Scope) - updates []*func(scope *Scope) - deletes []*func(scope *Scope) - queries []*func(scope *Scope) - rowQueries []*func(scope *Scope) - processors []*CallbackProcessor -} - -// CallbackProcessor contains callback informations -type CallbackProcessor struct { - logger logger - name string // current callback's name - before string // register current callback before a callback - after string // register current callback after a callback - replace bool // replace callbacks with same name - remove bool // delete callbacks with same name - kind string // callback type: create, update, delete, query, row_query - processor *func(scope *Scope) // callback handler - parent *Callback -} - -func (c *Callback) clone(logger logger) *Callback { - return &Callback{ - logger: logger, - creates: c.creates, - updates: c.updates, - deletes: c.deletes, - queries: c.queries, - rowQueries: c.rowQueries, - processors: c.processors, - } -} - -// Create could be used to register callbacks for creating object -// db.Callback().Create().After("gorm:create").Register("plugin:run_after_create", func(*Scope) { -// // business logic -// ... -// -// // set error if some thing wrong happened, will rollback the creating -// scope.Err(errors.New("error")) -// }) -func (c *Callback) Create() *CallbackProcessor { - return &CallbackProcessor{logger: c.logger, kind: "create", parent: c} -} - -// Update could be used to register callbacks for updating object, refer `Create` for usage -func (c *Callback) Update() *CallbackProcessor { - return &CallbackProcessor{logger: c.logger, kind: "update", parent: c} -} - -// Delete could be used to register callbacks for deleting object, refer `Create` for usage -func (c *Callback) Delete() *CallbackProcessor { - return &CallbackProcessor{logger: c.logger, kind: "delete", parent: c} -} - -// Query could be used to register callbacks for querying objects with query methods like `Find`, `First`, `Related`, `Association`... -// Refer `Create` for usage -func (c *Callback) Query() *CallbackProcessor { - return &CallbackProcessor{logger: c.logger, kind: "query", parent: c} -} - -// RowQuery could be used to register callbacks for querying objects with `Row`, `Rows`, refer `Create` for usage -func (c *Callback) RowQuery() *CallbackProcessor { - return &CallbackProcessor{logger: c.logger, kind: "row_query", parent: c} -} - -// After insert a new callback after callback `callbackName`, refer `Callbacks.Create` -func (cp *CallbackProcessor) After(callbackName string) *CallbackProcessor { - cp.after = callbackName - return cp -} - -// Before insert a new callback before callback `callbackName`, refer `Callbacks.Create` -func (cp *CallbackProcessor) Before(callbackName string) *CallbackProcessor { - cp.before = callbackName - return cp -} - -// Register a new callback, refer `Callbacks.Create` -func (cp *CallbackProcessor) Register(callbackName string, callback func(scope *Scope)) { - if cp.kind == "row_query" { - if cp.before == "" && cp.after == "" && callbackName != "gorm:row_query" { - cp.logger.Print("info", fmt.Sprintf("Registering RowQuery callback %v without specify order with Before(), After(), applying Before('gorm:row_query') by default for compatibility...", callbackName)) - cp.before = "gorm:row_query" - } - } - - cp.logger.Print("info", fmt.Sprintf("[info] registering callback `%v` from %v", callbackName, fileWithLineNum())) - cp.name = callbackName - cp.processor = &callback - cp.parent.processors = append(cp.parent.processors, cp) - cp.parent.reorder() -} - -// Remove a registered callback -// db.Callback().Create().Remove("gorm:update_time_stamp_when_create") -func (cp *CallbackProcessor) Remove(callbackName string) { - cp.logger.Print("info", fmt.Sprintf("[info] removing callback `%v` from %v", callbackName, fileWithLineNum())) - cp.name = callbackName - cp.remove = true - cp.parent.processors = append(cp.parent.processors, cp) - cp.parent.reorder() -} - -// Replace a registered callback with new callback -// db.Callback().Create().Replace("gorm:update_time_stamp_when_create", func(*Scope) { -// scope.SetColumn("CreatedAt", now) -// scope.SetColumn("UpdatedAt", now) -// }) -func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *Scope)) { - cp.logger.Print("info", fmt.Sprintf("[info] replacing callback `%v` from %v", callbackName, fileWithLineNum())) - cp.name = callbackName - cp.processor = &callback - cp.replace = true - cp.parent.processors = append(cp.parent.processors, cp) - cp.parent.reorder() -} - -// Get registered callback -// db.Callback().Create().Get("gorm:create") -func (cp *CallbackProcessor) Get(callbackName string) (callback func(scope *Scope)) { - for _, p := range cp.parent.processors { - if p.name == callbackName && p.kind == cp.kind { - if p.remove { - callback = nil - } else { - callback = *p.processor - } - } - } - return -} - -// getRIndex get right index from string slice -func getRIndex(strs []string, str string) int { - for i := len(strs) - 1; i >= 0; i-- { - if strs[i] == str { - return i - } - } - return -1 -} - -// sortProcessors sort callback processors based on its before, after, remove, replace -func sortProcessors(cps []*CallbackProcessor) []*func(scope *Scope) { - var ( - allNames, sortedNames []string - sortCallbackProcessor func(c *CallbackProcessor) - ) - - for _, cp := range cps { - // show warning message the callback name already exists - if index := getRIndex(allNames, cp.name); index > -1 && !cp.replace && !cp.remove { - cp.logger.Print("warning", fmt.Sprintf("[warning] duplicated callback `%v` from %v", cp.name, fileWithLineNum())) - } - allNames = append(allNames, cp.name) - } - - sortCallbackProcessor = func(c *CallbackProcessor) { - if getRIndex(sortedNames, c.name) == -1 { // if not sorted - if c.before != "" { // if defined before callback - if index := getRIndex(sortedNames, c.before); index != -1 { - // if before callback already sorted, append current callback just after it - sortedNames = append(sortedNames[:index], append([]string{c.name}, sortedNames[index:]...)...) - } else if index := getRIndex(allNames, c.before); index != -1 { - // if before callback exists but haven't sorted, append current callback to last - sortedNames = append(sortedNames, c.name) - sortCallbackProcessor(cps[index]) - } - } - - if c.after != "" { // if defined after callback - if index := getRIndex(sortedNames, c.after); index != -1 { - // if after callback already sorted, append current callback just before it - sortedNames = append(sortedNames[:index+1], append([]string{c.name}, sortedNames[index+1:]...)...) - } else if index := getRIndex(allNames, c.after); index != -1 { - // if after callback exists but haven't sorted - cp := cps[index] - // set after callback's before callback to current callback - if cp.before == "" { - cp.before = c.name - } - sortCallbackProcessor(cp) - } - } - - // if current callback haven't been sorted, append it to last - if getRIndex(sortedNames, c.name) == -1 { - sortedNames = append(sortedNames, c.name) - } - } - } - - for _, cp := range cps { - sortCallbackProcessor(cp) - } - - var sortedFuncs []*func(scope *Scope) - for _, name := range sortedNames { - if index := getRIndex(allNames, name); !cps[index].remove { - sortedFuncs = append(sortedFuncs, cps[index].processor) - } - } - - return sortedFuncs -} - -// reorder all registered processors, and reset CRUD callbacks -func (c *Callback) reorder() { - var creates, updates, deletes, queries, rowQueries []*CallbackProcessor - - for _, processor := range c.processors { - if processor.name != "" { - switch processor.kind { - case "create": - creates = append(creates, processor) - case "update": - updates = append(updates, processor) - case "delete": - deletes = append(deletes, processor) - case "query": - queries = append(queries, processor) - case "row_query": - rowQueries = append(rowQueries, processor) - } - } - } - - c.creates = sortProcessors(creates) - c.updates = sortProcessors(updates) - c.deletes = sortProcessors(deletes) - c.queries = sortProcessors(queries) - c.rowQueries = sortProcessors(rowQueries) -} diff --git a/callback_create.go b/callback_create.go deleted file mode 100644 index c4d25f37..00000000 --- a/callback_create.go +++ /dev/null @@ -1,197 +0,0 @@ -package gorm - -import ( - "fmt" - "strings" -) - -// Define callbacks for creating -func init() { - DefaultCallback.Create().Register("gorm:begin_transaction", beginTransactionCallback) - DefaultCallback.Create().Register("gorm:before_create", beforeCreateCallback) - DefaultCallback.Create().Register("gorm:save_before_associations", saveBeforeAssociationsCallback) - DefaultCallback.Create().Register("gorm:update_time_stamp", updateTimeStampForCreateCallback) - DefaultCallback.Create().Register("gorm:create", createCallback) - DefaultCallback.Create().Register("gorm:force_reload_after_create", forceReloadAfterCreateCallback) - DefaultCallback.Create().Register("gorm:save_after_associations", saveAfterAssociationsCallback) - DefaultCallback.Create().Register("gorm:after_create", afterCreateCallback) - DefaultCallback.Create().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback) -} - -// beforeCreateCallback will invoke `BeforeSave`, `BeforeCreate` method before creating -func beforeCreateCallback(scope *Scope) { - if !scope.HasError() { - scope.CallMethod("BeforeSave") - } - if !scope.HasError() { - scope.CallMethod("BeforeCreate") - } -} - -// updateTimeStampForCreateCallback will set `CreatedAt`, `UpdatedAt` when creating -func updateTimeStampForCreateCallback(scope *Scope) { - if !scope.HasError() { - now := scope.db.nowFunc() - - if createdAtField, ok := scope.FieldByName("CreatedAt"); ok { - if createdAtField.IsBlank { - createdAtField.Set(now) - } - } - - if updatedAtField, ok := scope.FieldByName("UpdatedAt"); ok { - if updatedAtField.IsBlank { - updatedAtField.Set(now) - } - } - } -} - -// createCallback the callback used to insert data into database -func createCallback(scope *Scope) { - if !scope.HasError() { - defer scope.trace(NowFunc()) - - var ( - columns, placeholders []string - blankColumnsWithDefaultValue []string - ) - - for _, field := range scope.Fields() { - if scope.changeableField(field) { - if field.IsNormal && !field.IsIgnored { - if field.IsBlank && field.HasDefaultValue { - blankColumnsWithDefaultValue = append(blankColumnsWithDefaultValue, scope.Quote(field.DBName)) - scope.InstanceSet("gorm:blank_columns_with_default_value", blankColumnsWithDefaultValue) - } else if !field.IsPrimaryKey || !field.IsBlank { - columns = append(columns, scope.Quote(field.DBName)) - placeholders = append(placeholders, scope.AddToVars(field.Field.Interface())) - } - } else if field.Relationship != nil && field.Relationship.Kind == "belongs_to" { - for _, foreignKey := range field.Relationship.ForeignDBNames { - if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) { - columns = append(columns, scope.Quote(foreignField.DBName)) - placeholders = append(placeholders, scope.AddToVars(foreignField.Field.Interface())) - } - } - } - } - } - - var ( - returningColumn = "*" - quotedTableName = scope.QuotedTableName() - primaryField = scope.PrimaryField() - extraOption string - insertModifier string - ) - - if str, ok := scope.Get("gorm:insert_option"); ok { - extraOption = fmt.Sprint(str) - } - if str, ok := scope.Get("gorm:insert_modifier"); ok { - insertModifier = strings.ToUpper(fmt.Sprint(str)) - if insertModifier == "INTO" { - insertModifier = "" - } - } - - if primaryField != nil { - returningColumn = scope.Quote(primaryField.DBName) - } - - lastInsertIDOutputInterstitial := scope.Dialect().LastInsertIDOutputInterstitial(quotedTableName, returningColumn, columns) - var lastInsertIDReturningSuffix string - if lastInsertIDOutputInterstitial == "" { - lastInsertIDReturningSuffix = scope.Dialect().LastInsertIDReturningSuffix(quotedTableName, returningColumn) - } - - if len(columns) == 0 { - scope.Raw(fmt.Sprintf( - "INSERT%v INTO %v %v%v%v", - addExtraSpaceIfExist(insertModifier), - quotedTableName, - scope.Dialect().DefaultValueStr(), - addExtraSpaceIfExist(extraOption), - addExtraSpaceIfExist(lastInsertIDReturningSuffix), - )) - } else { - scope.Raw(fmt.Sprintf( - "INSERT%v INTO %v (%v)%v VALUES (%v)%v%v", - addExtraSpaceIfExist(insertModifier), - scope.QuotedTableName(), - strings.Join(columns, ","), - addExtraSpaceIfExist(lastInsertIDOutputInterstitial), - strings.Join(placeholders, ","), - addExtraSpaceIfExist(extraOption), - addExtraSpaceIfExist(lastInsertIDReturningSuffix), - )) - } - - // execute create sql: no primaryField - if primaryField == nil { - if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { - // set rows affected count - scope.db.RowsAffected, _ = result.RowsAffected() - - // set primary value to primary field - if primaryField != nil && primaryField.IsBlank { - if primaryValue, err := result.LastInsertId(); scope.Err(err) == nil { - scope.Err(primaryField.Set(primaryValue)) - } - } - } - return - } - - // execute create sql: lastInsertID implemention for majority of dialects - if lastInsertIDReturningSuffix == "" && lastInsertIDOutputInterstitial == "" { - if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { - // set rows affected count - scope.db.RowsAffected, _ = result.RowsAffected() - - // set primary value to primary field - if primaryField != nil && primaryField.IsBlank { - if primaryValue, err := result.LastInsertId(); scope.Err(err) == nil { - scope.Err(primaryField.Set(primaryValue)) - } - } - } - return - } - - // execute create sql: dialects with additional lastInsertID requirements (currently postgres & mssql) - if primaryField.Field.CanAddr() { - if err := scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil { - primaryField.IsBlank = false - scope.db.RowsAffected = 1 - } - } else { - scope.Err(ErrUnaddressable) - } - return - } -} - -// forceReloadAfterCreateCallback will reload columns that having default value, and set it back to current object -func forceReloadAfterCreateCallback(scope *Scope) { - if blankColumnsWithDefaultValue, ok := scope.InstanceGet("gorm:blank_columns_with_default_value"); ok { - db := scope.DB().New().Table(scope.TableName()).Select(blankColumnsWithDefaultValue.([]string)) - for _, field := range scope.Fields() { - if field.IsPrimaryKey && !field.IsBlank { - db = db.Where(fmt.Sprintf("%v = ?", field.DBName), field.Field.Interface()) - } - } - db.Scan(scope.Value) - } -} - -// afterCreateCallback will invoke `AfterCreate`, `AfterSave` method after creating -func afterCreateCallback(scope *Scope) { - if !scope.HasError() { - scope.CallMethod("AfterCreate") - } - if !scope.HasError() { - scope.CallMethod("AfterSave") - } -} diff --git a/callback_delete.go b/callback_delete.go deleted file mode 100644 index 48b97acb..00000000 --- a/callback_delete.go +++ /dev/null @@ -1,63 +0,0 @@ -package gorm - -import ( - "errors" - "fmt" -) - -// Define callbacks for deleting -func init() { - DefaultCallback.Delete().Register("gorm:begin_transaction", beginTransactionCallback) - DefaultCallback.Delete().Register("gorm:before_delete", beforeDeleteCallback) - DefaultCallback.Delete().Register("gorm:delete", deleteCallback) - DefaultCallback.Delete().Register("gorm:after_delete", afterDeleteCallback) - DefaultCallback.Delete().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback) -} - -// beforeDeleteCallback will invoke `BeforeDelete` method before deleting -func beforeDeleteCallback(scope *Scope) { - if scope.DB().HasBlockGlobalUpdate() && !scope.hasConditions() { - scope.Err(errors.New("missing WHERE clause while deleting")) - return - } - if !scope.HasError() { - scope.CallMethod("BeforeDelete") - } -} - -// deleteCallback used to delete data from database or set deleted_at to current time (when using with soft delete) -func deleteCallback(scope *Scope) { - if !scope.HasError() { - var extraOption string - if str, ok := scope.Get("gorm:delete_option"); ok { - extraOption = fmt.Sprint(str) - } - - deletedAtField, hasDeletedAtField := scope.FieldByName("DeletedAt") - - if !scope.Search.Unscoped && hasDeletedAtField { - scope.Raw(fmt.Sprintf( - "UPDATE %v SET %v=%v%v%v", - scope.QuotedTableName(), - scope.Quote(deletedAtField.DBName), - scope.AddToVars(scope.db.nowFunc()), - addExtraSpaceIfExist(scope.CombinedConditionSql()), - addExtraSpaceIfExist(extraOption), - )).Exec() - } else { - scope.Raw(fmt.Sprintf( - "DELETE FROM %v%v%v", - scope.QuotedTableName(), - addExtraSpaceIfExist(scope.CombinedConditionSql()), - addExtraSpaceIfExist(extraOption), - )).Exec() - } - } -} - -// afterDeleteCallback will invoke `AfterDelete` method after deleting -func afterDeleteCallback(scope *Scope) { - if !scope.HasError() { - scope.CallMethod("AfterDelete") - } -} diff --git a/callback_query.go b/callback_query.go deleted file mode 100644 index 544afd63..00000000 --- a/callback_query.go +++ /dev/null @@ -1,109 +0,0 @@ -package gorm - -import ( - "errors" - "fmt" - "reflect" -) - -// Define callbacks for querying -func init() { - DefaultCallback.Query().Register("gorm:query", queryCallback) - DefaultCallback.Query().Register("gorm:preload", preloadCallback) - DefaultCallback.Query().Register("gorm:after_query", afterQueryCallback) -} - -// queryCallback used to query data from database -func queryCallback(scope *Scope) { - if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip { - return - } - - //we are only preloading relations, dont touch base model - if _, skip := scope.InstanceGet("gorm:only_preload"); skip { - return - } - - defer scope.trace(NowFunc()) - - var ( - isSlice, isPtr bool - resultType reflect.Type - results = scope.IndirectValue() - ) - - if orderBy, ok := scope.Get("gorm:order_by_primary_key"); ok { - if primaryField := scope.PrimaryField(); primaryField != nil { - scope.Search.Order(fmt.Sprintf("%v.%v %v", scope.QuotedTableName(), scope.Quote(primaryField.DBName), orderBy)) - } - } - - if value, ok := scope.Get("gorm:query_destination"); ok { - results = indirect(reflect.ValueOf(value)) - } - - if kind := results.Kind(); kind == reflect.Slice { - isSlice = true - resultType = results.Type().Elem() - results.Set(reflect.MakeSlice(results.Type(), 0, 0)) - - if resultType.Kind() == reflect.Ptr { - isPtr = true - resultType = resultType.Elem() - } - } else if kind != reflect.Struct { - scope.Err(errors.New("unsupported destination, should be slice or struct")) - return - } - - scope.prepareQuerySQL() - - if !scope.HasError() { - scope.db.RowsAffected = 0 - - if str, ok := scope.Get("gorm:query_hint"); ok { - scope.SQL = fmt.Sprint(str) + scope.SQL - } - - if str, ok := scope.Get("gorm:query_option"); ok { - scope.SQL += addExtraSpaceIfExist(fmt.Sprint(str)) - } - - if rows, err := scope.SQLDB().Query(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { - defer rows.Close() - - columns, _ := rows.Columns() - for rows.Next() { - scope.db.RowsAffected++ - - elem := results - if isSlice { - elem = reflect.New(resultType).Elem() - } - - scope.scan(rows, columns, scope.New(elem.Addr().Interface()).Fields()) - - if isSlice { - if isPtr { - results.Set(reflect.Append(results, elem.Addr())) - } else { - results.Set(reflect.Append(results, elem)) - } - } - } - - if err := rows.Err(); err != nil { - scope.Err(err) - } else if scope.db.RowsAffected == 0 && !isSlice { - scope.Err(ErrRecordNotFound) - } - } - } -} - -// afterQueryCallback will invoke `AfterFind` method after querying -func afterQueryCallback(scope *Scope) { - if !scope.HasError() { - scope.CallMethod("AfterFind") - } -} diff --git a/callback_query_preload.go b/callback_query_preload.go deleted file mode 100644 index a936180a..00000000 --- a/callback_query_preload.go +++ /dev/null @@ -1,410 +0,0 @@ -package gorm - -import ( - "errors" - "fmt" - "reflect" - "strconv" - "strings" -) - -// preloadCallback used to preload associations -func preloadCallback(scope *Scope) { - if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip { - return - } - - if ap, ok := scope.Get("gorm:auto_preload"); ok { - // If gorm:auto_preload IS NOT a bool then auto preload. - // Else if it IS a bool, use the value - if apb, ok := ap.(bool); !ok { - autoPreload(scope) - } else if apb { - autoPreload(scope) - } - } - - if scope.Search.preload == nil || scope.HasError() { - return - } - - var ( - preloadedMap = map[string]bool{} - fields = scope.Fields() - ) - - for _, preload := range scope.Search.preload { - var ( - preloadFields = strings.Split(preload.schema, ".") - currentScope = scope - currentFields = fields - ) - - for idx, preloadField := range preloadFields { - var currentPreloadConditions []interface{} - - if currentScope == nil { - continue - } - - // if not preloaded - if preloadKey := strings.Join(preloadFields[:idx+1], "."); !preloadedMap[preloadKey] { - - // assign search conditions to last preload - if idx == len(preloadFields)-1 { - currentPreloadConditions = preload.conditions - } - - for _, field := range currentFields { - if field.Name != preloadField || field.Relationship == nil { - continue - } - - switch field.Relationship.Kind { - case "has_one": - currentScope.handleHasOnePreload(field, currentPreloadConditions) - case "has_many": - currentScope.handleHasManyPreload(field, currentPreloadConditions) - case "belongs_to": - currentScope.handleBelongsToPreload(field, currentPreloadConditions) - case "many_to_many": - currentScope.handleManyToManyPreload(field, currentPreloadConditions) - default: - scope.Err(errors.New("unsupported relation")) - } - - preloadedMap[preloadKey] = true - break - } - - if !preloadedMap[preloadKey] { - scope.Err(fmt.Errorf("can't preload field %s for %s", preloadField, currentScope.GetModelStruct().ModelType)) - return - } - } - - // preload next level - if idx < len(preloadFields)-1 { - currentScope = currentScope.getColumnAsScope(preloadField) - if currentScope != nil { - currentFields = currentScope.Fields() - } - } - } - } -} - -func autoPreload(scope *Scope) { - for _, field := range scope.Fields() { - if field.Relationship == nil { - continue - } - - if val, ok := field.TagSettingsGet("PRELOAD"); ok { - if preload, err := strconv.ParseBool(val); err != nil { - scope.Err(errors.New("invalid preload option")) - return - } else if !preload { - continue - } - } - - scope.Search.Preload(field.Name) - } -} - -func (scope *Scope) generatePreloadDBWithConditions(conditions []interface{}) (*DB, []interface{}) { - var ( - preloadDB = scope.NewDB() - preloadConditions []interface{} - ) - - for _, condition := range conditions { - if scopes, ok := condition.(func(*DB) *DB); ok { - preloadDB = scopes(preloadDB) - } else { - preloadConditions = append(preloadConditions, condition) - } - } - - return preloadDB, preloadConditions -} - -// handleHasOnePreload used to preload has one associations -func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) { - relation := field.Relationship - - // get relations's primary keys - primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value) - if len(primaryKeys) == 0 { - return - } - - // preload conditions - preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions) - - // find relations - query := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)) - values := toQueryValues(primaryKeys) - if relation.PolymorphicType != "" { - query += fmt.Sprintf(" AND %v = ?", scope.Quote(relation.PolymorphicDBName)) - values = append(values, relation.PolymorphicValue) - } - - results := makeSlice(field.Struct.Type) - scope.Err(preloadDB.Where(query, values...).Find(results, preloadConditions...).Error) - - // assign find results - var ( - resultsValue = indirect(reflect.ValueOf(results)) - indirectScopeValue = scope.IndirectValue() - ) - - if indirectScopeValue.Kind() == reflect.Slice { - foreignValuesToResults := make(map[string]reflect.Value) - for i := 0; i < resultsValue.Len(); i++ { - result := resultsValue.Index(i) - foreignValues := toString(getValueFromFields(result, relation.ForeignFieldNames)) - foreignValuesToResults[foreignValues] = result - } - for j := 0; j < indirectScopeValue.Len(); j++ { - indirectValue := indirect(indirectScopeValue.Index(j)) - valueString := toString(getValueFromFields(indirectValue, relation.AssociationForeignFieldNames)) - if result, found := foreignValuesToResults[valueString]; found { - indirectValue.FieldByName(field.Name).Set(result) - } - } - } else { - for i := 0; i < resultsValue.Len(); i++ { - result := resultsValue.Index(i) - scope.Err(field.Set(result)) - } - } -} - -// handleHasManyPreload used to preload has many associations -func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) { - relation := field.Relationship - - // get relations's primary keys - primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value) - if len(primaryKeys) == 0 { - return - } - - // preload conditions - preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions) - - // find relations - query := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)) - values := toQueryValues(primaryKeys) - if relation.PolymorphicType != "" { - query += fmt.Sprintf(" AND %v = ?", scope.Quote(relation.PolymorphicDBName)) - values = append(values, relation.PolymorphicValue) - } - - results := makeSlice(field.Struct.Type) - scope.Err(preloadDB.Where(query, values...).Find(results, preloadConditions...).Error) - - // assign find results - var ( - resultsValue = indirect(reflect.ValueOf(results)) - indirectScopeValue = scope.IndirectValue() - ) - - if indirectScopeValue.Kind() == reflect.Slice { - preloadMap := make(map[string][]reflect.Value) - for i := 0; i < resultsValue.Len(); i++ { - result := resultsValue.Index(i) - foreignValues := getValueFromFields(result, relation.ForeignFieldNames) - preloadMap[toString(foreignValues)] = append(preloadMap[toString(foreignValues)], result) - } - - for j := 0; j < indirectScopeValue.Len(); j++ { - object := indirect(indirectScopeValue.Index(j)) - objectRealValue := getValueFromFields(object, relation.AssociationForeignFieldNames) - f := object.FieldByName(field.Name) - if results, ok := preloadMap[toString(objectRealValue)]; ok { - f.Set(reflect.Append(f, results...)) - } else { - f.Set(reflect.MakeSlice(f.Type(), 0, 0)) - } - } - } else { - scope.Err(field.Set(resultsValue)) - } -} - -// handleBelongsToPreload used to preload belongs to associations -func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{}) { - relation := field.Relationship - - // preload conditions - preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions) - - // get relations's primary keys - primaryKeys := scope.getColumnAsArray(relation.ForeignFieldNames, scope.Value) - if len(primaryKeys) == 0 { - return - } - - // find relations - results := makeSlice(field.Struct.Type) - scope.Err(preloadDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.AssociationForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, preloadConditions...).Error) - - // assign find results - var ( - resultsValue = indirect(reflect.ValueOf(results)) - indirectScopeValue = scope.IndirectValue() - ) - - foreignFieldToObjects := make(map[string][]*reflect.Value) - if indirectScopeValue.Kind() == reflect.Slice { - for j := 0; j < indirectScopeValue.Len(); j++ { - object := indirect(indirectScopeValue.Index(j)) - valueString := toString(getValueFromFields(object, relation.ForeignFieldNames)) - foreignFieldToObjects[valueString] = append(foreignFieldToObjects[valueString], &object) - } - } - - for i := 0; i < resultsValue.Len(); i++ { - result := resultsValue.Index(i) - if indirectScopeValue.Kind() == reflect.Slice { - valueString := toString(getValueFromFields(result, relation.AssociationForeignFieldNames)) - if objects, found := foreignFieldToObjects[valueString]; found { - for _, object := range objects { - object.FieldByName(field.Name).Set(result) - } - } - } else { - scope.Err(field.Set(result)) - } - } -} - -// handleManyToManyPreload used to preload many to many associations -func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface{}) { - var ( - relation = field.Relationship - joinTableHandler = relation.JoinTableHandler - fieldType = field.Struct.Type.Elem() - foreignKeyValue interface{} - foreignKeyType = reflect.ValueOf(&foreignKeyValue).Type() - linkHash = map[string][]reflect.Value{} - isPtr bool - ) - - if fieldType.Kind() == reflect.Ptr { - isPtr = true - fieldType = fieldType.Elem() - } - - var sourceKeys = []string{} - for _, key := range joinTableHandler.SourceForeignKeys() { - sourceKeys = append(sourceKeys, key.DBName) - } - - // preload conditions - preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions) - - // generate query with join table - newScope := scope.New(reflect.New(fieldType).Interface()) - preloadDB = preloadDB.Table(newScope.TableName()).Model(newScope.Value) - - if len(preloadDB.search.selects) == 0 { - preloadDB = preloadDB.Select("*") - } - - preloadDB = joinTableHandler.JoinWith(joinTableHandler, preloadDB, scope.Value) - - // preload inline conditions - if len(preloadConditions) > 0 { - preloadDB = preloadDB.Where(preloadConditions[0], preloadConditions[1:]...) - } - - rows, err := preloadDB.Rows() - - if scope.Err(err) != nil { - return - } - defer rows.Close() - - columns, _ := rows.Columns() - for rows.Next() { - var ( - elem = reflect.New(fieldType).Elem() - fields = scope.New(elem.Addr().Interface()).Fields() - ) - - // register foreign keys in join tables - var joinTableFields []*Field - for _, sourceKey := range sourceKeys { - joinTableFields = append(joinTableFields, &Field{StructField: &StructField{DBName: sourceKey, IsNormal: true}, Field: reflect.New(foreignKeyType).Elem()}) - } - - scope.scan(rows, columns, append(fields, joinTableFields...)) - - scope.New(elem.Addr().Interface()). - InstanceSet("gorm:skip_query_callback", true). - callCallbacks(scope.db.parent.callbacks.queries) - - var foreignKeys = make([]interface{}, len(sourceKeys)) - // generate hashed forkey keys in join table - for idx, joinTableField := range joinTableFields { - if !joinTableField.Field.IsNil() { - foreignKeys[idx] = joinTableField.Field.Elem().Interface() - } - } - hashedSourceKeys := toString(foreignKeys) - - if isPtr { - linkHash[hashedSourceKeys] = append(linkHash[hashedSourceKeys], elem.Addr()) - } else { - linkHash[hashedSourceKeys] = append(linkHash[hashedSourceKeys], elem) - } - } - - if err := rows.Err(); err != nil { - scope.Err(err) - } - - // assign find results - var ( - indirectScopeValue = scope.IndirectValue() - fieldsSourceMap = map[string][]reflect.Value{} - foreignFieldNames = []string{} - ) - - for _, dbName := range relation.ForeignFieldNames { - if field, ok := scope.FieldByName(dbName); ok { - foreignFieldNames = append(foreignFieldNames, field.Name) - } - } - - if indirectScopeValue.Kind() == reflect.Slice { - for j := 0; j < indirectScopeValue.Len(); j++ { - object := indirect(indirectScopeValue.Index(j)) - key := toString(getValueFromFields(object, foreignFieldNames)) - fieldsSourceMap[key] = append(fieldsSourceMap[key], object.FieldByName(field.Name)) - } - } else if indirectScopeValue.IsValid() { - key := toString(getValueFromFields(indirectScopeValue, foreignFieldNames)) - fieldsSourceMap[key] = append(fieldsSourceMap[key], indirectScopeValue.FieldByName(field.Name)) - } - - for source, fields := range fieldsSourceMap { - for _, f := range fields { - //If not 0 this means Value is a pointer and we already added preloaded models to it - if f.Len() != 0 { - continue - } - - v := reflect.MakeSlice(f.Type(), 0, 0) - if len(linkHash[source]) > 0 { - v = reflect.Append(f, linkHash[source]...) - } - - f.Set(v) - } - } -} diff --git a/callback_row_query.go b/callback_row_query.go deleted file mode 100644 index 323b1605..00000000 --- a/callback_row_query.go +++ /dev/null @@ -1,41 +0,0 @@ -package gorm - -import ( - "database/sql" - "fmt" -) - -// Define callbacks for row query -func init() { - DefaultCallback.RowQuery().Register("gorm:row_query", rowQueryCallback) -} - -type RowQueryResult struct { - Row *sql.Row -} - -type RowsQueryResult struct { - Rows *sql.Rows - Error error -} - -// queryCallback used to query data from database -func rowQueryCallback(scope *Scope) { - if result, ok := scope.InstanceGet("row_query_result"); ok { - scope.prepareQuerySQL() - - if str, ok := scope.Get("gorm:query_hint"); ok { - scope.SQL = fmt.Sprint(str) + scope.SQL - } - - if str, ok := scope.Get("gorm:query_option"); ok { - scope.SQL += addExtraSpaceIfExist(fmt.Sprint(str)) - } - - if rowResult, ok := result.(*RowQueryResult); ok { - rowResult.Row = scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...) - } else if rowsResult, ok := result.(*RowsQueryResult); ok { - rowsResult.Rows, rowsResult.Error = scope.SQLDB().Query(scope.SQL, scope.SQLVars...) - } - } -} diff --git a/callback_save.go b/callback_save.go deleted file mode 100644 index 3b4e0589..00000000 --- a/callback_save.go +++ /dev/null @@ -1,170 +0,0 @@ -package gorm - -import ( - "reflect" - "strings" -) - -func beginTransactionCallback(scope *Scope) { - scope.Begin() -} - -func commitOrRollbackTransactionCallback(scope *Scope) { - scope.CommitOrRollback() -} - -func saveAssociationCheck(scope *Scope, field *Field) (autoUpdate bool, autoCreate bool, saveReference bool, r *Relationship) { - checkTruth := func(value interface{}) bool { - if v, ok := value.(bool); ok && !v { - return false - } - - if v, ok := value.(string); ok { - v = strings.ToLower(v) - return v == "true" - } - - return true - } - - if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored { - if r = field.Relationship; r != nil { - autoUpdate, autoCreate, saveReference = true, true, true - - if value, ok := scope.Get("gorm:save_associations"); ok { - autoUpdate = checkTruth(value) - autoCreate = autoUpdate - saveReference = autoUpdate - } else if value, ok := field.TagSettingsGet("SAVE_ASSOCIATIONS"); ok { - autoUpdate = checkTruth(value) - autoCreate = autoUpdate - saveReference = autoUpdate - } - - if value, ok := scope.Get("gorm:association_autoupdate"); ok { - autoUpdate = checkTruth(value) - } else if value, ok := field.TagSettingsGet("ASSOCIATION_AUTOUPDATE"); ok { - autoUpdate = checkTruth(value) - } - - if value, ok := scope.Get("gorm:association_autocreate"); ok { - autoCreate = checkTruth(value) - } else if value, ok := field.TagSettingsGet("ASSOCIATION_AUTOCREATE"); ok { - autoCreate = checkTruth(value) - } - - if value, ok := scope.Get("gorm:association_save_reference"); ok { - saveReference = checkTruth(value) - } else if value, ok := field.TagSettingsGet("ASSOCIATION_SAVE_REFERENCE"); ok { - saveReference = checkTruth(value) - } - } - } - - return -} - -func saveBeforeAssociationsCallback(scope *Scope) { - for _, field := range scope.Fields() { - autoUpdate, autoCreate, saveReference, relationship := saveAssociationCheck(scope, field) - - if relationship != nil && relationship.Kind == "belongs_to" { - fieldValue := field.Field.Addr().Interface() - newScope := scope.New(fieldValue) - - if newScope.PrimaryKeyZero() { - if autoCreate { - scope.Err(scope.NewDB().Save(fieldValue).Error) - } - } else if autoUpdate { - scope.Err(scope.NewDB().Save(fieldValue).Error) - } - - if saveReference { - if len(relationship.ForeignFieldNames) != 0 { - // set value's foreign key - for idx, fieldName := range relationship.ForeignFieldNames { - associationForeignName := relationship.AssociationForeignDBNames[idx] - if foreignField, ok := scope.New(fieldValue).FieldByName(associationForeignName); ok { - scope.Err(scope.SetColumn(fieldName, foreignField.Field.Interface())) - } - } - } - } - } - } -} - -func saveAfterAssociationsCallback(scope *Scope) { - for _, field := range scope.Fields() { - autoUpdate, autoCreate, saveReference, relationship := saveAssociationCheck(scope, field) - - if relationship != nil && (relationship.Kind == "has_one" || relationship.Kind == "has_many" || relationship.Kind == "many_to_many") { - value := field.Field - - switch value.Kind() { - case reflect.Slice: - for i := 0; i < value.Len(); i++ { - newDB := scope.NewDB() - elem := value.Index(i).Addr().Interface() - newScope := newDB.NewScope(elem) - - if saveReference { - if relationship.JoinTableHandler == nil && len(relationship.ForeignFieldNames) != 0 { - for idx, fieldName := range relationship.ForeignFieldNames { - associationForeignName := relationship.AssociationForeignDBNames[idx] - if f, ok := scope.FieldByName(associationForeignName); ok { - scope.Err(newScope.SetColumn(fieldName, f.Field.Interface())) - } - } - } - - if relationship.PolymorphicType != "" { - scope.Err(newScope.SetColumn(relationship.PolymorphicType, relationship.PolymorphicValue)) - } - } - - if newScope.PrimaryKeyZero() { - if autoCreate { - scope.Err(newDB.Save(elem).Error) - } - } else if autoUpdate { - scope.Err(newDB.Save(elem).Error) - } - - if !scope.New(newScope.Value).PrimaryKeyZero() && saveReference { - if joinTableHandler := relationship.JoinTableHandler; joinTableHandler != nil { - scope.Err(joinTableHandler.Add(joinTableHandler, newDB, scope.Value, newScope.Value)) - } - } - } - default: - elem := value.Addr().Interface() - newScope := scope.New(elem) - - if saveReference { - if len(relationship.ForeignFieldNames) != 0 { - for idx, fieldName := range relationship.ForeignFieldNames { - associationForeignName := relationship.AssociationForeignDBNames[idx] - if f, ok := scope.FieldByName(associationForeignName); ok { - scope.Err(newScope.SetColumn(fieldName, f.Field.Interface())) - } - } - } - - if relationship.PolymorphicType != "" { - scope.Err(newScope.SetColumn(relationship.PolymorphicType, relationship.PolymorphicValue)) - } - } - - if newScope.PrimaryKeyZero() { - if autoCreate { - scope.Err(scope.NewDB().Save(elem).Error) - } - } else if autoUpdate { - scope.Err(scope.NewDB().Save(elem).Error) - } - } - } - } -} diff --git a/callback_system_test.go b/callback_system_test.go deleted file mode 100644 index 2482eda4..00000000 --- a/callback_system_test.go +++ /dev/null @@ -1,112 +0,0 @@ -package gorm - -import ( - "reflect" - "runtime" - "strings" - "testing" -) - -func equalFuncs(funcs []*func(s *Scope), fnames []string) bool { - var names []string - for _, f := range funcs { - fnames := strings.Split(runtime.FuncForPC(reflect.ValueOf(*f).Pointer()).Name(), ".") - names = append(names, fnames[len(fnames)-1]) - } - return reflect.DeepEqual(names, fnames) -} - -func create(s *Scope) {} -func beforeCreate1(s *Scope) {} -func beforeCreate2(s *Scope) {} -func afterCreate1(s *Scope) {} -func afterCreate2(s *Scope) {} - -func TestRegisterCallback(t *testing.T) { - var callback = &Callback{logger: defaultLogger} - - callback.Create().Register("before_create1", beforeCreate1) - callback.Create().Register("before_create2", beforeCreate2) - callback.Create().Register("create", create) - callback.Create().Register("after_create1", afterCreate1) - callback.Create().Register("after_create2", afterCreate2) - - if !equalFuncs(callback.creates, []string{"beforeCreate1", "beforeCreate2", "create", "afterCreate1", "afterCreate2"}) { - t.Errorf("register callback") - } -} - -func TestRegisterCallbackWithOrder(t *testing.T) { - var callback1 = &Callback{logger: defaultLogger} - callback1.Create().Register("before_create1", beforeCreate1) - callback1.Create().Register("create", create) - callback1.Create().Register("after_create1", afterCreate1) - callback1.Create().Before("after_create1").Register("after_create2", afterCreate2) - if !equalFuncs(callback1.creates, []string{"beforeCreate1", "create", "afterCreate2", "afterCreate1"}) { - t.Errorf("register callback with order") - } - - var callback2 = &Callback{logger: defaultLogger} - - callback2.Update().Register("create", create) - callback2.Update().Before("create").Register("before_create1", beforeCreate1) - callback2.Update().After("after_create2").Register("after_create1", afterCreate1) - callback2.Update().Before("before_create1").Register("before_create2", beforeCreate2) - callback2.Update().Register("after_create2", afterCreate2) - - if !equalFuncs(callback2.updates, []string{"beforeCreate2", "beforeCreate1", "create", "afterCreate2", "afterCreate1"}) { - t.Errorf("register callback with order") - } -} - -func TestRegisterCallbackWithComplexOrder(t *testing.T) { - var callback1 = &Callback{logger: defaultLogger} - - callback1.Query().Before("after_create1").After("before_create1").Register("create", create) - callback1.Query().Register("before_create1", beforeCreate1) - callback1.Query().Register("after_create1", afterCreate1) - - if !equalFuncs(callback1.queries, []string{"beforeCreate1", "create", "afterCreate1"}) { - t.Errorf("register callback with order") - } - - var callback2 = &Callback{logger: defaultLogger} - - callback2.Delete().Before("after_create1").After("before_create1").Register("create", create) - callback2.Delete().Before("create").Register("before_create1", beforeCreate1) - callback2.Delete().After("before_create1").Register("before_create2", beforeCreate2) - callback2.Delete().Register("after_create1", afterCreate1) - callback2.Delete().After("after_create1").Register("after_create2", afterCreate2) - - if !equalFuncs(callback2.deletes, []string{"beforeCreate1", "beforeCreate2", "create", "afterCreate1", "afterCreate2"}) { - t.Errorf("register callback with order") - } -} - -func replaceCreate(s *Scope) {} - -func TestReplaceCallback(t *testing.T) { - var callback = &Callback{logger: defaultLogger} - - callback.Create().Before("after_create1").After("before_create1").Register("create", create) - callback.Create().Register("before_create1", beforeCreate1) - callback.Create().Register("after_create1", afterCreate1) - callback.Create().Replace("create", replaceCreate) - - if !equalFuncs(callback.creates, []string{"beforeCreate1", "replaceCreate", "afterCreate1"}) { - t.Errorf("replace callback") - } -} - -func TestRemoveCallback(t *testing.T) { - var callback = &Callback{logger: defaultLogger} - - callback.Create().Before("after_create1").After("before_create1").Register("create", create) - callback.Create().Register("before_create1", beforeCreate1) - callback.Create().Register("after_create1", afterCreate1) - callback.Create().Remove("create") - - if !equalFuncs(callback.creates, []string{"beforeCreate1", "afterCreate1"}) { - t.Errorf("remove callback") - } -} diff --git a/callback_update.go b/callback_update.go deleted file mode 100644 index 699e534b..00000000 --- a/callback_update.go +++ /dev/null @@ -1,121 +0,0 @@ -package gorm - -import ( - "errors" - "fmt" - "sort" - "strings" -) - -// Define callbacks for updating -func init() { - DefaultCallback.Update().Register("gorm:assign_updating_attributes", assignUpdatingAttributesCallback) - DefaultCallback.Update().Register("gorm:begin_transaction", beginTransactionCallback) - DefaultCallback.Update().Register("gorm:before_update", beforeUpdateCallback) - DefaultCallback.Update().Register("gorm:save_before_associations", saveBeforeAssociationsCallback) - DefaultCallback.Update().Register("gorm:update_time_stamp", updateTimeStampForUpdateCallback) - DefaultCallback.Update().Register("gorm:update", updateCallback) - DefaultCallback.Update().Register("gorm:save_after_associations", saveAfterAssociationsCallback) - DefaultCallback.Update().Register("gorm:after_update", afterUpdateCallback) - DefaultCallback.Update().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback) -} - -// assignUpdatingAttributesCallback assign updating attributes to model -func assignUpdatingAttributesCallback(scope *Scope) { - if attrs, ok := scope.InstanceGet("gorm:update_interface"); ok { - if updateMaps, hasUpdate := scope.updatedAttrsWithValues(attrs); hasUpdate { - scope.InstanceSet("gorm:update_attrs", updateMaps) - } else { - scope.SkipLeft() - } - } -} - -// beforeUpdateCallback will invoke `BeforeSave`, `BeforeUpdate` method before updating -func beforeUpdateCallback(scope *Scope) { - if scope.DB().HasBlockGlobalUpdate() && !scope.hasConditions() { - scope.Err(errors.New("missing WHERE clause while updating")) - return - } - if _, ok := scope.Get("gorm:update_column"); !ok { - if !scope.HasError() { - scope.CallMethod("BeforeSave") - } - if !scope.HasError() { - scope.CallMethod("BeforeUpdate") - } - } -} - -// updateTimeStampForUpdateCallback will set `UpdatedAt` when updating -func updateTimeStampForUpdateCallback(scope *Scope) { - if _, ok := scope.Get("gorm:update_column"); !ok { - scope.SetColumn("UpdatedAt", scope.db.nowFunc()) - } -} - -// updateCallback the callback used to update data to database -func updateCallback(scope *Scope) { - if !scope.HasError() { - var sqls []string - - if updateAttrs, ok := scope.InstanceGet("gorm:update_attrs"); ok { - // Sort the column names so that the generated SQL is the same every time. - updateMap := updateAttrs.(map[string]interface{}) - var columns []string - for c := range updateMap { - columns = append(columns, c) - } - sort.Strings(columns) - - for _, column := range columns { - value := updateMap[column] - sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(column), scope.AddToVars(value))) - } - } else { - for _, field := range scope.Fields() { - if scope.changeableField(field) { - if !field.IsPrimaryKey && field.IsNormal && (field.Name != "CreatedAt" || !field.IsBlank) { - if !field.IsForeignKey || !field.IsBlank || !field.HasDefaultValue { - sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) - } - } else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" { - for _, foreignKey := range relationship.ForeignDBNames { - if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) { - sqls = append(sqls, - fmt.Sprintf("%v = %v", scope.Quote(foreignField.DBName), scope.AddToVars(foreignField.Field.Interface()))) - } - } - } - } - } - } - - var extraOption string - if str, ok := scope.Get("gorm:update_option"); ok { - extraOption = fmt.Sprint(str) - } - - if len(sqls) > 0 { - scope.Raw(fmt.Sprintf( - "UPDATE %v SET %v%v%v", - scope.QuotedTableName(), - strings.Join(sqls, ", "), - addExtraSpaceIfExist(scope.CombinedConditionSql()), - addExtraSpaceIfExist(extraOption), - )).Exec() - } - } -} - -// afterUpdateCallback will invoke `AfterUpdate`, `AfterSave` method after updating -func afterUpdateCallback(scope *Scope) { - if _, ok := scope.Get("gorm:update_column"); !ok { - if !scope.HasError() { - scope.CallMethod("AfterUpdate") - } - if !scope.HasError() { - scope.CallMethod("AfterSave") - } - } -} diff --git a/callbacks_test.go b/callbacks_test.go deleted file mode 100644 index bebd0e38..00000000 --- a/callbacks_test.go +++ /dev/null @@ -1,249 +0,0 @@ -package gorm_test - -import ( - "errors" - "reflect" - "testing" - - "github.com/jinzhu/gorm" -) - -func (s *Product) BeforeCreate() (err error) { - if s.Code == "Invalid" { - err = errors.New("invalid product") - } - s.BeforeCreateCallTimes = s.BeforeCreateCallTimes + 1 - return -} - -func (s *Product) BeforeUpdate() (err error) { - if s.Code == "dont_update" { - err = errors.New("can't update") - } - s.BeforeUpdateCallTimes = s.BeforeUpdateCallTimes + 1 - return -} - -func (s *Product) BeforeSave() (err error) { - if s.Code == "dont_save" { - err = errors.New("can't save") - } - s.BeforeSaveCallTimes = s.BeforeSaveCallTimes + 1 - return -} - -func (s *Product) AfterFind() { - s.AfterFindCallTimes = s.AfterFindCallTimes + 1 -} - -func (s *Product) AfterCreate(tx *gorm.DB) { - tx.Model(s).UpdateColumn(Product{AfterCreateCallTimes: s.AfterCreateCallTimes + 1}) -} - -func (s *Product) AfterUpdate() { - s.AfterUpdateCallTimes = s.AfterUpdateCallTimes + 1 -} - -func (s *Product) AfterSave() (err error) { - if s.Code == "after_save_error" { - err = errors.New("can't save") - } - s.AfterSaveCallTimes = s.AfterSaveCallTimes + 1 - return -} - -func (s *Product) BeforeDelete() (err error) { - if s.Code == "dont_delete" { - err = errors.New("can't delete") - } - s.BeforeDeleteCallTimes = s.BeforeDeleteCallTimes + 1 - return -} - -func (s *Product) AfterDelete() (err error) { - if s.Code == "after_delete_error" { - err = errors.New("can't delete") - } - s.AfterDeleteCallTimes = s.AfterDeleteCallTimes + 1 - return -} - -func (s *Product) GetCallTimes() []int64 { - return []int64{s.BeforeCreateCallTimes, s.BeforeSaveCallTimes, s.BeforeUpdateCallTimes, s.AfterCreateCallTimes, s.AfterSaveCallTimes, s.AfterUpdateCallTimes, s.BeforeDeleteCallTimes, s.AfterDeleteCallTimes, s.AfterFindCallTimes} -} - -func TestRunCallbacks(t *testing.T) { - p := Product{Code: "unique_code", Price: 100} - DB.Save(&p) - - if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 1, 0, 0, 0, 0}) { - t.Errorf("Callbacks should be invoked successfully, %v", p.GetCallTimes()) - } - - DB.Where("Code = ?", "unique_code").First(&p) - if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 0, 0, 0, 0, 1}) { - t.Errorf("After callbacks values are not saved, %v", p.GetCallTimes()) - } - - p.Price = 200 - DB.Save(&p) - if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 1, 1, 0, 0, 1}) { - t.Errorf("After update callbacks should be invoked successfully, %v", p.GetCallTimes()) - } - - var products []Product - DB.Find(&products, "code = ?", "unique_code") - if products[0].AfterFindCallTimes != 2 { - t.Errorf("AfterFind callbacks should work with slice") - } - - DB.Where("Code = ?", "unique_code").First(&p) - if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 0, 0, 2}) { - t.Errorf("After update callbacks values are not saved, %v", p.GetCallTimes()) - } - - DB.Delete(&p) - if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 1, 1, 2}) { - t.Errorf("After delete callbacks should be invoked successfully, %v", p.GetCallTimes()) - } - - if DB.Where("Code = ?", "unique_code").First(&p).Error == nil { - t.Errorf("Can't find a deleted record") - } -} - -func TestCallbacksWithErrors(t *testing.T) { - p := Product{Code: "Invalid", Price: 100} - if DB.Save(&p).Error == nil { - t.Errorf("An error from before create callbacks happened when create with invalid value") - } - - if DB.Where("code = ?", "Invalid").First(&Product{}).Error == nil { - t.Errorf("Should not save record that have errors") - } - - if DB.Save(&Product{Code: "dont_save", Price: 100}).Error == nil { - t.Errorf("An error from after create callbacks happened when create with invalid value") - } - - p2 := Product{Code: "update_callback", Price: 100} - DB.Save(&p2) - - p2.Code = "dont_update" - if DB.Save(&p2).Error == nil { - t.Errorf("An error from before update callbacks happened when update with invalid value") - } - - if DB.Where("code = ?", "update_callback").First(&Product{}).Error != nil { - t.Errorf("Record Should not be updated due to errors happened in before update callback") - } - - if DB.Where("code = ?", "dont_update").First(&Product{}).Error == nil { - t.Errorf("Record Should not be updated due to errors happened in before update callback") - } - - p2.Code = "dont_save" - if DB.Save(&p2).Error == nil { - t.Errorf("An error from before save callbacks happened when update with invalid value") - } - - p3 := Product{Code: "dont_delete", Price: 100} - DB.Save(&p3) - if DB.Delete(&p3).Error == nil { - t.Errorf("An error from before delete callbacks happened when delete") - } - - if DB.Where("Code = ?", "dont_delete").First(&p3).Error != nil { - t.Errorf("An error from before delete callbacks happened") - } - - p4 := Product{Code: "after_save_error", Price: 100} - DB.Save(&p4) - if err := DB.First(&Product{}, "code = ?", "after_save_error").Error; err == nil { - t.Errorf("Record should be reverted if get an error in after save callback") - } - - p5 := Product{Code: "after_delete_error", Price: 100} - DB.Save(&p5) - if err := DB.First(&Product{}, "code = ?", "after_delete_error").Error; err != nil { - t.Errorf("Record should be found") - } - - DB.Delete(&p5) - if err := DB.First(&Product{}, "code = ?", "after_delete_error").Error; err != nil { - t.Errorf("Record shouldn't be deleted because of an error happened in after delete callback") - } -} - -func TestGetCallback(t *testing.T) { - scope := DB.NewScope(nil) - - if DB.Callback().Create().Get("gorm:test_callback") != nil { - t.Errorf("`gorm:test_callback` should be nil") - } - - DB.Callback().Create().Register("gorm:test_callback", func(scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 1) }) - callback := DB.Callback().Create().Get("gorm:test_callback") - if callback == nil { - t.Errorf("`gorm:test_callback` should be non-nil") - } - callback(scope) - if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 1 { - t.Errorf("`gorm:test_callback_value` should be `1, true` but `%v, %v`", v, ok) - } - - DB.Callback().Create().Replace("gorm:test_callback", func(scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 2) }) - callback = DB.Callback().Create().Get("gorm:test_callback") - if callback == nil { - t.Errorf("`gorm:test_callback` should be non-nil") - } - callback(scope) - if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 2 { - t.Errorf("`gorm:test_callback_value` should be `2, true` but `%v, %v`", v, ok) - } - - DB.Callback().Create().Remove("gorm:test_callback") - if DB.Callback().Create().Get("gorm:test_callback") != nil { - t.Errorf("`gorm:test_callback` should be nil") - } - - DB.Callback().Create().Register("gorm:test_callback", func(scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 3) }) - callback = DB.Callback().Create().Get("gorm:test_callback") - if callback == nil { - t.Errorf("`gorm:test_callback` should be non-nil") - } - callback(scope) - if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 3 { - t.Errorf("`gorm:test_callback_value` should be `3, true` but `%v, %v`", v, ok) - } -} - -func TestUseDefaultCallback(t *testing.T) { - createCallbackName := "gorm:test_use_default_callback_for_create" - gorm.DefaultCallback.Create().Register(createCallbackName, func(*gorm.Scope) { - // nop - }) - if gorm.DefaultCallback.Create().Get(createCallbackName) == nil { - t.Errorf("`%s` expected non-nil, but got nil", createCallbackName) - } - gorm.DefaultCallback.Create().Remove(createCallbackName) - if gorm.DefaultCallback.Create().Get(createCallbackName) != nil { - t.Errorf("`%s` expected nil, but got non-nil", createCallbackName) - } - - updateCallbackName := "gorm:test_use_default_callback_for_update" - scopeValueName := "gorm:test_use_default_callback_for_update_value" - gorm.DefaultCallback.Update().Register(updateCallbackName, func(scope *gorm.Scope) { - scope.Set(scopeValueName, 1) - }) - gorm.DefaultCallback.Update().Replace(updateCallbackName, func(scope *gorm.Scope) { - scope.Set(scopeValueName, 2) - }) - - scope := DB.NewScope(nil) - callback := gorm.DefaultCallback.Update().Get(updateCallbackName) - callback(scope) - if v, ok := scope.Get(scopeValueName); !ok || v != 2 { - t.Errorf("`%s` should be `2, true` but `%v, %v`", scopeValueName, v, ok) - } -} diff --git a/create_test.go b/create_test.go deleted file mode 100644 index c80bdcbb..00000000 --- a/create_test.go +++ /dev/null @@ -1,288 +0,0 @@ -package gorm_test - -import ( - "os" - "reflect" - "testing" - "time" - - "github.com/jinzhu/now" -) - -func TestCreate(t *testing.T) { - float := 35.03554004971999 - now := time.Now() - user := User{Name: "CreateUser", Age: 18, Birthday: &now, UserNum: Num(111), PasswordHash: []byte{'f', 'a', 'k', '4'}, Latitude: float} - - if !DB.NewRecord(user) || !DB.NewRecord(&user) { - t.Error("User should be new record before create") - } - - if count := DB.Save(&user).RowsAffected; count != 1 { - t.Error("There should be one record be affected when create record") - } - - if DB.NewRecord(user) || DB.NewRecord(&user) { - t.Error("User should not new record after save") - } - - var newUser User - if err := DB.First(&newUser, user.Id).Error; err != nil { - t.Errorf("No error should happen, but got %v", err) - } - - if !reflect.DeepEqual(newUser.PasswordHash, []byte{'f', 'a', 'k', '4'}) { - t.Errorf("User's PasswordHash should be saved ([]byte)") - } - - if newUser.Age != 18 { - t.Errorf("User's Age should be saved (int)") - } - - if newUser.UserNum != Num(111) { - t.Errorf("User's UserNum should be saved (custom type), but got %v", newUser.UserNum) - } - - if newUser.Latitude != float { - t.Errorf("Float64 should not be changed after save") - } - - if user.CreatedAt.IsZero() { - t.Errorf("Should have created_at after create") - } - - if newUser.CreatedAt.IsZero() { - t.Errorf("Should have created_at after create") - } - - DB.Model(user).Update("name", "create_user_new_name") - DB.First(&user, user.Id) - if user.CreatedAt.Format(time.RFC3339Nano) != newUser.CreatedAt.Format(time.RFC3339Nano) { - t.Errorf("CreatedAt should not be changed after update") - } -} - -func TestCreateEmptyStrut(t *testing.T) { - type EmptyStruct struct { - ID uint - } - DB.AutoMigrate(&EmptyStruct{}) - - if err := DB.Create(&EmptyStruct{}).Error; err != nil { - t.Errorf("No error should happen when creating user, but got %v", err) - } -} - -func TestCreateWithExistingTimestamp(t *testing.T) { - user := User{Name: "CreateUserExistingTimestamp"} - - timeA := now.MustParse("2016-01-01") - user.CreatedAt = timeA - user.UpdatedAt = timeA - DB.Save(&user) - - if user.CreatedAt.UTC().Format(time.RFC3339) != timeA.UTC().Format(time.RFC3339) { - t.Errorf("CreatedAt should not be changed") - } - - if user.UpdatedAt.UTC().Format(time.RFC3339) != timeA.UTC().Format(time.RFC3339) { - t.Errorf("UpdatedAt should not be changed") - } - - var newUser User - DB.First(&newUser, user.Id) - - if newUser.CreatedAt.UTC().Format(time.RFC3339) != timeA.UTC().Format(time.RFC3339) { - t.Errorf("CreatedAt should not be changed") - } - - if newUser.UpdatedAt.UTC().Format(time.RFC3339) != timeA.UTC().Format(time.RFC3339) { - t.Errorf("UpdatedAt should not be changed") - } -} - -func TestCreateWithNowFuncOverride(t *testing.T) { - user1 := User{Name: "CreateUserTimestampOverride"} - - timeA := now.MustParse("2016-01-01") - - // do DB.New() because we don't want this test to affect other tests - db1 := DB.New() - // set the override to use static timeA - db1.SetNowFuncOverride(func() time.Time { - return timeA - }) - // call .New again to check the override is carried over as well during clone - db1 = db1.New() - - db1.Save(&user1) - - if user1.CreatedAt.UTC().Format(time.RFC3339) != timeA.UTC().Format(time.RFC3339) { - t.Errorf("CreatedAt be using the nowFuncOverride") - } - if user1.UpdatedAt.UTC().Format(time.RFC3339) != timeA.UTC().Format(time.RFC3339) { - t.Errorf("UpdatedAt be using the nowFuncOverride") - } - - // now create another user with a fresh DB.Now() that doesn't have the nowFuncOverride set - // to make sure that setting it only affected the above instance - - user2 := User{Name: "CreateUserTimestampOverrideNoMore"} - - db2 := DB.New() - - db2.Save(&user2) - - if user2.CreatedAt.UTC().Format(time.RFC3339) == timeA.UTC().Format(time.RFC3339) { - t.Errorf("CreatedAt no longer be using the nowFuncOverride") - } - if user2.UpdatedAt.UTC().Format(time.RFC3339) == timeA.UTC().Format(time.RFC3339) { - t.Errorf("UpdatedAt no longer be using the nowFuncOverride") - } -} - -type AutoIncrementUser struct { - User - Sequence uint `gorm:"AUTO_INCREMENT"` -} - -func TestCreateWithAutoIncrement(t *testing.T) { - if dialect := os.Getenv("GORM_DIALECT"); dialect != "postgres" { - t.Skip("Skipping this because only postgres properly support auto_increment on a non-primary_key column") - } - - DB.AutoMigrate(&AutoIncrementUser{}) - - user1 := AutoIncrementUser{} - user2 := AutoIncrementUser{} - - DB.Create(&user1) - DB.Create(&user2) - - if user2.Sequence-user1.Sequence != 1 { - t.Errorf("Auto increment should apply on Sequence") - } -} - -func TestCreateWithNoGORMPrimayKey(t *testing.T) { - if dialect := os.Getenv("GORM_DIALECT"); dialect == "mssql" { - t.Skip("Skipping this because MSSQL will return identity only if the table has an Id column") - } - - jt := JoinTable{From: 1, To: 2} - err := DB.Create(&jt).Error - if err != nil { - t.Errorf("No error should happen when create a record without a GORM primary key. But in the database this primary key exists and is the union of 2 or more fields\n But got: %s", err) - } -} - -func TestCreateWithNoStdPrimaryKeyAndDefaultValues(t *testing.T) { - animal := Animal{Name: "Ferdinand"} - if DB.Save(&animal).Error != nil { - t.Errorf("No error should happen when create a record without std primary key") - } - - if animal.Counter == 0 { - t.Errorf("No std primary key should be filled value after create") - } - - if animal.Name != "Ferdinand" { - t.Errorf("Default value should be overrided") - } - - // Test create with default value not overrided - an := Animal{From: "nerdz"} - - if DB.Save(&an).Error != nil { - t.Errorf("No error should happen when create an record without std primary key") - } - - // We must fetch the value again, to have the default fields updated - // (We can't do this in the update statements, since sql default can be expressions - // And be different from the fields' type (eg. a time.Time fields has a default value of "now()" - DB.Model(Animal{}).Where(&Animal{Counter: an.Counter}).First(&an) - - if an.Name != "galeone" { - t.Errorf("Default value should fill the field. But got %v", an.Name) - } -} - -func TestAnonymousScanner(t *testing.T) { - user := User{Name: "anonymous_scanner", Role: Role{Name: "admin"}} - DB.Save(&user) - - var user2 User - DB.First(&user2, "name = ?", "anonymous_scanner") - if user2.Role.Name != "admin" { - t.Errorf("Should be able to get anonymous scanner") - } - - if !user2.Role.IsAdmin() { - t.Errorf("Should be able to get anonymous scanner") - } -} - -func TestAnonymousField(t *testing.T) { - user := User{Name: "anonymous_field", Company: Company{Name: "company"}} - DB.Save(&user) - - var user2 User - DB.First(&user2, "name = ?", "anonymous_field") - DB.Model(&user2).Related(&user2.Company) - if user2.Company.Name != "company" { - t.Errorf("Should be able to get anonymous field") - } -} - -func TestSelectWithCreate(t *testing.T) { - user := getPreparedUser("select_user", "select_with_create") - DB.Select("Name", "BillingAddress", "CreditCard", "Company", "Emails").Create(user) - - var queryuser User - DB.Preload("BillingAddress").Preload("ShippingAddress"). - Preload("CreditCard").Preload("Emails").Preload("Company").First(&queryuser, user.Id) - - if queryuser.Name != user.Name || queryuser.Age == user.Age { - t.Errorf("Should only create users with name column") - } - - if queryuser.BillingAddressID.Int64 == 0 || queryuser.ShippingAddressId != 0 || - queryuser.CreditCard.ID == 0 || len(queryuser.Emails) == 0 { - t.Errorf("Should only create selected relationships") - } -} - -func TestOmitWithCreate(t *testing.T) { - user := getPreparedUser("omit_user", "omit_with_create") - DB.Omit("Name", "BillingAddress", "CreditCard", "Company", "Emails").Create(user) - - var queryuser User - DB.Preload("BillingAddress").Preload("ShippingAddress"). - Preload("CreditCard").Preload("Emails").Preload("Company").First(&queryuser, user.Id) - - if queryuser.Name == user.Name || queryuser.Age != user.Age { - t.Errorf("Should only create users with age column") - } - - if queryuser.BillingAddressID.Int64 != 0 || queryuser.ShippingAddressId == 0 || - queryuser.CreditCard.ID != 0 || len(queryuser.Emails) != 0 { - t.Errorf("Should not create omitted relationships") - } -} - -func TestCreateIgnore(t *testing.T) { - float := 35.03554004971999 - now := time.Now() - user := User{Name: "CreateUser", Age: 18, Birthday: &now, UserNum: Num(111), PasswordHash: []byte{'f', 'a', 'k', '4'}, Latitude: float} - - if !DB.NewRecord(user) || !DB.NewRecord(&user) { - t.Error("User should be new record before create") - } - - if count := DB.Create(&user).RowsAffected; count != 1 { - t.Error("There should be one record be affected when create record") - } - if DB.Dialect().GetName() == "mysql" && DB.Set("gorm:insert_modifier", "IGNORE").Create(&user).Error != nil { - t.Error("Should ignore duplicate user insert by insert modifier:IGNORE ") - } -} diff --git a/customize_column_test.go b/customize_column_test.go deleted file mode 100644 index c236ac24..00000000 --- a/customize_column_test.go +++ /dev/null @@ -1,357 +0,0 @@ -package gorm_test - -import ( - "testing" - "time" - - "github.com/jinzhu/gorm" -) - -type CustomizeColumn struct { - ID int64 `gorm:"column:mapped_id; primary_key:yes"` - Name string `gorm:"column:mapped_name"` - Date *time.Time `gorm:"column:mapped_time"` -} - -// Make sure an ignored field does not interfere with another field's custom -// column name that matches the ignored field. -type CustomColumnAndIgnoredFieldClash struct { - Body string `sql:"-"` - RawBody string `gorm:"column:body"` -} - -func TestCustomizeColumn(t *testing.T) { - col := "mapped_name" - DB.DropTable(&CustomizeColumn{}) - DB.AutoMigrate(&CustomizeColumn{}) - - scope := DB.NewScope(&CustomizeColumn{}) - if !scope.Dialect().HasColumn(scope.TableName(), col) { - t.Errorf("CustomizeColumn should have column %s", col) - } - - col = "mapped_id" - if scope.PrimaryKey() != col { - t.Errorf("CustomizeColumn should have primary key %s, but got %q", col, scope.PrimaryKey()) - } - - expected := "foo" - now := time.Now() - cc := CustomizeColumn{ID: 666, Name: expected, Date: &now} - - if count := DB.Create(&cc).RowsAffected; count != 1 { - t.Error("There should be one record be affected when create record") - } - - var cc1 CustomizeColumn - DB.First(&cc1, 666) - - if cc1.Name != expected { - t.Errorf("Failed to query CustomizeColumn") - } - - cc.Name = "bar" - DB.Save(&cc) - - var cc2 CustomizeColumn - DB.First(&cc2, 666) - if cc2.Name != "bar" { - t.Errorf("Failed to query CustomizeColumn") - } -} - -func TestCustomColumnAndIgnoredFieldClash(t *testing.T) { - DB.DropTable(&CustomColumnAndIgnoredFieldClash{}) - if err := DB.AutoMigrate(&CustomColumnAndIgnoredFieldClash{}).Error; err != nil { - t.Errorf("Should not raise error: %s", err) - } -} - -type CustomizePerson struct { - IdPerson string `gorm:"column:idPerson;primary_key:true"` - Accounts []CustomizeAccount `gorm:"many2many:PersonAccount;associationforeignkey:idAccount;foreignkey:idPerson"` -} - -type CustomizeAccount struct { - IdAccount string `gorm:"column:idAccount;primary_key:true"` - Name string -} - -func TestManyToManyWithCustomizedColumn(t *testing.T) { - DB.DropTable(&CustomizePerson{}, &CustomizeAccount{}, "PersonAccount") - DB.AutoMigrate(&CustomizePerson{}, &CustomizeAccount{}) - - account := CustomizeAccount{IdAccount: "account", Name: "id1"} - person := CustomizePerson{ - IdPerson: "person", - Accounts: []CustomizeAccount{account}, - } - - if err := DB.Create(&account).Error; err != nil { - t.Errorf("no error should happen, but got %v", err) - } - - if err := DB.Create(&person).Error; err != nil { - t.Errorf("no error should happen, but got %v", err) - } - - var person1 CustomizePerson - scope := DB.NewScope(nil) - if err := DB.Preload("Accounts").First(&person1, scope.Quote("idPerson")+" = ?", person.IdPerson).Error; err != nil { - t.Errorf("no error should happen when preloading customized column many2many relations, but got %v", err) - } - - if len(person1.Accounts) != 1 || person1.Accounts[0].IdAccount != "account" { - t.Errorf("should preload correct accounts") - } -} - -type CustomizeUser struct { - gorm.Model - Email string `sql:"column:email_address"` -} - -type CustomizeInvitation struct { - gorm.Model - Address string `sql:"column:invitation"` - Person *CustomizeUser `gorm:"foreignkey:Email;associationforeignkey:invitation"` -} - -func TestOneToOneWithCustomizedColumn(t *testing.T) { - DB.DropTable(&CustomizeUser{}, &CustomizeInvitation{}) - DB.AutoMigrate(&CustomizeUser{}, &CustomizeInvitation{}) - - user := CustomizeUser{ - Email: "hello@example.com", - } - invitation := CustomizeInvitation{ - Address: "hello@example.com", - } - - DB.Create(&user) - DB.Create(&invitation) - - var invitation2 CustomizeInvitation - if err := DB.Preload("Person").Find(&invitation2, invitation.ID).Error; err != nil { - t.Errorf("no error should happen, but got %v", err) - } - - if invitation2.Person.Email != user.Email { - t.Errorf("Should preload one to one relation with customize foreign keys") - } -} - -type PromotionDiscount struct { - gorm.Model - Name string - Coupons []*PromotionCoupon `gorm:"ForeignKey:discount_id"` - Rule *PromotionRule `gorm:"ForeignKey:discount_id"` - Benefits []PromotionBenefit `gorm:"ForeignKey:promotion_id"` -} - -type PromotionBenefit struct { - gorm.Model - Name string - PromotionID uint - Discount PromotionDiscount `gorm:"ForeignKey:promotion_id"` -} - -type PromotionCoupon struct { - gorm.Model - Code string - DiscountID uint - Discount PromotionDiscount -} - -type PromotionRule struct { - gorm.Model - Name string - Begin *time.Time - End *time.Time - DiscountID uint - Discount *PromotionDiscount -} - -func TestOneToManyWithCustomizedColumn(t *testing.T) { - DB.DropTable(&PromotionDiscount{}, &PromotionCoupon{}) - DB.AutoMigrate(&PromotionDiscount{}, &PromotionCoupon{}) - - discount := PromotionDiscount{ - Name: "Happy New Year", - Coupons: []*PromotionCoupon{ - {Code: "newyear1"}, - {Code: "newyear2"}, - }, - } - - if err := DB.Create(&discount).Error; err != nil { - t.Errorf("no error should happen but got %v", err) - } - - var discount1 PromotionDiscount - if err := DB.Preload("Coupons").First(&discount1, "id = ?", discount.ID).Error; err != nil { - t.Errorf("no error should happen but got %v", err) - } - - if len(discount.Coupons) != 2 { - t.Errorf("should find two coupons") - } - - var coupon PromotionCoupon - if err := DB.Preload("Discount").First(&coupon, "code = ?", "newyear1").Error; err != nil { - t.Errorf("no error should happen but got %v", err) - } - - if coupon.Discount.Name != "Happy New Year" { - t.Errorf("should preload discount from coupon") - } -} - -func TestHasOneWithPartialCustomizedColumn(t *testing.T) { - DB.DropTable(&PromotionDiscount{}, &PromotionRule{}) - DB.AutoMigrate(&PromotionDiscount{}, &PromotionRule{}) - - var begin = time.Now() - var end = time.Now().Add(24 * time.Hour) - discount := PromotionDiscount{ - Name: "Happy New Year 2", - Rule: &PromotionRule{ - Name: "time_limited", - Begin: &begin, - End: &end, - }, - } - - if err := DB.Create(&discount).Error; err != nil { - t.Errorf("no error should happen but got %v", err) - } - - var discount1 PromotionDiscount - if err := DB.Preload("Rule").First(&discount1, "id = ?", discount.ID).Error; err != nil { - t.Errorf("no error should happen but got %v", err) - } - - if discount.Rule.Begin.Format(time.RFC3339Nano) != begin.Format(time.RFC3339Nano) { - t.Errorf("Should be able to preload Rule") - } - - var rule PromotionRule - if err := DB.Preload("Discount").First(&rule, "name = ?", "time_limited").Error; err != nil { - t.Errorf("no error should happen but got %v", err) - } - - if rule.Discount.Name != "Happy New Year 2" { - t.Errorf("should preload discount from rule") - } -} - -func TestBelongsToWithPartialCustomizedColumn(t *testing.T) { - DB.DropTable(&PromotionDiscount{}, &PromotionBenefit{}) - DB.AutoMigrate(&PromotionDiscount{}, &PromotionBenefit{}) - - discount := PromotionDiscount{ - Name: "Happy New Year 3", - Benefits: []PromotionBenefit{ - {Name: "free cod"}, - {Name: "free shipping"}, - }, - } - - if err := DB.Create(&discount).Error; err != nil { - t.Errorf("no error should happen but got %v", err) - } - - var discount1 PromotionDiscount - if err := DB.Preload("Benefits").First(&discount1, "id = ?", discount.ID).Error; err != nil { - t.Errorf("no error should happen but got %v", err) - } - - if len(discount.Benefits) != 2 { - t.Errorf("should find two benefits") - } - - var benefit PromotionBenefit - if err := DB.Preload("Discount").First(&benefit, "name = ?", "free cod").Error; err != nil { - t.Errorf("no error should happen but got %v", err) - } - - if benefit.Discount.Name != "Happy New Year 3" { - t.Errorf("should preload discount from coupon") - } -} - -type SelfReferencingUser struct { - gorm.Model - Name string - Friends []*SelfReferencingUser `gorm:"many2many:UserFriends;association_jointable_foreignkey:friend_id"` -} - -func TestSelfReferencingMany2ManyColumn(t *testing.T) { - DB.DropTable(&SelfReferencingUser{}, "UserFriends") - DB.AutoMigrate(&SelfReferencingUser{}) - if !DB.HasTable("UserFriends") { - t.Errorf("auto migrate error, table UserFriends should be created") - } - - friend1 := SelfReferencingUser{Name: "friend1_m2m"} - if err := DB.Create(&friend1).Error; err != nil { - t.Errorf("no error should happen, but got %v", err) - } - - friend2 := SelfReferencingUser{Name: "friend2_m2m"} - if err := DB.Create(&friend2).Error; err != nil { - t.Errorf("no error should happen, but got %v", err) - } - - user := SelfReferencingUser{ - Name: "self_m2m", - Friends: []*SelfReferencingUser{&friend1, &friend2}, - } - - if err := DB.Create(&user).Error; err != nil { - t.Errorf("no error should happen, but got %v", err) - } - - if DB.Model(&user).Association("Friends").Count() != 2 { - t.Errorf("Should find created friends correctly") - } - - var count int - if err := DB.Table("UserFriends").Count(&count).Error; err != nil { - t.Errorf("no error should happen, but got %v", err) - } - if count == 0 { - t.Errorf("table UserFriends should have records") - } - - var newUser = SelfReferencingUser{} - - if err := DB.Preload("Friends").First(&newUser, "id = ?", user.ID).Error; err != nil { - t.Errorf("no error should happen, but got %v", err) - } - - if len(newUser.Friends) != 2 { - t.Errorf("Should preload created frineds for self reference m2m") - } - - DB.Model(&newUser).Association("Friends").Append(&SelfReferencingUser{Name: "friend3_m2m"}) - if DB.Model(&user).Association("Friends").Count() != 3 { - t.Errorf("Should find created friends correctly") - } - - DB.Model(&newUser).Association("Friends").Replace(&SelfReferencingUser{Name: "friend4_m2m"}) - if DB.Model(&user).Association("Friends").Count() != 1 { - t.Errorf("Should find created friends correctly") - } - - friend := SelfReferencingUser{} - DB.Model(&newUser).Association("Friends").Find(&friend) - if friend.Name != "friend4_m2m" { - t.Errorf("Should find created friends correctly") - } - - DB.Model(&newUser).Association("Friends").Delete(friend) - if DB.Model(&user).Association("Friends").Count() != 0 { - t.Errorf("All friends should be deleted") - } -} diff --git a/delete_test.go b/delete_test.go deleted file mode 100644 index 043641f7..00000000 --- a/delete_test.go +++ /dev/null @@ -1,91 +0,0 @@ -package gorm_test - -import ( - "testing" - "time" -) - -func TestDelete(t *testing.T) { - user1, user2 := User{Name: "delete1"}, User{Name: "delete2"} - DB.Save(&user1) - DB.Save(&user2) - - if err := DB.Delete(&user1).Error; err != nil { - t.Errorf("No error should happen when delete a record, err=%s", err) - } - - if !DB.Where("name = ?", user1.Name).First(&User{}).RecordNotFound() { - t.Errorf("User can't be found after delete") - } - - if DB.Where("name = ?", user2.Name).First(&User{}).RecordNotFound() { - t.Errorf("Other users that not deleted should be found-able") - } -} - -func TestInlineDelete(t *testing.T) { - user1, user2 := User{Name: "inline_delete1"}, User{Name: "inline_delete2"} - DB.Save(&user1) - DB.Save(&user2) - - if DB.Delete(&User{}, user1.Id).Error != nil { - t.Errorf("No error should happen when delete a record") - } else if !DB.Where("name = ?", user1.Name).First(&User{}).RecordNotFound() { - t.Errorf("User can't be found after delete") - } - - if err := DB.Delete(&User{}, "name = ?", user2.Name).Error; err != nil { - t.Errorf("No error should happen when delete a record, err=%s", err) - } else if !DB.Where("name = ?", user2.Name).First(&User{}).RecordNotFound() { - t.Errorf("User can't be found after delete") - } -} - -func TestSoftDelete(t *testing.T) { - type User struct { - Id int64 - Name string - DeletedAt *time.Time - } - DB.AutoMigrate(&User{}) - - user := User{Name: "soft_delete"} - DB.Save(&user) - DB.Delete(&user) - - if DB.First(&User{}, "name = ?", user.Name).Error == nil { - t.Errorf("Can't find a soft deleted record") - } - - if err := DB.Unscoped().First(&User{}, "name = ?", user.Name).Error; err != nil { - t.Errorf("Should be able to find soft deleted record with Unscoped, but err=%s", err) - } - - DB.Unscoped().Delete(&user) - if !DB.Unscoped().First(&User{}, "name = ?", user.Name).RecordNotFound() { - t.Errorf("Can't find permanently deleted record") - } -} - -func TestSoftDeleteWithCustomizedDeletedAtColumnName(t *testing.T) { - creditCard := CreditCard{Number: "411111111234567"} - DB.Save(&creditCard) - DB.Delete(&creditCard) - - if deletedAtField, ok := DB.NewScope(&CreditCard{}).FieldByName("DeletedAt"); !ok || deletedAtField.DBName != "deleted_time" { - t.Errorf("CreditCard's DeletedAt's column name should be `deleted_time`") - } - - if DB.First(&CreditCard{}, "number = ?", creditCard.Number).Error == nil { - t.Errorf("Can't find a soft deleted record") - } - - if err := DB.Unscoped().First(&CreditCard{}, "number = ?", creditCard.Number).Error; err != nil { - t.Errorf("Should be able to find soft deleted record with Unscoped, but err=%s", err) - } - - DB.Unscoped().Delete(&creditCard) - if !DB.Unscoped().First(&CreditCard{}, "number = ?", creditCard.Number).RecordNotFound() { - t.Errorf("Can't find permanently deleted record") - } -} diff --git a/dialect.go b/dialect.go deleted file mode 100644 index 749587f4..00000000 --- a/dialect.go +++ /dev/null @@ -1,147 +0,0 @@ -package gorm - -import ( - "database/sql" - "fmt" - "reflect" - "strconv" - "strings" -) - -// Dialect interface contains behaviors that differ across SQL database -type Dialect interface { - // GetName get dialect's name - GetName() string - - // SetDB set db for dialect - SetDB(db SQLCommon) - - // BindVar return the placeholder for actual values in SQL statements, in many dbs it is "?", Postgres using $1 - BindVar(i int) string - // Quote quotes field name to avoid SQL parsing exceptions by using a reserved word as a field name - Quote(key string) string - // DataTypeOf return data's sql type - DataTypeOf(field *StructField) string - - // HasIndex check has index or not - HasIndex(tableName string, indexName string) bool - // HasForeignKey check has foreign key or not - HasForeignKey(tableName string, foreignKeyName string) bool - // RemoveIndex remove index - RemoveIndex(tableName string, indexName string) error - // HasTable check has table or not - HasTable(tableName string) bool - // HasColumn check has column or not - HasColumn(tableName string, columnName string) bool - // ModifyColumn modify column's type - ModifyColumn(tableName string, columnName string, typ string) error - - // LimitAndOffsetSQL return generated SQL with Limit and Offset, as mssql has special case - LimitAndOffsetSQL(limit, offset interface{}) (string, error) - // SelectFromDummyTable return select values, for most dbs, `SELECT values` just works, mysql needs `SELECT value FROM DUAL` - SelectFromDummyTable() string - // LastInsertIDOutputInterstitial most dbs support LastInsertId, but mssql needs to use `OUTPUT` - LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string - // LastInsertIdReturningSuffix most dbs support LastInsertId, but postgres needs to use `RETURNING` - LastInsertIDReturningSuffix(tableName, columnName string) string - // DefaultValueStr - DefaultValueStr() string - - // BuildKeyName returns a valid key name (foreign key, index key) for the given table, field and reference - BuildKeyName(kind, tableName string, fields ...string) string - - // NormalizeIndexAndColumn returns valid index name and column name depending on each dialect - NormalizeIndexAndColumn(indexName, columnName string) (string, string) - - // CurrentDatabase return current database name - CurrentDatabase() string -} - -var dialectsMap = map[string]Dialect{} - -func newDialect(name string, db SQLCommon) Dialect { - if value, ok := dialectsMap[name]; ok { - dialect := reflect.New(reflect.TypeOf(value).Elem()).Interface().(Dialect) - dialect.SetDB(db) - return dialect - } - - fmt.Printf("`%v` is not officially supported, running under compatibility mode.\n", name) - commontDialect := &commonDialect{} - commontDialect.SetDB(db) - return commontDialect -} - -// RegisterDialect register new dialect -func RegisterDialect(name string, dialect Dialect) { - dialectsMap[name] = dialect -} - -// GetDialect gets the dialect for the specified dialect name -func GetDialect(name string) (dialect Dialect, ok bool) { - dialect, ok = dialectsMap[name] - return -} - -// ParseFieldStructForDialect get field's sql data type -var ParseFieldStructForDialect = func(field *StructField, dialect Dialect) (fieldValue reflect.Value, sqlType string, size int, additionalType string) { - // Get redirected field type - var ( - reflectType = field.Struct.Type - dataType, _ = field.TagSettingsGet("TYPE") - ) - - for reflectType.Kind() == reflect.Ptr { - reflectType = reflectType.Elem() - } - - // Get redirected field value - fieldValue = reflect.Indirect(reflect.New(reflectType)) - - if gormDataType, ok := fieldValue.Interface().(interface { - GormDataType(Dialect) string - }); ok { - dataType = gormDataType.GormDataType(dialect) - } - - // Get scanner's real value - if dataType == "" { - var getScannerValue func(reflect.Value) - getScannerValue = func(value reflect.Value) { - fieldValue = value - if _, isScanner := reflect.New(fieldValue.Type()).Interface().(sql.Scanner); isScanner && fieldValue.Kind() == reflect.Struct { - getScannerValue(fieldValue.Field(0)) - } - } - getScannerValue(fieldValue) - } - - // Default Size - if num, ok := field.TagSettingsGet("SIZE"); ok { - size, _ = strconv.Atoi(num) - } else { - size = 255 - } - - // Default type from tag setting - notNull, _ := field.TagSettingsGet("NOT NULL") - unique, _ := field.TagSettingsGet("UNIQUE") - additionalType = notNull + " " + unique - if value, ok := field.TagSettingsGet("DEFAULT"); ok { - additionalType = additionalType + " DEFAULT " + value - } - - if value, ok := field.TagSettingsGet("COMMENT"); ok { - additionalType = additionalType + " COMMENT " + value - } - - return fieldValue, dataType, size, strings.TrimSpace(additionalType) -} - -func currentDatabaseAndTable(dialect Dialect, tableName string) (string, string) { - if strings.Contains(tableName, ".") { - splitStrings := strings.SplitN(tableName, ".", 2) - return splitStrings[0], splitStrings[1] - } - return dialect.CurrentDatabase(), tableName -} diff --git a/dialect_common.go b/dialect_common.go deleted file mode 100644 index d549510c..00000000 --- a/dialect_common.go +++ /dev/null @@ -1,196 +0,0 @@ -package gorm - -import ( - "fmt" - "reflect" - "regexp" - "strconv" - "strings" - "time" -) - -var keyNameRegex = regexp.MustCompile("[^a-zA-Z0-9]+") - -// DefaultForeignKeyNamer contains the default foreign key name generator method -type DefaultForeignKeyNamer struct { -} - -type commonDialect struct { - db SQLCommon - DefaultForeignKeyNamer -} - -func init() { - RegisterDialect("common", &commonDialect{}) -} - -func (commonDialect) GetName() string { - return "common" -} - -func (s *commonDialect) SetDB(db SQLCommon) { - s.db = db -} - -func (commonDialect) BindVar(i int) string { - return "$$$" // ? -} - -func (commonDialect) Quote(key string) string { - return fmt.Sprintf(`"%s"`, key) -} - -func (s *commonDialect) fieldCanAutoIncrement(field *StructField) bool { - if value, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok { - return strings.ToLower(value) != "false" - } - return field.IsPrimaryKey -} - -func (s *commonDialect) DataTypeOf(field *StructField) string { - var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s) - - if sqlType == "" { - switch dataValue.Kind() { - case reflect.Bool: - sqlType = "BOOLEAN" - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: - if s.fieldCanAutoIncrement(field) { - sqlType = "INTEGER AUTO_INCREMENT" - } else { - sqlType = "INTEGER" - } - case reflect.Int64, reflect.Uint64: - if s.fieldCanAutoIncrement(field) { - sqlType = "BIGINT AUTO_INCREMENT" - } else { - sqlType = "BIGINT" - } - case reflect.Float32, reflect.Float64: - sqlType = "FLOAT" - case reflect.String: - if size > 0 && size < 65532 { - sqlType = fmt.Sprintf("VARCHAR(%d)", size) - } else { - sqlType = "VARCHAR(65532)" - } - case reflect.Struct: - if _, ok := dataValue.Interface().(time.Time); ok { - sqlType = "TIMESTAMP" - } - default: - if _, ok := dataValue.Interface().([]byte); ok { - if size > 0 && size < 65532 { - sqlType = fmt.Sprintf("BINARY(%d)", size) - } else { - sqlType = "BINARY(65532)" - } - } - } - } - - if sqlType == "" { - panic(fmt.Sprintf("invalid sql type %s (%s) for commonDialect", dataValue.Type().Name(), dataValue.Kind().String())) - } - - if strings.TrimSpace(additionalType) == "" { - return sqlType - } - return fmt.Sprintf("%v %v", sqlType, additionalType) -} - -func (s commonDialect) HasIndex(tableName string, indexName string) bool { - var count int - currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", currentDatabase, tableName, indexName).Scan(&count) - return count > 0 -} - -func (s commonDialect) RemoveIndex(tableName string, indexName string) error { - _, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v", indexName)) - return err -} - -func (s commonDialect) HasForeignKey(tableName string, foreignKeyName string) bool { - return false -} - -func (s commonDialect) HasTable(tableName string) bool { - var count int - currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", currentDatabase, tableName).Scan(&count) - return count > 0 -} - -func (s commonDialect) HasColumn(tableName string, columnName string) bool { - var count int - currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", currentDatabase, tableName, columnName).Scan(&count) - return count > 0 -} - -func (s commonDialect) ModifyColumn(tableName string, columnName string, typ string) error { - _, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v TYPE %v", tableName, columnName, typ)) - return err -} - -func (s commonDialect) CurrentDatabase() (name string) { - s.db.QueryRow("SELECT DATABASE()").Scan(&name) - return -} - -// LimitAndOffsetSQL return generated SQL with Limit and Offset -func (s commonDialect) LimitAndOffsetSQL(limit, offset interface{}) (sql string, err error) { - if limit != nil { - if parsedLimit, err := s.parseInt(limit); err != nil { - return "", err - } else if parsedLimit >= 0 { - sql += fmt.Sprintf(" LIMIT %d", parsedLimit) - } - } - if offset != nil { - if parsedOffset, err := s.parseInt(offset); err != nil { - return "", err - } else if parsedOffset >= 0 { - sql += fmt.Sprintf(" OFFSET %d", parsedOffset) - } - } - return -} - -func (commonDialect) SelectFromDummyTable() string { - return "" -} - -func (commonDialect) LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string { - return "" -} - -func (commonDialect) LastInsertIDReturningSuffix(tableName, columnName string) string { - return "" -} - -func (commonDialect) DefaultValueStr() string { - return "DEFAULT VALUES" -} - -// BuildKeyName returns a valid key name (foreign key, index key) for the given table, field and reference -func (DefaultForeignKeyNamer) BuildKeyName(kind, tableName string, fields ...string) string { - keyName := fmt.Sprintf("%s_%s_%s", kind, tableName, strings.Join(fields, "_")) - keyName = keyNameRegex.ReplaceAllString(keyName, "_") - return keyName -} - -// NormalizeIndexAndColumn returns argument's index name and column name without doing anything -func (commonDialect) NormalizeIndexAndColumn(indexName, columnName string) (string, string) { - return indexName, columnName -} - -func (commonDialect) parseInt(value interface{}) (int64, error) { - return strconv.ParseInt(fmt.Sprint(value), 0, 0) -} - -// IsByteArrayOrSlice returns true of the reflected value is an array or slice -func IsByteArrayOrSlice(value reflect.Value) bool { - return (value.Kind() == reflect.Array || value.Kind() == reflect.Slice) && value.Type().Elem() == reflect.TypeOf(uint8(0)) -} diff --git a/dialect_mysql.go b/dialect_mysql.go deleted file mode 100644 index b4467ffa..00000000 --- a/dialect_mysql.go +++ /dev/null @@ -1,246 +0,0 @@ -package gorm - -import ( - "crypto/sha1" - "database/sql" - "fmt" - "reflect" - "regexp" - "strings" - "time" - "unicode/utf8" -) - -var mysqlIndexRegex = regexp.MustCompile(`^(.+)\((\d+)\)$`) - -type mysql struct { - commonDialect -} - -func init() { - RegisterDialect("mysql", &mysql{}) -} - -func (mysql) GetName() string { - return "mysql" -} - -func (mysql) Quote(key string) string { - return fmt.Sprintf("`%s`", key) -} - -// Get Data Type for MySQL Dialect -func (s *mysql) DataTypeOf(field *StructField) string { - var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s) - - // MySQL allows only one auto increment column per table, and it must - // be a KEY column. - if _, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok { - if _, ok = field.TagSettingsGet("INDEX"); !ok && !field.IsPrimaryKey { - field.TagSettingsDelete("AUTO_INCREMENT") - } - } - - if sqlType == "" { - switch dataValue.Kind() { - case reflect.Bool: - sqlType = "boolean" - case reflect.Int8: - if s.fieldCanAutoIncrement(field) { - field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") - sqlType = "tinyint AUTO_INCREMENT" - } else { - sqlType = "tinyint" - } - case reflect.Int, reflect.Int16, reflect.Int32: - if s.fieldCanAutoIncrement(field) { - field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") - sqlType = "int AUTO_INCREMENT" - } else { - sqlType = "int" - } - case reflect.Uint8: - if s.fieldCanAutoIncrement(field) { - field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") - sqlType = "tinyint unsigned AUTO_INCREMENT" - } else { - sqlType = "tinyint unsigned" - } - case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uintptr: - if s.fieldCanAutoIncrement(field) { - field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") - sqlType = "int unsigned AUTO_INCREMENT" - } else { - sqlType = "int unsigned" - } - case reflect.Int64: - if s.fieldCanAutoIncrement(field) { - field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") - sqlType = "bigint AUTO_INCREMENT" - } else { - sqlType = "bigint" - } - case reflect.Uint64: - if s.fieldCanAutoIncrement(field) { - field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") - sqlType = "bigint unsigned AUTO_INCREMENT" - } else { - sqlType = "bigint unsigned" - } - case reflect.Float32, reflect.Float64: - sqlType = "double" - case reflect.String: - if size > 0 && size < 65532 { - sqlType = fmt.Sprintf("varchar(%d)", size) - } else { - sqlType = "longtext" - } - case reflect.Struct: - if _, ok := dataValue.Interface().(time.Time); ok { - precision := "" - if p, ok := field.TagSettingsGet("PRECISION"); ok { - precision = fmt.Sprintf("(%s)", p) - } - - if _, ok := field.TagSettings["NOT NULL"]; ok || field.IsPrimaryKey { - sqlType = fmt.Sprintf("DATETIME%v", precision) - } else { - sqlType = fmt.Sprintf("DATETIME%v NULL", precision) - } - } - default: - if IsByteArrayOrSlice(dataValue) { - if size > 0 && size < 65532 { - sqlType = fmt.Sprintf("varbinary(%d)", size) - } else { - sqlType = "longblob" - } - } - } - } - - if sqlType == "" { - panic(fmt.Sprintf("invalid sql type %s (%s) in field %s for mysql", dataValue.Type().Name(), dataValue.Kind().String(), field.Name)) - } - - if strings.TrimSpace(additionalType) == "" { - return sqlType - } - return fmt.Sprintf("%v %v", sqlType, additionalType) -} - -func (s mysql) RemoveIndex(tableName string, indexName string) error { - _, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName))) - return err -} - -func (s mysql) ModifyColumn(tableName string, columnName string, typ string) error { - _, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v MODIFY COLUMN %v %v", tableName, columnName, typ)) - return err -} - -func (s mysql) LimitAndOffsetSQL(limit, offset interface{}) (sql string, err error) { - if limit != nil { - parsedLimit, err := s.parseInt(limit) - if err != nil { - return "", err - } - if parsedLimit >= 0 { - sql += fmt.Sprintf(" LIMIT %d", parsedLimit) - - if offset != nil { - parsedOffset, err := s.parseInt(offset) - if err != nil { - return "", err - } - if parsedOffset >= 0 { - sql += fmt.Sprintf(" OFFSET %d", parsedOffset) - } - } - } - } - return -} - -func (s mysql) HasForeignKey(tableName string, foreignKeyName string) bool { - var count int - currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE CONSTRAINT_SCHEMA=? AND TABLE_NAME=? AND CONSTRAINT_NAME=? AND CONSTRAINT_TYPE='FOREIGN KEY'", currentDatabase, tableName, foreignKeyName).Scan(&count) - return count > 0 -} - -func (s mysql) HasTable(tableName string) bool { - currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) - var name string - // allow mysql database name with '-' character - if err := s.db.QueryRow(fmt.Sprintf("SHOW TABLES FROM `%s` WHERE `Tables_in_%s` = ?", currentDatabase, currentDatabase), tableName).Scan(&name); err != nil { - if err == sql.ErrNoRows { - return false - } - panic(err) - } else { - return true - } -} - -func (s mysql) HasIndex(tableName string, indexName string) bool { - currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) - if rows, err := s.db.Query(fmt.Sprintf("SHOW INDEXES FROM `%s` FROM `%s` WHERE Key_name = ?", tableName, currentDatabase), indexName); err != nil { - panic(err) - } else { - defer rows.Close() - return rows.Next() - } -} - -func (s mysql) HasColumn(tableName string, columnName string) bool { - currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) - if rows, err := s.db.Query(fmt.Sprintf("SHOW COLUMNS FROM `%s` FROM `%s` WHERE Field = ?", tableName, currentDatabase), columnName); err != nil { - panic(err) - } else { - defer rows.Close() - return rows.Next() - } -} - -func (s mysql) CurrentDatabase() (name string) { - s.db.QueryRow("SELECT DATABASE()").Scan(&name) - return -} - -func (mysql) SelectFromDummyTable() string { - return "FROM DUAL" -} - -func (s mysql) BuildKeyName(kind, tableName string, fields ...string) string { - keyName := s.commonDialect.BuildKeyName(kind, tableName, fields...) - if utf8.RuneCountInString(keyName) <= 64 { - return keyName - } - h := sha1.New() - h.Write([]byte(keyName)) - bs := h.Sum(nil) - - // sha1 is 40 characters, keep first 24 characters of destination - destRunes := []rune(keyNameRegex.ReplaceAllString(fields[0], "_")) - if len(destRunes) > 24 { - destRunes = destRunes[:24] - } - - return fmt.Sprintf("%s%x", string(destRunes), bs) -} - -// NormalizeIndexAndColumn returns index name and column name for specify an index prefix length if needed -func (mysql) NormalizeIndexAndColumn(indexName, columnName string) (string, string) { - submatch := mysqlIndexRegex.FindStringSubmatch(indexName) - if len(submatch) != 3 { - return indexName, columnName - } - indexName = submatch[1] - columnName = fmt.Sprintf("%s(%s)", columnName, submatch[2]) - return indexName, columnName -} - -func (mysql) DefaultValueStr() string { - return "VALUES()" -} diff --git a/dialect_postgres.go b/dialect_postgres.go deleted file mode 100644 index d2df3131..00000000 --- a/dialect_postgres.go +++ /dev/null @@ -1,147 +0,0 @@ -package gorm - -import ( - "encoding/json" - "fmt" - "reflect" - "strings" - "time" -) - -type postgres struct { - commonDialect -} - -func init() { - RegisterDialect("postgres", &postgres{}) - RegisterDialect("cloudsqlpostgres", &postgres{}) -} - -func (postgres) GetName() string { - return "postgres" -} - -func (postgres) BindVar(i int) string { - return fmt.Sprintf("$%v", i) -} - -func (s *postgres) DataTypeOf(field *StructField) string { - var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s) - - if sqlType == "" { - switch dataValue.Kind() { - case reflect.Bool: - sqlType = "boolean" - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uintptr: - if s.fieldCanAutoIncrement(field) { - field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") - sqlType = "serial" - } else { - sqlType = "integer" - } - case reflect.Int64, reflect.Uint32, reflect.Uint64: - if s.fieldCanAutoIncrement(field) { - field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") - sqlType = "bigserial" - } else { - sqlType = "bigint" - } - case reflect.Float32, reflect.Float64: - sqlType = "numeric" - case reflect.String: - if _, ok := field.TagSettingsGet("SIZE"); !ok { - size = 0 // if SIZE haven't been set, use `text` as the default type, as there are no performance different - } - - if size > 0 && size < 65532 { - sqlType = fmt.Sprintf("varchar(%d)", size) - } else { - sqlType = "text" - } - case reflect.Struct: - if _, ok := dataValue.Interface().(time.Time); ok { - sqlType = "timestamp with time zone" - } - case reflect.Map: - if dataValue.Type().Name() == "Hstore" { - sqlType = "hstore" - } - default: - if IsByteArrayOrSlice(dataValue) { - sqlType = "bytea" - - if isUUID(dataValue) { - sqlType = "uuid" - } - - if isJSON(dataValue) { - sqlType = "jsonb" - } - } - } - } - - if sqlType == "" { - panic(fmt.Sprintf("invalid sql type %s (%s) for postgres", dataValue.Type().Name(), dataValue.Kind().String())) - } - - if strings.TrimSpace(additionalType) == "" { - return sqlType - } - return fmt.Sprintf("%v %v", sqlType, additionalType) -} - -func (s postgres) HasIndex(tableName string, indexName string) bool { - var count int - s.db.QueryRow("SELECT count(*) FROM pg_indexes WHERE tablename = $1 AND indexname = $2 AND schemaname = CURRENT_SCHEMA()", tableName, indexName).Scan(&count) - return count > 0 -} - -func (s postgres) HasForeignKey(tableName string, foreignKeyName string) bool { - var count int - s.db.QueryRow("SELECT count(con.conname) FROM pg_constraint con WHERE $1::regclass::oid = con.conrelid AND con.conname = $2 AND con.contype='f'", tableName, foreignKeyName).Scan(&count) - return count > 0 -} - -func (s postgres) HasTable(tableName string) bool { - var count int - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = $1 AND table_type = 'BASE TABLE' AND table_schema = CURRENT_SCHEMA()", tableName).Scan(&count) - return count > 0 -} - -func (s postgres) HasColumn(tableName string, columnName string) bool { - var count int - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = $1 AND column_name = $2 AND table_schema = CURRENT_SCHEMA()", tableName, columnName).Scan(&count) - return count > 0 -} - -func (s postgres) CurrentDatabase() (name string) { - s.db.QueryRow("SELECT CURRENT_DATABASE()").Scan(&name) - return -} - -func (s postgres) LastInsertIDOutputInterstitial(tableName, key string, columns []string) string { - return "" -} - -func (s postgres) LastInsertIDReturningSuffix(tableName, key string) string { - return fmt.Sprintf("RETURNING %v.%v", tableName, key) -} - -func (postgres) SupportLastInsertID() bool { - return false -} - -func isUUID(value reflect.Value) bool { - if value.Kind() != reflect.Array || value.Type().Len() != 16 { - return false - } - typename := value.Type().Name() - lower := strings.ToLower(typename) - return "uuid" == lower || "guid" == lower -} - -func isJSON(value reflect.Value) bool { - _, ok := value.Interface().(json.RawMessage) - return ok -} diff --git a/dialect_sqlite3.go b/dialect_sqlite3.go deleted file mode 100644 index 5f96c363..00000000 --- a/dialect_sqlite3.go +++ /dev/null @@ -1,107 +0,0 @@ -package gorm - -import ( - "fmt" - "reflect" - "strings" - "time" -) - -type sqlite3 struct { - commonDialect -} - -func init() { - RegisterDialect("sqlite3", &sqlite3{}) -} - -func (sqlite3) GetName() string { - return "sqlite3" -} - -// Get Data Type for Sqlite Dialect -func (s *sqlite3) DataTypeOf(field *StructField) string { - var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s) - - if sqlType == "" { - switch dataValue.Kind() { - case reflect.Bool: - sqlType = "bool" - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: - if s.fieldCanAutoIncrement(field) { - field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") - sqlType = "integer primary key autoincrement" - } else { - sqlType = "integer" - } - case reflect.Int64, reflect.Uint64: - if s.fieldCanAutoIncrement(field) { - field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") - sqlType = "integer primary key autoincrement" - } else { - sqlType = "bigint" - } - case reflect.Float32, reflect.Float64: - sqlType = "real" - case reflect.String: - if size > 0 && size < 65532 { - sqlType = fmt.Sprintf("varchar(%d)", size) - } else { - sqlType = "text" - } - case reflect.Struct: - if _, ok := dataValue.Interface().(time.Time); ok { - sqlType = "datetime" - } - default: - if IsByteArrayOrSlice(dataValue) { - sqlType = "blob" - } - } - } - - if sqlType == "" { - panic(fmt.Sprintf("invalid sql type %s (%s) for sqlite3", dataValue.Type().Name(), dataValue.Kind().String())) - } - - if strings.TrimSpace(additionalType) == "" { - return sqlType - } - return fmt.Sprintf("%v %v", sqlType, additionalType) -} - -func (s sqlite3) HasIndex(tableName string, indexName string) bool { - var count int - s.db.QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'", indexName), tableName).Scan(&count) - return count > 0 -} - -func (s sqlite3) HasTable(tableName string) bool { - var count int - s.db.QueryRow("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", tableName).Scan(&count) - return count > 0 -} - -func (s sqlite3) HasColumn(tableName string, columnName string) bool { - var count int - s.db.QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE '%%\"%v\" %%' OR sql LIKE '%%%v %%');\n", columnName, columnName), tableName).Scan(&count) - return count > 0 -} - -func (s sqlite3) CurrentDatabase() (name string) { - var ( - ifaces = make([]interface{}, 3) - pointers = make([]*string, 3) - i int - ) - for i = 0; i < 3; i++ { - ifaces[i] = &pointers[i] - } - if err := s.db.QueryRow("PRAGMA database_list").Scan(ifaces...); err != nil { - return - } - if pointers[1] != nil { - name = *pointers[1] - } - return -} diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go deleted file mode 100644 index a516ed4a..00000000 --- a/dialects/mssql/mssql.go +++ /dev/null @@ -1,253 +0,0 @@ -package mssql - -import ( - "database/sql/driver" - "encoding/json" - "errors" - "fmt" - "reflect" - "strconv" - "strings" - "time" - - // Importing mssql driver package only in dialect file, otherwide not needed - _ "github.com/denisenkom/go-mssqldb" - "github.com/jinzhu/gorm" -) - -func setIdentityInsert(scope *gorm.Scope) { - if scope.Dialect().GetName() == "mssql" { - for _, field := range scope.PrimaryFields() { - if _, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok && !field.IsBlank { - scope.NewDB().Exec(fmt.Sprintf("SET IDENTITY_INSERT %v ON", scope.TableName())) - scope.InstanceSet("mssql:identity_insert_on", true) - } - } - } -} - -func turnOffIdentityInsert(scope *gorm.Scope) { - if scope.Dialect().GetName() == "mssql" { - if _, ok := scope.InstanceGet("mssql:identity_insert_on"); ok { - scope.NewDB().Exec(fmt.Sprintf("SET IDENTITY_INSERT %v OFF", scope.TableName())) - } - } -} - -func init() { - gorm.DefaultCallback.Create().After("gorm:begin_transaction").Register("mssql:set_identity_insert", setIdentityInsert) - gorm.DefaultCallback.Create().Before("gorm:commit_or_rollback_transaction").Register("mssql:turn_off_identity_insert", turnOffIdentityInsert) - gorm.RegisterDialect("mssql", &mssql{}) -} - -type mssql struct { - db gorm.SQLCommon - gorm.DefaultForeignKeyNamer -} - -func (mssql) GetName() string { - return "mssql" -} - -func (s *mssql) SetDB(db gorm.SQLCommon) { - s.db = db -} - -func (mssql) BindVar(i int) string { - return "$$$" // ? -} - -func (mssql) Quote(key string) string { - return fmt.Sprintf(`[%s]`, key) -} - -func (s *mssql) DataTypeOf(field *gorm.StructField) string { - var dataValue, sqlType, size, additionalType = gorm.ParseFieldStructForDialect(field, s) - - if sqlType == "" { - switch dataValue.Kind() { - case reflect.Bool: - sqlType = "bit" - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: - if s.fieldCanAutoIncrement(field) { - field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") - sqlType = "int IDENTITY(1,1)" - } else { - sqlType = "int" - } - case reflect.Int64, reflect.Uint64: - if s.fieldCanAutoIncrement(field) { - field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") - sqlType = "bigint IDENTITY(1,1)" - } else { - sqlType = "bigint" - } - case reflect.Float32, reflect.Float64: - sqlType = "float" - case reflect.String: - if size > 0 && size < 8000 { - sqlType = fmt.Sprintf("nvarchar(%d)", size) - } else { - sqlType = "nvarchar(max)" - } - case reflect.Struct: - if _, ok := dataValue.Interface().(time.Time); ok { - sqlType = "datetimeoffset" - } - default: - if gorm.IsByteArrayOrSlice(dataValue) { - if size > 0 && size < 8000 { - sqlType = fmt.Sprintf("varbinary(%d)", size) - } else { - sqlType = "varbinary(max)" - } - } - } - } - - if sqlType == "" { - panic(fmt.Sprintf("invalid sql type %s (%s) for mssql", dataValue.Type().Name(), dataValue.Kind().String())) - } - - if strings.TrimSpace(additionalType) == "" { - return sqlType - } - return fmt.Sprintf("%v %v", sqlType, additionalType) -} - -func (s mssql) fieldCanAutoIncrement(field *gorm.StructField) bool { - if value, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok { - return value != "FALSE" - } - return field.IsPrimaryKey -} - -func (s mssql) HasIndex(tableName string, indexName string) bool { - var count int - s.db.QueryRow("SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName).Scan(&count) - return count > 0 -} - -func (s mssql) RemoveIndex(tableName string, indexName string) error { - _, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName))) - return err -} - -func (s mssql) HasForeignKey(tableName string, foreignKeyName string) bool { - var count int - currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) - s.db.QueryRow(`SELECT count(*) - FROM sys.foreign_keys as F inner join sys.tables as T on F.parent_object_id=T.object_id - inner join information_schema.tables as I on I.TABLE_NAME = T.name - WHERE F.name = ? - AND T.Name = ? AND I.TABLE_CATALOG = ?;`, foreignKeyName, tableName, currentDatabase).Scan(&count) - return count > 0 -} - -func (s mssql) HasTable(tableName string) bool { - var count int - currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, currentDatabase).Scan(&count) - return count > 0 -} - -func (s mssql) HasColumn(tableName string, columnName string) bool { - var count int - currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) - s.db.QueryRow("SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", currentDatabase, tableName, columnName).Scan(&count) - return count > 0 -} - -func (s mssql) ModifyColumn(tableName string, columnName string, typ string) error { - _, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v %v", tableName, columnName, typ)) - return err -} - -func (s mssql) CurrentDatabase() (name string) { - s.db.QueryRow("SELECT DB_NAME() AS [Current Database]").Scan(&name) - return -} - -func parseInt(value interface{}) (int64, error) { - return strconv.ParseInt(fmt.Sprint(value), 0, 0) -} - -func (mssql) LimitAndOffsetSQL(limit, offset interface{}) (sql string, err error) { - if offset != nil { - if parsedOffset, err := parseInt(offset); err != nil { - return "", err - } else if parsedOffset >= 0 { - sql += fmt.Sprintf(" OFFSET %d ROWS", parsedOffset) - } - } - if limit != nil { - if parsedLimit, err := parseInt(limit); err != nil { - return "", err - } else if parsedLimit >= 0 { - if sql == "" { - // add default zero offset - sql += " OFFSET 0 ROWS" - } - sql += fmt.Sprintf(" FETCH NEXT %d ROWS ONLY", parsedLimit) - } - } - return -} - -func (mssql) SelectFromDummyTable() string { - return "" -} - -func (mssql) LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string { - if len(columns) == 0 { - // No OUTPUT to query - return "" - } - return fmt.Sprintf("OUTPUT Inserted.%v", columnName) -} - -func (mssql) LastInsertIDReturningSuffix(tableName, columnName string) string { - // https://stackoverflow.com/questions/5228780/how-to-get-last-inserted-id - return "; SELECT SCOPE_IDENTITY()" -} - -func (mssql) DefaultValueStr() string { - return "DEFAULT VALUES" -} - -// NormalizeIndexAndColumn returns argument's index name and column name without doing anything -func (mssql) NormalizeIndexAndColumn(indexName, columnName string) (string, string) { - return indexName, columnName -} - -func currentDatabaseAndTable(dialect gorm.Dialect, tableName string) (string, string) { - if strings.Contains(tableName, ".") { - splitStrings := strings.SplitN(tableName, ".", 2) - return splitStrings[0], splitStrings[1] - } - return dialect.CurrentDatabase(), tableName -} - -// JSON type to support easy handling of JSON data in character table fields -// using golang json.RawMessage for deferred decoding/encoding -type JSON struct { - json.RawMessage -} - -// Value get value of JSON -func (j JSON) Value() (driver.Value, error) { - if len(j.RawMessage) == 0 { - return nil, nil - } - return j.MarshalJSON() -} - -// Scan scan value into JSON -func (j *JSON) Scan(value interface{}) error { - str, ok := value.(string) - if !ok { - return errors.New(fmt.Sprint("Failed to unmarshal JSONB value (strcast):", value)) - } - bytes := []byte(str) - return json.Unmarshal(bytes, j) -} diff --git a/dialects/mysql/mysql.go b/dialects/mysql/mysql.go deleted file mode 100644 index 9deba48a..00000000 --- a/dialects/mysql/mysql.go +++ /dev/null @@ -1,3 +0,0 @@ -package mysql - -import _ "github.com/go-sql-driver/mysql" diff --git a/dialects/postgres/postgres.go b/dialects/postgres/postgres.go deleted file mode 100644 index e6c088b1..00000000 --- a/dialects/postgres/postgres.go +++ /dev/null @@ -1,81 +0,0 @@ -package postgres - -import ( - "database/sql" - "database/sql/driver" - - "encoding/json" - "errors" - "fmt" - - _ "github.com/lib/pq" - "github.com/lib/pq/hstore" -) - -type Hstore map[string]*string - -// Value get value of Hstore -func (h Hstore) Value() (driver.Value, error) { - hstore := hstore.Hstore{Map: map[string]sql.NullString{}} - if len(h) == 0 { - return nil, nil - } - - for key, value := range h { - var s sql.NullString - if value != nil { - s.String = *value - s.Valid = true - } - hstore.Map[key] = s - } - return hstore.Value() -} - -// Scan scan value into Hstore -func (h *Hstore) Scan(value interface{}) error { - hstore := hstore.Hstore{} - - if err := hstore.Scan(value); err != nil { - return err - } - - if len(hstore.Map) == 0 { - return nil - } - - *h = Hstore{} - for k := range hstore.Map { - if hstore.Map[k].Valid { - s := hstore.Map[k].String - (*h)[k] = &s - } else { - (*h)[k] = nil - } - } - - return nil -} - -// Jsonb Postgresql's JSONB data type -type Jsonb struct { - json.RawMessage -} - -// Value get value of Jsonb -func (j Jsonb) Value() (driver.Value, error) { - if len(j.RawMessage) == 0 { - return nil, nil - } - return j.MarshalJSON() -} - -// Scan scan value into Jsonb -func (j *Jsonb) Scan(value interface{}) error { - bytes, ok := value.([]byte) - if !ok { - return errors.New(fmt.Sprint("Failed to unmarshal JSONB value:", value)) - } - - return json.Unmarshal(bytes, j) -} diff --git a/dialects/sqlite/sqlite.go b/dialects/sqlite/sqlite.go deleted file mode 100644 index 069ad3a9..00000000 --- a/dialects/sqlite/sqlite.go +++ /dev/null @@ -1,3 +0,0 @@ -package sqlite - -import _ "github.com/mattn/go-sqlite3" diff --git a/docker-compose.yml b/docker-compose.yml deleted file mode 100644 index 79bf5fc3..00000000 --- a/docker-compose.yml +++ /dev/null @@ -1,30 +0,0 @@ -version: '3' - -services: - mysql: - image: 'mysql:latest' - ports: - - 9910:3306 - environment: - - MYSQL_DATABASE=gorm - - MYSQL_USER=gorm - - MYSQL_PASSWORD=gorm - - MYSQL_RANDOM_ROOT_PASSWORD="yes" - postgres: - image: 'postgres:latest' - ports: - - 9920:5432 - environment: - - POSTGRES_USER=gorm - - POSTGRES_DB=gorm - - POSTGRES_PASSWORD=gorm - mssql: - image: 'mcmoe/mssqldocker:latest' - ports: - - 9930:1433 - environment: - - ACCEPT_EULA=Y - - SA_PASSWORD=LoremIpsum86 - - MSSQL_DB=gorm - - MSSQL_USER=gorm - - MSSQL_PASSWORD=LoremIpsum86 diff --git a/embedded_struct_test.go b/embedded_struct_test.go deleted file mode 100644 index 5f8ece57..00000000 --- a/embedded_struct_test.go +++ /dev/null @@ -1,91 +0,0 @@ -package gorm_test - -import "testing" - -type BasePost struct { - Id int64 - Title string - URL string -} - -type Author struct { - ID string - Name string - Email string -} - -type HNPost struct { - BasePost - Author `gorm:"embedded_prefix:user_"` // Embedded struct - Upvotes int32 -} - -type EngadgetPost struct { - BasePost BasePost `gorm:"embedded"` - Author Author `gorm:"embedded;embedded_prefix:author_"` // Embedded struct - ImageUrl string -} - -func TestPrefixColumnNameForEmbeddedStruct(t *testing.T) { - dialect := DB.NewScope(&EngadgetPost{}).Dialect() - engadgetPostScope := DB.NewScope(&EngadgetPost{}) - if !dialect.HasColumn(engadgetPostScope.TableName(), "author_id") || !dialect.HasColumn(engadgetPostScope.TableName(), "author_name") || !dialect.HasColumn(engadgetPostScope.TableName(), "author_email") { - t.Errorf("should has prefix for embedded columns") - } - - if len(engadgetPostScope.PrimaryFields()) != 1 { - t.Errorf("should have only one primary field with embedded struct, but got %v", len(engadgetPostScope.PrimaryFields())) - } - - hnScope := DB.NewScope(&HNPost{}) - if !dialect.HasColumn(hnScope.TableName(), "user_id") || !dialect.HasColumn(hnScope.TableName(), "user_name") || !dialect.HasColumn(hnScope.TableName(), "user_email") { - t.Errorf("should has prefix for embedded columns") - } -} - -func TestSaveAndQueryEmbeddedStruct(t *testing.T) { - DB.Save(&HNPost{BasePost: BasePost{Title: "news"}}) - DB.Save(&HNPost{BasePost: BasePost{Title: "hn_news"}}) - var news HNPost - if err := DB.First(&news, "title = ?", "hn_news").Error; err != nil { - t.Errorf("no error should happen when query with embedded struct, but got %v", err) - } else if news.Title != "hn_news" { - t.Errorf("embedded struct's value should be scanned correctly") - } - - DB.Save(&EngadgetPost{BasePost: BasePost{Title: "engadget_news"}}) - var egNews EngadgetPost - if err := DB.First(&egNews, "title = ?", "engadget_news").Error; err != nil { - t.Errorf("no error should happen when query with embedded struct, but got %v", err) - } else if egNews.BasePost.Title != "engadget_news" { - t.Errorf("embedded struct's value should be scanned correctly") - } - - if DB.NewScope(&HNPost{}).PrimaryField() == nil { - t.Errorf("primary key with embedded struct should works") - } - - for _, field := range DB.NewScope(&HNPost{}).Fields() { - if field.Name == "BasePost" { - t.Errorf("scope Fields should not contain embedded struct") - } - } -} - -func TestEmbeddedPointerTypeStruct(t *testing.T) { - type HNPost struct { - *BasePost - Upvotes int32 - } - - DB.Create(&HNPost{BasePost: &BasePost{Title: "embedded_pointer_type"}}) - - var hnPost HNPost - if err := DB.First(&hnPost, "title = ?", "embedded_pointer_type").Error; err != nil { - t.Errorf("No error should happen when find embedded pointer type, but got %v", err) - } - - if hnPost.Title != "embedded_pointer_type" { - t.Errorf("Should find correct value for embedded pointer type") - } -} diff --git a/errors.go b/errors.go deleted file mode 100644 index d5ef8d57..00000000 --- a/errors.go +++ /dev/null @@ -1,72 +0,0 @@ -package gorm - -import ( - "errors" - "strings" -) - -var ( - // ErrRecordNotFound returns a "record not found error". Occurs only when attempting to query the database with a struct; querying with a slice won't return this error - ErrRecordNotFound = errors.New("record not found") - // ErrInvalidSQL occurs when you attempt a query with invalid SQL - ErrInvalidSQL = errors.New("invalid SQL") - // ErrInvalidTransaction occurs when you are trying to `Commit` or `Rollback` - ErrInvalidTransaction = errors.New("no valid transaction") - // ErrCantStartTransaction can't start transaction when you are trying to start one with `Begin` - ErrCantStartTransaction = errors.New("can't start transaction") - // ErrUnaddressable unaddressable value - ErrUnaddressable = errors.New("using unaddressable value") -) - -// Errors contains all happened errors -type Errors []error - -// IsRecordNotFoundError returns true if error contains a RecordNotFound error -func IsRecordNotFoundError(err error) bool { - if errs, ok := err.(Errors); ok { - for _, err := range errs { - if err == ErrRecordNotFound { - return true - } - } - } - return err == ErrRecordNotFound -} - -// GetErrors gets all errors that have occurred and returns a slice of errors (Error type) -func (errs Errors) GetErrors() []error { - return errs -} - -// Add adds an error to a given slice of errors -func (errs Errors) Add(newErrors ...error) Errors { - for _, err := range newErrors { - if err == nil { - continue - } - - if errors, ok := err.(Errors); ok { - errs = errs.Add(errors...) - } else { - ok = true - for _, e := range errs { - if err == e { - ok = false - } - } - if ok { - errs = append(errs, err) - } - } - } - return errs -} - -// Error takes a slice of all errors that have occurred and returns it as a formatted string -func (errs Errors) Error() string { - var errors = []string{} - for _, e := range errs { - errors = append(errors, e.Error()) - } - return strings.Join(errors, "; ") -} diff --git a/errors_test.go b/errors_test.go deleted file mode 100644 index 9a428dec..00000000 --- a/errors_test.go +++ /dev/null @@ -1,20 +0,0 @@ -package gorm_test - -import ( - "errors" - "testing" - - "github.com/jinzhu/gorm" -) - -func TestErrorsCanBeUsedOutsideGorm(t *testing.T) { - errs := []error{errors.New("First"), errors.New("Second")} - - gErrs := gorm.Errors(errs) - gErrs = gErrs.Add(errors.New("Third")) - gErrs = gErrs.Add(gErrs) - - if gErrs.Error() != "First; Second; Third" { - t.Fatalf("Gave wrong error, got %s", gErrs.Error()) - } -} diff --git a/field.go b/field.go deleted file mode 100644 index acd06e20..00000000 --- a/field.go +++ /dev/null @@ -1,66 +0,0 @@ -package gorm - -import ( - "database/sql" - "database/sql/driver" - "errors" - "fmt" - "reflect" -) - -// Field model field definition -type Field struct { - *StructField - IsBlank bool - Field reflect.Value -} - -// Set set a value to the field -func (field *Field) Set(value interface{}) (err error) { - if !field.Field.IsValid() { - return errors.New("field value not valid") - } - - if !field.Field.CanAddr() { - return ErrUnaddressable - } - - reflectValue, ok := value.(reflect.Value) - if !ok { - reflectValue = reflect.ValueOf(value) - } - - fieldValue := field.Field - if reflectValue.IsValid() { - if reflectValue.Type().ConvertibleTo(fieldValue.Type()) { - fieldValue.Set(reflectValue.Convert(fieldValue.Type())) - } else { - if fieldValue.Kind() == reflect.Ptr { - if fieldValue.IsNil() { - fieldValue.Set(reflect.New(field.Struct.Type.Elem())) - } - fieldValue = fieldValue.Elem() - } - - if reflectValue.Type().ConvertibleTo(fieldValue.Type()) { - fieldValue.Set(reflectValue.Convert(fieldValue.Type())) - } else if scanner, ok := fieldValue.Addr().Interface().(sql.Scanner); ok { - v := reflectValue.Interface() - if valuer, ok := v.(driver.Valuer); ok { - if v, err = valuer.Value(); err == nil { - err = scanner.Scan(v) - } - } else { - err = scanner.Scan(v) - } - } else { - err = fmt.Errorf("could not convert argument of field %s from %s to %s", field.Name, reflectValue.Type(), fieldValue.Type()) - } - } - } else { - field.Field.Set(reflect.Zero(field.Field.Type())) - } - - field.IsBlank = isBlank(field.Field) - return err -} diff --git a/field_test.go b/field_test.go deleted file mode 100644 index 715661f0..00000000 --- a/field_test.go +++ /dev/null @@ -1,130 +0,0 @@ -package gorm_test - -import ( - "database/sql/driver" - "encoding/hex" - "fmt" - "testing" - - "github.com/jinzhu/gorm" -) - -type CalculateField struct { - gorm.Model - Name string - Children []CalculateFieldChild - Category CalculateFieldCategory - EmbeddedField -} - -type EmbeddedField struct { - EmbeddedName string `sql:"NOT NULL;DEFAULT:'hello'"` -} - -type CalculateFieldChild struct { - gorm.Model - CalculateFieldID uint - Name string -} - -type CalculateFieldCategory struct { - gorm.Model - CalculateFieldID uint - Name string -} - -func TestCalculateField(t *testing.T) { - var field CalculateField - var scope = DB.NewScope(&field) - if field, ok := scope.FieldByName("Children"); !ok || field.Relationship == nil { - t.Errorf("Should calculate fields correctly for the first time") - } - - if field, ok := scope.FieldByName("Category"); !ok || field.Relationship == nil { - t.Errorf("Should calculate fields correctly for the first time") - } - - if field, ok := scope.FieldByName("embedded_name"); !ok { - t.Errorf("should find embedded field") - } else if _, ok := field.TagSettingsGet("NOT NULL"); !ok { - t.Errorf("should find embedded field's tag settings") - } -} - -type UUID [16]byte - -type NullUUID struct { - UUID - Valid bool -} - -func FromString(input string) (u UUID) { - src := []byte(input) - return FromBytes(src) -} - -func FromBytes(src []byte) (u UUID) { - dst := u[:] - hex.Decode(dst[0:4], src[0:8]) - hex.Decode(dst[4:6], src[9:13]) - hex.Decode(dst[6:8], src[14:18]) - hex.Decode(dst[8:10], src[19:23]) - hex.Decode(dst[10:], src[24:]) - return -} - -func (u UUID) String() string { - buf := make([]byte, 36) - src := u[:] - hex.Encode(buf[0:8], src[0:4]) - buf[8] = '-' - hex.Encode(buf[9:13], src[4:6]) - buf[13] = '-' - hex.Encode(buf[14:18], src[6:8]) - buf[18] = '-' - hex.Encode(buf[19:23], src[8:10]) - buf[23] = '-' - hex.Encode(buf[24:], src[10:]) - return string(buf) -} - -func (u UUID) Value() (driver.Value, error) { - return u.String(), nil -} - -func (u *UUID) Scan(src interface{}) error { - switch src := src.(type) { - case UUID: // support gorm convert from UUID to NullUUID - *u = src - return nil - case []byte: - *u = FromBytes(src) - return nil - case string: - *u = FromString(src) - return nil - } - return fmt.Errorf("uuid: cannot convert %T to UUID", src) -} - -func (u *NullUUID) Scan(src interface{}) error { - u.Valid = true - return u.UUID.Scan(src) -} - -func TestFieldSet(t *testing.T) { - type TestFieldSetNullUUID struct { - NullUUID NullUUID - } - scope := DB.NewScope(&TestFieldSetNullUUID{}) - field := scope.Fields()[0] - err := field.Set(FromString("3034d44a-da03-11e8-b366-4a00070b9f00")) - if err != nil { - t.Fatal(err) - } - if id, ok := field.Field.Addr().Interface().(*NullUUID); !ok { - t.Fatal() - } else if !id.Valid || id.UUID.String() != "3034d44a-da03-11e8-b366-4a00070b9f00" { - t.Fatal(id) - } -} diff --git a/go.mod b/go.mod index 91ff3cb8..0b3e3065 100644 --- a/go.mod +++ b/go.mod @@ -1,14 +1 @@ module github.com/jinzhu/gorm - -go 1.12 - -require ( - github.com/denisenkom/go-mssqldb v0.0.0-20191124224453-732737034ffd - github.com/erikstmartin/go-testdb v0.0.0-20160219214506-8d10e4a1bae5 - github.com/go-sql-driver/mysql v1.5.0 - github.com/jinzhu/inflection v1.0.0 - github.com/jinzhu/now v1.0.1 - github.com/lib/pq v1.1.1 - github.com/mattn/go-sqlite3 v2.0.1+incompatible - golang.org/x/crypto v0.0.0-20191205180655-e7c4368fe9dd // indirect -) diff --git a/go.sum b/go.sum deleted file mode 100644 index e09a0352..00000000 --- a/go.sum +++ /dev/null @@ -1,25 +0,0 @@ -github.com/denisenkom/go-mssqldb v0.0.0-20191124224453-732737034ffd h1:83Wprp6ROGeiHFAP8WJdI2RoxALQYgdllERc3N5N2DM= -github.com/denisenkom/go-mssqldb v0.0.0-20191124224453-732737034ffd/go.mod h1:xbL0rPBG9cCiLr28tMa8zpbdarY27NDyej4t/EjAShU= -github.com/erikstmartin/go-testdb v0.0.0-20160219214506-8d10e4a1bae5 h1:Yzb9+7DPaBjB8zlTR87/ElzFsnQfuHnVUVqpZZIcV5Y= -github.com/erikstmartin/go-testdb v0.0.0-20160219214506-8d10e4a1bae5/go.mod h1:a2zkGnVExMxdzMo3M0Hi/3sEU+cWnZpSni0O6/Yb/P0= -github.com/go-sql-driver/mysql v1.5.0 h1:ozyZYNQW3x3HtqT1jira07DN2PArx2v7/mN66gGcHOs= -github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= -github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe h1:lXe2qZdvpiX5WZkZR4hgp4KJVfY3nMkvmwbVkpv1rVY= -github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= -github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= -github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= -github.com/jinzhu/now v1.0.1 h1:HjfetcXq097iXP0uoPCdnM4Efp5/9MsM0/M+XOTeR3M= -github.com/jinzhu/now v1.0.1/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= -github.com/lib/pq v1.1.1 h1:sJZmqHoEaY7f+NPP8pgLB/WxulyR3fewgCM2qaSlBb4= -github.com/lib/pq v1.1.1/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= -github.com/mattn/go-sqlite3 v2.0.1+incompatible h1:xQ15muvnzGBHpIpdrNi1DA5x0+TcBZzsIDwmw9uTHzw= -github.com/mattn/go-sqlite3 v2.0.1+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c h1:Vj5n4GlwjmQteupaxJ9+0FNOmBrHfq7vN4btdGoDZgI= -golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20191205180655-e7c4368fe9dd h1:GGJVjV8waZKRHrgwvtH66z9ZGVurTD1MT0n1Bb+q4aM= -golang.org/x/crypto v0.0.0-20191205180655-e7c4368fe9dd/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= diff --git a/interface.go b/interface.go deleted file mode 100644 index fe649231..00000000 --- a/interface.go +++ /dev/null @@ -1,24 +0,0 @@ -package gorm - -import ( - "context" - "database/sql" -) - -// SQLCommon is the minimal database connection functionality gorm requires. Implemented by *sql.DB. -type SQLCommon interface { - Exec(query string, args ...interface{}) (sql.Result, error) - Prepare(query string) (*sql.Stmt, error) - Query(query string, args ...interface{}) (*sql.Rows, error) - QueryRow(query string, args ...interface{}) *sql.Row -} - -type sqlDb interface { - Begin() (*sql.Tx, error) - BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) -} - -type sqlTx interface { - Commit() error - Rollback() error -} diff --git a/join_table_handler.go b/join_table_handler.go deleted file mode 100644 index a036d46d..00000000 --- a/join_table_handler.go +++ /dev/null @@ -1,211 +0,0 @@ -package gorm - -import ( - "errors" - "fmt" - "reflect" - "strings" -) - -// JoinTableHandlerInterface is an interface for how to handle many2many relations -type JoinTableHandlerInterface interface { - // initialize join table handler - Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type) - // Table return join table's table name - Table(db *DB) string - // Add create relationship in join table for source and destination - Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error - // Delete delete relationship in join table for sources - Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error - // JoinWith query with `Join` conditions - JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB - // SourceForeignKeys return source foreign keys - SourceForeignKeys() []JoinTableForeignKey - // DestinationForeignKeys return destination foreign keys - DestinationForeignKeys() []JoinTableForeignKey -} - -// JoinTableForeignKey join table foreign key struct -type JoinTableForeignKey struct { - DBName string - AssociationDBName string -} - -// JoinTableSource is a struct that contains model type and foreign keys -type JoinTableSource struct { - ModelType reflect.Type - ForeignKeys []JoinTableForeignKey -} - -// JoinTableHandler default join table handler -type JoinTableHandler struct { - TableName string `sql:"-"` - Source JoinTableSource `sql:"-"` - Destination JoinTableSource `sql:"-"` -} - -// SourceForeignKeys return source foreign keys -func (s *JoinTableHandler) SourceForeignKeys() []JoinTableForeignKey { - return s.Source.ForeignKeys -} - -// DestinationForeignKeys return destination foreign keys -func (s *JoinTableHandler) DestinationForeignKeys() []JoinTableForeignKey { - return s.Destination.ForeignKeys -} - -// Setup initialize a default join table handler -func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type) { - s.TableName = tableName - - s.Source = JoinTableSource{ModelType: source} - s.Source.ForeignKeys = []JoinTableForeignKey{} - for idx, dbName := range relationship.ForeignFieldNames { - s.Source.ForeignKeys = append(s.Source.ForeignKeys, JoinTableForeignKey{ - DBName: relationship.ForeignDBNames[idx], - AssociationDBName: dbName, - }) - } - - s.Destination = JoinTableSource{ModelType: destination} - s.Destination.ForeignKeys = []JoinTableForeignKey{} - for idx, dbName := range relationship.AssociationForeignFieldNames { - s.Destination.ForeignKeys = append(s.Destination.ForeignKeys, JoinTableForeignKey{ - DBName: relationship.AssociationForeignDBNames[idx], - AssociationDBName: dbName, - }) - } -} - -// Table return join table's table name -func (s JoinTableHandler) Table(db *DB) string { - return DefaultTableNameHandler(db, s.TableName) -} - -func (s JoinTableHandler) updateConditionMap(conditionMap map[string]interface{}, db *DB, joinTableSources []JoinTableSource, sources ...interface{}) { - for _, source := range sources { - scope := db.NewScope(source) - modelType := scope.GetModelStruct().ModelType - - for _, joinTableSource := range joinTableSources { - if joinTableSource.ModelType == modelType { - for _, foreignKey := range joinTableSource.ForeignKeys { - if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok { - conditionMap[foreignKey.DBName] = field.Field.Interface() - } - } - break - } - } - } -} - -// Add create relationship in join table for source and destination -func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error { - var ( - scope = db.NewScope("") - conditionMap = map[string]interface{}{} - ) - - // Update condition map for source - s.updateConditionMap(conditionMap, db, []JoinTableSource{s.Source}, source) - - // Update condition map for destination - s.updateConditionMap(conditionMap, db, []JoinTableSource{s.Destination}, destination) - - var assignColumns, binVars, conditions []string - var values []interface{} - for key, value := range conditionMap { - assignColumns = append(assignColumns, scope.Quote(key)) - binVars = append(binVars, `?`) - conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key))) - values = append(values, value) - } - - for _, value := range values { - values = append(values, value) - } - - quotedTable := scope.Quote(handler.Table(db)) - sql := fmt.Sprintf( - "INSERT INTO %v (%v) SELECT %v %v WHERE NOT EXISTS (SELECT * FROM %v WHERE %v)", - quotedTable, - strings.Join(assignColumns, ","), - strings.Join(binVars, ","), - scope.Dialect().SelectFromDummyTable(), - quotedTable, - strings.Join(conditions, " AND "), - ) - - return db.Exec(sql, values...).Error -} - -// Delete delete relationship in join table for sources -func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error { - var ( - scope = db.NewScope(nil) - conditions []string - values []interface{} - conditionMap = map[string]interface{}{} - ) - - s.updateConditionMap(conditionMap, db, []JoinTableSource{s.Source, s.Destination}, sources...) - - for key, value := range conditionMap { - conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key))) - values = append(values, value) - } - - return db.Table(handler.Table(db)).Where(strings.Join(conditions, " AND "), values...).Delete("").Error -} - -// JoinWith query with `Join` conditions -func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB { - var ( - scope = db.NewScope(source) - tableName = handler.Table(db) - quotedTableName = scope.Quote(tableName) - joinConditions []string - values []interface{} - ) - - if s.Source.ModelType == scope.GetModelStruct().ModelType { - destinationTableName := db.NewScope(reflect.New(s.Destination.ModelType).Interface()).QuotedTableName() - for _, foreignKey := range s.Destination.ForeignKeys { - joinConditions = append(joinConditions, fmt.Sprintf("%v.%v = %v.%v", quotedTableName, scope.Quote(foreignKey.DBName), destinationTableName, scope.Quote(foreignKey.AssociationDBName))) - } - - var foreignDBNames []string - var foreignFieldNames []string - - for _, foreignKey := range s.Source.ForeignKeys { - foreignDBNames = append(foreignDBNames, foreignKey.DBName) - if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok { - foreignFieldNames = append(foreignFieldNames, field.Name) - } - } - - foreignFieldValues := scope.getColumnAsArray(foreignFieldNames, scope.Value) - - var condString string - if len(foreignFieldValues) > 0 { - var quotedForeignDBNames []string - for _, dbName := range foreignDBNames { - quotedForeignDBNames = append(quotedForeignDBNames, tableName+"."+dbName) - } - - condString = fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, quotedForeignDBNames), toQueryMarks(foreignFieldValues)) - - keys := scope.getColumnAsArray(foreignFieldNames, scope.Value) - values = append(values, toQueryValues(keys)) - } else { - condString = fmt.Sprintf("1 <> 1") - } - - return db.Joins(fmt.Sprintf("INNER JOIN %v ON %v", quotedTableName, strings.Join(joinConditions, " AND "))). - Where(condString, toQueryValues(foreignFieldValues)...) - } - - db.Error = errors.New("wrong source type for join table handler") - return db -} diff --git a/join_table_test.go b/join_table_test.go deleted file mode 100644 index 6d5f427d..00000000 --- a/join_table_test.go +++ /dev/null @@ -1,117 +0,0 @@ -package gorm_test - -import ( - "fmt" - "strconv" - "testing" - "time" - - "github.com/jinzhu/gorm" -) - -type Person struct { - Id int - Name string - Addresses []*Address `gorm:"many2many:person_addresses;"` -} - -type PersonAddress struct { - gorm.JoinTableHandler - PersonID int - AddressID int - DeletedAt *time.Time - CreatedAt time.Time -} - -func (*PersonAddress) Add(handler gorm.JoinTableHandlerInterface, db *gorm.DB, foreignValue interface{}, associationValue interface{}) error { - foreignPrimaryKey, _ := strconv.Atoi(fmt.Sprint(db.NewScope(foreignValue).PrimaryKeyValue())) - associationPrimaryKey, _ := strconv.Atoi(fmt.Sprint(db.NewScope(associationValue).PrimaryKeyValue())) - if result := db.Unscoped().Model(&PersonAddress{}).Where(map[string]interface{}{ - "person_id": foreignPrimaryKey, - "address_id": associationPrimaryKey, - }).Update(map[string]interface{}{ - "person_id": foreignPrimaryKey, - "address_id": associationPrimaryKey, - "deleted_at": gorm.Expr("NULL"), - }).RowsAffected; result == 0 { - return db.Create(&PersonAddress{ - PersonID: foreignPrimaryKey, - AddressID: associationPrimaryKey, - }).Error - } - - return nil -} - -func (*PersonAddress) Delete(handler gorm.JoinTableHandlerInterface, db *gorm.DB, sources ...interface{}) error { - return db.Delete(&PersonAddress{}).Error -} - -func (pa *PersonAddress) JoinWith(handler gorm.JoinTableHandlerInterface, db *gorm.DB, source interface{}) *gorm.DB { - table := pa.Table(db) - return db.Joins("INNER JOIN person_addresses ON person_addresses.address_id = addresses.id").Where(fmt.Sprintf("%v.deleted_at IS NULL OR %v.deleted_at <= '0001-01-02'", table, table)) -} - -func TestJoinTable(t *testing.T) { - DB.Exec("drop table person_addresses;") - DB.AutoMigrate(&Person{}) - DB.SetJoinTableHandler(&Person{}, "Addresses", &PersonAddress{}) - - address1 := &Address{Address1: "address 1"} - address2 := &Address{Address1: "address 2"} - person := &Person{Name: "person", Addresses: []*Address{address1, address2}} - DB.Save(person) - - DB.Model(person).Association("Addresses").Delete(address1) - - if DB.Find(&[]PersonAddress{}, "person_id = ?", person.Id).RowsAffected != 1 { - t.Errorf("Should found one address") - } - - if DB.Model(person).Association("Addresses").Count() != 1 { - t.Errorf("Should found one address") - } - - if DB.Unscoped().Find(&[]PersonAddress{}, "person_id = ?", person.Id).RowsAffected != 2 { - t.Errorf("Found two addresses with Unscoped") - } - - if DB.Model(person).Association("Addresses").Clear(); DB.Model(person).Association("Addresses").Count() != 0 { - t.Errorf("Should deleted all addresses") - } -} - -func TestEmbeddedMany2ManyRelationship(t *testing.T) { - type EmbeddedPerson struct { - ID int - Name string - Addresses []*Address `gorm:"many2many:person_addresses;"` - } - - type NewPerson struct { - EmbeddedPerson - ExternalID uint - } - DB.Exec("drop table person_addresses;") - DB.AutoMigrate(&NewPerson{}) - - address1 := &Address{Address1: "address 1"} - address2 := &Address{Address1: "address 2"} - person := &NewPerson{ExternalID: 100, EmbeddedPerson: EmbeddedPerson{Name: "person", Addresses: []*Address{address1, address2}}} - if err := DB.Save(person).Error; err != nil { - t.Errorf("no error should return when save embedded many2many relationship, but got %v", err) - } - - if err := DB.Model(person).Association("Addresses").Delete(address1).Error; err != nil { - t.Errorf("no error should return when delete embedded many2many relationship, but got %v", err) - } - - association := DB.Model(person).Association("Addresses") - if count := association.Count(); count != 1 || association.Error != nil { - t.Errorf("Should found one address, but got %v, error is %v", count, association.Error) - } - - if association.Clear(); association.Count() != 0 { - t.Errorf("Should deleted all addresses") - } -} diff --git a/logger.go b/logger.go deleted file mode 100644 index 88e167dd..00000000 --- a/logger.go +++ /dev/null @@ -1,141 +0,0 @@ -package gorm - -import ( - "database/sql/driver" - "fmt" - "log" - "os" - "reflect" - "regexp" - "strconv" - "time" - "unicode" -) - -var ( - defaultLogger = Logger{log.New(os.Stdout, "\r\n", 0)} - sqlRegexp = regexp.MustCompile(`\?`) - numericPlaceHolderRegexp = regexp.MustCompile(`\$\d+`) -) - -func isPrintable(s string) bool { - for _, r := range s { - if !unicode.IsPrint(r) { - return false - } - } - return true -} - -var LogFormatter = func(values ...interface{}) (messages []interface{}) { - if len(values) > 1 { - var ( - sql string - formattedValues []string - level = values[0] - currentTime = "\n\033[33m[" + NowFunc().Format("2006-01-02 15:04:05") + "]\033[0m" - source = fmt.Sprintf("\033[35m(%v)\033[0m", values[1]) - ) - - messages = []interface{}{source, currentTime} - - if len(values) == 2 { - //remove the line break - currentTime = currentTime[1:] - //remove the brackets - source = fmt.Sprintf("\033[35m%v\033[0m", values[1]) - - messages = []interface{}{currentTime, source} - } - - if level == "sql" { - // duration - messages = append(messages, fmt.Sprintf(" \033[36;1m[%.2fms]\033[0m ", float64(values[2].(time.Duration).Nanoseconds()/1e4)/100.0)) - // sql - - for _, value := range values[4].([]interface{}) { - indirectValue := reflect.Indirect(reflect.ValueOf(value)) - if indirectValue.IsValid() { - value = indirectValue.Interface() - if t, ok := value.(time.Time); ok { - if t.IsZero() { - formattedValues = append(formattedValues, fmt.Sprintf("'%v'", "0000-00-00 00:00:00")) - } else { - formattedValues = append(formattedValues, fmt.Sprintf("'%v'", t.Format("2006-01-02 15:04:05"))) - } - } else if b, ok := value.([]byte); ok { - if str := string(b); isPrintable(str) { - formattedValues = append(formattedValues, fmt.Sprintf("'%v'", str)) - } else { - formattedValues = append(formattedValues, "''") - } - } else if r, ok := value.(driver.Valuer); ok { - if value, err := r.Value(); err == nil && value != nil { - formattedValues = append(formattedValues, fmt.Sprintf("'%v'", value)) - } else { - formattedValues = append(formattedValues, "NULL") - } - } else { - switch value.(type) { - case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64, bool: - formattedValues = append(formattedValues, fmt.Sprintf("%v", value)) - default: - formattedValues = append(formattedValues, fmt.Sprintf("'%v'", value)) - } - } - } else { - formattedValues = append(formattedValues, "NULL") - } - } - - // differentiate between $n placeholders or else treat like ? - if numericPlaceHolderRegexp.MatchString(values[3].(string)) { - sql = values[3].(string) - for index, value := range formattedValues { - placeholder := fmt.Sprintf(`\$%d([^\d]|$)`, index+1) - sql = regexp.MustCompile(placeholder).ReplaceAllString(sql, value+"$1") - } - } else { - formattedValuesLength := len(formattedValues) - for index, value := range sqlRegexp.Split(values[3].(string), -1) { - sql += value - if index < formattedValuesLength { - sql += formattedValues[index] - } - } - } - - messages = append(messages, sql) - messages = append(messages, fmt.Sprintf(" \n\033[36;31m[%v]\033[0m ", strconv.FormatInt(values[5].(int64), 10)+" rows affected or returned ")) - } else { - messages = append(messages, "\033[31;1m") - messages = append(messages, values[2:]...) - messages = append(messages, "\033[0m") - } - } - - return -} - -type logger interface { - Print(v ...interface{}) -} - -// LogWriter log writer interface -type LogWriter interface { - Println(v ...interface{}) -} - -// Logger default logger -type Logger struct { - LogWriter -} - -// Print format & print log -func (logger Logger) Print(values ...interface{}) { - logger.Println(LogFormatter(values...)...) -} - -type nopLogger struct{} - -func (nopLogger) Print(values ...interface{}) {} diff --git a/main.go b/main.go deleted file mode 100644 index 3db87870..00000000 --- a/main.go +++ /dev/null @@ -1,881 +0,0 @@ -package gorm - -import ( - "context" - "database/sql" - "errors" - "fmt" - "reflect" - "strings" - "sync" - "time" -) - -// DB contains information for current db connection -type DB struct { - sync.RWMutex - Value interface{} - Error error - RowsAffected int64 - - // single db - db SQLCommon - blockGlobalUpdate bool - logMode logModeValue - logger logger - search *search - values sync.Map - - // global db - parent *DB - callbacks *Callback - dialect Dialect - singularTable bool - - // function to be used to override the creating of a new timestamp - nowFuncOverride func() time.Time -} - -type logModeValue int - -const ( - defaultLogMode logModeValue = iota - noLogMode - detailedLogMode -) - -// Open initialize a new db connection, need to import driver first, e.g: -// -// import _ "github.com/go-sql-driver/mysql" -// func main() { -// db, err := gorm.Open("mysql", "user:password@/dbname?charset=utf8&parseTime=True&loc=Local") -// } -// GORM has wrapped some drivers, for easier to remember driver's import path, so you could import the mysql driver with -// import _ "github.com/jinzhu/gorm/dialects/mysql" -// // import _ "github.com/jinzhu/gorm/dialects/postgres" -// // import _ "github.com/jinzhu/gorm/dialects/sqlite" -// // import _ "github.com/jinzhu/gorm/dialects/mssql" -func Open(dialect string, args ...interface{}) (db *DB, err error) { - if len(args) == 0 { - err = errors.New("invalid database source") - return nil, err - } - var source string - var dbSQL SQLCommon - var ownDbSQL bool - - switch value := args[0].(type) { - case string: - var driver = dialect - if len(args) == 1 { - source = value - } else if len(args) >= 2 { - driver = value - source = args[1].(string) - } - dbSQL, err = sql.Open(driver, source) - ownDbSQL = true - case SQLCommon: - dbSQL = value - ownDbSQL = false - default: - return nil, fmt.Errorf("invalid database source: %v is not a valid type", value) - } - - db = &DB{ - db: dbSQL, - logger: defaultLogger, - callbacks: DefaultCallback, - dialect: newDialect(dialect, dbSQL), - } - db.parent = db - if err != nil { - return - } - // Send a ping to make sure the database connection is alive. - if d, ok := dbSQL.(*sql.DB); ok { - if err = d.Ping(); err != nil && ownDbSQL { - d.Close() - } - } - return -} - -// New clone a new db connection without search conditions -func (s *DB) New() *DB { - clone := s.clone() - clone.search = nil - clone.Value = nil - return clone -} - -type closer interface { - Close() error -} - -// Close close current db connection. If database connection is not an io.Closer, returns an error. -func (s *DB) Close() error { - if db, ok := s.parent.db.(closer); ok { - return db.Close() - } - return errors.New("can't close current db") -} - -// DB get `*sql.DB` from current connection -// If the underlying database connection is not a *sql.DB, returns nil -func (s *DB) DB() *sql.DB { - db, ok := s.db.(*sql.DB) - if !ok { - panic("can't support full GORM on currently status, maybe this is a TX instance.") - } - return db -} - -// CommonDB return the underlying `*sql.DB` or `*sql.Tx` instance, mainly intended to allow coexistence with legacy non-GORM code. -func (s *DB) CommonDB() SQLCommon { - return s.db -} - -// Dialect get dialect -func (s *DB) Dialect() Dialect { - return s.dialect -} - -// Callback return `Callbacks` container, you could add/change/delete callbacks with it -// db.Callback().Create().Register("update_created_at", updateCreated) -// Refer https://jinzhu.github.io/gorm/development.html#callbacks -func (s *DB) Callback() *Callback { - s.parent.callbacks = s.parent.callbacks.clone(s.logger) - return s.parent.callbacks -} - -// SetLogger replace default logger -func (s *DB) SetLogger(log logger) { - s.logger = log -} - -// LogMode set log mode, `true` for detailed logs, `false` for no log, default, will only print error logs -func (s *DB) LogMode(enable bool) *DB { - if enable { - s.logMode = detailedLogMode - } else { - s.logMode = noLogMode - } - return s -} - -// SetNowFuncOverride set the function to be used when creating a new timestamp -func (s *DB) SetNowFuncOverride(nowFuncOverride func() time.Time) *DB { - s.nowFuncOverride = nowFuncOverride - return s -} - -// Get a new timestamp, using the provided nowFuncOverride on the DB instance if set, -// otherwise defaults to the global NowFunc() -func (s *DB) nowFunc() time.Time { - if s.nowFuncOverride != nil { - return s.nowFuncOverride() - } - - return NowFunc() -} - -// BlockGlobalUpdate if true, generates an error on update/delete without where clause. -// This is to prevent eventual error with empty objects updates/deletions -func (s *DB) BlockGlobalUpdate(enable bool) *DB { - s.blockGlobalUpdate = enable - return s -} - -// HasBlockGlobalUpdate return state of block -func (s *DB) HasBlockGlobalUpdate() bool { - return s.blockGlobalUpdate -} - -// SingularTable use singular table by default -func (s *DB) SingularTable(enable bool) { - s.parent.Lock() - defer s.parent.Unlock() - s.parent.singularTable = enable -} - -// NewScope create a scope for current operation -func (s *DB) NewScope(value interface{}) *Scope { - dbClone := s.clone() - dbClone.Value = value - scope := &Scope{db: dbClone, Value: value} - if s.search != nil { - scope.Search = s.search.clone() - } else { - scope.Search = &search{} - } - return scope -} - -// QueryExpr returns the query as SqlExpr object -func (s *DB) QueryExpr() *SqlExpr { - scope := s.NewScope(s.Value) - scope.InstanceSet("skip_bindvar", true) - scope.prepareQuerySQL() - - return Expr(scope.SQL, scope.SQLVars...) -} - -// SubQuery returns the query as sub query -func (s *DB) SubQuery() *SqlExpr { - scope := s.NewScope(s.Value) - scope.InstanceSet("skip_bindvar", true) - scope.prepareQuerySQL() - - return Expr(fmt.Sprintf("(%v)", scope.SQL), scope.SQLVars...) -} - -// Where return a new relation, filter records with given conditions, accepts `map`, `struct` or `string` as conditions, refer http://jinzhu.github.io/gorm/crud.html#query -func (s *DB) Where(query interface{}, args ...interface{}) *DB { - return s.clone().search.Where(query, args...).db -} - -// Or filter records that match before conditions or this one, similar to `Where` -func (s *DB) Or(query interface{}, args ...interface{}) *DB { - return s.clone().search.Or(query, args...).db -} - -// Not filter records that don't match current conditions, similar to `Where` -func (s *DB) Not(query interface{}, args ...interface{}) *DB { - return s.clone().search.Not(query, args...).db -} - -// Limit specify the number of records to be retrieved -func (s *DB) Limit(limit interface{}) *DB { - return s.clone().search.Limit(limit).db -} - -// Offset specify the number of records to skip before starting to return the records -func (s *DB) Offset(offset interface{}) *DB { - return s.clone().search.Offset(offset).db -} - -// Order specify order when retrieve records from database, set reorder to `true` to overwrite defined conditions -// db.Order("name DESC") -// db.Order("name DESC", true) // reorder -// db.Order(gorm.Expr("name = ? DESC", "first")) // sql expression -func (s *DB) Order(value interface{}, reorder ...bool) *DB { - return s.clone().search.Order(value, reorder...).db -} - -// Select specify fields that you want to retrieve from database when querying, by default, will select all fields; -// When creating/updating, specify fields that you want to save to database -func (s *DB) Select(query interface{}, args ...interface{}) *DB { - return s.clone().search.Select(query, args...).db -} - -// Omit specify fields that you want to ignore when saving to database for creating, updating -func (s *DB) Omit(columns ...string) *DB { - return s.clone().search.Omit(columns...).db -} - -// Group specify the group method on the find -func (s *DB) Group(query string) *DB { - return s.clone().search.Group(query).db -} - -// Having specify HAVING conditions for GROUP BY -func (s *DB) Having(query interface{}, values ...interface{}) *DB { - return s.clone().search.Having(query, values...).db -} - -// Joins specify Joins conditions -// db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user) -func (s *DB) Joins(query string, args ...interface{}) *DB { - return s.clone().search.Joins(query, args...).db -} - -// Scopes pass current database connection to arguments `func(*DB) *DB`, which could be used to add conditions dynamically -// func AmountGreaterThan1000(db *gorm.DB) *gorm.DB { -// return db.Where("amount > ?", 1000) -// } -// -// func OrderStatus(status []string) func (db *gorm.DB) *gorm.DB { -// return func (db *gorm.DB) *gorm.DB { -// return db.Scopes(AmountGreaterThan1000).Where("status in (?)", status) -// } -// } -// -// db.Scopes(AmountGreaterThan1000, OrderStatus([]string{"paid", "shipped"})).Find(&orders) -// Refer https://jinzhu.github.io/gorm/crud.html#scopes -func (s *DB) Scopes(funcs ...func(*DB) *DB) *DB { - for _, f := range funcs { - s = f(s) - } - return s -} - -// Unscoped return all record including deleted record, refer Soft Delete https://jinzhu.github.io/gorm/crud.html#soft-delete -func (s *DB) Unscoped() *DB { - return s.clone().search.unscoped().db -} - -// Attrs initialize struct with argument if record not found with `FirstOrInit` https://jinzhu.github.io/gorm/crud.html#firstorinit or `FirstOrCreate` https://jinzhu.github.io/gorm/crud.html#firstorcreate -func (s *DB) Attrs(attrs ...interface{}) *DB { - return s.clone().search.Attrs(attrs...).db -} - -// Assign assign result with argument regardless it is found or not with `FirstOrInit` https://jinzhu.github.io/gorm/crud.html#firstorinit or `FirstOrCreate` https://jinzhu.github.io/gorm/crud.html#firstorcreate -func (s *DB) Assign(attrs ...interface{}) *DB { - return s.clone().search.Assign(attrs...).db -} - -// First find first record that match given conditions, order by primary key -func (s *DB) First(out interface{}, where ...interface{}) *DB { - newScope := s.NewScope(out) - newScope.Search.Limit(1) - - return newScope.Set("gorm:order_by_primary_key", "ASC"). - inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db -} - -// Take return a record that match given conditions, the order will depend on the database implementation -func (s *DB) Take(out interface{}, where ...interface{}) *DB { - newScope := s.NewScope(out) - newScope.Search.Limit(1) - return newScope.inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db -} - -// Last find last record that match given conditions, order by primary key -func (s *DB) Last(out interface{}, where ...interface{}) *DB { - newScope := s.NewScope(out) - newScope.Search.Limit(1) - return newScope.Set("gorm:order_by_primary_key", "DESC"). - inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db -} - -// Find find records that match given conditions -func (s *DB) Find(out interface{}, where ...interface{}) *DB { - return s.NewScope(out).inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db -} - -//Preloads preloads relations, don`t touch out -func (s *DB) Preloads(out interface{}) *DB { - return s.NewScope(out).InstanceSet("gorm:only_preload", 1).callCallbacks(s.parent.callbacks.queries).db -} - -// Scan scan value to a struct -func (s *DB) Scan(dest interface{}) *DB { - return s.NewScope(s.Value).Set("gorm:query_destination", dest).callCallbacks(s.parent.callbacks.queries).db -} - -// Row return `*sql.Row` with given conditions -func (s *DB) Row() *sql.Row { - return s.NewScope(s.Value).row() -} - -// Rows return `*sql.Rows` with given conditions -func (s *DB) Rows() (*sql.Rows, error) { - return s.NewScope(s.Value).rows() -} - -// ScanRows scan `*sql.Rows` to give struct -func (s *DB) ScanRows(rows *sql.Rows, result interface{}) error { - var ( - scope = s.NewScope(result) - clone = scope.db - columns, err = rows.Columns() - ) - - if clone.AddError(err) == nil { - scope.scan(rows, columns, scope.Fields()) - } - - return clone.Error -} - -// Pluck used to query single column from a model as a map -// var ages []int64 -// db.Find(&users).Pluck("age", &ages) -func (s *DB) Pluck(column string, value interface{}) *DB { - return s.NewScope(s.Value).pluck(column, value).db -} - -// Count get how many records for a model -func (s *DB) Count(value interface{}) *DB { - return s.NewScope(s.Value).count(value).db -} - -// Related get related associations -func (s *DB) Related(value interface{}, foreignKeys ...string) *DB { - return s.NewScope(s.Value).related(value, foreignKeys...).db -} - -// FirstOrInit find first matched record or initialize a new one with given conditions (only works with struct, map conditions) -// https://jinzhu.github.io/gorm/crud.html#firstorinit -func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB { - c := s.clone() - if result := c.First(out, where...); result.Error != nil { - if !result.RecordNotFound() { - return result - } - c.NewScope(out).inlineCondition(where...).initialize() - } else { - c.NewScope(out).updatedAttrsWithValues(c.search.assignAttrs) - } - return c -} - -// FirstOrCreate find first matched record or create a new one with given conditions (only works with struct, map conditions) -// https://jinzhu.github.io/gorm/crud.html#firstorcreate -func (s *DB) FirstOrCreate(out interface{}, where ...interface{}) *DB { - c := s.clone() - if result := s.First(out, where...); result.Error != nil { - if !result.RecordNotFound() { - return result - } - return c.NewScope(out).inlineCondition(where...).initialize().callCallbacks(c.parent.callbacks.creates).db - } else if len(c.search.assignAttrs) > 0 { - return c.NewScope(out).InstanceSet("gorm:update_interface", c.search.assignAttrs).callCallbacks(c.parent.callbacks.updates).db - } - return c -} - -// Update update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update -// WARNING when update with struct, GORM will not update fields that with zero value -func (s *DB) Update(attrs ...interface{}) *DB { - return s.Updates(toSearchableMap(attrs...), true) -} - -// Updates update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update -func (s *DB) Updates(values interface{}, ignoreProtectedAttrs ...bool) *DB { - return s.NewScope(s.Value). - Set("gorm:ignore_protected_attrs", len(ignoreProtectedAttrs) > 0). - InstanceSet("gorm:update_interface", values). - callCallbacks(s.parent.callbacks.updates).db -} - -// UpdateColumn update attributes without callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update -func (s *DB) UpdateColumn(attrs ...interface{}) *DB { - return s.UpdateColumns(toSearchableMap(attrs...)) -} - -// UpdateColumns update attributes without callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update -func (s *DB) UpdateColumns(values interface{}) *DB { - return s.NewScope(s.Value). - Set("gorm:update_column", true). - Set("gorm:save_associations", false). - InstanceSet("gorm:update_interface", values). - callCallbacks(s.parent.callbacks.updates).db -} - -// Save update value in database, if the value doesn't have primary key, will insert it -func (s *DB) Save(value interface{}) *DB { - scope := s.NewScope(value) - if !scope.PrimaryKeyZero() { - newDB := scope.callCallbacks(s.parent.callbacks.updates).db - if newDB.Error == nil && newDB.RowsAffected == 0 { - return s.New().Table(scope.TableName()).FirstOrCreate(value) - } - return newDB - } - return scope.callCallbacks(s.parent.callbacks.creates).db -} - -// Create insert the value into database -func (s *DB) Create(value interface{}) *DB { - scope := s.NewScope(value) - return scope.callCallbacks(s.parent.callbacks.creates).db -} - -// Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition -// WARNING If model has DeletedAt field, GORM will only set field DeletedAt's value to current time -func (s *DB) Delete(value interface{}, where ...interface{}) *DB { - return s.NewScope(value).inlineCondition(where...).callCallbacks(s.parent.callbacks.deletes).db -} - -// Raw use raw sql as conditions, won't run it unless invoked by other methods -// db.Raw("SELECT name, age FROM users WHERE name = ?", 3).Scan(&result) -func (s *DB) Raw(sql string, values ...interface{}) *DB { - return s.clone().search.Raw(true).Where(sql, values...).db -} - -// Exec execute raw sql -func (s *DB) Exec(sql string, values ...interface{}) *DB { - scope := s.NewScope(nil) - generatedSQL := scope.buildCondition(map[string]interface{}{"query": sql, "args": values}, true) - generatedSQL = strings.TrimSuffix(strings.TrimPrefix(generatedSQL, "("), ")") - scope.Raw(generatedSQL) - return scope.Exec().db -} - -// Model specify the model you would like to run db operations -// // update all users's name to `hello` -// db.Model(&User{}).Update("name", "hello") -// // if user's primary key is non-blank, will use it as condition, then will only update the user's name to `hello` -// db.Model(&user).Update("name", "hello") -func (s *DB) Model(value interface{}) *DB { - c := s.clone() - c.Value = value - return c -} - -// Table specify the table you would like to run db operations -func (s *DB) Table(name string) *DB { - clone := s.clone() - clone.search.Table(name) - clone.Value = nil - return clone -} - -// Debug start debug mode -func (s *DB) Debug() *DB { - return s.clone().LogMode(true) -} - -// Transaction start a transaction as a block, -// return error will rollback, otherwise to commit. -func (s *DB) Transaction(fc func(tx *DB) error) (err error) { - panicked := true - tx := s.Begin() - defer func() { - // Make sure to rollback when panic, Block error or Commit error - if panicked || err != nil { - tx.Rollback() - } - }() - - err = fc(tx) - - if err == nil { - err = tx.Commit().Error - } - - panicked = false - return -} - -// Begin begins a transaction -func (s *DB) Begin() *DB { - return s.BeginTx(context.Background(), &sql.TxOptions{}) -} - -// BeginTx begins a transaction with options -func (s *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) *DB { - c := s.clone() - if db, ok := c.db.(sqlDb); ok && db != nil { - tx, err := db.BeginTx(ctx, opts) - c.db = interface{}(tx).(SQLCommon) - - c.dialect.SetDB(c.db) - c.AddError(err) - } else { - c.AddError(ErrCantStartTransaction) - } - return c -} - -// Commit commit a transaction -func (s *DB) Commit() *DB { - var emptySQLTx *sql.Tx - if db, ok := s.db.(sqlTx); ok && db != nil && db != emptySQLTx { - s.AddError(db.Commit()) - } else { - s.AddError(ErrInvalidTransaction) - } - return s -} - -// Rollback rollback a transaction -func (s *DB) Rollback() *DB { - var emptySQLTx *sql.Tx - if db, ok := s.db.(sqlTx); ok && db != nil && db != emptySQLTx { - if err := db.Rollback(); err != nil && err != sql.ErrTxDone { - s.AddError(err) - } - } else { - s.AddError(ErrInvalidTransaction) - } - return s -} - -// RollbackUnlessCommitted rollback a transaction if it has not yet been -// committed. -func (s *DB) RollbackUnlessCommitted() *DB { - var emptySQLTx *sql.Tx - if db, ok := s.db.(sqlTx); ok && db != nil && db != emptySQLTx { - err := db.Rollback() - // Ignore the error indicating that the transaction has already - // been committed. - if err != sql.ErrTxDone { - s.AddError(err) - } - } else { - s.AddError(ErrInvalidTransaction) - } - return s -} - -// NewRecord check if value's primary key is blank -func (s *DB) NewRecord(value interface{}) bool { - return s.NewScope(value).PrimaryKeyZero() -} - -// RecordNotFound check if returning ErrRecordNotFound error -func (s *DB) RecordNotFound() bool { - for _, err := range s.GetErrors() { - if err == ErrRecordNotFound { - return true - } - } - return false -} - -// CreateTable create table for models -func (s *DB) CreateTable(models ...interface{}) *DB { - db := s.Unscoped() - for _, model := range models { - db = db.NewScope(model).createTable().db - } - return db -} - -// DropTable drop table for models -func (s *DB) DropTable(values ...interface{}) *DB { - db := s.clone() - for _, value := range values { - if tableName, ok := value.(string); ok { - db = db.Table(tableName) - } - - db = db.NewScope(value).dropTable().db - } - return db -} - -// DropTableIfExists drop table if it is exist -func (s *DB) DropTableIfExists(values ...interface{}) *DB { - db := s.clone() - for _, value := range values { - if s.HasTable(value) { - db.AddError(s.DropTable(value).Error) - } - } - return db -} - -// HasTable check has table or not -func (s *DB) HasTable(value interface{}) bool { - var ( - scope = s.NewScope(value) - tableName string - ) - - if name, ok := value.(string); ok { - tableName = name - } else { - tableName = scope.TableName() - } - - has := scope.Dialect().HasTable(tableName) - s.AddError(scope.db.Error) - return has -} - -// AutoMigrate run auto migration for given models, will only add missing fields, won't delete/change current data -func (s *DB) AutoMigrate(values ...interface{}) *DB { - db := s.Unscoped() - for _, value := range values { - db = db.NewScope(value).autoMigrate().db - } - return db -} - -// ModifyColumn modify column to type -func (s *DB) ModifyColumn(column string, typ string) *DB { - scope := s.NewScope(s.Value) - scope.modifyColumn(column, typ) - return scope.db -} - -// DropColumn drop a column -func (s *DB) DropColumn(column string) *DB { - scope := s.NewScope(s.Value) - scope.dropColumn(column) - return scope.db -} - -// AddIndex add index for columns with given name -func (s *DB) AddIndex(indexName string, columns ...string) *DB { - scope := s.Unscoped().NewScope(s.Value) - scope.addIndex(false, indexName, columns...) - return scope.db -} - -// AddUniqueIndex add unique index for columns with given name -func (s *DB) AddUniqueIndex(indexName string, columns ...string) *DB { - scope := s.Unscoped().NewScope(s.Value) - scope.addIndex(true, indexName, columns...) - return scope.db -} - -// RemoveIndex remove index with name -func (s *DB) RemoveIndex(indexName string) *DB { - scope := s.NewScope(s.Value) - scope.removeIndex(indexName) - return scope.db -} - -// AddForeignKey Add foreign key to the given scope, e.g: -// db.Model(&User{}).AddForeignKey("city_id", "cities(id)", "RESTRICT", "RESTRICT") -func (s *DB) AddForeignKey(field string, dest string, onDelete string, onUpdate string) *DB { - scope := s.NewScope(s.Value) - scope.addForeignKey(field, dest, onDelete, onUpdate) - return scope.db -} - -// RemoveForeignKey Remove foreign key from the given scope, e.g: -// db.Model(&User{}).RemoveForeignKey("city_id", "cities(id)") -func (s *DB) RemoveForeignKey(field string, dest string) *DB { - scope := s.clone().NewScope(s.Value) - scope.removeForeignKey(field, dest) - return scope.db -} - -// Association start `Association Mode` to handler relations things easir in that mode, refer: https://jinzhu.github.io/gorm/associations.html#association-mode -func (s *DB) Association(column string) *Association { - var err error - var scope = s.Set("gorm:association:source", s.Value).NewScope(s.Value) - - if primaryField := scope.PrimaryField(); primaryField.IsBlank { - err = errors.New("primary key can't be nil") - } else { - if field, ok := scope.FieldByName(column); ok { - if field.Relationship == nil || len(field.Relationship.ForeignFieldNames) == 0 { - err = fmt.Errorf("invalid association %v for %v", column, scope.IndirectValue().Type()) - } else { - return &Association{scope: scope, column: column, field: field} - } - } else { - err = fmt.Errorf("%v doesn't have column %v", scope.IndirectValue().Type(), column) - } - } - - return &Association{Error: err} -} - -// Preload preload associations with given conditions -// db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users) -func (s *DB) Preload(column string, conditions ...interface{}) *DB { - return s.clone().search.Preload(column, conditions...).db -} - -// Set set setting by name, which could be used in callbacks, will clone a new db, and update its setting -func (s *DB) Set(name string, value interface{}) *DB { - return s.clone().InstantSet(name, value) -} - -// InstantSet instant set setting, will affect current db -func (s *DB) InstantSet(name string, value interface{}) *DB { - s.values.Store(name, value) - return s -} - -// Get get setting by name -func (s *DB) Get(name string) (value interface{}, ok bool) { - value, ok = s.values.Load(name) - return -} - -// SetJoinTableHandler set a model's join table handler for a relation -func (s *DB) SetJoinTableHandler(source interface{}, column string, handler JoinTableHandlerInterface) { - scope := s.NewScope(source) - for _, field := range scope.GetModelStruct().StructFields { - if field.Name == column || field.DBName == column { - if many2many, _ := field.TagSettingsGet("MANY2MANY"); many2many != "" { - source := (&Scope{Value: source}).GetModelStruct().ModelType - destination := (&Scope{Value: reflect.New(field.Struct.Type).Interface()}).GetModelStruct().ModelType - handler.Setup(field.Relationship, many2many, source, destination) - field.Relationship.JoinTableHandler = handler - if table := handler.Table(s); scope.Dialect().HasTable(table) { - s.Table(table).AutoMigrate(handler) - } - } - } - } -} - -// AddError add error to the db -func (s *DB) AddError(err error) error { - if err != nil { - if err != ErrRecordNotFound { - if s.logMode == defaultLogMode { - go s.print("error", fileWithLineNum(), err) - } else { - s.log(err) - } - - errors := Errors(s.GetErrors()) - errors = errors.Add(err) - if len(errors) > 1 { - err = errors - } - } - - s.Error = err - } - return err -} - -// GetErrors get happened errors from the db -func (s *DB) GetErrors() []error { - if errs, ok := s.Error.(Errors); ok { - return errs - } else if s.Error != nil { - return []error{s.Error} - } - return []error{} -} - -//////////////////////////////////////////////////////////////////////////////// -// Private Methods For DB -//////////////////////////////////////////////////////////////////////////////// - -func (s *DB) clone() *DB { - db := &DB{ - db: s.db, - parent: s.parent, - logger: s.logger, - logMode: s.logMode, - Value: s.Value, - Error: s.Error, - blockGlobalUpdate: s.blockGlobalUpdate, - dialect: newDialect(s.dialect.GetName(), s.db), - nowFuncOverride: s.nowFuncOverride, - } - - s.values.Range(func(k, v interface{}) bool { - db.values.Store(k, v) - return true - }) - - if s.search == nil { - db.search = &search{limit: -1, offset: -1} - } else { - db.search = s.search.clone() - } - - db.search.db = db - return db -} - -func (s *DB) print(v ...interface{}) { - s.logger.Print(v...) -} - -func (s *DB) log(v ...interface{}) { - if s != nil && s.logMode == detailedLogMode { - s.print(append([]interface{}{"log", fileWithLineNum()}, v...)...) - } -} - -func (s *DB) slog(sql string, t time.Time, vars ...interface{}) { - if s.logMode == detailedLogMode { - s.print("sql", fileWithLineNum(), NowFunc().Sub(t), sql, vars, s.RowsAffected) - } -} diff --git a/main_test.go b/main_test.go deleted file mode 100644 index b51fe413..00000000 --- a/main_test.go +++ /dev/null @@ -1,1444 +0,0 @@ -package gorm_test - -// Run tests -// $ docker-compose up -// $ ./test_all.sh - -import ( - "context" - "database/sql" - "database/sql/driver" - "errors" - "fmt" - "os" - "path/filepath" - "reflect" - "strconv" - "strings" - "sync" - "testing" - "time" - - "github.com/erikstmartin/go-testdb" - "github.com/jinzhu/gorm" - _ "github.com/jinzhu/gorm/dialects/mssql" - _ "github.com/jinzhu/gorm/dialects/mysql" - "github.com/jinzhu/gorm/dialects/postgres" - _ "github.com/jinzhu/gorm/dialects/sqlite" - "github.com/jinzhu/now" -) - -var ( - DB *gorm.DB - t1, t2, t3, t4, t5 time.Time -) - -func init() { - var err error - - if DB, err = OpenTestConnection(); err != nil { - panic(fmt.Sprintf("No error should happen when connecting to test database, but got err=%+v", err)) - } - - runMigration() -} - -func OpenTestConnection() (db *gorm.DB, err error) { - dbDSN := os.Getenv("GORM_DSN") - switch os.Getenv("GORM_DIALECT") { - case "mysql": - fmt.Println("testing mysql...") - if dbDSN == "" { - dbDSN = "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True" - } - db, err = gorm.Open("mysql", dbDSN) - case "postgres": - fmt.Println("testing postgres...") - if dbDSN == "" { - dbDSN = "user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable" - } - db, err = gorm.Open("postgres", dbDSN) - case "mssql": - // CREATE LOGIN gorm WITH PASSWORD = 'LoremIpsum86'; - // CREATE DATABASE gorm; - // USE gorm; - // CREATE USER gorm FROM LOGIN gorm; - // sp_changedbowner 'gorm'; - fmt.Println("testing mssql...") - if dbDSN == "" { - dbDSN = "sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm" - } - db, err = gorm.Open("mssql", dbDSN) - default: - fmt.Println("testing sqlite3...") - db, err = gorm.Open("sqlite3", filepath.Join(os.TempDir(), "gorm.db")) - } - - // db.SetLogger(Logger{log.New(os.Stdout, "\r\n", 0)}) - // db.SetLogger(log.New(os.Stdout, "\r\n", 0)) - if debug := os.Getenv("DEBUG"); debug == "true" { - db.LogMode(true) - } else if debug == "false" { - db.LogMode(false) - } - - db.DB().SetMaxIdleConns(10) - - return -} - -func TestOpen_ReturnsError_WithBadArgs(t *testing.T) { - stringRef := "foo" - testCases := []interface{}{42, time.Now(), &stringRef} - for _, tc := range testCases { - t.Run(fmt.Sprintf("%v", tc), func(t *testing.T) { - _, err := gorm.Open("postgresql", tc) - if err == nil { - t.Error("Should got error with invalid database source") - } - if !strings.HasPrefix(err.Error(), "invalid database source:") { - t.Errorf("Should got error starting with \"invalid database source:\", but got %q", err.Error()) - } - }) - } -} - -func TestStringPrimaryKey(t *testing.T) { - type UUIDStruct struct { - ID string `gorm:"primary_key"` - Name string - } - DB.DropTable(&UUIDStruct{}) - DB.AutoMigrate(&UUIDStruct{}) - - data := UUIDStruct{ID: "uuid", Name: "hello"} - if err := DB.Save(&data).Error; err != nil || data.ID != "uuid" || data.Name != "hello" { - t.Errorf("string primary key should not be populated") - } - - data = UUIDStruct{ID: "uuid", Name: "hello world"} - if err := DB.Save(&data).Error; err != nil || data.ID != "uuid" || data.Name != "hello world" { - t.Errorf("string primary key should not be populated") - } -} - -func TestExceptionsWithInvalidSql(t *testing.T) { - var columns []string - if DB.Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns).Error == nil { - t.Errorf("Should got error with invalid SQL") - } - - if DB.Model(&User{}).Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns).Error == nil { - t.Errorf("Should got error with invalid SQL") - } - - if DB.Where("sdsd.zaaa = ?", "sd;;;aa").Find(&User{}).Error == nil { - t.Errorf("Should got error with invalid SQL") - } - - var count1, count2 int64 - DB.Model(&User{}).Count(&count1) - if count1 <= 0 { - t.Errorf("Should find some users") - } - - if DB.Where("name = ?", "jinzhu; delete * from users").First(&User{}).Error == nil { - t.Errorf("Should got error with invalid SQL") - } - - DB.Model(&User{}).Count(&count2) - if count1 != count2 { - t.Errorf("No user should not be deleted by invalid SQL") - } -} - -func TestSetTable(t *testing.T) { - DB.Create(getPreparedUser("pluck_user1", "pluck_user")) - DB.Create(getPreparedUser("pluck_user2", "pluck_user")) - DB.Create(getPreparedUser("pluck_user3", "pluck_user")) - - if err := DB.Table("users").Where("role = ?", "pluck_user").Pluck("age", &[]int{}).Error; err != nil { - t.Error("No errors should happen if set table for pluck", err) - } - - var users []User - if DB.Table("users").Find(&[]User{}).Error != nil { - t.Errorf("No errors should happen if set table for find") - } - - if DB.Table("invalid_table").Find(&users).Error == nil { - t.Errorf("Should got error when table is set to an invalid table") - } - - DB.Exec("drop table deleted_users;") - if DB.Table("deleted_users").CreateTable(&User{}).Error != nil { - t.Errorf("Create table with specified table") - } - - DB.Table("deleted_users").Save(&User{Name: "DeletedUser"}) - - var deletedUsers []User - DB.Table("deleted_users").Find(&deletedUsers) - if len(deletedUsers) != 1 { - t.Errorf("Query from specified table") - } - - var user User - DB.Table("deleted_users").First(&user, "name = ?", "DeletedUser") - - user.Age = 20 - DB.Table("deleted_users").Save(&user) - if DB.Table("deleted_users").First(&user, "name = ? AND age = ?", "DeletedUser", 20).RecordNotFound() { - t.Errorf("Failed to found updated user") - } - - DB.Save(getPreparedUser("normal_user", "reset_table")) - DB.Table("deleted_users").Save(getPreparedUser("deleted_user", "reset_table")) - var user1, user2, user3 User - DB.Where("role = ?", "reset_table").First(&user1).Table("deleted_users").First(&user2).Table("").First(&user3) - if (user1.Name != "normal_user") || (user2.Name != "deleted_user") || (user3.Name != "normal_user") { - t.Errorf("unset specified table with blank string") - } -} - -type Order struct { -} - -type Cart struct { -} - -func (c Cart) TableName() string { - return "shopping_cart" -} - -func TestHasTable(t *testing.T) { - type Foo struct { - Id int - Stuff string - } - DB.DropTable(&Foo{}) - - // Table should not exist at this point, HasTable should return false - if ok := DB.HasTable("foos"); ok { - t.Errorf("Table should not exist, but does") - } - if ok := DB.HasTable(&Foo{}); ok { - t.Errorf("Table should not exist, but does") - } - - // We create the table - if err := DB.CreateTable(&Foo{}).Error; err != nil { - t.Errorf("Table should be created") - } - - // And now it should exits, and HasTable should return true - if ok := DB.HasTable("foos"); !ok { - t.Errorf("Table should exist, but HasTable informs it does not") - } - if ok := DB.HasTable(&Foo{}); !ok { - t.Errorf("Table should exist, but HasTable informs it does not") - } -} - -func TestTableName(t *testing.T) { - DB := DB.Model("") - if DB.NewScope(Order{}).TableName() != "orders" { - t.Errorf("Order's table name should be orders") - } - - if DB.NewScope(&Order{}).TableName() != "orders" { - t.Errorf("&Order's table name should be orders") - } - - if DB.NewScope([]Order{}).TableName() != "orders" { - t.Errorf("[]Order's table name should be orders") - } - - if DB.NewScope(&[]Order{}).TableName() != "orders" { - t.Errorf("&[]Order's table name should be orders") - } - - DB.SingularTable(true) - if DB.NewScope(Order{}).TableName() != "order" { - t.Errorf("Order's singular table name should be order") - } - - if DB.NewScope(&Order{}).TableName() != "order" { - t.Errorf("&Order's singular table name should be order") - } - - if DB.NewScope([]Order{}).TableName() != "order" { - t.Errorf("[]Order's singular table name should be order") - } - - if DB.NewScope(&[]Order{}).TableName() != "order" { - t.Errorf("&[]Order's singular table name should be order") - } - - if DB.NewScope(&Cart{}).TableName() != "shopping_cart" { - t.Errorf("&Cart's singular table name should be shopping_cart") - } - - if DB.NewScope(Cart{}).TableName() != "shopping_cart" { - t.Errorf("Cart's singular table name should be shopping_cart") - } - - if DB.NewScope(&[]Cart{}).TableName() != "shopping_cart" { - t.Errorf("&[]Cart's singular table name should be shopping_cart") - } - - if DB.NewScope([]Cart{}).TableName() != "shopping_cart" { - t.Errorf("[]Cart's singular table name should be shopping_cart") - } - DB.SingularTable(false) -} - -func TestTableNameConcurrently(t *testing.T) { - DB := DB.Model("") - if DB.NewScope(Order{}).TableName() != "orders" { - t.Errorf("Order's table name should be orders") - } - - var wg sync.WaitGroup - wg.Add(10) - - for i := 1; i <= 10; i++ { - go func(db *gorm.DB) { - DB.SingularTable(true) - wg.Done() - }(DB) - } - wg.Wait() - - if DB.NewScope(Order{}).TableName() != "order" { - t.Errorf("Order's singular table name should be order") - } - - DB.SingularTable(false) -} - -func TestNullValues(t *testing.T) { - DB.DropTable(&NullValue{}) - DB.AutoMigrate(&NullValue{}) - - if err := DB.Save(&NullValue{ - Name: sql.NullString{String: "hello", Valid: true}, - Gender: &sql.NullString{String: "M", Valid: true}, - Age: sql.NullInt64{Int64: 18, Valid: true}, - Male: sql.NullBool{Bool: true, Valid: true}, - Height: sql.NullFloat64{Float64: 100.11, Valid: true}, - AddedAt: NullTime{Time: time.Now(), Valid: true}, - }).Error; err != nil { - t.Errorf("Not error should raise when test null value") - } - - var nv NullValue - DB.First(&nv, "name = ?", "hello") - - if nv.Name.String != "hello" || nv.Gender.String != "M" || nv.Age.Int64 != 18 || nv.Male.Bool != true || nv.Height.Float64 != 100.11 || nv.AddedAt.Valid != true { - t.Errorf("Should be able to fetch null value") - } - - if err := DB.Save(&NullValue{ - Name: sql.NullString{String: "hello-2", Valid: true}, - Gender: &sql.NullString{String: "F", Valid: true}, - Age: sql.NullInt64{Int64: 18, Valid: false}, - Male: sql.NullBool{Bool: true, Valid: true}, - Height: sql.NullFloat64{Float64: 100.11, Valid: true}, - AddedAt: NullTime{Time: time.Now(), Valid: false}, - }).Error; err != nil { - t.Errorf("Not error should raise when test null value") - } - - var nv2 NullValue - DB.First(&nv2, "name = ?", "hello-2") - if nv2.Name.String != "hello-2" || nv2.Gender.String != "F" || nv2.Age.Int64 != 0 || nv2.Male.Bool != true || nv2.Height.Float64 != 100.11 || nv2.AddedAt.Valid != false { - t.Errorf("Should be able to fetch null value") - } - - if err := DB.Save(&NullValue{ - Name: sql.NullString{String: "hello-3", Valid: false}, - Gender: &sql.NullString{String: "M", Valid: true}, - Age: sql.NullInt64{Int64: 18, Valid: false}, - Male: sql.NullBool{Bool: true, Valid: true}, - Height: sql.NullFloat64{Float64: 100.11, Valid: true}, - AddedAt: NullTime{Time: time.Now(), Valid: false}, - }).Error; err == nil { - t.Errorf("Can't save because of name can't be null") - } -} - -func TestNullValuesWithFirstOrCreate(t *testing.T) { - var nv1 = NullValue{ - Name: sql.NullString{String: "first_or_create", Valid: true}, - Gender: &sql.NullString{String: "M", Valid: true}, - } - - var nv2 NullValue - result := DB.Where(nv1).FirstOrCreate(&nv2) - - if result.RowsAffected != 1 { - t.Errorf("RowsAffected should be 1 after create some record") - } - - if result.Error != nil { - t.Errorf("Should not raise any error, but got %v", result.Error) - } - - if nv2.Name.String != "first_or_create" || nv2.Gender.String != "M" { - t.Errorf("first or create with nullvalues") - } - - if err := DB.Where(nv1).Assign(NullValue{Age: sql.NullInt64{Int64: 18, Valid: true}}).FirstOrCreate(&nv2).Error; err != nil { - t.Errorf("Should not raise any error, but got %v", err) - } - - if nv2.Age.Int64 != 18 { - t.Errorf("should update age to 18") - } -} - -func TestTransaction(t *testing.T) { - tx := DB.Begin() - u := User{Name: "transcation"} - if err := tx.Save(&u).Error; err != nil { - t.Errorf("No error should raise") - } - - if err := tx.First(&User{}, "name = ?", "transcation").Error; err != nil { - t.Errorf("Should find saved record") - } - - if sqlTx, ok := tx.CommonDB().(*sql.Tx); !ok || sqlTx == nil { - t.Errorf("Should return the underlying sql.Tx") - } - - tx.Rollback() - - if err := tx.First(&User{}, "name = ?", "transcation").Error; err == nil { - t.Errorf("Should not find record after rollback") - } - - tx2 := DB.Begin() - u2 := User{Name: "transcation-2"} - if err := tx2.Save(&u2).Error; err != nil { - t.Errorf("No error should raise") - } - - if err := tx2.First(&User{}, "name = ?", "transcation-2").Error; err != nil { - t.Errorf("Should find saved record") - } - - tx2.Commit() - - if err := DB.First(&User{}, "name = ?", "transcation-2").Error; err != nil { - t.Errorf("Should be able to find committed record") - } - - tx3 := DB.Begin() - u3 := User{Name: "transcation-3"} - if err := tx3.Save(&u3).Error; err != nil { - t.Errorf("No error should raise") - } - - if err := tx3.First(&User{}, "name = ?", "transcation-3").Error; err != nil { - t.Errorf("Should find saved record") - } - - tx3.RollbackUnlessCommitted() - - if err := tx.First(&User{}, "name = ?", "transcation").Error; err == nil { - t.Errorf("Should not find record after rollback") - } - - tx4 := DB.Begin() - u4 := User{Name: "transcation-4"} - if err := tx4.Save(&u4).Error; err != nil { - t.Errorf("No error should raise") - } - - if err := tx4.First(&User{}, "name = ?", "transcation-4").Error; err != nil { - t.Errorf("Should find saved record") - } - - tx4.Commit() - - tx4.RollbackUnlessCommitted() - - if err := DB.First(&User{}, "name = ?", "transcation-4").Error; err != nil { - t.Errorf("Should be able to find committed record") - } -} - -func assertPanic(t *testing.T, f func()) { - defer func() { - if r := recover(); r == nil { - t.Errorf("The code did not panic") - } - }() - f() -} - -func TestTransactionWithBlock(t *testing.T) { - // rollback - err := DB.Transaction(func(tx *gorm.DB) error { - u := User{Name: "transcation"} - if err := tx.Save(&u).Error; err != nil { - t.Errorf("No error should raise") - } - - if err := tx.First(&User{}, "name = ?", "transcation").Error; err != nil { - t.Errorf("Should find saved record") - } - - return errors.New("the error message") - }) - - if err.Error() != "the error message" { - t.Errorf("Transaction return error will equal the block returns error") - } - - if err := DB.First(&User{}, "name = ?", "transcation").Error; err == nil { - t.Errorf("Should not find record after rollback") - } - - // commit - DB.Transaction(func(tx *gorm.DB) error { - u2 := User{Name: "transcation-2"} - if err := tx.Save(&u2).Error; err != nil { - t.Errorf("No error should raise") - } - - if err := tx.First(&User{}, "name = ?", "transcation-2").Error; err != nil { - t.Errorf("Should find saved record") - } - return nil - }) - - if err := DB.First(&User{}, "name = ?", "transcation-2").Error; err != nil { - t.Errorf("Should be able to find committed record") - } - - // panic will rollback - assertPanic(t, func() { - DB.Transaction(func(tx *gorm.DB) error { - u3 := User{Name: "transcation-3"} - if err := tx.Save(&u3).Error; err != nil { - t.Errorf("No error should raise") - } - - if err := tx.First(&User{}, "name = ?", "transcation-3").Error; err != nil { - t.Errorf("Should find saved record") - } - - panic("force panic") - }) - }) - - if err := DB.First(&User{}, "name = ?", "transcation").Error; err == nil { - t.Errorf("Should not find record after panic rollback") - } -} - -func TestTransaction_NoErrorOnRollbackAfterCommit(t *testing.T) { - tx := DB.Begin() - u := User{Name: "transcation"} - if err := tx.Save(&u).Error; err != nil { - t.Errorf("No error should raise") - } - - if err := tx.Commit().Error; err != nil { - t.Errorf("Commit should not raise error") - } - - if err := tx.Rollback().Error; err != nil { - t.Errorf("Rollback should not raise error") - } -} - -func TestTransactionReadonly(t *testing.T) { - dialect := os.Getenv("GORM_DIALECT") - if dialect == "" { - dialect = "sqlite" - } - switch dialect { - case "mssql", "sqlite": - t.Skipf("%s does not support readonly transactions\n", dialect) - } - - tx := DB.Begin() - u := User{Name: "transcation"} - if err := tx.Save(&u).Error; err != nil { - t.Errorf("No error should raise") - } - tx.Commit() - - tx = DB.BeginTx(context.Background(), &sql.TxOptions{ReadOnly: true}) - if err := tx.First(&User{}, "name = ?", "transcation").Error; err != nil { - t.Errorf("Should find saved record") - } - - if sqlTx, ok := tx.CommonDB().(*sql.Tx); !ok || sqlTx == nil { - t.Errorf("Should return the underlying sql.Tx") - } - - u = User{Name: "transcation-2"} - if err := tx.Save(&u).Error; err == nil { - t.Errorf("Error should have been raised in a readonly transaction") - } - - tx.Rollback() -} - -func TestRow(t *testing.T) { - user1 := User{Name: "RowUser1", Age: 1, Birthday: parseTime("2000-1-1")} - user2 := User{Name: "RowUser2", Age: 10, Birthday: parseTime("2010-1-1")} - user3 := User{Name: "RowUser3", Age: 20, Birthday: parseTime("2020-1-1")} - DB.Save(&user1).Save(&user2).Save(&user3) - - row := DB.Table("users").Where("name = ?", user2.Name).Select("age").Row() - var age int64 - row.Scan(&age) - if age != 10 { - t.Errorf("Scan with Row") - } -} - -func TestRows(t *testing.T) { - user1 := User{Name: "RowsUser1", Age: 1, Birthday: parseTime("2000-1-1")} - user2 := User{Name: "RowsUser2", Age: 10, Birthday: parseTime("2010-1-1")} - user3 := User{Name: "RowsUser3", Age: 20, Birthday: parseTime("2020-1-1")} - DB.Save(&user1).Save(&user2).Save(&user3) - - rows, err := DB.Table("users").Where("name = ? or name = ?", user2.Name, user3.Name).Select("name, age").Rows() - if err != nil { - t.Errorf("Not error should happen, got %v", err) - } - - count := 0 - for rows.Next() { - var name string - var age int64 - rows.Scan(&name, &age) - count++ - } - - if count != 2 { - t.Errorf("Should found two records") - } -} - -func TestScanRows(t *testing.T) { - user1 := User{Name: "ScanRowsUser1", Age: 1, Birthday: parseTime("2000-1-1")} - user2 := User{Name: "ScanRowsUser2", Age: 10, Birthday: parseTime("2010-1-1")} - user3 := User{Name: "ScanRowsUser3", Age: 20, Birthday: parseTime("2020-1-1")} - DB.Save(&user1).Save(&user2).Save(&user3) - - rows, err := DB.Table("users").Where("name = ? or name = ?", user2.Name, user3.Name).Select("name, age").Rows() - if err != nil { - t.Errorf("Not error should happen, got %v", err) - } - - type Result struct { - Name string - Age int - } - - var results []Result - for rows.Next() { - var result Result - if err := DB.ScanRows(rows, &result); err != nil { - t.Errorf("should get no error, but got %v", err) - } - results = append(results, result) - } - - if !reflect.DeepEqual(results, []Result{{Name: "ScanRowsUser2", Age: 10}, {Name: "ScanRowsUser3", Age: 20}}) { - t.Errorf("Should find expected results") - } -} - -func TestScan(t *testing.T) { - user1 := User{Name: "ScanUser1", Age: 1, Birthday: parseTime("2000-1-1")} - user2 := User{Name: "ScanUser2", Age: 10, Birthday: parseTime("2010-1-1")} - user3 := User{Name: "ScanUser3", Age: 20, Birthday: parseTime("2020-1-1")} - DB.Save(&user1).Save(&user2).Save(&user3) - - type result struct { - Name string - Age int - } - - var res result - DB.Table("users").Select("name, age").Where("name = ?", user3.Name).Scan(&res) - if res.Name != user3.Name { - t.Errorf("Scan into struct should work") - } - - var doubleAgeRes = &result{} - if err := DB.Table("users").Select("age + age as age").Where("name = ?", user3.Name).Scan(&doubleAgeRes).Error; err != nil { - t.Errorf("Scan to pointer of pointer") - } - if doubleAgeRes.Age != res.Age*2 { - t.Errorf("Scan double age as age") - } - - var ress []result - DB.Table("users").Select("name, age").Where("name in (?)", []string{user2.Name, user3.Name}).Scan(&ress) - if len(ress) != 2 || ress[0].Name != user2.Name || ress[1].Name != user3.Name { - t.Errorf("Scan into struct map") - } -} - -func TestRaw(t *testing.T) { - user1 := User{Name: "ExecRawSqlUser1", Age: 1, Birthday: parseTime("2000-1-1")} - user2 := User{Name: "ExecRawSqlUser2", Age: 10, Birthday: parseTime("2010-1-1")} - user3 := User{Name: "ExecRawSqlUser3", Age: 20, Birthday: parseTime("2020-1-1")} - DB.Save(&user1).Save(&user2).Save(&user3) - - type result struct { - Name string - Email string - } - - var ress []result - DB.Raw("SELECT name, age FROM users WHERE name = ? or name = ?", user2.Name, user3.Name).Scan(&ress) - if len(ress) != 2 || ress[0].Name != user2.Name || ress[1].Name != user3.Name { - t.Errorf("Raw with scan") - } - - rows, _ := DB.Raw("select name, age from users where name = ?", user3.Name).Rows() - count := 0 - for rows.Next() { - count++ - } - if count != 1 { - t.Errorf("Raw with Rows should find one record with name 3") - } - - DB.Exec("update users set name=? where name in (?)", "jinzhu", []string{user1.Name, user2.Name, user3.Name}) - if DB.Where("name in (?)", []string{user1.Name, user2.Name, user3.Name}).First(&User{}).Error != gorm.ErrRecordNotFound { - t.Error("Raw sql to update records") - } -} - -func TestGroup(t *testing.T) { - rows, err := DB.Select("name").Table("users").Group("name").Rows() - - if err == nil { - defer rows.Close() - for rows.Next() { - var name string - rows.Scan(&name) - } - } else { - t.Errorf("Should not raise any error") - } -} - -func TestJoins(t *testing.T) { - var user = User{ - Name: "joins", - CreditCard: CreditCard{Number: "411111111111"}, - Emails: []Email{{Email: "join1@example.com"}, {Email: "join2@example.com"}}, - } - DB.Save(&user) - - var users1 []User - DB.Joins("left join emails on emails.user_id = users.id").Where("name = ?", "joins").Find(&users1) - if len(users1) != 2 { - t.Errorf("should find two users using left join") - } - - var users2 []User - DB.Joins("left join emails on emails.user_id = users.id AND emails.email = ?", "join1@example.com").Where("name = ?", "joins").First(&users2) - if len(users2) != 1 { - t.Errorf("should find one users using left join with conditions") - } - - var users3 []User - DB.Joins("join emails on emails.user_id = users.id AND emails.email = ?", "join1@example.com").Joins("join credit_cards on credit_cards.user_id = users.id AND credit_cards.number = ?", "411111111111").Where("name = ?", "joins").First(&users3) - if len(users3) != 1 { - t.Errorf("should find one users using multiple left join conditions") - } - - var users4 []User - DB.Joins("join emails on emails.user_id = users.id AND emails.email = ?", "join1@example.com").Joins("join credit_cards on credit_cards.user_id = users.id AND credit_cards.number = ?", "422222222222").Where("name = ?", "joins").First(&users4) - if len(users4) != 0 { - t.Errorf("should find no user when searching with unexisting credit card") - } - - var users5 []User - db5 := DB.Joins("join emails on emails.user_id = users.id AND emails.email = ?", "join1@example.com").Joins("join credit_cards on credit_cards.user_id = users.id AND credit_cards.number = ?", "411111111111").Where(User{Id: 1}).Where(Email{Id: 1}).Not(Email{Id: 10}).First(&users5) - if db5.Error != nil { - t.Errorf("Should not raise error for join where identical fields in different tables. Error: %s", db5.Error.Error()) - } -} - -type JoinedIds struct { - UserID int64 `gorm:"column:id"` - BillingAddressID int64 `gorm:"column:id"` - EmailID int64 `gorm:"column:id"` -} - -func TestScanIdenticalColumnNames(t *testing.T) { - var user = User{ - Name: "joinsIds", - Email: "joinIds@example.com", - BillingAddress: Address{ - Address1: "One Park Place", - }, - Emails: []Email{{Email: "join1@example.com"}, {Email: "join2@example.com"}}, - } - DB.Save(&user) - - var users []JoinedIds - DB.Select("users.id, addresses.id, emails.id").Table("users"). - Joins("left join addresses on users.billing_address_id = addresses.id"). - Joins("left join emails on emails.user_id = users.id"). - Where("name = ?", "joinsIds").Scan(&users) - - if len(users) != 2 { - t.Fatal("should find two rows using left join") - } - - if user.Id != users[0].UserID { - t.Errorf("Expected result row to contain UserID %d, but got %d", user.Id, users[0].UserID) - } - if user.Id != users[1].UserID { - t.Errorf("Expected result row to contain UserID %d, but got %d", user.Id, users[1].UserID) - } - - if user.BillingAddressID.Int64 != users[0].BillingAddressID { - t.Errorf("Expected result row to contain BillingAddressID %d, but got %d", user.BillingAddressID.Int64, users[0].BillingAddressID) - } - if user.BillingAddressID.Int64 != users[1].BillingAddressID { - t.Errorf("Expected result row to contain BillingAddressID %d, but got %d", user.BillingAddressID.Int64, users[0].BillingAddressID) - } - - if users[0].EmailID == users[1].EmailID { - t.Errorf("Email ids should be unique. Got %d and %d", users[0].EmailID, users[1].EmailID) - } - - if int64(user.Emails[0].Id) != users[0].EmailID && int64(user.Emails[1].Id) != users[0].EmailID { - t.Errorf("Expected result row ID to be either %d or %d, but was %d", user.Emails[0].Id, user.Emails[1].Id, users[0].EmailID) - } - - if int64(user.Emails[0].Id) != users[1].EmailID && int64(user.Emails[1].Id) != users[1].EmailID { - t.Errorf("Expected result row ID to be either %d or %d, but was %d", user.Emails[0].Id, user.Emails[1].Id, users[1].EmailID) - } -} - -func TestJoinsWithSelect(t *testing.T) { - type result struct { - Name string - Email string - } - - user := User{ - Name: "joins_with_select", - Emails: []Email{{Email: "join1@example.com"}, {Email: "join2@example.com"}}, - } - DB.Save(&user) - - var results []result - DB.Table("users").Select("name, emails.email").Joins("left join emails on emails.user_id = users.id").Where("name = ?", "joins_with_select").Scan(&results) - if len(results) != 2 || results[0].Email != "join1@example.com" || results[1].Email != "join2@example.com" { - t.Errorf("Should find all two emails with Join select") - } -} - -func TestHaving(t *testing.T) { - rows, err := DB.Select("name, count(*) as total").Table("users").Group("name").Having("name IN (?)", []string{"2", "3"}).Rows() - - if err == nil { - defer rows.Close() - for rows.Next() { - var name string - var total int64 - rows.Scan(&name, &total) - - if name == "2" && total != 1 { - t.Errorf("Should have one user having name 2") - } - if name == "3" && total != 2 { - t.Errorf("Should have two users having name 3") - } - } - } else { - t.Errorf("Should not raise any error") - } -} - -func TestQueryBuilderSubselectInWhere(t *testing.T) { - user := User{Name: "query_expr_select_ruser1", Email: "root@user1.com", Age: 32} - DB.Save(&user) - user = User{Name: "query_expr_select_ruser2", Email: "nobody@user2.com", Age: 16} - DB.Save(&user) - user = User{Name: "query_expr_select_ruser3", Email: "root@user3.com", Age: 64} - DB.Save(&user) - user = User{Name: "query_expr_select_ruser4", Email: "somebody@user3.com", Age: 128} - DB.Save(&user) - - var users []User - DB.Select("*").Where("name IN (?)", DB. - Select("name").Table("users").Where("name LIKE ?", "query_expr_select%").QueryExpr()).Find(&users) - - if len(users) != 4 { - t.Errorf("Four users should be found, instead found %d", len(users)) - } - - DB.Select("*").Where("name LIKE ?", "query_expr_select%").Where("age >= (?)", DB. - Select("AVG(age)").Table("users").Where("name LIKE ?", "query_expr_select%").QueryExpr()).Find(&users) - - if len(users) != 2 { - t.Errorf("Two users should be found, instead found %d", len(users)) - } -} - -func TestQueryBuilderRawQueryWithSubquery(t *testing.T) { - user := User{Name: "subquery_test_user1", Age: 10} - DB.Save(&user) - user = User{Name: "subquery_test_user2", Age: 11} - DB.Save(&user) - user = User{Name: "subquery_test_user3", Age: 12} - DB.Save(&user) - - var count int - err := DB.Raw("select count(*) from (?) tmp", - DB.Table("users"). - Select("name"). - Where("age >= ? and name in (?)", 10, []string{"subquery_test_user1", "subquery_test_user2"}). - Group("name"). - QueryExpr(), - ).Count(&count).Error - - if err != nil { - t.Errorf("Expected to get no errors, but got %v", err) - } - if count != 2 { - t.Errorf("Row count must be 2, instead got %d", count) - } - - err = DB.Raw("select count(*) from (?) tmp", - DB.Table("users"). - Select("name"). - Where("name LIKE ?", "subquery_test%"). - Not("age <= ?", 10).Not("name in (?)", []string{"subquery_test_user1", "subquery_test_user2"}). - Group("name"). - QueryExpr(), - ).Count(&count).Error - - if err != nil { - t.Errorf("Expected to get no errors, but got %v", err) - } - if count != 1 { - t.Errorf("Row count must be 1, instead got %d", count) - } -} - -func TestQueryBuilderSubselectInHaving(t *testing.T) { - user := User{Name: "query_expr_having_ruser1", Email: "root@user1.com", Age: 64} - DB.Save(&user) - user = User{Name: "query_expr_having_ruser2", Email: "root@user2.com", Age: 128} - DB.Save(&user) - user = User{Name: "query_expr_having_ruser3", Email: "root@user1.com", Age: 64} - DB.Save(&user) - user = User{Name: "query_expr_having_ruser4", Email: "root@user2.com", Age: 128} - DB.Save(&user) - - var users []User - DB.Select("AVG(age) as avgage").Where("name LIKE ?", "query_expr_having_%").Group("email").Having("AVG(age) > (?)", DB. - Select("AVG(age)").Where("name LIKE ?", "query_expr_having_%").Table("users").QueryExpr()).Find(&users) - - if len(users) != 1 { - t.Errorf("Two user group should be found, instead found %d", len(users)) - } -} - -func DialectHasTzSupport() bool { - // NB: mssql and FoundationDB do not support time zones. - if dialect := os.Getenv("GORM_DIALECT"); dialect == "foundation" { - return false - } - return true -} - -func TestTimeWithZone(t *testing.T) { - var format = "2006-01-02 15:04:05 -0700" - var times []time.Time - GMT8, _ := time.LoadLocation("Asia/Shanghai") - times = append(times, time.Date(2013, 02, 19, 1, 51, 49, 123456789, GMT8)) - times = append(times, time.Date(2013, 02, 18, 17, 51, 49, 123456789, time.UTC)) - - for index, vtime := range times { - name := "time_with_zone_" + strconv.Itoa(index) - user := User{Name: name, Birthday: &vtime} - - if !DialectHasTzSupport() { - // If our driver dialect doesn't support TZ's, just use UTC for everything here. - utcBirthday := user.Birthday.UTC() - user.Birthday = &utcBirthday - } - - DB.Save(&user) - expectedBirthday := "2013-02-18 17:51:49 +0000" - foundBirthday := user.Birthday.UTC().Format(format) - if foundBirthday != expectedBirthday { - t.Errorf("User's birthday should not be changed after save for name=%s, expected bday=%+v but actual value=%+v", name, expectedBirthday, foundBirthday) - } - - var findUser, findUser2, findUser3 User - DB.First(&findUser, "name = ?", name) - foundBirthday = findUser.Birthday.UTC().Format(format) - if foundBirthday != expectedBirthday { - t.Errorf("User's birthday should not be changed after find for name=%s, expected bday=%+v but actual value=%+v", name, expectedBirthday, foundBirthday) - } - - if DB.Where("id = ? AND birthday >= ?", findUser.Id, user.Birthday.Add(-time.Minute)).First(&findUser2).RecordNotFound() { - t.Errorf("User should be found") - } - - if !DB.Where("id = ? AND birthday >= ?", findUser.Id, user.Birthday.Add(time.Minute)).First(&findUser3).RecordNotFound() { - t.Errorf("User should not be found") - } - } -} - -func TestHstore(t *testing.T) { - type Details struct { - Id int64 - Bulk postgres.Hstore - } - - if dialect := os.Getenv("GORM_DIALECT"); dialect != "postgres" { - t.Skip() - } - - if err := DB.Exec("CREATE EXTENSION IF NOT EXISTS hstore").Error; err != nil { - fmt.Println("\033[31mHINT: Must be superuser to create hstore extension (ALTER USER gorm WITH SUPERUSER;)\033[0m") - panic(fmt.Sprintf("No error should happen when create hstore extension, but got %+v", err)) - } - - DB.Exec("drop table details") - - if err := DB.CreateTable(&Details{}).Error; err != nil { - panic(fmt.Sprintf("No error should happen when create table, but got %+v", err)) - } - - bankAccountId, phoneNumber, opinion := "123456", "14151321232", "sharkbait" - bulk := map[string]*string{ - "bankAccountId": &bankAccountId, - "phoneNumber": &phoneNumber, - "opinion": &opinion, - } - d := Details{Bulk: bulk} - DB.Save(&d) - - var d2 Details - if err := DB.First(&d2).Error; err != nil { - t.Errorf("Got error when tried to fetch details: %+v", err) - } - - for k := range bulk { - if r, ok := d2.Bulk[k]; ok { - if res, _ := bulk[k]; *res != *r { - t.Errorf("Details should be equal") - } - } else { - t.Errorf("Details should be existed") - } - } -} - -func TestSetAndGet(t *testing.T) { - if value, ok := DB.Set("hello", "world").Get("hello"); !ok { - t.Errorf("Should be able to get setting after set") - } else { - if value.(string) != "world" { - t.Errorf("Setted value should not be changed") - } - } - - if _, ok := DB.Get("non_existing"); ok { - t.Errorf("Get non existing key should return error") - } -} - -func TestCompatibilityMode(t *testing.T) { - DB, _ := gorm.Open("testdb", "") - testdb.SetQueryFunc(func(query string) (driver.Rows, error) { - columns := []string{"id", "name", "age"} - result := ` - 1,Tim,20 - 2,Joe,25 - 3,Bob,30 - ` - return testdb.RowsFromCSVString(columns, result), nil - }) - - var users []User - DB.Find(&users) - if (users[0].Name != "Tim") || len(users) != 3 { - t.Errorf("Unexcepted result returned") - } -} - -func TestOpenExistingDB(t *testing.T) { - DB.Save(&User{Name: "jnfeinstein"}) - dialect := os.Getenv("GORM_DIALECT") - - db, err := gorm.Open(dialect, DB.DB()) - if err != nil { - t.Errorf("Should have wrapped the existing DB connection") - } - - var user User - if db.Where("name = ?", "jnfeinstein").First(&user).Error == gorm.ErrRecordNotFound { - t.Errorf("Should have found existing record") - } -} - -func TestDdlErrors(t *testing.T) { - var err error - - if err = DB.Close(); err != nil { - t.Errorf("Closing DDL test db connection err=%s", err) - } - defer func() { - // Reopen DB connection. - if DB, err = OpenTestConnection(); err != nil { - t.Fatalf("Failed re-opening db connection: %s", err) - } - }() - - if err := DB.Find(&User{}).Error; err == nil { - t.Errorf("Expected operation on closed db to produce an error, but err was nil") - } -} - -func TestOpenWithOneParameter(t *testing.T) { - db, err := gorm.Open("dialect") - if db != nil { - t.Error("Open with one parameter returned non nil for db") - } - if err == nil { - t.Error("Open with one parameter returned err as nil") - } -} - -func TestSaveAssociations(t *testing.T) { - db := DB.New() - deltaAddressCount := 0 - if err := db.Model(&Address{}).Count(&deltaAddressCount).Error; err != nil { - t.Errorf("failed to fetch address count") - t.FailNow() - } - - placeAddress := &Address{ - Address1: "somewhere on earth", - } - ownerAddress1 := &Address{ - Address1: "near place address", - } - ownerAddress2 := &Address{ - Address1: "address2", - } - db.Create(placeAddress) - - addressCountShouldBe := func(t *testing.T, expectedCount int) { - countFromDB := 0 - t.Helper() - err := db.Model(&Address{}).Count(&countFromDB).Error - if err != nil { - t.Error("failed to fetch address count") - } - if countFromDB != expectedCount { - t.Errorf("address count mismatch: %d", countFromDB) - } - } - addressCountShouldBe(t, deltaAddressCount+1) - - // owner address should be created, place address should be reused - place1 := &Place{ - PlaceAddressID: placeAddress.ID, - PlaceAddress: placeAddress, - OwnerAddress: ownerAddress1, - } - err := db.Create(place1).Error - if err != nil { - t.Errorf("failed to store place: %s", err.Error()) - } - addressCountShouldBe(t, deltaAddressCount+2) - - // owner address should be created again, place address should be reused - place2 := &Place{ - PlaceAddressID: placeAddress.ID, - PlaceAddress: &Address{ - ID: 777, - Address1: "address1", - }, - OwnerAddress: ownerAddress2, - OwnerAddressID: 778, - } - err = db.Create(place2).Error - if err != nil { - t.Errorf("failed to store place: %s", err.Error()) - } - addressCountShouldBe(t, deltaAddressCount+3) - - count := 0 - db.Model(&Place{}).Where(&Place{ - PlaceAddressID: placeAddress.ID, - OwnerAddressID: ownerAddress1.ID, - }).Count(&count) - if count != 1 { - t.Errorf("only one instance of (%d, %d) should be available, found: %d", - placeAddress.ID, ownerAddress1.ID, count) - } - - db.Model(&Place{}).Where(&Place{ - PlaceAddressID: placeAddress.ID, - OwnerAddressID: ownerAddress2.ID, - }).Count(&count) - if count != 1 { - t.Errorf("only one instance of (%d, %d) should be available, found: %d", - placeAddress.ID, ownerAddress2.ID, count) - } - - db.Model(&Place{}).Where(&Place{ - PlaceAddressID: placeAddress.ID, - }).Count(&count) - if count != 2 { - t.Errorf("two instances of (%d) should be available, found: %d", - placeAddress.ID, count) - } -} - -func TestBlockGlobalUpdate(t *testing.T) { - db := DB.New() - db.Create(&Toy{Name: "Stuffed Animal", OwnerType: "Nobody"}) - - err := db.Model(&Toy{}).Update("OwnerType", "Human").Error - if err != nil { - t.Error("Unexpected error on global update") - } - - err = db.Delete(&Toy{}).Error - if err != nil { - t.Error("Unexpected error on global delete") - } - - db.BlockGlobalUpdate(true) - - db.Create(&Toy{Name: "Stuffed Animal", OwnerType: "Nobody"}) - - err = db.Model(&Toy{}).Update("OwnerType", "Human").Error - if err == nil { - t.Error("Expected error on global update") - } - - err = db.Model(&Toy{}).Where(&Toy{OwnerType: "Martian"}).Update("OwnerType", "Astronaut").Error - if err != nil { - t.Error("Unxpected error on conditional update") - } - - err = db.Delete(&Toy{}).Error - if err == nil { - t.Error("Expected error on global delete") - } - err = db.Where(&Toy{OwnerType: "Martian"}).Delete(&Toy{}).Error - if err != nil { - t.Error("Unexpected error on conditional delete") - } -} - -func TestCountWithHaving(t *testing.T) { - db := DB.New() - db.Delete(User{}) - defer db.Delete(User{}) - - DB.Create(getPreparedUser("user1", "pluck_user")) - DB.Create(getPreparedUser("user2", "pluck_user")) - user3 := getPreparedUser("user3", "pluck_user") - user3.Languages = []Language{} - DB.Create(user3) - - var count int - err := db.Model(User{}).Select("users.id"). - Joins("LEFT JOIN user_languages ON user_languages.user_id = users.id"). - Joins("LEFT JOIN languages ON user_languages.language_id = languages.id"). - Group("users.id").Having("COUNT(languages.id) > 1").Count(&count).Error - - if err != nil { - t.Error("Unexpected error on query count with having") - } - - if count != 2 { - t.Error("Unexpected result on query count with having") - } -} - -func TestPluck(t *testing.T) { - db := DB.New() - db.Delete(User{}) - defer db.Delete(User{}) - - DB.Create(&User{Id: 1, Name: "user1"}) - DB.Create(&User{Id: 2, Name: "user2"}) - DB.Create(&User{Id: 3, Name: "user3"}) - - var ids []int64 - err := db.Model(User{}).Order("id").Pluck("id", &ids).Error - - if err != nil { - t.Error("Unexpected error on pluck") - } - - if len(ids) != 3 || ids[0] != 1 || ids[1] != 2 || ids[2] != 3 { - t.Error("Unexpected result on pluck") - } - - err = db.Model(User{}).Order("id").Pluck("id", &ids).Error - - if err != nil { - t.Error("Unexpected error on pluck again") - } - - if len(ids) != 3 || ids[0] != 1 || ids[1] != 2 || ids[2] != 3 { - t.Error("Unexpected result on pluck again") - } -} - -func TestCountWithQueryOption(t *testing.T) { - db := DB.New() - db.Delete(User{}) - defer db.Delete(User{}) - - DB.Create(&User{Name: "user1"}) - DB.Create(&User{Name: "user2"}) - DB.Create(&User{Name: "user3"}) - - var count int - err := db.Model(User{}).Select("users.id"). - Set("gorm:query_option", "WHERE users.name='user2'"). - Count(&count).Error - - if err != nil { - t.Error("Unexpected error on query count with query_option") - } - - if count != 1 { - t.Error("Unexpected result on query count with query_option") - } -} - -func TestQueryHint1(t *testing.T) { - db := DB.New() - - _, err := db.Model(User{}).Raw("select 1").Rows() - - if err != nil { - t.Error("Unexpected error on query count with query_option") - } -} - -func TestQueryHint2(t *testing.T) { - type TestStruct struct { - ID string `gorm:"primary_key"` - Name string - } - DB.DropTable(&TestStruct{}) - DB.AutoMigrate(&TestStruct{}) - - data := TestStruct{ID: "uuid", Name: "hello"} - if err := DB.Set("gorm:query_hint", "/*master*/").Save(&data).Error; err != nil { - t.Error("Unexpected error on query count with query_option") - } -} - -func TestFloatColumnPrecision(t *testing.T) { - if dialect := os.Getenv("GORM_DIALECT"); dialect != "mysql" && dialect != "sqlite" { - t.Skip() - } - - type FloatTest struct { - ID string `gorm:"primary_key"` - FloatValue float64 `gorm:"column:float_value" sql:"type:float(255,5);"` - } - DB.DropTable(&FloatTest{}) - DB.AutoMigrate(&FloatTest{}) - - data := FloatTest{ID: "uuid", FloatValue: 112.57315} - if err := DB.Save(&data).Error; err != nil || data.ID != "uuid" || data.FloatValue != 112.57315 { - t.Errorf("Float value should not lose precision") - } -} - -func TestWhereUpdates(t *testing.T) { - type OwnerEntity struct { - gorm.Model - OwnerID uint - OwnerType string - } - - type SomeEntity struct { - gorm.Model - Name string - OwnerEntity OwnerEntity `gorm:"polymorphic:Owner"` - } - - DB.DropTable(&SomeEntity{}) - DB.AutoMigrate(&SomeEntity{}) - - a := SomeEntity{Name: "test"} - DB.Model(&a).Where(a).Updates(SomeEntity{Name: "test2"}) -} - -func BenchmarkGorm(b *testing.B) { - b.N = 2000 - for x := 0; x < b.N; x++ { - e := strconv.Itoa(x) + "benchmark@example.org" - now := time.Now() - email := EmailWithIdx{Email: e, UserAgent: "pc", RegisteredAt: &now} - // Insert - DB.Save(&email) - // Query - DB.First(&EmailWithIdx{}, "email = ?", e) - // Update - DB.Model(&email).UpdateColumn("email", "new-"+e) - // Delete - DB.Delete(&email) - } -} - -func BenchmarkRawSql(b *testing.B) { - DB, _ := sql.Open("postgres", "user=gorm DB.ame=gorm sslmode=disable") - DB.SetMaxIdleConns(10) - insertSql := "INSERT INTO emails (user_id,email,user_agent,registered_at,created_at,updated_at) VALUES ($1,$2,$3,$4,$5,$6) RETURNING id" - querySql := "SELECT * FROM emails WHERE email = $1 ORDER BY id LIMIT 1" - updateSql := "UPDATE emails SET email = $1, updated_at = $2 WHERE id = $3" - deleteSql := "DELETE FROM orders WHERE id = $1" - - b.N = 2000 - for x := 0; x < b.N; x++ { - var id int64 - e := strconv.Itoa(x) + "benchmark@example.org" - now := time.Now() - email := EmailWithIdx{Email: e, UserAgent: "pc", RegisteredAt: &now} - // Insert - DB.QueryRow(insertSql, email.UserId, email.Email, email.UserAgent, email.RegisteredAt, time.Now(), time.Now()).Scan(&id) - // Query - rows, _ := DB.Query(querySql, email.Email) - rows.Close() - // Update - DB.Exec(updateSql, "new-"+e, time.Now(), id) - // Delete - DB.Exec(deleteSql, id) - } -} - -func parseTime(str string) *time.Time { - t := now.New(time.Now().UTC()).MustParse(str) - return &t -} diff --git a/migration_test.go b/migration_test.go deleted file mode 100644 index d94ec9ec..00000000 --- a/migration_test.go +++ /dev/null @@ -1,579 +0,0 @@ -package gorm_test - -import ( - "database/sql" - "database/sql/driver" - "errors" - "fmt" - "os" - "reflect" - "strconv" - "testing" - "time" - - "github.com/jinzhu/gorm" -) - -type User struct { - Id int64 - Age int64 - UserNum Num - Name string `sql:"size:255"` - Email string - Birthday *time.Time // Time - CreatedAt time.Time // CreatedAt: Time of record is created, will be insert automatically - UpdatedAt time.Time // UpdatedAt: Time of record is updated, will be updated automatically - Emails []Email // Embedded structs - BillingAddress Address // Embedded struct - BillingAddressID sql.NullInt64 // Embedded struct's foreign key - ShippingAddress Address // Embedded struct - ShippingAddressId int64 // Embedded struct's foreign key - CreditCard CreditCard - Latitude float64 - Languages []Language `gorm:"many2many:user_languages;"` - CompanyID *int - Company Company - Role Role - Password EncryptedData - PasswordHash []byte - IgnoreMe int64 `sql:"-"` - IgnoreStringSlice []string `sql:"-"` - Ignored struct{ Name string } `sql:"-"` - IgnoredPointer *User `sql:"-"` -} - -type NotSoLongTableName struct { - Id int64 - ReallyLongThingID int64 - ReallyLongThing ReallyLongTableNameToTestMySQLNameLengthLimit -} - -type ReallyLongTableNameToTestMySQLNameLengthLimit struct { - Id int64 -} - -type ReallyLongThingThatReferencesShort struct { - Id int64 - ShortID int64 - Short Short -} - -type Short struct { - Id int64 -} - -type CreditCard struct { - ID int8 - Number string - UserId sql.NullInt64 - CreatedAt time.Time `sql:"not null"` - UpdatedAt time.Time - DeletedAt *time.Time `sql:"column:deleted_time"` -} - -type Email struct { - Id int16 - UserId int - Email string `sql:"type:varchar(100);"` - CreatedAt time.Time - UpdatedAt time.Time -} - -type Address struct { - ID int - Address1 string - Address2 string - Post string - CreatedAt time.Time - UpdatedAt time.Time - DeletedAt *time.Time -} - -type Language struct { - gorm.Model - Name string - Users []User `gorm:"many2many:user_languages;"` -} - -type Product struct { - Id int64 - Code string - Price int64 - CreatedAt time.Time - UpdatedAt time.Time - AfterFindCallTimes int64 - BeforeCreateCallTimes int64 - AfterCreateCallTimes int64 - BeforeUpdateCallTimes int64 - AfterUpdateCallTimes int64 - BeforeSaveCallTimes int64 - AfterSaveCallTimes int64 - BeforeDeleteCallTimes int64 - AfterDeleteCallTimes int64 -} - -type Company struct { - Id int64 - Name string - Owner *User `sql:"-"` -} - -type Place struct { - Id int64 - PlaceAddressID int - PlaceAddress *Address `gorm:"save_associations:false"` - OwnerAddressID int - OwnerAddress *Address `gorm:"save_associations:true"` -} - -type EncryptedData []byte - -func (data *EncryptedData) Scan(value interface{}) error { - if b, ok := value.([]byte); ok { - if len(b) < 3 || b[0] != '*' || b[1] != '*' || b[2] != '*' { - return errors.New("Too short") - } - - *data = b[3:] - return nil - } - - return errors.New("Bytes expected") -} - -func (data EncryptedData) Value() (driver.Value, error) { - if len(data) > 0 && data[0] == 'x' { - //needed to test failures - return nil, errors.New("Should not start with 'x'") - } - - //prepend asterisks - return append([]byte("***"), data...), nil -} - -type Role struct { - Name string `gorm:"size:256"` -} - -func (role *Role) Scan(value interface{}) error { - if b, ok := value.([]uint8); ok { - role.Name = string(b) - } else { - role.Name = value.(string) - } - return nil -} - -func (role Role) Value() (driver.Value, error) { - return role.Name, nil -} - -func (role Role) IsAdmin() bool { - return role.Name == "admin" -} - -type Num int64 - -func (i *Num) Scan(src interface{}) error { - switch s := src.(type) { - case []byte: - n, _ := strconv.Atoi(string(s)) - *i = Num(n) - case int64: - *i = Num(s) - default: - return errors.New("Cannot scan NamedInt from " + reflect.ValueOf(src).String()) - } - return nil -} - -type Animal struct { - Counter uint64 `gorm:"primary_key:yes"` - Name string `sql:"DEFAULT:'galeone'"` - From string //test reserved sql keyword as field name - Age time.Time `sql:"DEFAULT:current_timestamp"` - unexported string // unexported value - CreatedAt time.Time - UpdatedAt time.Time -} - -type JoinTable struct { - From uint64 - To uint64 - Time time.Time `sql:"default: null"` -} - -type Post struct { - Id int64 - CategoryId sql.NullInt64 - MainCategoryId int64 - Title string - Body string - Comments []*Comment - Category Category - MainCategory Category -} - -type Category struct { - gorm.Model - Name string - - Categories []Category - CategoryID *uint -} - -type Comment struct { - gorm.Model - PostId int64 - Content string - Post Post -} - -// Scanner -type NullValue struct { - Id int64 - Name sql.NullString `sql:"not null"` - Gender *sql.NullString `sql:"not null"` - Age sql.NullInt64 - Male sql.NullBool - Height sql.NullFloat64 - AddedAt NullTime -} - -type NullTime struct { - Time time.Time - Valid bool -} - -func (nt *NullTime) Scan(value interface{}) error { - if value == nil { - nt.Valid = false - return nil - } - nt.Time, nt.Valid = value.(time.Time), true - return nil -} - -func (nt NullTime) Value() (driver.Value, error) { - if !nt.Valid { - return nil, nil - } - return nt.Time, nil -} - -func getPreparedUser(name string, role string) *User { - var company Company - DB.Where(Company{Name: role}).FirstOrCreate(&company) - - return &User{ - Name: name, - Age: 20, - Role: Role{role}, - BillingAddress: Address{Address1: fmt.Sprintf("Billing Address %v", name)}, - ShippingAddress: Address{Address1: fmt.Sprintf("Shipping Address %v", name)}, - CreditCard: CreditCard{Number: fmt.Sprintf("123456%v", name)}, - Emails: []Email{ - {Email: fmt.Sprintf("user_%v@example1.com", name)}, {Email: fmt.Sprintf("user_%v@example2.com", name)}, - }, - Company: company, - Languages: []Language{ - {Name: fmt.Sprintf("lang_1_%v", name)}, - {Name: fmt.Sprintf("lang_2_%v", name)}, - }, - } -} - -func runMigration() { - if err := DB.DropTableIfExists(&User{}).Error; err != nil { - fmt.Printf("Got error when try to delete table users, %+v\n", err) - } - - for _, table := range []string{"animals", "user_languages"} { - DB.Exec(fmt.Sprintf("drop table %v;", table)) - } - - values := []interface{}{&Short{}, &ReallyLongThingThatReferencesShort{}, &ReallyLongTableNameToTestMySQLNameLengthLimit{}, &NotSoLongTableName{}, &Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}, &Animal{}, &User{}, &JoinTable{}, &Post{}, &Category{}, &Comment{}, &Cat{}, &Dog{}, &Hamster{}, &Toy{}, &ElementWithIgnoredField{}, &Place{}} - for _, value := range values { - DB.DropTable(value) - } - if err := DB.AutoMigrate(values...).Error; err != nil { - panic(fmt.Sprintf("No error should happen when create table, but got %+v", err)) - } -} - -func TestIndexes(t *testing.T) { - if err := DB.Model(&Email{}).AddIndex("idx_email_email", "email").Error; err != nil { - t.Errorf("Got error when tried to create index: %+v", err) - } - - scope := DB.NewScope(&Email{}) - if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email") { - t.Errorf("Email should have index idx_email_email") - } - - if err := DB.Model(&Email{}).RemoveIndex("idx_email_email").Error; err != nil { - t.Errorf("Got error when tried to remove index: %+v", err) - } - - if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email") { - t.Errorf("Email's index idx_email_email should be deleted") - } - - if err := DB.Model(&Email{}).AddIndex("idx_email_email_and_user_id", "user_id", "email").Error; err != nil { - t.Errorf("Got error when tried to create index: %+v", err) - } - - if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") { - t.Errorf("Email should have index idx_email_email_and_user_id") - } - - if err := DB.Model(&Email{}).RemoveIndex("idx_email_email_and_user_id").Error; err != nil { - t.Errorf("Got error when tried to remove index: %+v", err) - } - - if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") { - t.Errorf("Email's index idx_email_email_and_user_id should be deleted") - } - - if err := DB.Model(&Email{}).AddUniqueIndex("idx_email_email_and_user_id", "user_id", "email").Error; err != nil { - t.Errorf("Got error when tried to create index: %+v", err) - } - - if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") { - t.Errorf("Email should have index idx_email_email_and_user_id") - } - - if DB.Save(&User{Name: "unique_indexes", Emails: []Email{{Email: "user1@example.comiii"}, {Email: "user1@example.com"}, {Email: "user1@example.com"}}}).Error == nil { - t.Errorf("Should get to create duplicate record when having unique index") - } - - var user = User{Name: "sample_user"} - DB.Save(&user) - if DB.Model(&user).Association("Emails").Append(Email{Email: "not-1duplicated@gmail.com"}, Email{Email: "not-duplicated2@gmail.com"}).Error != nil { - t.Errorf("Should get no error when append two emails for user") - } - - if DB.Model(&user).Association("Emails").Append(Email{Email: "duplicated@gmail.com"}, Email{Email: "duplicated@gmail.com"}).Error == nil { - t.Errorf("Should get no duplicated email error when insert duplicated emails for a user") - } - - if err := DB.Model(&Email{}).RemoveIndex("idx_email_email_and_user_id").Error; err != nil { - t.Errorf("Got error when tried to remove index: %+v", err) - } - - if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") { - t.Errorf("Email's index idx_email_email_and_user_id should be deleted") - } - - if DB.Save(&User{Name: "unique_indexes", Emails: []Email{{Email: "user1@example.com"}, {Email: "user1@example.com"}}}).Error != nil { - t.Errorf("Should be able to create duplicated emails after remove unique index") - } -} - -type EmailWithIdx struct { - Id int64 - UserId int64 - Email string `sql:"index:idx_email_agent"` - UserAgent string `sql:"index:idx_email_agent"` - RegisteredAt *time.Time `sql:"unique_index"` - CreatedAt time.Time - UpdatedAt time.Time -} - -func TestAutoMigration(t *testing.T) { - DB.AutoMigrate(&Address{}) - DB.DropTable(&EmailWithIdx{}) - if err := DB.AutoMigrate(&EmailWithIdx{}).Error; err != nil { - t.Errorf("Auto Migrate should not raise any error") - } - - now := time.Now() - DB.Save(&EmailWithIdx{Email: "jinzhu@example.org", UserAgent: "pc", RegisteredAt: &now}) - - scope := DB.NewScope(&EmailWithIdx{}) - if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_agent") { - t.Errorf("Failed to create index") - } - - if !scope.Dialect().HasIndex(scope.TableName(), "uix_email_with_idxes_registered_at") { - t.Errorf("Failed to create index") - } - - var bigemail EmailWithIdx - DB.First(&bigemail, "user_agent = ?", "pc") - if bigemail.Email != "jinzhu@example.org" || bigemail.UserAgent != "pc" || bigemail.RegisteredAt.IsZero() { - t.Error("Big Emails should be saved and fetched correctly") - } -} - -func TestCreateAndAutomigrateTransaction(t *testing.T) { - tx := DB.Begin() - - func() { - type Bar struct { - ID uint - } - DB.DropTableIfExists(&Bar{}) - - if ok := DB.HasTable("bars"); ok { - t.Errorf("Table should not exist, but does") - } - - if ok := tx.HasTable("bars"); ok { - t.Errorf("Table should not exist, but does") - } - }() - - func() { - type Bar struct { - Name string - } - err := tx.CreateTable(&Bar{}).Error - - if err != nil { - t.Errorf("Should have been able to create the table, but couldn't: %s", err) - } - - if ok := tx.HasTable(&Bar{}); !ok { - t.Errorf("The transaction should be able to see the table") - } - }() - - func() { - type Bar struct { - Stuff string - } - - err := tx.AutoMigrate(&Bar{}).Error - if err != nil { - t.Errorf("Should have been able to alter the table, but couldn't") - } - }() - - tx.Rollback() -} - -type MultipleIndexes struct { - ID int64 - UserID int64 `sql:"unique_index:uix_multipleindexes_user_name,uix_multipleindexes_user_email;index:idx_multipleindexes_user_other"` - Name string `sql:"unique_index:uix_multipleindexes_user_name"` - Email string `sql:"unique_index:,uix_multipleindexes_user_email"` - Other string `sql:"index:,idx_multipleindexes_user_other"` -} - -func TestMultipleIndexes(t *testing.T) { - if err := DB.DropTableIfExists(&MultipleIndexes{}).Error; err != nil { - fmt.Printf("Got error when try to delete table multiple_indexes, %+v\n", err) - } - - DB.AutoMigrate(&MultipleIndexes{}) - if err := DB.AutoMigrate(&EmailWithIdx{}).Error; err != nil { - t.Errorf("Auto Migrate should not raise any error") - } - - DB.Save(&MultipleIndexes{UserID: 1, Name: "jinzhu", Email: "jinzhu@example.org", Other: "foo"}) - - scope := DB.NewScope(&MultipleIndexes{}) - if !scope.Dialect().HasIndex(scope.TableName(), "uix_multipleindexes_user_name") { - t.Errorf("Failed to create index") - } - - if !scope.Dialect().HasIndex(scope.TableName(), "uix_multipleindexes_user_email") { - t.Errorf("Failed to create index") - } - - if !scope.Dialect().HasIndex(scope.TableName(), "uix_multiple_indexes_email") { - t.Errorf("Failed to create index") - } - - if !scope.Dialect().HasIndex(scope.TableName(), "idx_multipleindexes_user_other") { - t.Errorf("Failed to create index") - } - - if !scope.Dialect().HasIndex(scope.TableName(), "idx_multiple_indexes_other") { - t.Errorf("Failed to create index") - } - - var mutipleIndexes MultipleIndexes - DB.First(&mutipleIndexes, "name = ?", "jinzhu") - if mutipleIndexes.Email != "jinzhu@example.org" || mutipleIndexes.Name != "jinzhu" { - t.Error("MutipleIndexes should be saved and fetched correctly") - } - - // Check unique constraints - if err := DB.Save(&MultipleIndexes{UserID: 1, Name: "name1", Email: "jinzhu@example.org", Other: "foo"}).Error; err == nil { - t.Error("MultipleIndexes unique index failed") - } - - if err := DB.Save(&MultipleIndexes{UserID: 1, Name: "name1", Email: "foo@example.org", Other: "foo"}).Error; err != nil { - t.Error("MultipleIndexes unique index failed") - } - - if err := DB.Save(&MultipleIndexes{UserID: 2, Name: "name1", Email: "jinzhu@example.org", Other: "foo"}).Error; err == nil { - t.Error("MultipleIndexes unique index failed") - } - - if err := DB.Save(&MultipleIndexes{UserID: 2, Name: "name1", Email: "foo2@example.org", Other: "foo"}).Error; err != nil { - t.Error("MultipleIndexes unique index failed") - } -} - -func TestModifyColumnType(t *testing.T) { - if dialect := os.Getenv("GORM_DIALECT"); dialect != "postgres" && dialect != "mysql" && dialect != "mssql" { - t.Skip("Skipping this because only postgres, mysql and mssql support altering a column type") - } - - type ModifyColumnType struct { - gorm.Model - Name1 string `gorm:"length:100"` - Name2 string `gorm:"length:200"` - } - DB.DropTable(&ModifyColumnType{}) - DB.CreateTable(&ModifyColumnType{}) - - name2Field, _ := DB.NewScope(&ModifyColumnType{}).FieldByName("Name2") - name2Type := DB.Dialect().DataTypeOf(name2Field.StructField) - - if err := DB.Model(&ModifyColumnType{}).ModifyColumn("name1", name2Type).Error; err != nil { - t.Errorf("No error should happen when ModifyColumn, but got %v", err) - } -} - -func TestIndexWithPrefixLength(t *testing.T) { - if dialect := os.Getenv("GORM_DIALECT"); dialect != "mysql" { - t.Skip("Skipping this because only mysql support setting an index prefix length") - } - - type IndexWithPrefix struct { - gorm.Model - Name string - Description string `gorm:"type:text;index:idx_index_with_prefixes_length(100)"` - } - type IndexesWithPrefix struct { - gorm.Model - Name string - Description1 string `gorm:"type:text;index:idx_index_with_prefixes_length(100)"` - Description2 string `gorm:"type:text;index:idx_index_with_prefixes_length(100)"` - } - type IndexesWithPrefixAndWithoutPrefix struct { - gorm.Model - Name string `gorm:"index:idx_index_with_prefixes_length"` - Description string `gorm:"type:text;index:idx_index_with_prefixes_length(100)"` - } - tables := []interface{}{&IndexWithPrefix{}, &IndexesWithPrefix{}, &IndexesWithPrefixAndWithoutPrefix{}} - for _, table := range tables { - scope := DB.NewScope(table) - tableName := scope.TableName() - t.Run(fmt.Sprintf("Create index with prefix length: %s", tableName), func(t *testing.T) { - if err := DB.DropTableIfExists(table).Error; err != nil { - t.Errorf("Failed to drop %s table: %v", tableName, err) - } - if err := DB.CreateTable(table).Error; err != nil { - t.Errorf("Failed to create %s table: %v", tableName, err) - } - if !scope.Dialect().HasIndex(tableName, "idx_index_with_prefixes_length") { - t.Errorf("Failed to create %s table index:", tableName) - } - }) - } -} diff --git a/model.go b/model.go deleted file mode 100644 index f37ff7ea..00000000 --- a/model.go +++ /dev/null @@ -1,14 +0,0 @@ -package gorm - -import "time" - -// Model base model definition, including fields `ID`, `CreatedAt`, `UpdatedAt`, `DeletedAt`, which could be embedded in your models -// type User struct { -// gorm.Model -// } -type Model struct { - ID uint `gorm:"primary_key"` - CreatedAt time.Time - UpdatedAt time.Time - DeletedAt *time.Time `sql:"index"` -} diff --git a/model_struct.go b/model_struct.go deleted file mode 100644 index d9e2e90f..00000000 --- a/model_struct.go +++ /dev/null @@ -1,671 +0,0 @@ -package gorm - -import ( - "database/sql" - "errors" - "go/ast" - "reflect" - "strings" - "sync" - "time" - - "github.com/jinzhu/inflection" -) - -// DefaultTableNameHandler default table name handler -var DefaultTableNameHandler = func(db *DB, defaultTableName string) string { - return defaultTableName -} - -// lock for mutating global cached model metadata -var structsLock sync.Mutex - -// global cache of model metadata -var modelStructsMap sync.Map - -// ModelStruct model definition -type ModelStruct struct { - PrimaryFields []*StructField - StructFields []*StructField - ModelType reflect.Type - - defaultTableName string - l sync.Mutex -} - -// TableName returns model's table name -func (s *ModelStruct) TableName(db *DB) string { - s.l.Lock() - defer s.l.Unlock() - - if s.defaultTableName == "" && db != nil && s.ModelType != nil { - // Set default table name - if tabler, ok := reflect.New(s.ModelType).Interface().(tabler); ok { - s.defaultTableName = tabler.TableName() - } else { - tableName := ToTableName(s.ModelType.Name()) - db.parent.RLock() - if db == nil || (db.parent != nil && !db.parent.singularTable) { - tableName = inflection.Plural(tableName) - } - db.parent.RUnlock() - s.defaultTableName = tableName - } - } - - return DefaultTableNameHandler(db, s.defaultTableName) -} - -// StructField model field's struct definition -type StructField struct { - DBName string - Name string - Names []string - IsPrimaryKey bool - IsNormal bool - IsIgnored bool - IsScanner bool - HasDefaultValue bool - Tag reflect.StructTag - TagSettings map[string]string - Struct reflect.StructField - IsForeignKey bool - Relationship *Relationship - - tagSettingsLock sync.RWMutex -} - -// TagSettingsSet Sets a tag in the tag settings map -func (sf *StructField) TagSettingsSet(key, val string) { - sf.tagSettingsLock.Lock() - defer sf.tagSettingsLock.Unlock() - sf.TagSettings[key] = val -} - -// TagSettingsGet returns a tag from the tag settings -func (sf *StructField) TagSettingsGet(key string) (string, bool) { - sf.tagSettingsLock.RLock() - defer sf.tagSettingsLock.RUnlock() - val, ok := sf.TagSettings[key] - return val, ok -} - -// TagSettingsDelete deletes a tag -func (sf *StructField) TagSettingsDelete(key string) { - sf.tagSettingsLock.Lock() - defer sf.tagSettingsLock.Unlock() - delete(sf.TagSettings, key) -} - -func (sf *StructField) clone() *StructField { - clone := &StructField{ - DBName: sf.DBName, - Name: sf.Name, - Names: sf.Names, - IsPrimaryKey: sf.IsPrimaryKey, - IsNormal: sf.IsNormal, - IsIgnored: sf.IsIgnored, - IsScanner: sf.IsScanner, - HasDefaultValue: sf.HasDefaultValue, - Tag: sf.Tag, - TagSettings: map[string]string{}, - Struct: sf.Struct, - IsForeignKey: sf.IsForeignKey, - } - - if sf.Relationship != nil { - relationship := *sf.Relationship - clone.Relationship = &relationship - } - - // copy the struct field tagSettings, they should be read-locked while they are copied - sf.tagSettingsLock.Lock() - defer sf.tagSettingsLock.Unlock() - for key, value := range sf.TagSettings { - clone.TagSettings[key] = value - } - - return clone -} - -// Relationship described the relationship between models -type Relationship struct { - Kind string - PolymorphicType string - PolymorphicDBName string - PolymorphicValue string - ForeignFieldNames []string - ForeignDBNames []string - AssociationForeignFieldNames []string - AssociationForeignDBNames []string - JoinTableHandler JoinTableHandlerInterface -} - -func getForeignField(column string, fields []*StructField) *StructField { - for _, field := range fields { - if field.Name == column || field.DBName == column || field.DBName == ToColumnName(column) { - return field - } - } - return nil -} - -// GetModelStruct get value's model struct, relationships based on struct and tag definition -func (scope *Scope) GetModelStruct() *ModelStruct { - var modelStruct ModelStruct - // Scope value can't be nil - if scope.Value == nil { - return &modelStruct - } - - reflectType := reflect.ValueOf(scope.Value).Type() - for reflectType.Kind() == reflect.Slice || reflectType.Kind() == reflect.Ptr { - reflectType = reflectType.Elem() - } - - // Scope value need to be a struct - if reflectType.Kind() != reflect.Struct { - return &modelStruct - } - - // Get Cached model struct - isSingularTable := false - if scope.db != nil && scope.db.parent != nil { - scope.db.parent.RLock() - isSingularTable = scope.db.parent.singularTable - scope.db.parent.RUnlock() - } - - hashKey := struct { - singularTable bool - reflectType reflect.Type - }{isSingularTable, reflectType} - if value, ok := modelStructsMap.Load(hashKey); ok && value != nil { - return value.(*ModelStruct) - } - - modelStruct.ModelType = reflectType - - // Get all fields - for i := 0; i < reflectType.NumField(); i++ { - if fieldStruct := reflectType.Field(i); ast.IsExported(fieldStruct.Name) { - field := &StructField{ - Struct: fieldStruct, - Name: fieldStruct.Name, - Names: []string{fieldStruct.Name}, - Tag: fieldStruct.Tag, - TagSettings: parseTagSetting(fieldStruct.Tag), - } - - // is ignored field - if _, ok := field.TagSettingsGet("-"); ok { - field.IsIgnored = true - } else { - if _, ok := field.TagSettingsGet("PRIMARY_KEY"); ok { - field.IsPrimaryKey = true - modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field) - } - - if _, ok := field.TagSettingsGet("DEFAULT"); ok && !field.IsPrimaryKey { - field.HasDefaultValue = true - } - - if _, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok && !field.IsPrimaryKey { - field.HasDefaultValue = true - } - - indirectType := fieldStruct.Type - for indirectType.Kind() == reflect.Ptr { - indirectType = indirectType.Elem() - } - - fieldValue := reflect.New(indirectType).Interface() - if _, isScanner := fieldValue.(sql.Scanner); isScanner { - // is scanner - field.IsScanner, field.IsNormal = true, true - if indirectType.Kind() == reflect.Struct { - for i := 0; i < indirectType.NumField(); i++ { - for key, value := range parseTagSetting(indirectType.Field(i).Tag) { - if _, ok := field.TagSettingsGet(key); !ok { - field.TagSettingsSet(key, value) - } - } - } - } - } else if _, isTime := fieldValue.(*time.Time); isTime { - // is time - field.IsNormal = true - } else if _, ok := field.TagSettingsGet("EMBEDDED"); ok || fieldStruct.Anonymous { - // is embedded struct - for _, subField := range scope.New(fieldValue).GetModelStruct().StructFields { - subField = subField.clone() - subField.Names = append([]string{fieldStruct.Name}, subField.Names...) - if prefix, ok := field.TagSettingsGet("EMBEDDED_PREFIX"); ok { - subField.DBName = prefix + subField.DBName - } - - if subField.IsPrimaryKey { - if _, ok := subField.TagSettingsGet("PRIMARY_KEY"); ok { - modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, subField) - } else { - subField.IsPrimaryKey = false - } - } - - if subField.Relationship != nil && subField.Relationship.JoinTableHandler != nil { - if joinTableHandler, ok := subField.Relationship.JoinTableHandler.(*JoinTableHandler); ok { - newJoinTableHandler := &JoinTableHandler{} - newJoinTableHandler.Setup(subField.Relationship, joinTableHandler.TableName, reflectType, joinTableHandler.Destination.ModelType) - subField.Relationship.JoinTableHandler = newJoinTableHandler - } - } - - modelStruct.StructFields = append(modelStruct.StructFields, subField) - } - continue - } else { - // build relationships - switch indirectType.Kind() { - case reflect.Slice: - defer func(field *StructField) { - var ( - relationship = &Relationship{} - toScope = scope.New(reflect.New(field.Struct.Type).Interface()) - foreignKeys []string - associationForeignKeys []string - elemType = field.Struct.Type - ) - - if foreignKey, _ := field.TagSettingsGet("FOREIGNKEY"); foreignKey != "" { - foreignKeys = strings.Split(foreignKey, ",") - } - - if foreignKey, _ := field.TagSettingsGet("ASSOCIATION_FOREIGNKEY"); foreignKey != "" { - associationForeignKeys = strings.Split(foreignKey, ",") - } else if foreignKey, _ := field.TagSettingsGet("ASSOCIATIONFOREIGNKEY"); foreignKey != "" { - associationForeignKeys = strings.Split(foreignKey, ",") - } - - for elemType.Kind() == reflect.Slice || elemType.Kind() == reflect.Ptr { - elemType = elemType.Elem() - } - - if elemType.Kind() == reflect.Struct { - if many2many, _ := field.TagSettingsGet("MANY2MANY"); many2many != "" { - relationship.Kind = "many_to_many" - - { // Foreign Keys for Source - joinTableDBNames := []string{} - - if foreignKey, _ := field.TagSettingsGet("JOINTABLE_FOREIGNKEY"); foreignKey != "" { - joinTableDBNames = strings.Split(foreignKey, ",") - } - - // if no foreign keys defined with tag - if len(foreignKeys) == 0 { - for _, field := range modelStruct.PrimaryFields { - foreignKeys = append(foreignKeys, field.DBName) - } - } - - for idx, foreignKey := range foreignKeys { - if foreignField := getForeignField(foreignKey, modelStruct.StructFields); foreignField != nil { - // source foreign keys (db names) - relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.DBName) - - // setup join table foreign keys for source - if len(joinTableDBNames) > idx { - // if defined join table's foreign key - relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBNames[idx]) - } else { - defaultJointableForeignKey := ToColumnName(reflectType.Name()) + "_" + foreignField.DBName - relationship.ForeignDBNames = append(relationship.ForeignDBNames, defaultJointableForeignKey) - } - } - } - } - - { // Foreign Keys for Association (Destination) - associationJoinTableDBNames := []string{} - - if foreignKey, _ := field.TagSettingsGet("ASSOCIATION_JOINTABLE_FOREIGNKEY"); foreignKey != "" { - associationJoinTableDBNames = strings.Split(foreignKey, ",") - } - - // if no association foreign keys defined with tag - if len(associationForeignKeys) == 0 { - for _, field := range toScope.PrimaryFields() { - associationForeignKeys = append(associationForeignKeys, field.DBName) - } - } - - for idx, name := range associationForeignKeys { - if field, ok := toScope.FieldByName(name); ok { - // association foreign keys (db names) - relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.DBName) - - // setup join table foreign keys for association - if len(associationJoinTableDBNames) > idx { - relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationJoinTableDBNames[idx]) - } else { - // join table foreign keys for association - joinTableDBName := ToColumnName(elemType.Name()) + "_" + field.DBName - relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, joinTableDBName) - } - } - } - } - - joinTableHandler := JoinTableHandler{} - joinTableHandler.Setup(relationship, many2many, reflectType, elemType) - relationship.JoinTableHandler = &joinTableHandler - field.Relationship = relationship - } else { - // User has many comments, associationType is User, comment use UserID as foreign key - var associationType = reflectType.Name() - var toFields = toScope.GetStructFields() - relationship.Kind = "has_many" - - if polymorphic, _ := field.TagSettingsGet("POLYMORPHIC"); polymorphic != "" { - // Dog has many toys, tag polymorphic is Owner, then associationType is Owner - // Toy use OwnerID, OwnerType ('dogs') as foreign key - if polymorphicType := getForeignField(polymorphic+"Type", toFields); polymorphicType != nil { - associationType = polymorphic - relationship.PolymorphicType = polymorphicType.Name - relationship.PolymorphicDBName = polymorphicType.DBName - // if Dog has multiple set of toys set name of the set (instead of default 'dogs') - if value, ok := field.TagSettingsGet("POLYMORPHIC_VALUE"); ok { - relationship.PolymorphicValue = value - } else { - relationship.PolymorphicValue = scope.TableName() - } - polymorphicType.IsForeignKey = true - } - } - - // if no foreign keys defined with tag - if len(foreignKeys) == 0 { - // if no association foreign keys defined with tag - if len(associationForeignKeys) == 0 { - for _, field := range modelStruct.PrimaryFields { - foreignKeys = append(foreignKeys, associationType+field.Name) - associationForeignKeys = append(associationForeignKeys, field.Name) - } - } else { - // generate foreign keys from defined association foreign keys - for _, scopeFieldName := range associationForeignKeys { - if foreignField := getForeignField(scopeFieldName, modelStruct.StructFields); foreignField != nil { - foreignKeys = append(foreignKeys, associationType+foreignField.Name) - associationForeignKeys = append(associationForeignKeys, foreignField.Name) - } - } - } - } else { - // generate association foreign keys from foreign keys - if len(associationForeignKeys) == 0 { - for _, foreignKey := range foreignKeys { - if strings.HasPrefix(foreignKey, associationType) { - associationForeignKey := strings.TrimPrefix(foreignKey, associationType) - if foreignField := getForeignField(associationForeignKey, modelStruct.StructFields); foreignField != nil { - associationForeignKeys = append(associationForeignKeys, associationForeignKey) - } - } - } - if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 { - associationForeignKeys = []string{scope.PrimaryKey()} - } - } else if len(foreignKeys) != len(associationForeignKeys) { - scope.Err(errors.New("invalid foreign keys, should have same length")) - return - } - } - - for idx, foreignKey := range foreignKeys { - if foreignField := getForeignField(foreignKey, toFields); foreignField != nil { - if associationField := getForeignField(associationForeignKeys[idx], modelStruct.StructFields); associationField != nil { - // mark field as foreignkey, use global lock to avoid race - structsLock.Lock() - foreignField.IsForeignKey = true - structsLock.Unlock() - - // association foreign keys - relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, associationField.Name) - relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationField.DBName) - - // association foreign keys - relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) - relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) - } - } - } - - if len(relationship.ForeignFieldNames) != 0 { - field.Relationship = relationship - } - } - } else { - field.IsNormal = true - } - }(field) - case reflect.Struct: - defer func(field *StructField) { - var ( - // user has one profile, associationType is User, profile use UserID as foreign key - // user belongs to profile, associationType is Profile, user use ProfileID as foreign key - associationType = reflectType.Name() - relationship = &Relationship{} - toScope = scope.New(reflect.New(field.Struct.Type).Interface()) - toFields = toScope.GetStructFields() - tagForeignKeys []string - tagAssociationForeignKeys []string - ) - - if foreignKey, _ := field.TagSettingsGet("FOREIGNKEY"); foreignKey != "" { - tagForeignKeys = strings.Split(foreignKey, ",") - } - - if foreignKey, _ := field.TagSettingsGet("ASSOCIATION_FOREIGNKEY"); foreignKey != "" { - tagAssociationForeignKeys = strings.Split(foreignKey, ",") - } else if foreignKey, _ := field.TagSettingsGet("ASSOCIATIONFOREIGNKEY"); foreignKey != "" { - tagAssociationForeignKeys = strings.Split(foreignKey, ",") - } - - if polymorphic, _ := field.TagSettingsGet("POLYMORPHIC"); polymorphic != "" { - // Cat has one toy, tag polymorphic is Owner, then associationType is Owner - // Toy use OwnerID, OwnerType ('cats') as foreign key - if polymorphicType := getForeignField(polymorphic+"Type", toFields); polymorphicType != nil { - associationType = polymorphic - relationship.PolymorphicType = polymorphicType.Name - relationship.PolymorphicDBName = polymorphicType.DBName - // if Cat has several different types of toys set name for each (instead of default 'cats') - if value, ok := field.TagSettingsGet("POLYMORPHIC_VALUE"); ok { - relationship.PolymorphicValue = value - } else { - relationship.PolymorphicValue = scope.TableName() - } - polymorphicType.IsForeignKey = true - } - } - - // Has One - { - var foreignKeys = tagForeignKeys - var associationForeignKeys = tagAssociationForeignKeys - // if no foreign keys defined with tag - if len(foreignKeys) == 0 { - // if no association foreign keys defined with tag - if len(associationForeignKeys) == 0 { - for _, primaryField := range modelStruct.PrimaryFields { - foreignKeys = append(foreignKeys, associationType+primaryField.Name) - associationForeignKeys = append(associationForeignKeys, primaryField.Name) - } - } else { - // generate foreign keys form association foreign keys - for _, associationForeignKey := range tagAssociationForeignKeys { - if foreignField := getForeignField(associationForeignKey, modelStruct.StructFields); foreignField != nil { - foreignKeys = append(foreignKeys, associationType+foreignField.Name) - associationForeignKeys = append(associationForeignKeys, foreignField.Name) - } - } - } - } else { - // generate association foreign keys from foreign keys - if len(associationForeignKeys) == 0 { - for _, foreignKey := range foreignKeys { - if strings.HasPrefix(foreignKey, associationType) { - associationForeignKey := strings.TrimPrefix(foreignKey, associationType) - if foreignField := getForeignField(associationForeignKey, modelStruct.StructFields); foreignField != nil { - associationForeignKeys = append(associationForeignKeys, associationForeignKey) - } - } - } - if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 { - associationForeignKeys = []string{scope.PrimaryKey()} - } - } else if len(foreignKeys) != len(associationForeignKeys) { - scope.Err(errors.New("invalid foreign keys, should have same length")) - return - } - } - - for idx, foreignKey := range foreignKeys { - if foreignField := getForeignField(foreignKey, toFields); foreignField != nil { - if scopeField := getForeignField(associationForeignKeys[idx], modelStruct.StructFields); scopeField != nil { - // mark field as foreignkey, use global lock to avoid race - structsLock.Lock() - foreignField.IsForeignKey = true - structsLock.Unlock() - - // association foreign keys - relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, scopeField.Name) - relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, scopeField.DBName) - - // association foreign keys - relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) - relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) - } - } - } - } - - if len(relationship.ForeignFieldNames) != 0 { - relationship.Kind = "has_one" - field.Relationship = relationship - } else { - var foreignKeys = tagForeignKeys - var associationForeignKeys = tagAssociationForeignKeys - - if len(foreignKeys) == 0 { - // generate foreign keys & association foreign keys - if len(associationForeignKeys) == 0 { - for _, primaryField := range toScope.PrimaryFields() { - foreignKeys = append(foreignKeys, field.Name+primaryField.Name) - associationForeignKeys = append(associationForeignKeys, primaryField.Name) - } - } else { - // generate foreign keys with association foreign keys - for _, associationForeignKey := range associationForeignKeys { - if foreignField := getForeignField(associationForeignKey, toFields); foreignField != nil { - foreignKeys = append(foreignKeys, field.Name+foreignField.Name) - associationForeignKeys = append(associationForeignKeys, foreignField.Name) - } - } - } - } else { - // generate foreign keys & association foreign keys - if len(associationForeignKeys) == 0 { - for _, foreignKey := range foreignKeys { - if strings.HasPrefix(foreignKey, field.Name) { - associationForeignKey := strings.TrimPrefix(foreignKey, field.Name) - if foreignField := getForeignField(associationForeignKey, toFields); foreignField != nil { - associationForeignKeys = append(associationForeignKeys, associationForeignKey) - } - } - } - if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 { - associationForeignKeys = []string{toScope.PrimaryKey()} - } - } else if len(foreignKeys) != len(associationForeignKeys) { - scope.Err(errors.New("invalid foreign keys, should have same length")) - return - } - } - - for idx, foreignKey := range foreignKeys { - if foreignField := getForeignField(foreignKey, modelStruct.StructFields); foreignField != nil { - if associationField := getForeignField(associationForeignKeys[idx], toFields); associationField != nil { - // mark field as foreignkey, use global lock to avoid race - structsLock.Lock() - foreignField.IsForeignKey = true - structsLock.Unlock() - - // association foreign keys - relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, associationField.Name) - relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationField.DBName) - - // source foreign keys - relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) - relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) - } - } - } - - if len(relationship.ForeignFieldNames) != 0 { - relationship.Kind = "belongs_to" - field.Relationship = relationship - } - } - }(field) - default: - field.IsNormal = true - } - } - } - - // Even it is ignored, also possible to decode db value into the field - if value, ok := field.TagSettingsGet("COLUMN"); ok { - field.DBName = value - } else { - field.DBName = ToColumnName(fieldStruct.Name) - } - - modelStruct.StructFields = append(modelStruct.StructFields, field) - } - } - - if len(modelStruct.PrimaryFields) == 0 { - if field := getForeignField("id", modelStruct.StructFields); field != nil { - field.IsPrimaryKey = true - modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field) - } - } - - modelStructsMap.Store(hashKey, &modelStruct) - - return &modelStruct -} - -// GetStructFields get model's field structs -func (scope *Scope) GetStructFields() (fields []*StructField) { - return scope.GetModelStruct().StructFields -} - -func parseTagSetting(tags reflect.StructTag) map[string]string { - setting := map[string]string{} - for _, str := range []string{tags.Get("sql"), tags.Get("gorm")} { - if str == "" { - continue - } - tags := strings.Split(str, ";") - for _, value := range tags { - v := strings.Split(value, ":") - k := strings.TrimSpace(strings.ToUpper(v[0])) - if len(v) >= 2 { - setting[k] = strings.Join(v[1:], ":") - } else { - setting[k] = k - } - } - } - return setting -} diff --git a/model_struct_test.go b/model_struct_test.go deleted file mode 100644 index 2ae419a0..00000000 --- a/model_struct_test.go +++ /dev/null @@ -1,93 +0,0 @@ -package gorm_test - -import ( - "sync" - "testing" - - "github.com/jinzhu/gorm" -) - -type ModelA struct { - gorm.Model - Name string - - ModelCs []ModelC `gorm:"foreignkey:OtherAID"` -} - -type ModelB struct { - gorm.Model - Name string - - ModelCs []ModelC `gorm:"foreignkey:OtherBID"` -} - -type ModelC struct { - gorm.Model - Name string - - OtherAID uint64 - OtherA *ModelA `gorm:"foreignkey:OtherAID"` - OtherBID uint64 - OtherB *ModelB `gorm:"foreignkey:OtherBID"` -} - -// This test will try to cause a race condition on the model's foreignkey metadata -func TestModelStructRaceSameModel(t *testing.T) { - // use a WaitGroup to execute as much in-sync as possible - // it's more likely to hit a race condition than without - n := 32 - start := sync.WaitGroup{} - start.Add(n) - - // use another WaitGroup to know when the test is done - done := sync.WaitGroup{} - done.Add(n) - - for i := 0; i < n; i++ { - go func() { - start.Wait() - - // call GetStructFields, this had a race condition before we fixed it - DB.NewScope(&ModelA{}).GetStructFields() - - done.Done() - }() - - start.Done() - } - - done.Wait() -} - -// This test will try to cause a race condition on the model's foreignkey metadata -func TestModelStructRaceDifferentModel(t *testing.T) { - // use a WaitGroup to execute as much in-sync as possible - // it's more likely to hit a race condition than without - n := 32 - start := sync.WaitGroup{} - start.Add(n) - - // use another WaitGroup to know when the test is done - done := sync.WaitGroup{} - done.Add(n) - - for i := 0; i < n; i++ { - i := i - go func() { - start.Wait() - - // call GetStructFields, this had a race condition before we fixed it - if i%2 == 0 { - DB.NewScope(&ModelA{}).GetStructFields() - } else { - DB.NewScope(&ModelB{}).GetStructFields() - } - - done.Done() - }() - - start.Done() - } - - done.Wait() -} diff --git a/multi_primary_keys_test.go b/multi_primary_keys_test.go deleted file mode 100644 index 32a14772..00000000 --- a/multi_primary_keys_test.go +++ /dev/null @@ -1,381 +0,0 @@ -package gorm_test - -import ( - "os" - "reflect" - "sort" - "testing" -) - -type Blog struct { - ID uint `gorm:"primary_key"` - Locale string `gorm:"primary_key"` - Subject string - Body string - Tags []Tag `gorm:"many2many:blog_tags;"` - SharedTags []Tag `gorm:"many2many:shared_blog_tags;ForeignKey:id;AssociationForeignKey:id"` - LocaleTags []Tag `gorm:"many2many:locale_blog_tags;ForeignKey:id,locale;AssociationForeignKey:id"` -} - -type Tag struct { - ID uint `gorm:"primary_key"` - Locale string `gorm:"primary_key"` - Value string - Blogs []*Blog `gorm:"many2many:blogs_tags"` -} - -func compareTags(tags []Tag, contents []string) bool { - var tagContents []string - for _, tag := range tags { - tagContents = append(tagContents, tag.Value) - } - sort.Strings(tagContents) - sort.Strings(contents) - return reflect.DeepEqual(tagContents, contents) -} - -func TestManyToManyWithMultiPrimaryKeys(t *testing.T) { - if dialect := os.Getenv("GORM_DIALECT"); dialect != "" && dialect != "sqlite" && dialect != "mssql" { - DB.DropTable(&Blog{}, &Tag{}) - DB.DropTable("blog_tags") - DB.CreateTable(&Blog{}, &Tag{}) - blog := Blog{ - Locale: "ZH", - Subject: "subject", - Body: "body", - Tags: []Tag{ - {Locale: "ZH", Value: "tag1"}, - {Locale: "ZH", Value: "tag2"}, - }, - } - - DB.Save(&blog) - if !compareTags(blog.Tags, []string{"tag1", "tag2"}) { - t.Errorf("Blog should has two tags") - } - - // Append - var tag3 = &Tag{Locale: "ZH", Value: "tag3"} - DB.Model(&blog).Association("Tags").Append([]*Tag{tag3}) - if !compareTags(blog.Tags, []string{"tag1", "tag2", "tag3"}) { - t.Errorf("Blog should has three tags after Append") - } - - if DB.Model(&blog).Association("Tags").Count() != 3 { - t.Errorf("Blog should has three tags after Append") - } - - var tags []Tag - DB.Model(&blog).Related(&tags, "Tags") - if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) { - t.Errorf("Should find 3 tags with Related") - } - - var blog1 Blog - DB.Preload("Tags").Find(&blog1) - if !compareTags(blog1.Tags, []string{"tag1", "tag2", "tag3"}) { - t.Errorf("Preload many2many relations") - } - - // Replace - var tag5 = &Tag{Locale: "ZH", Value: "tag5"} - var tag6 = &Tag{Locale: "ZH", Value: "tag6"} - DB.Model(&blog).Association("Tags").Replace(tag5, tag6) - var tags2 []Tag - DB.Model(&blog).Related(&tags2, "Tags") - if !compareTags(tags2, []string{"tag5", "tag6"}) { - t.Errorf("Should find 2 tags after Replace") - } - - if DB.Model(&blog).Association("Tags").Count() != 2 { - t.Errorf("Blog should has three tags after Replace") - } - - // Delete - DB.Model(&blog).Association("Tags").Delete(tag5) - var tags3 []Tag - DB.Model(&blog).Related(&tags3, "Tags") - if !compareTags(tags3, []string{"tag6"}) { - t.Errorf("Should find 1 tags after Delete") - } - - if DB.Model(&blog).Association("Tags").Count() != 1 { - t.Errorf("Blog should has three tags after Delete") - } - - DB.Model(&blog).Association("Tags").Delete(tag3) - var tags4 []Tag - DB.Model(&blog).Related(&tags4, "Tags") - if !compareTags(tags4, []string{"tag6"}) { - t.Errorf("Tag should not be deleted when Delete with a unrelated tag") - } - - // Clear - DB.Model(&blog).Association("Tags").Clear() - if DB.Model(&blog).Association("Tags").Count() != 0 { - t.Errorf("All tags should be cleared") - } - } -} - -func TestManyToManyWithCustomizedForeignKeys(t *testing.T) { - if dialect := os.Getenv("GORM_DIALECT"); dialect != "" && dialect != "sqlite" && dialect != "mssql" { - DB.DropTable(&Blog{}, &Tag{}) - DB.DropTable("shared_blog_tags") - DB.CreateTable(&Blog{}, &Tag{}) - blog := Blog{ - Locale: "ZH", - Subject: "subject", - Body: "body", - SharedTags: []Tag{ - {Locale: "ZH", Value: "tag1"}, - {Locale: "ZH", Value: "tag2"}, - }, - } - DB.Save(&blog) - - blog2 := Blog{ - ID: blog.ID, - Locale: "EN", - } - DB.Create(&blog2) - - if !compareTags(blog.SharedTags, []string{"tag1", "tag2"}) { - t.Errorf("Blog should has two tags") - } - - // Append - var tag3 = &Tag{Locale: "ZH", Value: "tag3"} - DB.Model(&blog).Association("SharedTags").Append([]*Tag{tag3}) - if !compareTags(blog.SharedTags, []string{"tag1", "tag2", "tag3"}) { - t.Errorf("Blog should has three tags after Append") - } - - if DB.Model(&blog).Association("SharedTags").Count() != 3 { - t.Errorf("Blog should has three tags after Append") - } - - if DB.Model(&blog2).Association("SharedTags").Count() != 3 { - t.Errorf("Blog should has three tags after Append") - } - - var tags []Tag - DB.Model(&blog).Related(&tags, "SharedTags") - if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) { - t.Errorf("Should find 3 tags with Related") - } - - DB.Model(&blog2).Related(&tags, "SharedTags") - if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) { - t.Errorf("Should find 3 tags with Related") - } - - var blog1 Blog - DB.Preload("SharedTags").Find(&blog1) - if !compareTags(blog1.SharedTags, []string{"tag1", "tag2", "tag3"}) { - t.Errorf("Preload many2many relations") - } - - var tag4 = &Tag{Locale: "ZH", Value: "tag4"} - DB.Model(&blog2).Association("SharedTags").Append(tag4) - - DB.Model(&blog).Related(&tags, "SharedTags") - if !compareTags(tags, []string{"tag1", "tag2", "tag3", "tag4"}) { - t.Errorf("Should find 3 tags with Related") - } - - DB.Model(&blog2).Related(&tags, "SharedTags") - if !compareTags(tags, []string{"tag1", "tag2", "tag3", "tag4"}) { - t.Errorf("Should find 3 tags with Related") - } - - // Replace - var tag5 = &Tag{Locale: "ZH", Value: "tag5"} - var tag6 = &Tag{Locale: "ZH", Value: "tag6"} - DB.Model(&blog2).Association("SharedTags").Replace(tag5, tag6) - var tags2 []Tag - DB.Model(&blog).Related(&tags2, "SharedTags") - if !compareTags(tags2, []string{"tag5", "tag6"}) { - t.Errorf("Should find 2 tags after Replace") - } - - DB.Model(&blog2).Related(&tags2, "SharedTags") - if !compareTags(tags2, []string{"tag5", "tag6"}) { - t.Errorf("Should find 2 tags after Replace") - } - - if DB.Model(&blog).Association("SharedTags").Count() != 2 { - t.Errorf("Blog should has three tags after Replace") - } - - // Delete - DB.Model(&blog).Association("SharedTags").Delete(tag5) - var tags3 []Tag - DB.Model(&blog).Related(&tags3, "SharedTags") - if !compareTags(tags3, []string{"tag6"}) { - t.Errorf("Should find 1 tags after Delete") - } - - if DB.Model(&blog).Association("SharedTags").Count() != 1 { - t.Errorf("Blog should has three tags after Delete") - } - - DB.Model(&blog2).Association("SharedTags").Delete(tag3) - var tags4 []Tag - DB.Model(&blog).Related(&tags4, "SharedTags") - if !compareTags(tags4, []string{"tag6"}) { - t.Errorf("Tag should not be deleted when Delete with a unrelated tag") - } - - // Clear - DB.Model(&blog2).Association("SharedTags").Clear() - if DB.Model(&blog).Association("SharedTags").Count() != 0 { - t.Errorf("All tags should be cleared") - } - } -} - -func TestManyToManyWithCustomizedForeignKeys2(t *testing.T) { - if dialect := os.Getenv("GORM_DIALECT"); dialect != "" && dialect != "sqlite" && dialect != "mssql" { - DB.DropTable(&Blog{}, &Tag{}) - DB.DropTable("locale_blog_tags") - DB.CreateTable(&Blog{}, &Tag{}) - blog := Blog{ - Locale: "ZH", - Subject: "subject", - Body: "body", - LocaleTags: []Tag{ - {Locale: "ZH", Value: "tag1"}, - {Locale: "ZH", Value: "tag2"}, - }, - } - DB.Save(&blog) - - blog2 := Blog{ - ID: blog.ID, - Locale: "EN", - } - DB.Create(&blog2) - - // Append - var tag3 = &Tag{Locale: "ZH", Value: "tag3"} - DB.Model(&blog).Association("LocaleTags").Append([]*Tag{tag3}) - if !compareTags(blog.LocaleTags, []string{"tag1", "tag2", "tag3"}) { - t.Errorf("Blog should has three tags after Append") - } - - if DB.Model(&blog).Association("LocaleTags").Count() != 3 { - t.Errorf("Blog should has three tags after Append") - } - - if DB.Model(&blog2).Association("LocaleTags").Count() != 0 { - t.Errorf("EN Blog should has 0 tags after ZH Blog Append") - } - - var tags []Tag - DB.Model(&blog).Related(&tags, "LocaleTags") - if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) { - t.Errorf("Should find 3 tags with Related") - } - - DB.Model(&blog2).Related(&tags, "LocaleTags") - if len(tags) != 0 { - t.Errorf("Should find 0 tags with Related for EN Blog") - } - - var blog1 Blog - DB.Preload("LocaleTags").Find(&blog1, "locale = ? AND id = ?", "ZH", blog.ID) - if !compareTags(blog1.LocaleTags, []string{"tag1", "tag2", "tag3"}) { - t.Errorf("Preload many2many relations") - } - - var tag4 = &Tag{Locale: "ZH", Value: "tag4"} - DB.Model(&blog2).Association("LocaleTags").Append(tag4) - - DB.Model(&blog).Related(&tags, "LocaleTags") - if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) { - t.Errorf("Should find 3 tags with Related for EN Blog") - } - - DB.Model(&blog2).Related(&tags, "LocaleTags") - if !compareTags(tags, []string{"tag4"}) { - t.Errorf("Should find 1 tags with Related for EN Blog") - } - - // Replace - var tag5 = &Tag{Locale: "ZH", Value: "tag5"} - var tag6 = &Tag{Locale: "ZH", Value: "tag6"} - DB.Model(&blog2).Association("LocaleTags").Replace(tag5, tag6) - - var tags2 []Tag - DB.Model(&blog).Related(&tags2, "LocaleTags") - if !compareTags(tags2, []string{"tag1", "tag2", "tag3"}) { - t.Errorf("CN Blog's tags should not be changed after EN Blog Replace") - } - - var blog11 Blog - DB.Preload("LocaleTags").First(&blog11, "id = ? AND locale = ?", blog.ID, blog.Locale) - if !compareTags(blog11.LocaleTags, []string{"tag1", "tag2", "tag3"}) { - t.Errorf("CN Blog's tags should not be changed after EN Blog Replace") - } - - DB.Model(&blog2).Related(&tags2, "LocaleTags") - if !compareTags(tags2, []string{"tag5", "tag6"}) { - t.Errorf("Should find 2 tags after Replace") - } - - var blog21 Blog - DB.Preload("LocaleTags").First(&blog21, "id = ? AND locale = ?", blog2.ID, blog2.Locale) - if !compareTags(blog21.LocaleTags, []string{"tag5", "tag6"}) { - t.Errorf("EN Blog's tags should be changed after Replace") - } - - if DB.Model(&blog).Association("LocaleTags").Count() != 3 { - t.Errorf("ZH Blog should has three tags after Replace") - } - - if DB.Model(&blog2).Association("LocaleTags").Count() != 2 { - t.Errorf("EN Blog should has two tags after Replace") - } - - // Delete - DB.Model(&blog).Association("LocaleTags").Delete(tag5) - - if DB.Model(&blog).Association("LocaleTags").Count() != 3 { - t.Errorf("ZH Blog should has three tags after Delete with EN's tag") - } - - if DB.Model(&blog2).Association("LocaleTags").Count() != 2 { - t.Errorf("EN Blog should has two tags after ZH Blog Delete with EN's tag") - } - - DB.Model(&blog2).Association("LocaleTags").Delete(tag5) - - if DB.Model(&blog).Association("LocaleTags").Count() != 3 { - t.Errorf("ZH Blog should has three tags after EN Blog Delete with EN's tag") - } - - if DB.Model(&blog2).Association("LocaleTags").Count() != 1 { - t.Errorf("EN Blog should has 1 tags after EN Blog Delete with EN's tag") - } - - // Clear - DB.Model(&blog2).Association("LocaleTags").Clear() - if DB.Model(&blog).Association("LocaleTags").Count() != 3 { - t.Errorf("ZH Blog's tags should not be cleared when clear EN Blog's tags") - } - - if DB.Model(&blog2).Association("LocaleTags").Count() != 0 { - t.Errorf("EN Blog's tags should be cleared when clear EN Blog's tags") - } - - DB.Model(&blog).Association("LocaleTags").Clear() - if DB.Model(&blog).Association("LocaleTags").Count() != 0 { - t.Errorf("ZH Blog's tags should be cleared when clear ZH Blog's tags") - } - - if DB.Model(&blog2).Association("LocaleTags").Count() != 0 { - t.Errorf("EN Blog's tags should be cleared") - } - } -} diff --git a/naming.go b/naming.go deleted file mode 100644 index 6b0a4fdd..00000000 --- a/naming.go +++ /dev/null @@ -1,124 +0,0 @@ -package gorm - -import ( - "bytes" - "strings" -) - -// Namer is a function type which is given a string and return a string -type Namer func(string) string - -// NamingStrategy represents naming strategies -type NamingStrategy struct { - DB Namer - Table Namer - Column Namer -} - -// TheNamingStrategy is being initialized with defaultNamingStrategy -var TheNamingStrategy = &NamingStrategy{ - DB: defaultNamer, - Table: defaultNamer, - Column: defaultNamer, -} - -// AddNamingStrategy sets the naming strategy -func AddNamingStrategy(ns *NamingStrategy) { - if ns.DB == nil { - ns.DB = defaultNamer - } - if ns.Table == nil { - ns.Table = defaultNamer - } - if ns.Column == nil { - ns.Column = defaultNamer - } - TheNamingStrategy = ns -} - -// DBName alters the given name by DB -func (ns *NamingStrategy) DBName(name string) string { - return ns.DB(name) -} - -// TableName alters the given name by Table -func (ns *NamingStrategy) TableName(name string) string { - return ns.Table(name) -} - -// ColumnName alters the given name by Column -func (ns *NamingStrategy) ColumnName(name string) string { - return ns.Column(name) -} - -// ToDBName convert string to db name -func ToDBName(name string) string { - return TheNamingStrategy.DBName(name) -} - -// ToTableName convert string to table name -func ToTableName(name string) string { - return TheNamingStrategy.TableName(name) -} - -// ToColumnName convert string to db name -func ToColumnName(name string) string { - return TheNamingStrategy.ColumnName(name) -} - -var smap = newSafeMap() - -func defaultNamer(name string) string { - const ( - lower = false - upper = true - ) - - if v := smap.Get(name); v != "" { - return v - } - - if name == "" { - return "" - } - - var ( - value = commonInitialismsReplacer.Replace(name) - buf = bytes.NewBufferString("") - lastCase, currCase, nextCase, nextNumber bool - ) - - for i, v := range value[:len(value)-1] { - nextCase = bool(value[i+1] >= 'A' && value[i+1] <= 'Z') - nextNumber = bool(value[i+1] >= '0' && value[i+1] <= '9') - - if i > 0 { - if currCase == upper { - if lastCase == upper && (nextCase == upper || nextNumber == upper) { - buf.WriteRune(v) - } else { - if value[i-1] != '_' && value[i+1] != '_' { - buf.WriteRune('_') - } - buf.WriteRune(v) - } - } else { - buf.WriteRune(v) - if i == len(value)-2 && (nextCase == upper && nextNumber == lower) { - buf.WriteRune('_') - } - } - } else { - currCase = upper - buf.WriteRune(v) - } - lastCase = currCase - currCase = nextCase - } - - buf.WriteByte(value[len(value)-1]) - - s := strings.ToLower(buf.String()) - smap.Set(name, s) - return s -} diff --git a/naming_test.go b/naming_test.go deleted file mode 100644 index 0c6f7713..00000000 --- a/naming_test.go +++ /dev/null @@ -1,69 +0,0 @@ -package gorm_test - -import ( - "testing" - - "github.com/jinzhu/gorm" -) - -func TestTheNamingStrategy(t *testing.T) { - - cases := []struct { - name string - namer gorm.Namer - expected string - }{ - {name: "auth", expected: "auth", namer: gorm.TheNamingStrategy.DB}, - {name: "userRestrictions", expected: "user_restrictions", namer: gorm.TheNamingStrategy.Table}, - {name: "clientID", expected: "client_id", namer: gorm.TheNamingStrategy.Column}, - } - - for _, c := range cases { - t.Run(c.name, func(t *testing.T) { - result := c.namer(c.name) - if result != c.expected { - t.Errorf("error in naming strategy. expected: %v got :%v\n", c.expected, result) - } - }) - } - -} - -func TestNamingStrategy(t *testing.T) { - - dbNameNS := func(name string) string { - return "db_" + name - } - tableNameNS := func(name string) string { - return "tbl_" + name - } - columnNameNS := func(name string) string { - return "col_" + name - } - - ns := &gorm.NamingStrategy{ - DB: dbNameNS, - Table: tableNameNS, - Column: columnNameNS, - } - - cases := []struct { - name string - namer gorm.Namer - expected string - }{ - {name: "auth", expected: "db_auth", namer: ns.DB}, - {name: "user", expected: "tbl_user", namer: ns.Table}, - {name: "password", expected: "col_password", namer: ns.Column}, - } - - for _, c := range cases { - t.Run(c.name, func(t *testing.T) { - result := c.namer(c.name) - if result != c.expected { - t.Errorf("error in naming strategy. expected: %v got :%v\n", c.expected, result) - } - }) - } - -} diff --git a/pointer_test.go b/pointer_test.go deleted file mode 100644 index 2a68a5ab..00000000 --- a/pointer_test.go +++ /dev/null @@ -1,84 +0,0 @@ -package gorm_test - -import "testing" - -type PointerStruct struct { - ID int64 - Name *string - Num *int -} - -type NormalStruct struct { - ID int64 - Name string - Num int -} - -func TestPointerFields(t *testing.T) { - DB.DropTable(&PointerStruct{}) - DB.AutoMigrate(&PointerStruct{}) - var name = "pointer struct 1" - var num = 100 - pointerStruct := PointerStruct{Name: &name, Num: &num} - if DB.Create(&pointerStruct).Error != nil { - t.Errorf("Failed to save pointer struct") - } - - var pointerStructResult PointerStruct - if err := DB.First(&pointerStructResult, "id = ?", pointerStruct.ID).Error; err != nil || *pointerStructResult.Name != name || *pointerStructResult.Num != num { - t.Errorf("Failed to query saved pointer struct") - } - - var tableName = DB.NewScope(&PointerStruct{}).TableName() - - var normalStruct NormalStruct - DB.Table(tableName).First(&normalStruct) - if normalStruct.Name != name || normalStruct.Num != num { - t.Errorf("Failed to query saved Normal struct") - } - - var nilPointerStruct = PointerStruct{} - if err := DB.Create(&nilPointerStruct).Error; err != nil { - t.Error("Failed to save nil pointer struct", err) - } - - var pointerStruct2 PointerStruct - if err := DB.First(&pointerStruct2, "id = ?", nilPointerStruct.ID).Error; err != nil { - t.Error("Failed to query saved nil pointer struct", err) - } - - var normalStruct2 NormalStruct - if err := DB.Table(tableName).First(&normalStruct2, "id = ?", nilPointerStruct.ID).Error; err != nil { - t.Error("Failed to query saved nil pointer struct", err) - } - - var partialNilPointerStruct1 = PointerStruct{Num: &num} - if err := DB.Create(&partialNilPointerStruct1).Error; err != nil { - t.Error("Failed to save partial nil pointer struct", err) - } - - var pointerStruct3 PointerStruct - if err := DB.First(&pointerStruct3, "id = ?", partialNilPointerStruct1.ID).Error; err != nil || *pointerStruct3.Num != num { - t.Error("Failed to query saved partial nil pointer struct", err) - } - - var normalStruct3 NormalStruct - if err := DB.Table(tableName).First(&normalStruct3, "id = ?", partialNilPointerStruct1.ID).Error; err != nil || normalStruct3.Num != num { - t.Error("Failed to query saved partial pointer struct", err) - } - - var partialNilPointerStruct2 = PointerStruct{Name: &name} - if err := DB.Create(&partialNilPointerStruct2).Error; err != nil { - t.Error("Failed to save partial nil pointer struct", err) - } - - var pointerStruct4 PointerStruct - if err := DB.First(&pointerStruct4, "id = ?", partialNilPointerStruct2.ID).Error; err != nil || *pointerStruct4.Name != name { - t.Error("Failed to query saved partial nil pointer struct", err) - } - - var normalStruct4 NormalStruct - if err := DB.Table(tableName).First(&normalStruct4, "id = ?", partialNilPointerStruct2.ID).Error; err != nil || normalStruct4.Name != name { - t.Error("Failed to query saved partial pointer struct", err) - } -} diff --git a/polymorphic_test.go b/polymorphic_test.go deleted file mode 100644 index d1ecfbbb..00000000 --- a/polymorphic_test.go +++ /dev/null @@ -1,366 +0,0 @@ -package gorm_test - -import ( - "reflect" - "sort" - "testing" -) - -type Cat struct { - Id int - Name string - Toy Toy `gorm:"polymorphic:Owner;"` -} - -type Dog struct { - Id int - Name string - Toys []Toy `gorm:"polymorphic:Owner;"` -} - -type Hamster struct { - Id int - Name string - PreferredToy Toy `gorm:"polymorphic:Owner;polymorphic_value:hamster_preferred"` - OtherToy Toy `gorm:"polymorphic:Owner;polymorphic_value:hamster_other"` -} - -type Toy struct { - Id int - Name string - OwnerId int - OwnerType string -} - -var compareToys = func(toys []Toy, contents []string) bool { - var toyContents []string - for _, toy := range toys { - toyContents = append(toyContents, toy.Name) - } - sort.Strings(toyContents) - sort.Strings(contents) - return reflect.DeepEqual(toyContents, contents) -} - -func TestPolymorphic(t *testing.T) { - cat := Cat{Name: "Mr. Bigglesworth", Toy: Toy{Name: "cat toy"}} - dog := Dog{Name: "Pluto", Toys: []Toy{{Name: "dog toy 1"}, {Name: "dog toy 2"}}} - DB.Save(&cat).Save(&dog) - - if DB.Model(&cat).Association("Toy").Count() != 1 { - t.Errorf("Cat's toys count should be 1") - } - - if DB.Model(&dog).Association("Toys").Count() != 2 { - t.Errorf("Dog's toys count should be 2") - } - - // Query - var catToys []Toy - if DB.Model(&cat).Related(&catToys, "Toy").RecordNotFound() { - t.Errorf("Did not find any has one polymorphic association") - } else if len(catToys) != 1 { - t.Errorf("Should have found only one polymorphic has one association") - } else if catToys[0].Name != cat.Toy.Name { - t.Errorf("Should have found the proper has one polymorphic association") - } - - var dogToys []Toy - if DB.Model(&dog).Related(&dogToys, "Toys").RecordNotFound() { - t.Errorf("Did not find any polymorphic has many associations") - } else if len(dogToys) != len(dog.Toys) { - t.Errorf("Should have found all polymorphic has many associations") - } - - var catToy Toy - DB.Model(&cat).Association("Toy").Find(&catToy) - if catToy.Name != cat.Toy.Name { - t.Errorf("Should find has one polymorphic association") - } - - var dogToys1 []Toy - DB.Model(&dog).Association("Toys").Find(&dogToys1) - if !compareToys(dogToys1, []string{"dog toy 1", "dog toy 2"}) { - t.Errorf("Should find has many polymorphic association") - } - - // Append - DB.Model(&cat).Association("Toy").Append(&Toy{ - Name: "cat toy 2", - }) - - var catToy2 Toy - DB.Model(&cat).Association("Toy").Find(&catToy2) - if catToy2.Name != "cat toy 2" { - t.Errorf("Should update has one polymorphic association with Append") - } - - if DB.Model(&cat).Association("Toy").Count() != 1 { - t.Errorf("Cat's toys count should be 1 after Append") - } - - if DB.Model(&dog).Association("Toys").Count() != 2 { - t.Errorf("Should return two polymorphic has many associations") - } - - DB.Model(&dog).Association("Toys").Append(&Toy{ - Name: "dog toy 3", - }) - - var dogToys2 []Toy - DB.Model(&dog).Association("Toys").Find(&dogToys2) - if !compareToys(dogToys2, []string{"dog toy 1", "dog toy 2", "dog toy 3"}) { - t.Errorf("Dog's toys should be updated with Append") - } - - if DB.Model(&dog).Association("Toys").Count() != 3 { - t.Errorf("Should return three polymorphic has many associations") - } - - // Replace - DB.Model(&cat).Association("Toy").Replace(&Toy{ - Name: "cat toy 3", - }) - - var catToy3 Toy - DB.Model(&cat).Association("Toy").Find(&catToy3) - if catToy3.Name != "cat toy 3" { - t.Errorf("Should update has one polymorphic association with Replace") - } - - if DB.Model(&cat).Association("Toy").Count() != 1 { - t.Errorf("Cat's toys count should be 1 after Replace") - } - - if DB.Model(&dog).Association("Toys").Count() != 3 { - t.Errorf("Should return three polymorphic has many associations") - } - - DB.Model(&dog).Association("Toys").Replace(&Toy{ - Name: "dog toy 4", - }, []Toy{ - {Name: "dog toy 5"}, {Name: "dog toy 6"}, {Name: "dog toy 7"}, - }) - - var dogToys3 []Toy - DB.Model(&dog).Association("Toys").Find(&dogToys3) - if !compareToys(dogToys3, []string{"dog toy 4", "dog toy 5", "dog toy 6", "dog toy 7"}) { - t.Errorf("Dog's toys should be updated with Replace") - } - - if DB.Model(&dog).Association("Toys").Count() != 4 { - t.Errorf("Should return three polymorphic has many associations") - } - - // Delete - DB.Model(&cat).Association("Toy").Delete(&catToy2) - - var catToy4 Toy - DB.Model(&cat).Association("Toy").Find(&catToy4) - if catToy4.Name != "cat toy 3" { - t.Errorf("Should not update has one polymorphic association when Delete a unrelated Toy") - } - - if DB.Model(&cat).Association("Toy").Count() != 1 { - t.Errorf("Cat's toys count should be 1") - } - - if DB.Model(&dog).Association("Toys").Count() != 4 { - t.Errorf("Dog's toys count should be 4") - } - - DB.Model(&cat).Association("Toy").Delete(&catToy3) - - if !DB.Model(&cat).Related(&Toy{}, "Toy").RecordNotFound() { - t.Errorf("Toy should be deleted with Delete") - } - - if DB.Model(&cat).Association("Toy").Count() != 0 { - t.Errorf("Cat's toys count should be 0 after Delete") - } - - if DB.Model(&dog).Association("Toys").Count() != 4 { - t.Errorf("Dog's toys count should not be changed when delete cat's toy") - } - - DB.Model(&dog).Association("Toys").Delete(&dogToys2) - - if DB.Model(&dog).Association("Toys").Count() != 4 { - t.Errorf("Dog's toys count should not be changed when delete unrelated toys") - } - - DB.Model(&dog).Association("Toys").Delete(&dogToys3) - - if DB.Model(&dog).Association("Toys").Count() != 0 { - t.Errorf("Dog's toys count should be deleted with Delete") - } - - // Clear - DB.Model(&cat).Association("Toy").Append(&Toy{ - Name: "cat toy 2", - }) - - if DB.Model(&cat).Association("Toy").Count() != 1 { - t.Errorf("Cat's toys should be added with Append") - } - - DB.Model(&cat).Association("Toy").Clear() - - if DB.Model(&cat).Association("Toy").Count() != 0 { - t.Errorf("Cat's toys should be cleared with Clear") - } - - DB.Model(&dog).Association("Toys").Append(&Toy{ - Name: "dog toy 8", - }) - - if DB.Model(&dog).Association("Toys").Count() != 1 { - t.Errorf("Dog's toys should be added with Append") - } - - DB.Model(&dog).Association("Toys").Clear() - - if DB.Model(&dog).Association("Toys").Count() != 0 { - t.Errorf("Dog's toys should be cleared with Clear") - } -} - -func TestNamedPolymorphic(t *testing.T) { - hamster := Hamster{Name: "Mr. Hammond", PreferredToy: Toy{Name: "bike"}, OtherToy: Toy{Name: "treadmill"}} - DB.Save(&hamster) - - hamster2 := Hamster{} - DB.Preload("PreferredToy").Preload("OtherToy").Find(&hamster2, hamster.Id) - if hamster2.PreferredToy.Id != hamster.PreferredToy.Id || hamster2.PreferredToy.Name != hamster.PreferredToy.Name { - t.Errorf("Hamster's preferred toy couldn't be preloaded") - } - if hamster2.OtherToy.Id != hamster.OtherToy.Id || hamster2.OtherToy.Name != hamster.OtherToy.Name { - t.Errorf("Hamster's other toy couldn't be preloaded") - } - - // clear to omit Toy.Id in count - hamster2.PreferredToy = Toy{} - hamster2.OtherToy = Toy{} - - if DB.Model(&hamster2).Association("PreferredToy").Count() != 1 { - t.Errorf("Hamster's preferred toy count should be 1") - } - - if DB.Model(&hamster2).Association("OtherToy").Count() != 1 { - t.Errorf("Hamster's other toy count should be 1") - } - - // Query - var hamsterToys []Toy - if DB.Model(&hamster).Related(&hamsterToys, "PreferredToy").RecordNotFound() { - t.Errorf("Did not find any has one polymorphic association") - } else if len(hamsterToys) != 1 { - t.Errorf("Should have found only one polymorphic has one association") - } else if hamsterToys[0].Name != hamster.PreferredToy.Name { - t.Errorf("Should have found the proper has one polymorphic association") - } - - if DB.Model(&hamster).Related(&hamsterToys, "OtherToy").RecordNotFound() { - t.Errorf("Did not find any has one polymorphic association") - } else if len(hamsterToys) != 1 { - t.Errorf("Should have found only one polymorphic has one association") - } else if hamsterToys[0].Name != hamster.OtherToy.Name { - t.Errorf("Should have found the proper has one polymorphic association") - } - - hamsterToy := Toy{} - DB.Model(&hamster).Association("PreferredToy").Find(&hamsterToy) - if hamsterToy.Name != hamster.PreferredToy.Name { - t.Errorf("Should find has one polymorphic association") - } - hamsterToy = Toy{} - DB.Model(&hamster).Association("OtherToy").Find(&hamsterToy) - if hamsterToy.Name != hamster.OtherToy.Name { - t.Errorf("Should find has one polymorphic association") - } - - // Append - DB.Model(&hamster).Association("PreferredToy").Append(&Toy{ - Name: "bike 2", - }) - DB.Model(&hamster).Association("OtherToy").Append(&Toy{ - Name: "treadmill 2", - }) - - hamsterToy = Toy{} - DB.Model(&hamster).Association("PreferredToy").Find(&hamsterToy) - if hamsterToy.Name != "bike 2" { - t.Errorf("Should update has one polymorphic association with Append") - } - - hamsterToy = Toy{} - DB.Model(&hamster).Association("OtherToy").Find(&hamsterToy) - if hamsterToy.Name != "treadmill 2" { - t.Errorf("Should update has one polymorphic association with Append") - } - - if DB.Model(&hamster2).Association("PreferredToy").Count() != 1 { - t.Errorf("Hamster's toys count should be 1 after Append") - } - - if DB.Model(&hamster2).Association("OtherToy").Count() != 1 { - t.Errorf("Hamster's toys count should be 1 after Append") - } - - // Replace - DB.Model(&hamster).Association("PreferredToy").Replace(&Toy{ - Name: "bike 3", - }) - DB.Model(&hamster).Association("OtherToy").Replace(&Toy{ - Name: "treadmill 3", - }) - - hamsterToy = Toy{} - DB.Model(&hamster).Association("PreferredToy").Find(&hamsterToy) - if hamsterToy.Name != "bike 3" { - t.Errorf("Should update has one polymorphic association with Replace") - } - - hamsterToy = Toy{} - DB.Model(&hamster).Association("OtherToy").Find(&hamsterToy) - if hamsterToy.Name != "treadmill 3" { - t.Errorf("Should update has one polymorphic association with Replace") - } - - if DB.Model(&hamster2).Association("PreferredToy").Count() != 1 { - t.Errorf("hamster's toys count should be 1 after Replace") - } - - if DB.Model(&hamster2).Association("OtherToy").Count() != 1 { - t.Errorf("hamster's toys count should be 1 after Replace") - } - - // Clear - DB.Model(&hamster).Association("PreferredToy").Append(&Toy{ - Name: "bike 2", - }) - DB.Model(&hamster).Association("OtherToy").Append(&Toy{ - Name: "treadmill 2", - }) - - if DB.Model(&hamster).Association("PreferredToy").Count() != 1 { - t.Errorf("Hamster's toys should be added with Append") - } - if DB.Model(&hamster).Association("OtherToy").Count() != 1 { - t.Errorf("Hamster's toys should be added with Append") - } - - DB.Model(&hamster).Association("PreferredToy").Clear() - - if DB.Model(&hamster2).Association("PreferredToy").Count() != 0 { - t.Errorf("Hamster's preferred toy should be cleared with Clear") - } - if DB.Model(&hamster2).Association("OtherToy").Count() != 1 { - t.Errorf("Hamster's other toy should be still available") - } - - DB.Model(&hamster).Association("OtherToy").Clear() - if DB.Model(&hamster).Association("OtherToy").Count() != 0 { - t.Errorf("Hamster's other toy should be cleared with Clear") - } -} diff --git a/preload_test.go b/preload_test.go deleted file mode 100644 index dd29fb5e..00000000 --- a/preload_test.go +++ /dev/null @@ -1,1701 +0,0 @@ -package gorm_test - -import ( - "database/sql" - "encoding/json" - "os" - "reflect" - "testing" - - "github.com/jinzhu/gorm" -) - -func getPreloadUser(name string) *User { - return getPreparedUser(name, "Preload") -} - -func checkUserHasPreloadData(user User, t *testing.T) { - u := getPreloadUser(user.Name) - if user.BillingAddress.Address1 != u.BillingAddress.Address1 { - t.Error("Failed to preload user's BillingAddress") - } - - if user.ShippingAddress.Address1 != u.ShippingAddress.Address1 { - t.Error("Failed to preload user's ShippingAddress") - } - - if user.CreditCard.Number != u.CreditCard.Number { - t.Error("Failed to preload user's CreditCard") - } - - if user.Company.Name != u.Company.Name { - t.Error("Failed to preload user's Company") - } - - if len(user.Emails) != len(u.Emails) { - t.Error("Failed to preload user's Emails") - } else { - var found int - for _, e1 := range u.Emails { - for _, e2 := range user.Emails { - if e1.Email == e2.Email { - found++ - break - } - } - } - if found != len(u.Emails) { - t.Error("Failed to preload user's email details") - } - } -} - -func TestPreload(t *testing.T) { - user1 := getPreloadUser("user1") - DB.Save(user1) - - preloadDB := DB.Where("role = ?", "Preload").Preload("BillingAddress").Preload("ShippingAddress"). - Preload("CreditCard").Preload("Emails").Preload("Company") - var user User - preloadDB.Find(&user) - checkUserHasPreloadData(user, t) - - user2 := getPreloadUser("user2") - DB.Save(user2) - - user3 := getPreloadUser("user3") - DB.Save(user3) - - var users []User - preloadDB.Find(&users) - - for _, user := range users { - checkUserHasPreloadData(user, t) - } - - var users2 []*User - preloadDB.Find(&users2) - - for _, user := range users2 { - checkUserHasPreloadData(*user, t) - } - - var users3 []*User - preloadDB.Preload("Emails", "email = ?", user3.Emails[0].Email).Find(&users3) - - for _, user := range users3 { - if user.Name == user3.Name { - if len(user.Emails) != 1 { - t.Errorf("should only preload one emails for user3 when with condition") - } - } else if len(user.Emails) != 0 { - t.Errorf("should not preload any emails for other users when with condition") - } else if user.Emails == nil { - t.Errorf("should return an empty slice to indicate zero results") - } - } -} - -func TestAutoPreload(t *testing.T) { - user1 := getPreloadUser("auto_user1") - DB.Save(user1) - - preloadDB := DB.Set("gorm:auto_preload", true).Where("role = ?", "Preload") - var user User - preloadDB.Find(&user) - checkUserHasPreloadData(user, t) - - user2 := getPreloadUser("auto_user2") - DB.Save(user2) - - var users []User - preloadDB.Find(&users) - - for _, user := range users { - checkUserHasPreloadData(user, t) - } - - var users2 []*User - preloadDB.Find(&users2) - - for _, user := range users2 { - checkUserHasPreloadData(*user, t) - } -} - -func TestAutoPreloadFalseDoesntPreload(t *testing.T) { - user1 := getPreloadUser("auto_user1") - DB.Save(user1) - - preloadDB := DB.Set("gorm:auto_preload", false).Where("role = ?", "Preload") - var user User - preloadDB.Find(&user) - - if user.BillingAddress.Address1 != "" { - t.Error("AutoPreload was set to fasle, but still fetched data") - } - - user2 := getPreloadUser("auto_user2") - DB.Save(user2) - - var users []User - preloadDB.Find(&users) - - for _, user := range users { - if user.BillingAddress.Address1 != "" { - t.Error("AutoPreload was set to fasle, but still fetched data") - } - } -} - -func TestNestedPreload1(t *testing.T) { - type ( - Level1 struct { - ID uint - Value string - Level2ID uint - } - Level2 struct { - ID uint - Level1 Level1 - Level3ID uint - } - Level3 struct { - ID uint - Name string - Level2 Level2 - } - ) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { - t.Error(err) - } - - want := Level3{Level2: Level2{Level1: Level1{Value: "value"}}} - if err := DB.Create(&want).Error; err != nil { - t.Error(err) - } - - var got Level3 - if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got).Error; err != nil { - t.Error(err) - } - - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) - } - - if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got, "name = ?", "not_found").Error; err != gorm.ErrRecordNotFound { - t.Error(err) - } -} - -func TestNestedPreload2(t *testing.T) { - type ( - Level1 struct { - ID uint - Value string - Level2ID uint - } - Level2 struct { - ID uint - Level1s []*Level1 - Level3ID uint - } - Level3 struct { - ID uint - Name string - Level2s []Level2 - } - ) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { - t.Error(err) - } - - want := Level3{ - Level2s: []Level2{ - { - Level1s: []*Level1{ - {Value: "value1"}, - {Value: "value2"}, - }, - }, - { - Level1s: []*Level1{ - {Value: "value3"}, - }, - }, - }, - } - if err := DB.Create(&want).Error; err != nil { - t.Error(err) - } - - var got Level3 - if err := DB.Preload("Level2s.Level1s").Find(&got).Error; err != nil { - t.Error(err) - } - - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) - } -} - -func TestNestedPreload3(t *testing.T) { - type ( - Level1 struct { - ID uint - Value string - Level2ID uint - } - Level2 struct { - ID uint - Level1 Level1 - Level3ID uint - } - Level3 struct { - Name string - ID uint - Level2s []Level2 - } - ) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { - t.Error(err) - } - - want := Level3{ - Level2s: []Level2{ - {Level1: Level1{Value: "value1"}}, - {Level1: Level1{Value: "value2"}}, - }, - } - if err := DB.Create(&want).Error; err != nil { - t.Error(err) - } - - var got Level3 - if err := DB.Preload("Level2s.Level1").Find(&got).Error; err != nil { - t.Error(err) - } - - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) - } -} - -func TestNestedPreload4(t *testing.T) { - type ( - Level1 struct { - ID uint - Value string - Level2ID uint - } - Level2 struct { - ID uint - Level1s []Level1 - Level3ID uint - } - Level3 struct { - ID uint - Name string - Level2 Level2 - } - ) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { - t.Error(err) - } - - want := Level3{ - Level2: Level2{ - Level1s: []Level1{ - {Value: "value1"}, - {Value: "value2"}, - }, - }, - } - if err := DB.Create(&want).Error; err != nil { - t.Error(err) - } - - var got Level3 - if err := DB.Preload("Level2.Level1s").Find(&got).Error; err != nil { - t.Error(err) - } - - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) - } -} - -// Slice: []Level3 -func TestNestedPreload5(t *testing.T) { - type ( - Level1 struct { - ID uint - Value string - Level2ID uint - } - Level2 struct { - ID uint - Level1 Level1 - Level3ID uint - } - Level3 struct { - ID uint - Name string - Level2 Level2 - } - ) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { - t.Error(err) - } - - want := make([]Level3, 2) - want[0] = Level3{Level2: Level2{Level1: Level1{Value: "value"}}} - if err := DB.Create(&want[0]).Error; err != nil { - t.Error(err) - } - want[1] = Level3{Level2: Level2{Level1: Level1{Value: "value2"}}} - if err := DB.Create(&want[1]).Error; err != nil { - t.Error(err) - } - - var got []Level3 - if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got).Error; err != nil { - t.Error(err) - } - - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) - } -} - -func TestNestedPreload6(t *testing.T) { - type ( - Level1 struct { - ID uint - Value string - Level2ID uint - } - Level2 struct { - ID uint - Level1s []Level1 - Level3ID uint - } - Level3 struct { - ID uint - Name string - Level2s []Level2 - } - ) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { - t.Error(err) - } - - want := make([]Level3, 2) - want[0] = Level3{ - Level2s: []Level2{ - { - Level1s: []Level1{ - {Value: "value1"}, - {Value: "value2"}, - }, - }, - { - Level1s: []Level1{ - {Value: "value3"}, - }, - }, - }, - } - if err := DB.Create(&want[0]).Error; err != nil { - t.Error(err) - } - - want[1] = Level3{ - Level2s: []Level2{ - { - Level1s: []Level1{ - {Value: "value3"}, - {Value: "value4"}, - }, - }, - { - Level1s: []Level1{ - {Value: "value5"}, - }, - }, - }, - } - if err := DB.Create(&want[1]).Error; err != nil { - t.Error(err) - } - - var got []Level3 - if err := DB.Preload("Level2s.Level1s").Find(&got).Error; err != nil { - t.Error(err) - } - - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) - } -} - -func TestNestedPreload7(t *testing.T) { - type ( - Level1 struct { - ID uint - Value string - Level2ID uint - } - Level2 struct { - ID uint - Level1 Level1 - Level3ID uint - } - Level3 struct { - ID uint - Name string - Level2s []Level2 - } - ) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { - t.Error(err) - } - - want := make([]Level3, 2) - want[0] = Level3{ - Level2s: []Level2{ - {Level1: Level1{Value: "value1"}}, - {Level1: Level1{Value: "value2"}}, - }, - } - if err := DB.Create(&want[0]).Error; err != nil { - t.Error(err) - } - - want[1] = Level3{ - Level2s: []Level2{ - {Level1: Level1{Value: "value3"}}, - {Level1: Level1{Value: "value4"}}, - }, - } - if err := DB.Create(&want[1]).Error; err != nil { - t.Error(err) - } - - var got []Level3 - if err := DB.Preload("Level2s.Level1").Find(&got).Error; err != nil { - t.Error(err) - } - - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) - } -} - -func TestNestedPreload8(t *testing.T) { - type ( - Level1 struct { - ID uint - Value string - Level2ID uint - } - Level2 struct { - ID uint - Level1s []Level1 - Level3ID uint - } - Level3 struct { - ID uint - Name string - Level2 Level2 - } - ) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { - t.Error(err) - } - - want := make([]Level3, 2) - want[0] = Level3{ - Level2: Level2{ - Level1s: []Level1{ - {Value: "value1"}, - {Value: "value2"}, - }, - }, - } - if err := DB.Create(&want[0]).Error; err != nil { - t.Error(err) - } - want[1] = Level3{ - Level2: Level2{ - Level1s: []Level1{ - {Value: "value3"}, - {Value: "value4"}, - }, - }, - } - if err := DB.Create(&want[1]).Error; err != nil { - t.Error(err) - } - - var got []Level3 - if err := DB.Preload("Level2.Level1s").Find(&got).Error; err != nil { - t.Error(err) - } - - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) - } -} - -func TestNestedPreload9(t *testing.T) { - type ( - Level0 struct { - ID uint - Value string - Level1ID uint - } - Level1 struct { - ID uint - Value string - Level2ID uint - Level2_1ID uint - Level0s []Level0 - } - Level2 struct { - ID uint - Level1s []Level1 - Level3ID uint - } - Level2_1 struct { - ID uint - Level1s []Level1 - Level3ID uint - } - Level3 struct { - ID uint - Name string - Level2 Level2 - Level2_1 Level2_1 - } - ) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level2_1{}) - DB.DropTableIfExists(&Level1{}) - DB.DropTableIfExists(&Level0{}) - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}, &Level2_1{}, &Level0{}).Error; err != nil { - t.Error(err) - } - - want := make([]Level3, 2) - want[0] = Level3{ - Level2: Level2{ - Level1s: []Level1{ - {Value: "value1"}, - {Value: "value2"}, - }, - }, - Level2_1: Level2_1{ - Level1s: []Level1{ - { - Value: "value1-1", - Level0s: []Level0{{Value: "Level0-1"}}, - }, - { - Value: "value2-2", - Level0s: []Level0{{Value: "Level0-2"}}, - }, - }, - }, - } - if err := DB.Create(&want[0]).Error; err != nil { - t.Error(err) - } - want[1] = Level3{ - Level2: Level2{ - Level1s: []Level1{ - {Value: "value3"}, - {Value: "value4"}, - }, - }, - Level2_1: Level2_1{ - Level1s: []Level1{ - { - Value: "value3-3", - Level0s: []Level0{}, - }, - { - Value: "value4-4", - Level0s: []Level0{}, - }, - }, - }, - } - if err := DB.Create(&want[1]).Error; err != nil { - t.Error(err) - } - - var got []Level3 - if err := DB.Preload("Level2").Preload("Level2.Level1s").Preload("Level2_1").Preload("Level2_1.Level1s").Preload("Level2_1.Level1s.Level0s").Find(&got).Error; err != nil { - t.Error(err) - } - - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) - } -} - -type LevelA1 struct { - ID uint - Value string -} - -type LevelA2 struct { - ID uint - Value string - LevelA3s []*LevelA3 -} - -type LevelA3 struct { - ID uint - Value string - LevelA1ID sql.NullInt64 - LevelA1 *LevelA1 - LevelA2ID sql.NullInt64 - LevelA2 *LevelA2 -} - -func TestNestedPreload10(t *testing.T) { - DB.DropTableIfExists(&LevelA3{}) - DB.DropTableIfExists(&LevelA2{}) - DB.DropTableIfExists(&LevelA1{}) - - if err := DB.AutoMigrate(&LevelA1{}, &LevelA2{}, &LevelA3{}).Error; err != nil { - t.Error(err) - } - - levelA1 := &LevelA1{Value: "foo"} - if err := DB.Save(levelA1).Error; err != nil { - t.Error(err) - } - - want := []*LevelA2{ - { - Value: "bar", - LevelA3s: []*LevelA3{ - { - Value: "qux", - LevelA1: levelA1, - }, - }, - }, - { - Value: "bar 2", - LevelA3s: []*LevelA3{}, - }, - } - for _, levelA2 := range want { - if err := DB.Save(levelA2).Error; err != nil { - t.Error(err) - } - } - - var got []*LevelA2 - if err := DB.Preload("LevelA3s.LevelA1").Find(&got).Error; err != nil { - t.Error(err) - } - - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) - } -} - -type LevelB1 struct { - ID uint - Value string - LevelB3s []*LevelB3 -} - -type LevelB2 struct { - ID uint - Value string -} - -type LevelB3 struct { - ID uint - Value string - LevelB1ID sql.NullInt64 - LevelB1 *LevelB1 - LevelB2s []*LevelB2 `gorm:"many2many:levelb1_levelb3_levelb2s"` -} - -func TestNestedPreload11(t *testing.T) { - DB.DropTableIfExists(&LevelB2{}) - DB.DropTableIfExists(&LevelB3{}) - DB.DropTableIfExists(&LevelB1{}) - if err := DB.AutoMigrate(&LevelB1{}, &LevelB2{}, &LevelB3{}).Error; err != nil { - t.Error(err) - } - - levelB1 := &LevelB1{Value: "foo"} - if err := DB.Create(levelB1).Error; err != nil { - t.Error(err) - } - - levelB3 := &LevelB3{ - Value: "bar", - LevelB1ID: sql.NullInt64{Valid: true, Int64: int64(levelB1.ID)}, - LevelB2s: []*LevelB2{}, - } - if err := DB.Create(levelB3).Error; err != nil { - t.Error(err) - } - levelB1.LevelB3s = []*LevelB3{levelB3} - - want := []*LevelB1{levelB1} - var got []*LevelB1 - if err := DB.Preload("LevelB3s.LevelB2s").Find(&got).Error; err != nil { - t.Error(err) - } - - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) - } -} - -type LevelC1 struct { - ID uint - Value string - LevelC2ID uint -} - -type LevelC2 struct { - ID uint - Value string - LevelC1 LevelC1 -} - -type LevelC3 struct { - ID uint - Value string - LevelC2ID uint - LevelC2 LevelC2 -} - -func TestNestedPreload12(t *testing.T) { - DB.DropTableIfExists(&LevelC2{}) - DB.DropTableIfExists(&LevelC3{}) - DB.DropTableIfExists(&LevelC1{}) - if err := DB.AutoMigrate(&LevelC1{}, &LevelC2{}, &LevelC3{}).Error; err != nil { - t.Error(err) - } - - level2 := LevelC2{ - Value: "c2", - LevelC1: LevelC1{ - Value: "c1", - }, - } - DB.Create(&level2) - - want := []LevelC3{ - { - Value: "c3-1", - LevelC2: level2, - }, { - Value: "c3-2", - LevelC2: level2, - }, - } - - for i := range want { - if err := DB.Create(&want[i]).Error; err != nil { - t.Error(err) - } - } - - var got []LevelC3 - if err := DB.Preload("LevelC2").Preload("LevelC2.LevelC1").Find(&got).Error; err != nil { - t.Error(err) - } - - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) - } -} - -func TestManyToManyPreloadWithMultiPrimaryKeys(t *testing.T) { - if dialect := os.Getenv("GORM_DIALECT"); dialect == "" || dialect == "sqlite" || dialect == "mssql" { - return - } - - type ( - Level1 struct { - ID uint `gorm:"primary_key;"` - LanguageCode string `gorm:"primary_key"` - Value string - } - Level2 struct { - ID uint `gorm:"primary_key;"` - LanguageCode string `gorm:"primary_key"` - Value string - Level1s []Level1 `gorm:"many2many:levels;"` - } - ) - - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - DB.DropTableIfExists("levels") - - if err := DB.AutoMigrate(&Level2{}, &Level1{}).Error; err != nil { - t.Error(err) - } - - want := Level2{Value: "Bob", LanguageCode: "ru", Level1s: []Level1{ - {Value: "ru", LanguageCode: "ru"}, - {Value: "en", LanguageCode: "en"}, - }} - if err := DB.Save(&want).Error; err != nil { - t.Error(err) - } - - want2 := Level2{Value: "Tom", LanguageCode: "zh", Level1s: []Level1{ - {Value: "zh", LanguageCode: "zh"}, - {Value: "de", LanguageCode: "de"}, - }} - if err := DB.Save(&want2).Error; err != nil { - t.Error(err) - } - - var got Level2 - if err := DB.Preload("Level1s").Find(&got, "value = ?", "Bob").Error; err != nil { - t.Error(err) - } - - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) - } - - var got2 Level2 - if err := DB.Preload("Level1s").Find(&got2, "value = ?", "Tom").Error; err != nil { - t.Error(err) - } - - if !reflect.DeepEqual(got2, want2) { - t.Errorf("got %s; want %s", toJSONString(got2), toJSONString(want2)) - } - - var got3 []Level2 - if err := DB.Preload("Level1s").Find(&got3, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil { - t.Error(err) - } - - if !reflect.DeepEqual(got3, []Level2{got, got2}) { - t.Errorf("got %s; want %s", toJSONString(got3), toJSONString([]Level2{got, got2})) - } - - var got4 []Level2 - if err := DB.Preload("Level1s", "value IN (?)", []string{"zh", "ru"}).Find(&got4, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil { - t.Error(err) - } - - var ruLevel1 Level1 - var zhLevel1 Level1 - DB.First(&ruLevel1, "value = ?", "ru") - DB.First(&zhLevel1, "value = ?", "zh") - - got.Level1s = []Level1{ruLevel1} - got2.Level1s = []Level1{zhLevel1} - if !reflect.DeepEqual(got4, []Level2{got, got2}) { - t.Errorf("got %s; want %s", toJSONString(got4), toJSONString([]Level2{got, got2})) - } - - if err := DB.Preload("Level1s").Find(&got4, "value IN (?)", []string{"non-existing"}).Error; err != nil { - t.Error(err) - } -} - -func TestManyToManyPreloadForNestedPointer(t *testing.T) { - type ( - Level1 struct { - ID uint - Value string - } - Level2 struct { - ID uint - Value string - Level1s []*Level1 `gorm:"many2many:levels;"` - } - Level3 struct { - ID uint - Value string - Level2ID sql.NullInt64 - Level2 *Level2 - } - ) - - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - DB.DropTableIfExists("levels") - - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { - t.Error(err) - } - - want := Level3{ - Value: "Bob", - Level2: &Level2{ - Value: "Foo", - Level1s: []*Level1{ - {Value: "ru"}, - {Value: "en"}, - }, - }, - } - if err := DB.Save(&want).Error; err != nil { - t.Error(err) - } - - want2 := Level3{ - Value: "Tom", - Level2: &Level2{ - Value: "Bar", - Level1s: []*Level1{ - {Value: "zh"}, - {Value: "de"}, - }, - }, - } - if err := DB.Save(&want2).Error; err != nil { - t.Error(err) - } - - var got Level3 - if err := DB.Preload("Level2.Level1s").Find(&got, "value = ?", "Bob").Error; err != nil { - t.Error(err) - } - - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) - } - - var got2 Level3 - if err := DB.Preload("Level2.Level1s").Find(&got2, "value = ?", "Tom").Error; err != nil { - t.Error(err) - } - - if !reflect.DeepEqual(got2, want2) { - t.Errorf("got %s; want %s", toJSONString(got2), toJSONString(want2)) - } - - var got3 []Level3 - if err := DB.Preload("Level2.Level1s").Find(&got3, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil { - t.Error(err) - } - - if !reflect.DeepEqual(got3, []Level3{got, got2}) { - t.Errorf("got %s; want %s", toJSONString(got3), toJSONString([]Level3{got, got2})) - } - - var got4 []Level3 - if err := DB.Preload("Level2.Level1s", "value IN (?)", []string{"zh", "ru"}).Find(&got4, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil { - t.Error(err) - } - - var got5 Level3 - DB.Preload("Level2.Level1s").Find(&got5, "value = ?", "bogus") - - var ruLevel1 Level1 - var zhLevel1 Level1 - DB.First(&ruLevel1, "value = ?", "ru") - DB.First(&zhLevel1, "value = ?", "zh") - - got.Level2.Level1s = []*Level1{&ruLevel1} - got2.Level2.Level1s = []*Level1{&zhLevel1} - if !reflect.DeepEqual(got4, []Level3{got, got2}) { - t.Errorf("got %s; want %s", toJSONString(got4), toJSONString([]Level3{got, got2})) - } -} - -func TestNestedManyToManyPreload(t *testing.T) { - type ( - Level1 struct { - ID uint - Value string - } - Level2 struct { - ID uint - Value string - Level1s []*Level1 `gorm:"many2many:level1_level2;"` - } - Level3 struct { - ID uint - Value string - Level2s []Level2 `gorm:"many2many:level2_level3;"` - } - ) - - DB.DropTableIfExists(&Level1{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists("level1_level2") - DB.DropTableIfExists("level2_level3") - - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { - t.Error(err) - } - - want := Level3{ - Value: "Level3", - Level2s: []Level2{ - { - Value: "Bob", - Level1s: []*Level1{ - {Value: "ru"}, - {Value: "en"}, - }, - }, { - Value: "Tom", - Level1s: []*Level1{ - {Value: "zh"}, - {Value: "de"}, - }, - }, - }, - } - - if err := DB.Save(&want).Error; err != nil { - t.Error(err) - } - - var got Level3 - if err := DB.Preload("Level2s").Preload("Level2s.Level1s").Find(&got, "value = ?", "Level3").Error; err != nil { - t.Error(err) - } - - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) - } - - if err := DB.Preload("Level2s.Level1s").Find(&got, "value = ?", "not_found").Error; err != gorm.ErrRecordNotFound { - t.Error(err) - } -} - -func TestNestedManyToManyPreload2(t *testing.T) { - type ( - Level1 struct { - ID uint - Value string - } - Level2 struct { - ID uint - Value string - Level1s []*Level1 `gorm:"many2many:level1_level2;"` - } - Level3 struct { - ID uint - Value string - Level2ID sql.NullInt64 - Level2 *Level2 - } - ) - - DB.DropTableIfExists(&Level1{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists("level1_level2") - - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { - t.Error(err) - } - - want := Level3{ - Value: "Level3", - Level2: &Level2{ - Value: "Bob", - Level1s: []*Level1{ - {Value: "ru"}, - {Value: "en"}, - }, - }, - } - - if err := DB.Save(&want).Error; err != nil { - t.Error(err) - } - - var got Level3 - if err := DB.Preload("Level2.Level1s").Find(&got, "value = ?", "Level3").Error; err != nil { - t.Error(err) - } - - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) - } - - if err := DB.Preload("Level2.Level1s").Find(&got, "value = ?", "not_found").Error; err != gorm.ErrRecordNotFound { - t.Error(err) - } -} - -func TestNestedManyToManyPreload3(t *testing.T) { - type ( - Level1 struct { - ID uint - Value string - } - Level2 struct { - ID uint - Value string - Level1s []*Level1 `gorm:"many2many:level1_level2;"` - } - Level3 struct { - ID uint - Value string - Level2ID sql.NullInt64 - Level2 *Level2 - } - ) - - DB.DropTableIfExists(&Level1{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists("level1_level2") - - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { - t.Error(err) - } - - level1Zh := &Level1{Value: "zh"} - level1Ru := &Level1{Value: "ru"} - level1En := &Level1{Value: "en"} - - level21 := &Level2{ - Value: "Level2-1", - Level1s: []*Level1{level1Zh, level1Ru}, - } - - level22 := &Level2{ - Value: "Level2-2", - Level1s: []*Level1{level1Zh, level1En}, - } - - wants := []*Level3{ - { - Value: "Level3-1", - Level2: level21, - }, - { - Value: "Level3-2", - Level2: level22, - }, - { - Value: "Level3-3", - Level2: level21, - }, - } - - for _, want := range wants { - if err := DB.Save(&want).Error; err != nil { - t.Error(err) - } - } - - var gots []*Level3 - if err := DB.Preload("Level2.Level1s", func(db *gorm.DB) *gorm.DB { - return db.Order("level1.id ASC") - }).Find(&gots).Error; err != nil { - t.Error(err) - } - - if !reflect.DeepEqual(gots, wants) { - t.Errorf("got %s; want %s", toJSONString(gots), toJSONString(wants)) - } -} - -func TestNestedManyToManyPreload3ForStruct(t *testing.T) { - type ( - Level1 struct { - ID uint - Value string - } - Level2 struct { - ID uint - Value string - Level1s []Level1 `gorm:"many2many:level1_level2;"` - } - Level3 struct { - ID uint - Value string - Level2ID sql.NullInt64 - Level2 Level2 - } - ) - - DB.DropTableIfExists(&Level1{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists("level1_level2") - - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { - t.Error(err) - } - - level1Zh := Level1{Value: "zh"} - level1Ru := Level1{Value: "ru"} - level1En := Level1{Value: "en"} - - level21 := Level2{ - Value: "Level2-1", - Level1s: []Level1{level1Zh, level1Ru}, - } - - level22 := Level2{ - Value: "Level2-2", - Level1s: []Level1{level1Zh, level1En}, - } - - wants := []*Level3{ - { - Value: "Level3-1", - Level2: level21, - }, - { - Value: "Level3-2", - Level2: level22, - }, - { - Value: "Level3-3", - Level2: level21, - }, - } - - for _, want := range wants { - if err := DB.Save(&want).Error; err != nil { - t.Error(err) - } - } - - var gots []*Level3 - if err := DB.Preload("Level2.Level1s", func(db *gorm.DB) *gorm.DB { - return db.Order("level1.id ASC") - }).Find(&gots).Error; err != nil { - t.Error(err) - } - - if !reflect.DeepEqual(gots, wants) { - t.Errorf("got %s; want %s", toJSONString(gots), toJSONString(wants)) - } -} - -func TestNestedManyToManyPreload4(t *testing.T) { - type ( - Level4 struct { - ID uint - Value string - Level3ID uint - } - Level3 struct { - ID uint - Value string - Level4s []*Level4 - } - Level2 struct { - ID uint - Value string - Level3s []*Level3 `gorm:"many2many:level2_level3;"` - } - Level1 struct { - ID uint - Value string - Level2s []*Level2 `gorm:"many2many:level1_level2;"` - } - ) - - DB.DropTableIfExists(&Level1{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists(&Level4{}) - DB.DropTableIfExists("level1_level2") - DB.DropTableIfExists("level2_level3") - - dummy := Level1{ - Value: "Level1", - Level2s: []*Level2{{ - Value: "Level2", - Level3s: []*Level3{{ - Value: "Level3", - Level4s: []*Level4{{ - Value: "Level4", - }}, - }}, - }}, - } - - if err := DB.AutoMigrate(&Level4{}, &Level3{}, &Level2{}, &Level1{}).Error; err != nil { - t.Error(err) - } - - if err := DB.Save(&dummy).Error; err != nil { - t.Error(err) - } - - var level1 Level1 - if err := DB.Preload("Level2s").Preload("Level2s.Level3s").Preload("Level2s.Level3s.Level4s").First(&level1).Error; err != nil { - t.Error(err) - } -} - -func TestManyToManyPreloadForPointer(t *testing.T) { - type ( - Level1 struct { - ID uint - Value string - } - Level2 struct { - ID uint - Value string - Level1s []*Level1 `gorm:"many2many:levels;"` - } - ) - - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - DB.DropTableIfExists("levels") - - if err := DB.AutoMigrate(&Level2{}, &Level1{}).Error; err != nil { - t.Error(err) - } - - want := Level2{Value: "Bob", Level1s: []*Level1{ - {Value: "ru"}, - {Value: "en"}, - }} - if err := DB.Save(&want).Error; err != nil { - t.Error(err) - } - - want2 := Level2{Value: "Tom", Level1s: []*Level1{ - {Value: "zh"}, - {Value: "de"}, - }} - if err := DB.Save(&want2).Error; err != nil { - t.Error(err) - } - - var got Level2 - if err := DB.Preload("Level1s").Find(&got, "value = ?", "Bob").Error; err != nil { - t.Error(err) - } - - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) - } - - var got2 Level2 - if err := DB.Preload("Level1s").Find(&got2, "value = ?", "Tom").Error; err != nil { - t.Error(err) - } - - if !reflect.DeepEqual(got2, want2) { - t.Errorf("got %s; want %s", toJSONString(got2), toJSONString(want2)) - } - - var got3 []Level2 - if err := DB.Preload("Level1s").Find(&got3, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil { - t.Error(err) - } - - if !reflect.DeepEqual(got3, []Level2{got, got2}) { - t.Errorf("got %s; want %s", toJSONString(got3), toJSONString([]Level2{got, got2})) - } - - var got4 []Level2 - if err := DB.Preload("Level1s", "value IN (?)", []string{"zh", "ru"}).Find(&got4, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil { - t.Error(err) - } - - var got5 Level2 - DB.Preload("Level1s").First(&got5, "value = ?", "bogus") - - var ruLevel1 Level1 - var zhLevel1 Level1 - DB.First(&ruLevel1, "value = ?", "ru") - DB.First(&zhLevel1, "value = ?", "zh") - - got.Level1s = []*Level1{&ruLevel1} - got2.Level1s = []*Level1{&zhLevel1} - if !reflect.DeepEqual(got4, []Level2{got, got2}) { - t.Errorf("got %s; want %s", toJSONString(got4), toJSONString([]Level2{got, got2})) - } -} - -func TestNilPointerSlice(t *testing.T) { - type ( - Level3 struct { - ID uint - Value string - } - Level2 struct { - ID uint - Value string - Level3ID uint - Level3 *Level3 - } - Level1 struct { - ID uint - Value string - Level2ID uint - Level2 *Level2 - } - ) - - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { - t.Error(err) - } - - want := Level1{ - Value: "Bob", - Level2: &Level2{ - Value: "en", - Level3: &Level3{ - Value: "native", - }, - }, - } - if err := DB.Save(&want).Error; err != nil { - t.Error(err) - } - - want2 := Level1{ - Value: "Tom", - Level2: nil, - } - if err := DB.Save(&want2).Error; err != nil { - t.Error(err) - } - - var got []Level1 - if err := DB.Preload("Level2").Preload("Level2.Level3").Find(&got).Error; err != nil { - t.Error(err) - } - - if len(got) != 2 { - t.Errorf("got %v items, expected 2", len(got)) - } - - if !reflect.DeepEqual(got[0], want) && !reflect.DeepEqual(got[1], want) { - t.Errorf("got %s; want array containing %s", toJSONString(got), toJSONString(want)) - } - - if !reflect.DeepEqual(got[0], want2) && !reflect.DeepEqual(got[1], want2) { - t.Errorf("got %s; want array containing %s", toJSONString(got), toJSONString(want2)) - } -} - -func TestNilPointerSlice2(t *testing.T) { - type ( - Level4 struct { - ID uint - } - Level3 struct { - ID uint - Level4ID sql.NullInt64 `sql:"index"` - Level4 *Level4 - } - Level2 struct { - ID uint - Level3s []*Level3 `gorm:"many2many:level2_level3s"` - } - Level1 struct { - ID uint - Level2ID sql.NullInt64 `sql:"index"` - Level2 *Level2 - } - ) - - DB.DropTableIfExists(new(Level4)) - DB.DropTableIfExists(new(Level3)) - DB.DropTableIfExists(new(Level2)) - DB.DropTableIfExists(new(Level1)) - - if err := DB.AutoMigrate(new(Level4), new(Level3), new(Level2), new(Level1)).Error; err != nil { - t.Error(err) - } - - want := new(Level1) - if err := DB.Save(want).Error; err != nil { - t.Error(err) - } - - got := new(Level1) - err := DB.Preload("Level2.Level3s.Level4").Last(&got).Error - if err != nil { - t.Error(err) - } - - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) - } -} - -func TestPrefixedPreloadDuplication(t *testing.T) { - type ( - Level4 struct { - ID uint - Name string - Level3ID uint - } - Level3 struct { - ID uint - Name string - Level4s []*Level4 - } - Level2 struct { - ID uint - Name string - Level3ID sql.NullInt64 `sql:"index"` - Level3 *Level3 - } - Level1 struct { - ID uint - Name string - Level2ID sql.NullInt64 `sql:"index"` - Level2 *Level2 - } - ) - - DB.DropTableIfExists(new(Level3)) - DB.DropTableIfExists(new(Level4)) - DB.DropTableIfExists(new(Level2)) - DB.DropTableIfExists(new(Level1)) - - if err := DB.AutoMigrate(new(Level3), new(Level4), new(Level2), new(Level1)).Error; err != nil { - t.Error(err) - } - - lvl := &Level3{} - if err := DB.Save(lvl).Error; err != nil { - t.Error(err) - } - - sublvl1 := &Level4{Level3ID: lvl.ID} - if err := DB.Save(sublvl1).Error; err != nil { - t.Error(err) - } - sublvl2 := &Level4{Level3ID: lvl.ID} - if err := DB.Save(sublvl2).Error; err != nil { - t.Error(err) - } - - lvl.Level4s = []*Level4{sublvl1, sublvl2} - - want1 := Level1{ - Level2: &Level2{ - Level3: lvl, - }, - } - if err := DB.Save(&want1).Error; err != nil { - t.Error(err) - } - - want2 := Level1{ - Level2: &Level2{ - Level3: lvl, - }, - } - if err := DB.Save(&want2).Error; err != nil { - t.Error(err) - } - - want := []Level1{want1, want2} - - var got []Level1 - err := DB.Preload("Level2.Level3.Level4s").Find(&got).Error - if err != nil { - t.Error(err) - } - - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) - } -} - -func TestPreloadManyToManyCallbacks(t *testing.T) { - type ( - Level2 struct { - ID uint - Name string - } - Level1 struct { - ID uint - Name string - Level2s []Level2 `gorm:"many2many:level1_level2s;AssociationForeignKey:ID;ForeignKey:ID"` - } - ) - - DB.DropTableIfExists("level1_level2s") - DB.DropTableIfExists(new(Level1)) - DB.DropTableIfExists(new(Level2)) - - if err := DB.AutoMigrate(new(Level1), new(Level2)).Error; err != nil { - t.Error(err) - } - - lvl := Level1{ - Name: "l1", - Level2s: []Level2{ - {Name: "l2-1"}, {Name: "l2-2"}, - }, - } - DB.Save(&lvl) - - called := 0 - - DB.Callback().Query().After("gorm:query").Register("TestPreloadManyToManyCallbacks", func(scope *gorm.Scope) { - called = called + 1 - }) - - DB.Preload("Level2s").First(&Level1{}, "id = ?", lvl.ID) - - if called != 3 { - t.Errorf("Wanted callback to be called 3 times but got %d", called) - } -} - -func toJSONString(v interface{}) []byte { - r, _ := json.MarshalIndent(v, "", " ") - return r -} diff --git a/query_test.go b/query_test.go deleted file mode 100644 index a23a9e24..00000000 --- a/query_test.go +++ /dev/null @@ -1,841 +0,0 @@ -package gorm_test - -import ( - "fmt" - "reflect" - - "github.com/jinzhu/gorm" - - "testing" - "time" -) - -func TestFirstAndLast(t *testing.T) { - DB.Save(&User{Name: "user1", Emails: []Email{{Email: "user1@example.com"}}}) - DB.Save(&User{Name: "user2", Emails: []Email{{Email: "user2@example.com"}}}) - - var user1, user2, user3, user4 User - DB.First(&user1) - DB.Order("id").Limit(1).Find(&user2) - - ptrOfUser3 := &user3 - DB.Last(&ptrOfUser3) - DB.Order("id desc").Limit(1).Find(&user4) - if user1.Id != user2.Id || user3.Id != user4.Id { - t.Errorf("First and Last should by order by primary key") - } - - var users []User - DB.First(&users) - if len(users) != 1 { - t.Errorf("Find first record as slice") - } - - var user User - if DB.Joins("left join emails on emails.user_id = users.id").First(&user).Error != nil { - t.Errorf("Should not raise any error when order with Join table") - } - - if user.Email != "" { - t.Errorf("User's Email should be blank as no one set it") - } -} - -func TestFirstAndLastWithNoStdPrimaryKey(t *testing.T) { - DB.Save(&Animal{Name: "animal1"}) - DB.Save(&Animal{Name: "animal2"}) - - var animal1, animal2, animal3, animal4 Animal - DB.First(&animal1) - DB.Order("counter").Limit(1).Find(&animal2) - - DB.Last(&animal3) - DB.Order("counter desc").Limit(1).Find(&animal4) - if animal1.Counter != animal2.Counter || animal3.Counter != animal4.Counter { - t.Errorf("First and Last should work correctly") - } -} - -func TestFirstAndLastWithRaw(t *testing.T) { - user1 := User{Name: "user", Emails: []Email{{Email: "user1@example.com"}}} - user2 := User{Name: "user", Emails: []Email{{Email: "user2@example.com"}}} - DB.Save(&user1) - DB.Save(&user2) - - var user3, user4 User - DB.Raw("select * from users WHERE name = ?", "user").First(&user3) - if user3.Id != user1.Id { - t.Errorf("Find first record with raw") - } - - DB.Raw("select * from users WHERE name = ?", "user").Last(&user4) - if user4.Id != user2.Id { - t.Errorf("Find last record with raw") - } -} - -func TestUIntPrimaryKey(t *testing.T) { - var animal Animal - DB.First(&animal, uint64(1)) - if animal.Counter != 1 { - t.Errorf("Fetch a record from with a non-int primary key should work, but failed") - } - - DB.Model(Animal{}).Where(Animal{Counter: uint64(2)}).Scan(&animal) - if animal.Counter != 2 { - t.Errorf("Fetch a record from with a non-int primary key should work, but failed") - } -} - -func TestCustomizedTypePrimaryKey(t *testing.T) { - type ID uint - type CustomizedTypePrimaryKey struct { - ID ID - Name string - } - - DB.AutoMigrate(&CustomizedTypePrimaryKey{}) - - p1 := CustomizedTypePrimaryKey{Name: "p1"} - p2 := CustomizedTypePrimaryKey{Name: "p2"} - p3 := CustomizedTypePrimaryKey{Name: "p3"} - DB.Create(&p1) - DB.Create(&p2) - DB.Create(&p3) - - var p CustomizedTypePrimaryKey - - if err := DB.First(&p, p2.ID).Error; err == nil { - t.Errorf("Should return error for invalid query condition") - } - - if err := DB.First(&p, "id = ?", p2.ID).Error; err != nil { - t.Errorf("No error should happen when querying with customized type for primary key, got err %v", err) - } - - if p.Name != "p2" { - t.Errorf("Should find correct value when querying with customized type for primary key") - } -} - -func TestStringPrimaryKeyForNumericValueStartingWithZero(t *testing.T) { - type AddressByZipCode struct { - ZipCode string `gorm:"primary_key"` - Address string - } - - DB.AutoMigrate(&AddressByZipCode{}) - DB.Create(&AddressByZipCode{ZipCode: "00501", Address: "Holtsville"}) - - var address AddressByZipCode - DB.First(&address, "00501") - if address.ZipCode != "00501" { - t.Errorf("Fetch a record from with a string primary key for a numeric value starting with zero should work, but failed, zip code is %v", address.ZipCode) - } -} - -func TestFindAsSliceOfPointers(t *testing.T) { - DB.Save(&User{Name: "user"}) - - var users []User - DB.Find(&users) - - var userPointers []*User - DB.Find(&userPointers) - - if len(users) == 0 || len(users) != len(userPointers) { - t.Errorf("Find slice of pointers") - } -} - -func TestSearchWithPlainSQL(t *testing.T) { - user1 := User{Name: "PlainSqlUser1", Age: 1, Birthday: parseTime("2000-1-1")} - user2 := User{Name: "PlainSqlUser2", Age: 10, Birthday: parseTime("2010-1-1")} - user3 := User{Name: "PlainSqlUser3", Age: 20, Birthday: parseTime("2020-1-1")} - DB.Save(&user1).Save(&user2).Save(&user3) - scopedb := DB.Where("name LIKE ?", "%PlainSqlUser%") - - if DB.Where("name = ?", user1.Name).First(&User{}).RecordNotFound() { - t.Errorf("Search with plain SQL") - } - - if DB.Where("name LIKE ?", "%"+user1.Name+"%").First(&User{}).RecordNotFound() { - t.Errorf("Search with plan SQL (regexp)") - } - - var users []User - DB.Find(&users, "name LIKE ? and age > ?", "%PlainSqlUser%", 1) - if len(users) != 2 { - t.Errorf("Should found 2 users that age > 1, but got %v", len(users)) - } - - DB.Where("name LIKE ?", "%PlainSqlUser%").Where("age >= ?", 1).Find(&users) - if len(users) != 3 { - t.Errorf("Should found 3 users that age >= 1, but got %v", len(users)) - } - - scopedb.Where("age <> ?", 20).Find(&users) - if len(users) != 2 { - t.Errorf("Should found 2 users age != 20, but got %v", len(users)) - } - - scopedb.Where("birthday > ?", parseTime("2000-1-1")).Find(&users) - if len(users) != 2 { - t.Errorf("Should found 2 users' birthday > 2000-1-1, but got %v", len(users)) - } - - scopedb.Where("birthday > ?", "2002-10-10").Find(&users) - if len(users) != 2 { - t.Errorf("Should found 2 users' birthday >= 2002-10-10, but got %v", len(users)) - } - - scopedb.Where("birthday >= ?", "2010-1-1").Where("birthday < ?", "2020-1-1").Find(&users) - if len(users) != 1 { - t.Errorf("Should found 1 users' birthday < 2020-1-1 and >= 2010-1-1, but got %v", len(users)) - } - - DB.Where("name in (?)", []string{user1.Name, user2.Name}).Find(&users) - if len(users) != 2 { - t.Errorf("Should found 2 users, but got %v", len(users)) - } - - DB.Where("id in (?)", []int64{user1.Id, user2.Id, user3.Id}).Find(&users) - if len(users) != 3 { - t.Errorf("Should found 3 users, but got %v", len(users)) - } - - DB.Where("id in (?)", user1.Id).Find(&users) - if len(users) != 1 { - t.Errorf("Should found 1 users, but got %v", len(users)) - } - - if err := DB.Where("id IN (?)", []string{}).Find(&users).Error; err != nil { - t.Error("no error should happen when query with empty slice, but got: ", err) - } - - if err := DB.Not("id IN (?)", []string{}).Find(&users).Error; err != nil { - t.Error("no error should happen when query with empty slice, but got: ", err) - } - - if DB.Where("name = ?", "none existing").Find(&[]User{}).RecordNotFound() { - t.Errorf("Should not get RecordNotFound error when looking for none existing records") - } -} - -func TestSearchWithTwoDimensionalArray(t *testing.T) { - var users []User - user1 := User{Name: "2DSearchUser1", Age: 1, Birthday: parseTime("2000-1-1")} - user2 := User{Name: "2DSearchUser2", Age: 10, Birthday: parseTime("2010-1-1")} - user3 := User{Name: "2DSearchUser3", Age: 20, Birthday: parseTime("2020-1-1")} - DB.Create(&user1) - DB.Create(&user2) - DB.Create(&user3) - - if dialect := DB.Dialect().GetName(); dialect == "mysql" || dialect == "postgres" { - if err := DB.Where("(name, age) IN (?)", [][]interface{}{{"2DSearchUser1", 1}, {"2DSearchUser2", 10}}).Find(&users).Error; err != nil { - t.Errorf("No error should happen when query with 2D array, but got %v", err) - - if len(users) != 2 { - t.Errorf("Should find 2 users with 2D array, but got %v", len(users)) - } - } - } - - if dialect := DB.Dialect().GetName(); dialect == "mssql" { - if err := DB.Joins("JOIN (VALUES ?) AS x (col1, col2) ON x.col1 = name AND x.col2 = age", [][]interface{}{{"2DSearchUser1", 1}, {"2DSearchUser2", 10}}).Find(&users).Error; err != nil { - t.Errorf("No error should happen when query with 2D array, but got %v", err) - - if len(users) != 2 { - t.Errorf("Should find 2 users with 2D array, but got %v", len(users)) - } - } - } -} - -func TestSearchWithStruct(t *testing.T) { - user1 := User{Name: "StructSearchUser1", Age: 1, Birthday: parseTime("2000-1-1")} - user2 := User{Name: "StructSearchUser2", Age: 10, Birthday: parseTime("2010-1-1")} - user3 := User{Name: "StructSearchUser3", Age: 20, Birthday: parseTime("2020-1-1")} - DB.Save(&user1).Save(&user2).Save(&user3) - - if DB.Where(user1.Id).First(&User{}).RecordNotFound() { - t.Errorf("Search with primary key") - } - - if DB.First(&User{}, user1.Id).RecordNotFound() { - t.Errorf("Search with primary key as inline condition") - } - - if DB.First(&User{}, fmt.Sprintf("%v", user1.Id)).RecordNotFound() { - t.Errorf("Search with primary key as inline condition") - } - - var users []User - DB.Where([]int64{user1.Id, user2.Id, user3.Id}).Find(&users) - if len(users) != 3 { - t.Errorf("Should found 3 users when search with primary keys, but got %v", len(users)) - } - - var user User - DB.First(&user, &User{Name: user1.Name}) - if user.Id == 0 || user.Name != user1.Name { - t.Errorf("Search first record with inline pointer of struct") - } - - DB.First(&user, User{Name: user1.Name}) - if user.Id == 0 || user.Name != user1.Name { - t.Errorf("Search first record with inline struct") - } - - DB.Where(&User{Name: user1.Name}).First(&user) - if user.Id == 0 || user.Name != user1.Name { - t.Errorf("Search first record with where struct") - } - - DB.Find(&users, &User{Name: user2.Name}) - if len(users) != 1 { - t.Errorf("Search all records with inline struct") - } -} - -func TestSearchWithMap(t *testing.T) { - companyID := 1 - user1 := User{Name: "MapSearchUser1", Age: 1, Birthday: parseTime("2000-1-1")} - user2 := User{Name: "MapSearchUser2", Age: 10, Birthday: parseTime("2010-1-1")} - user3 := User{Name: "MapSearchUser3", Age: 20, Birthday: parseTime("2020-1-1")} - user4 := User{Name: "MapSearchUser4", Age: 30, Birthday: parseTime("2020-1-1"), CompanyID: &companyID} - DB.Save(&user1).Save(&user2).Save(&user3).Save(&user4) - - var user User - DB.First(&user, map[string]interface{}{"name": user1.Name}) - if user.Id == 0 || user.Name != user1.Name { - t.Errorf("Search first record with inline map") - } - - user = User{} - DB.Where(map[string]interface{}{"name": user2.Name}).First(&user) - if user.Id == 0 || user.Name != user2.Name { - t.Errorf("Search first record with where map") - } - - var users []User - DB.Where(map[string]interface{}{"name": user3.Name}).Find(&users) - if len(users) != 1 { - t.Errorf("Search all records with inline map") - } - - DB.Find(&users, map[string]interface{}{"name": user3.Name}) - if len(users) != 1 { - t.Errorf("Search all records with inline map") - } - - DB.Find(&users, map[string]interface{}{"name": user4.Name, "company_id": nil}) - if len(users) != 0 { - t.Errorf("Search all records with inline map containing null value finding 0 records") - } - - DB.Find(&users, map[string]interface{}{"name": user1.Name, "company_id": nil}) - if len(users) != 1 { - t.Errorf("Search all records with inline map containing null value finding 1 record") - } - - DB.Find(&users, map[string]interface{}{"name": user4.Name, "company_id": companyID}) - if len(users) != 1 { - t.Errorf("Search all records with inline multiple value map") - } -} - -func TestSearchWithEmptyChain(t *testing.T) { - user1 := User{Name: "ChainSearchUser1", Age: 1, Birthday: parseTime("2000-1-1")} - user2 := User{Name: "ChainearchUser2", Age: 10, Birthday: parseTime("2010-1-1")} - user3 := User{Name: "ChainearchUser3", Age: 20, Birthday: parseTime("2020-1-1")} - DB.Save(&user1).Save(&user2).Save(&user3) - - if DB.Where("").Where("").First(&User{}).Error != nil { - t.Errorf("Should not raise any error if searching with empty strings") - } - - if DB.Where(&User{}).Where("name = ?", user1.Name).First(&User{}).Error != nil { - t.Errorf("Should not raise any error if searching with empty struct") - } - - if DB.Where(map[string]interface{}{}).Where("name = ?", user1.Name).First(&User{}).Error != nil { - t.Errorf("Should not raise any error if searching with empty map") - } -} - -func TestSelect(t *testing.T) { - user1 := User{Name: "SelectUser1"} - DB.Save(&user1) - - var user User - DB.Where("name = ?", user1.Name).Select("name").Find(&user) - if user.Id != 0 { - t.Errorf("Should not have ID because only selected name, %+v", user.Id) - } - - if user.Name != user1.Name { - t.Errorf("Should have user Name when selected it") - } -} - -func TestOrderAndPluck(t *testing.T) { - user1 := User{Name: "OrderPluckUser1", Age: 1} - user2 := User{Name: "OrderPluckUser2", Age: 10} - user3 := User{Name: "OrderPluckUser3", Age: 20} - DB.Save(&user1).Save(&user2).Save(&user3) - scopedb := DB.Model(&User{}).Where("name like ?", "%OrderPluckUser%") - - var user User - scopedb.Order(gorm.Expr("case when name = ? then 0 else 1 end", "OrderPluckUser2")).First(&user) - if user.Name != "OrderPluckUser2" { - t.Errorf("Order with sql expression") - } - - var ages []int64 - scopedb.Order("age desc").Pluck("age", &ages) - if ages[0] != 20 { - t.Errorf("The first age should be 20 when order with age desc") - } - - var ages1, ages2 []int64 - scopedb.Order("age desc").Pluck("age", &ages1).Pluck("age", &ages2) - if !reflect.DeepEqual(ages1, ages2) { - t.Errorf("The first order is the primary order") - } - - var ages3, ages4 []int64 - scopedb.Model(&User{}).Order("age desc").Pluck("age", &ages3).Order("age", true).Pluck("age", &ages4) - if reflect.DeepEqual(ages3, ages4) { - t.Errorf("Reorder should work") - } - - var names []string - var ages5 []int64 - scopedb.Model(User{}).Order("name").Order("age desc").Pluck("age", &ages5).Pluck("name", &names) - if names != nil && ages5 != nil { - if !(names[0] == user1.Name && names[1] == user2.Name && names[2] == user3.Name && ages5[2] == 20) { - t.Errorf("Order with multiple orders") - } - } else { - t.Errorf("Order with multiple orders") - } - - var ages6 []int64 - if err := scopedb.Order("").Pluck("age", &ages6).Error; err != nil { - t.Errorf("An empty string as order clause produces invalid queries") - } - - DB.Model(User{}).Select("name, age").Find(&[]User{}) -} - -func TestLimit(t *testing.T) { - user1 := User{Name: "LimitUser1", Age: 1} - user2 := User{Name: "LimitUser2", Age: 10} - user3 := User{Name: "LimitUser3", Age: 20} - user4 := User{Name: "LimitUser4", Age: 10} - user5 := User{Name: "LimitUser5", Age: 20} - DB.Save(&user1).Save(&user2).Save(&user3).Save(&user4).Save(&user5) - - var users1, users2, users3 []User - DB.Order("age desc").Limit(3).Find(&users1).Limit(5).Find(&users2).Limit(-1).Find(&users3) - - if len(users1) != 3 || len(users2) != 5 || len(users3) <= 5 { - t.Errorf("Limit should works") - } -} - -func TestOffset(t *testing.T) { - for i := 0; i < 20; i++ { - DB.Save(&User{Name: fmt.Sprintf("OffsetUser%v", i)}) - } - var users1, users2, users3, users4 []User - DB.Limit(100).Where("name like ?", "OffsetUser%").Order("age desc").Find(&users1).Offset(3).Find(&users2).Offset(5).Find(&users3).Offset(-1).Find(&users4) - - if (len(users1) != len(users4)) || (len(users1)-len(users2) != 3) || (len(users1)-len(users3) != 5) { - t.Errorf("Offset should work") - } -} - -func TestLimitAndOffsetSQL(t *testing.T) { - user1 := User{Name: "TestLimitAndOffsetSQL1", Age: 10} - user2 := User{Name: "TestLimitAndOffsetSQL2", Age: 20} - user3 := User{Name: "TestLimitAndOffsetSQL3", Age: 30} - user4 := User{Name: "TestLimitAndOffsetSQL4", Age: 40} - user5 := User{Name: "TestLimitAndOffsetSQL5", Age: 50} - if err := DB.Save(&user1).Save(&user2).Save(&user3).Save(&user4).Save(&user5).Error; err != nil { - t.Fatal(err) - } - - tests := []struct { - name string - limit, offset interface{} - users []*User - ok bool - }{ - { - name: "OK", - limit: float64(2), - offset: float64(2), - users: []*User{ - &User{Name: "TestLimitAndOffsetSQL3", Age: 30}, - &User{Name: "TestLimitAndOffsetSQL2", Age: 20}, - }, - ok: true, - }, - { - name: "Limit parse error", - limit: float64(1000000), // 1e+06 - offset: float64(2), - ok: false, - }, - { - name: "Offset parse error", - limit: float64(2), - offset: float64(1000000), // 1e+06 - ok: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - var users []*User - err := DB.Where("name LIKE ?", "TestLimitAndOffsetSQL%").Order("age desc").Limit(tt.limit).Offset(tt.offset).Find(&users).Error - if tt.ok { - if err != nil { - t.Errorf("error expected nil, but got %v", err) - } - if len(users) != len(tt.users) { - t.Errorf("users length expected %d, but got %d", len(tt.users), len(users)) - } - for i := range tt.users { - if users[i].Name != tt.users[i].Name { - t.Errorf("users[%d] name expected %s, but got %s", i, tt.users[i].Name, users[i].Name) - } - if users[i].Age != tt.users[i].Age { - t.Errorf("users[%d] age expected %d, but got %d", i, tt.users[i].Age, users[i].Age) - } - } - } else { - if err == nil { - t.Error("error expected not nil, but got nil") - } - } - }) - } -} - -func TestOr(t *testing.T) { - user1 := User{Name: "OrUser1", Age: 1} - user2 := User{Name: "OrUser2", Age: 10} - user3 := User{Name: "OrUser3", Age: 20} - DB.Save(&user1).Save(&user2).Save(&user3) - - var users []User - DB.Where("name = ?", user1.Name).Or("name = ?", user2.Name).Find(&users) - if len(users) != 2 { - t.Errorf("Find users with or") - } -} - -func TestCount(t *testing.T) { - user1 := User{Name: "CountUser1", Age: 1} - user2 := User{Name: "CountUser2", Age: 10} - user3 := User{Name: "CountUser3", Age: 20} - - DB.Save(&user1).Save(&user2).Save(&user3) - var count, count1, count2 int64 - var users []User - - if err := DB.Where("name = ?", user1.Name).Or("name = ?", user3.Name).Find(&users).Count(&count).Error; err != nil { - t.Errorf(fmt.Sprintf("Count should work, but got err %v", err)) - } - - if count != int64(len(users)) { - t.Errorf("Count() method should get correct value") - } - - DB.Model(&User{}).Where("name = ?", user1.Name).Count(&count1).Or("name in (?)", []string{user2.Name, user3.Name}).Count(&count2) - if count1 != 1 || count2 != 3 { - t.Errorf("Multiple count in chain") - } - - var count3 int - if err := DB.Model(&User{}).Where("name in (?)", []string{user2.Name, user2.Name, user3.Name}).Group("id").Count(&count3).Error; err != nil { - t.Errorf("Not error should happen, but got %v", err) - } - - if count3 != 2 { - t.Errorf("Should get correct count, but got %v", count3) - } -} - -func TestNot(t *testing.T) { - DB.Create(getPreparedUser("user1", "not")) - DB.Create(getPreparedUser("user2", "not")) - DB.Create(getPreparedUser("user3", "not")) - - user4 := getPreparedUser("user4", "not") - user4.Company = Company{} - DB.Create(user4) - - DB := DB.Where("role = ?", "not") - - var users1, users2, users3, users4, users5, users6, users7, users8, users9 []User - if DB.Find(&users1).RowsAffected != 4 { - t.Errorf("should find 4 not users") - } - DB.Not(users1[0].Id).Find(&users2) - - if len(users1)-len(users2) != 1 { - t.Errorf("Should ignore the first users with Not") - } - - DB.Not([]int{}).Find(&users3) - if len(users1)-len(users3) != 0 { - t.Errorf("Should find all users with a blank condition") - } - - var name3Count int64 - DB.Table("users").Where("name = ?", "user3").Count(&name3Count) - DB.Not("name", "user3").Find(&users4) - if len(users1)-len(users4) != int(name3Count) { - t.Errorf("Should find all users' name not equal 3") - } - - DB.Not("name = ?", "user3").Find(&users4) - if len(users1)-len(users4) != int(name3Count) { - t.Errorf("Should find all users' name not equal 3") - } - - DB.Not("name <> ?", "user3").Find(&users4) - if len(users4) != int(name3Count) { - t.Errorf("Should find all users' name not equal 3") - } - - DB.Not(User{Name: "user3"}).Find(&users5) - - if len(users1)-len(users5) != int(name3Count) { - t.Errorf("Should find all users' name not equal 3") - } - - DB.Not(map[string]interface{}{"name": "user3"}).Find(&users6) - if len(users1)-len(users6) != int(name3Count) { - t.Errorf("Should find all users' name not equal 3") - } - - DB.Not(map[string]interface{}{"name": "user3", "company_id": nil}).Find(&users7) - if len(users1)-len(users7) != 2 { // not user3 or user4 - t.Errorf("Should find all user's name not equal to 3 who do not have company id") - } - - DB.Not("name", []string{"user3"}).Find(&users8) - if len(users1)-len(users8) != int(name3Count) { - t.Errorf("Should find all users' name not equal 3") - } - - var name2Count int64 - DB.Table("users").Where("name = ?", "user2").Count(&name2Count) - DB.Not("name", []string{"user3", "user2"}).Find(&users9) - if len(users1)-len(users9) != (int(name3Count) + int(name2Count)) { - t.Errorf("Should find all users' name not equal 3") - } -} - -func TestFillSmallerStruct(t *testing.T) { - user1 := User{Name: "SmallerUser", Age: 100} - DB.Save(&user1) - type SimpleUser struct { - Name string - Id int64 - UpdatedAt time.Time - CreatedAt time.Time - } - - var simpleUser SimpleUser - DB.Table("users").Where("name = ?", user1.Name).First(&simpleUser) - - if simpleUser.Id == 0 || simpleUser.Name == "" { - t.Errorf("Should fill data correctly into smaller struct") - } -} - -func TestFindOrInitialize(t *testing.T) { - var user1, user2, user3, user4, user5, user6 User - DB.Where(&User{Name: "find or init", Age: 33}).FirstOrInit(&user1) - if user1.Name != "find or init" || user1.Id != 0 || user1.Age != 33 { - t.Errorf("user should be initialized with search value") - } - - DB.Where(User{Name: "find or init", Age: 33}).FirstOrInit(&user2) - if user2.Name != "find or init" || user2.Id != 0 || user2.Age != 33 { - t.Errorf("user should be initialized with search value") - } - - DB.FirstOrInit(&user3, map[string]interface{}{"name": "find or init 2"}) - if user3.Name != "find or init 2" || user3.Id != 0 { - t.Errorf("user should be initialized with inline search value") - } - - DB.Where(&User{Name: "find or init"}).Attrs(User{Age: 44}).FirstOrInit(&user4) - if user4.Name != "find or init" || user4.Id != 0 || user4.Age != 44 { - t.Errorf("user should be initialized with search value and attrs") - } - - DB.Where(&User{Name: "find or init"}).Assign("age", 44).FirstOrInit(&user4) - if user4.Name != "find or init" || user4.Id != 0 || user4.Age != 44 { - t.Errorf("user should be initialized with search value and assign attrs") - } - - DB.Save(&User{Name: "find or init", Age: 33}) - DB.Where(&User{Name: "find or init"}).Attrs("age", 44).FirstOrInit(&user5) - if user5.Name != "find or init" || user5.Id == 0 || user5.Age != 33 { - t.Errorf("user should be found and not initialized by Attrs") - } - - DB.Where(&User{Name: "find or init", Age: 33}).FirstOrInit(&user6) - if user6.Name != "find or init" || user6.Id == 0 || user6.Age != 33 { - t.Errorf("user should be found with FirstOrInit") - } - - DB.Where(&User{Name: "find or init"}).Assign(User{Age: 44}).FirstOrInit(&user6) - if user6.Name != "find or init" || user6.Id == 0 || user6.Age != 44 { - t.Errorf("user should be found and updated with assigned attrs") - } -} - -func TestFindOrCreate(t *testing.T) { - var user1, user2, user3, user4, user5, user6, user7, user8 User - DB.Where(&User{Name: "find or create", Age: 33}).FirstOrCreate(&user1) - if user1.Name != "find or create" || user1.Id == 0 || user1.Age != 33 { - t.Errorf("user should be created with search value") - } - - DB.Where(&User{Name: "find or create", Age: 33}).FirstOrCreate(&user2) - if user1.Id != user2.Id || user2.Name != "find or create" || user2.Id == 0 || user2.Age != 33 { - t.Errorf("user should be created with search value") - } - - DB.FirstOrCreate(&user3, map[string]interface{}{"name": "find or create 2"}) - if user3.Name != "find or create 2" || user3.Id == 0 { - t.Errorf("user should be created with inline search value") - } - - DB.Where(&User{Name: "find or create 3"}).Attrs("age", 44).FirstOrCreate(&user4) - if user4.Name != "find or create 3" || user4.Id == 0 || user4.Age != 44 { - t.Errorf("user should be created with search value and attrs") - } - - updatedAt1 := user4.UpdatedAt - DB.Where(&User{Name: "find or create 3"}).Assign("age", 55).FirstOrCreate(&user4) - if updatedAt1.Format(time.RFC3339Nano) == user4.UpdatedAt.Format(time.RFC3339Nano) { - t.Errorf("UpdateAt should be changed when update values with assign") - } - - DB.Where(&User{Name: "find or create 4"}).Assign(User{Age: 44}).FirstOrCreate(&user4) - if user4.Name != "find or create 4" || user4.Id == 0 || user4.Age != 44 { - t.Errorf("user should be created with search value and assigned attrs") - } - - DB.Where(&User{Name: "find or create"}).Attrs("age", 44).FirstOrInit(&user5) - if user5.Name != "find or create" || user5.Id == 0 || user5.Age != 33 { - t.Errorf("user should be found and not initialized by Attrs") - } - - DB.Where(&User{Name: "find or create"}).Assign(User{Age: 44}).FirstOrCreate(&user6) - if user6.Name != "find or create" || user6.Id == 0 || user6.Age != 44 { - t.Errorf("user should be found and updated with assigned attrs") - } - - DB.Where(&User{Name: "find or create"}).Find(&user7) - if user7.Name != "find or create" || user7.Id == 0 || user7.Age != 44 { - t.Errorf("user should be found and updated with assigned attrs") - } - - DB.Where(&User{Name: "find or create embedded struct"}).Assign(User{Age: 44, CreditCard: CreditCard{Number: "1231231231"}, Emails: []Email{{Email: "jinzhu@assign_embedded_struct.com"}, {Email: "jinzhu-2@assign_embedded_struct.com"}}}).FirstOrCreate(&user8) - if DB.Where("email = ?", "jinzhu-2@assign_embedded_struct.com").First(&Email{}).RecordNotFound() { - t.Errorf("embedded struct email should be saved") - } - - if DB.Where("email = ?", "1231231231").First(&CreditCard{}).RecordNotFound() { - t.Errorf("embedded struct credit card should be saved") - } -} - -func TestSelectWithEscapedFieldName(t *testing.T) { - user1 := User{Name: "EscapedFieldNameUser", Age: 1} - user2 := User{Name: "EscapedFieldNameUser", Age: 10} - user3 := User{Name: "EscapedFieldNameUser", Age: 20} - DB.Save(&user1).Save(&user2).Save(&user3) - - var names []string - DB.Model(User{}).Where(&User{Name: "EscapedFieldNameUser"}).Pluck("\"name\"", &names) - - if len(names) != 3 { - t.Errorf("Expected 3 name, but got: %d", len(names)) - } -} - -func TestSelectWithVariables(t *testing.T) { - DB.Save(&User{Name: "jinzhu"}) - - rows, _ := DB.Table("users").Select("? as fake", gorm.Expr("name")).Rows() - - if !rows.Next() { - t.Errorf("Should have returned at least one row") - } else { - columns, _ := rows.Columns() - if !reflect.DeepEqual(columns, []string{"fake"}) { - t.Errorf("Should only contains one column") - } - } - - rows.Close() -} - -func TestSelectWithArrayInput(t *testing.T) { - DB.Save(&User{Name: "jinzhu", Age: 42}) - - var user User - DB.Select([]string{"name", "age"}).Where("age = 42 AND name = 'jinzhu'").First(&user) - - if user.Name != "jinzhu" || user.Age != 42 { - t.Errorf("Should have selected both age and name") - } -} - -func TestPluckWithSelect(t *testing.T) { - var ( - user = User{Name: "matematik7_pluck_with_select", Age: 25} - combinedName = fmt.Sprintf("%v%v", user.Name, user.Age) - combineUserAgeSQL = fmt.Sprintf("concat(%v, %v)", DB.Dialect().Quote("name"), DB.Dialect().Quote("age")) - ) - - if dialect := DB.Dialect().GetName(); dialect == "sqlite3" { - combineUserAgeSQL = fmt.Sprintf("(%v || %v)", DB.Dialect().Quote("name"), DB.Dialect().Quote("age")) - } - - DB.Save(&user) - - selectStr := combineUserAgeSQL + " as user_age" - var userAges []string - err := DB.Model(&User{}).Where("age = ?", 25).Select(selectStr).Pluck("user_age", &userAges).Error - if err != nil { - t.Error(err) - } - - if len(userAges) != 1 || userAges[0] != combinedName { - t.Errorf("Should correctly pluck with select, got: %s", userAges) - } - - selectStr = combineUserAgeSQL + fmt.Sprintf(" as %v", DB.Dialect().Quote("user_age")) - userAges = userAges[:0] - err = DB.Model(&User{}).Where("age = ?", 25).Select(selectStr).Pluck("user_age", &userAges).Error - if err != nil { - t.Error(err) - } - - if len(userAges) != 1 || userAges[0] != combinedName { - t.Errorf("Should correctly pluck with select, got: %s", userAges) - } -} diff --git a/scaner_test.go b/scaner_test.go deleted file mode 100644 index 9e251dd6..00000000 --- a/scaner_test.go +++ /dev/null @@ -1,139 +0,0 @@ -package gorm_test - -import ( - "database/sql/driver" - "encoding/json" - "errors" - "testing" - - "github.com/jinzhu/gorm" -) - -func TestScannableSlices(t *testing.T) { - if err := DB.AutoMigrate(&RecordWithSlice{}).Error; err != nil { - t.Errorf("Should create table with slice values correctly: %s", err) - } - - r1 := RecordWithSlice{ - Strings: ExampleStringSlice{"a", "b", "c"}, - Structs: ExampleStructSlice{ - {"name1", "value1"}, - {"name2", "value2"}, - }, - } - - if err := DB.Save(&r1).Error; err != nil { - t.Errorf("Should save record with slice values") - } - - var r2 RecordWithSlice - - if err := DB.Find(&r2).Error; err != nil { - t.Errorf("Should fetch record with slice values") - } - - if len(r2.Strings) != 3 || r2.Strings[0] != "a" || r2.Strings[1] != "b" || r2.Strings[2] != "c" { - t.Errorf("Should have serialised and deserialised a string array") - } - - if len(r2.Structs) != 2 || r2.Structs[0].Name != "name1" || r2.Structs[0].Value != "value1" || r2.Structs[1].Name != "name2" || r2.Structs[1].Value != "value2" { - t.Errorf("Should have serialised and deserialised a struct array") - } -} - -type RecordWithSlice struct { - ID uint64 - Strings ExampleStringSlice `sql:"type:text"` - Structs ExampleStructSlice `sql:"type:text"` -} - -type ExampleStringSlice []string - -func (l ExampleStringSlice) Value() (driver.Value, error) { - bytes, err := json.Marshal(l) - return string(bytes), err -} - -func (l *ExampleStringSlice) Scan(input interface{}) error { - switch value := input.(type) { - case string: - return json.Unmarshal([]byte(value), l) - case []byte: - return json.Unmarshal(value, l) - default: - return errors.New("not supported") - } -} - -type ExampleStruct struct { - Name string - Value string -} - -type ExampleStructSlice []ExampleStruct - -func (l ExampleStructSlice) Value() (driver.Value, error) { - bytes, err := json.Marshal(l) - return string(bytes), err -} - -func (l *ExampleStructSlice) Scan(input interface{}) error { - switch value := input.(type) { - case string: - return json.Unmarshal([]byte(value), l) - case []byte: - return json.Unmarshal(value, l) - default: - return errors.New("not supported") - } -} - -type ScannerDataType struct { - Street string `sql:"TYPE:varchar(24)"` -} - -func (ScannerDataType) Value() (driver.Value, error) { - return nil, nil -} - -func (*ScannerDataType) Scan(input interface{}) error { - return nil -} - -type ScannerDataTypeTestStruct struct { - Field1 int - ScannerDataType *ScannerDataType `sql:"TYPE:json"` -} - -type ScannerDataType2 struct { - Street string `sql:"TYPE:varchar(24)"` -} - -func (ScannerDataType2) Value() (driver.Value, error) { - return nil, nil -} - -func (*ScannerDataType2) Scan(input interface{}) error { - return nil -} - -type ScannerDataTypeTestStruct2 struct { - Field1 int - ScannerDataType *ScannerDataType2 -} - -func TestScannerDataType(t *testing.T) { - scope := gorm.Scope{Value: &ScannerDataTypeTestStruct{}} - if field, ok := scope.FieldByName("ScannerDataType"); ok { - if DB.Dialect().DataTypeOf(field.StructField) != "json" { - t.Errorf("data type for scanner is wrong") - } - } - - scope = gorm.Scope{Value: &ScannerDataTypeTestStruct2{}} - if field, ok := scope.FieldByName("ScannerDataType"); ok { - if DB.Dialect().DataTypeOf(field.StructField) != "varchar(24)" { - t.Errorf("data type for scanner is wrong") - } - } -} diff --git a/scope.go b/scope.go deleted file mode 100644 index d82cadbc..00000000 --- a/scope.go +++ /dev/null @@ -1,1421 +0,0 @@ -package gorm - -import ( - "bytes" - "database/sql" - "database/sql/driver" - "errors" - "fmt" - "reflect" - "regexp" - "strings" - "time" -) - -// Scope contain current operation's information when you perform any operation on the database -type Scope struct { - Search *search - Value interface{} - SQL string - SQLVars []interface{} - db *DB - instanceID string - primaryKeyField *Field - skipLeft bool - fields *[]*Field - selectAttrs *[]string -} - -// IndirectValue return scope's reflect value's indirect value -func (scope *Scope) IndirectValue() reflect.Value { - return indirect(reflect.ValueOf(scope.Value)) -} - -// New create a new Scope without search information -func (scope *Scope) New(value interface{}) *Scope { - return &Scope{db: scope.NewDB(), Search: &search{}, Value: value} -} - -//////////////////////////////////////////////////////////////////////////////// -// Scope DB -//////////////////////////////////////////////////////////////////////////////// - -// DB return scope's DB connection -func (scope *Scope) DB() *DB { - return scope.db -} - -// NewDB create a new DB without search information -func (scope *Scope) NewDB() *DB { - if scope.db != nil { - db := scope.db.clone() - db.search = nil - db.Value = nil - return db - } - return nil -} - -// SQLDB return *sql.DB -func (scope *Scope) SQLDB() SQLCommon { - return scope.db.db -} - -// Dialect get dialect -func (scope *Scope) Dialect() Dialect { - return scope.db.dialect -} - -// Quote used to quote string to escape them for database -func (scope *Scope) Quote(str string) string { - if strings.Contains(str, ".") { - newStrs := []string{} - for _, str := range strings.Split(str, ".") { - newStrs = append(newStrs, scope.Dialect().Quote(str)) - } - return strings.Join(newStrs, ".") - } - - return scope.Dialect().Quote(str) -} - -// Err add error to Scope -func (scope *Scope) Err(err error) error { - if err != nil { - scope.db.AddError(err) - } - return err -} - -// HasError check if there are any error -func (scope *Scope) HasError() bool { - return scope.db.Error != nil -} - -// Log print log message -func (scope *Scope) Log(v ...interface{}) { - scope.db.log(v...) -} - -// SkipLeft skip remaining callbacks -func (scope *Scope) SkipLeft() { - scope.skipLeft = true -} - -// Fields get value's fields -func (scope *Scope) Fields() []*Field { - if scope.fields == nil { - var ( - fields []*Field - indirectScopeValue = scope.IndirectValue() - isStruct = indirectScopeValue.Kind() == reflect.Struct - ) - - for _, structField := range scope.GetModelStruct().StructFields { - if isStruct { - fieldValue := indirectScopeValue - for _, name := range structField.Names { - if fieldValue.Kind() == reflect.Ptr && fieldValue.IsNil() { - fieldValue.Set(reflect.New(fieldValue.Type().Elem())) - } - fieldValue = reflect.Indirect(fieldValue).FieldByName(name) - } - fields = append(fields, &Field{StructField: structField, Field: fieldValue, IsBlank: isBlank(fieldValue)}) - } else { - fields = append(fields, &Field{StructField: structField, IsBlank: true}) - } - } - scope.fields = &fields - } - - return *scope.fields -} - -// FieldByName find `gorm.Field` with field name or db name -func (scope *Scope) FieldByName(name string) (field *Field, ok bool) { - var ( - dbName = ToColumnName(name) - mostMatchedField *Field - ) - - for _, field := range scope.Fields() { - if field.Name == name || field.DBName == name { - return field, true - } - if field.DBName == dbName { - mostMatchedField = field - } - } - return mostMatchedField, mostMatchedField != nil -} - -// PrimaryFields return scope's primary fields -func (scope *Scope) PrimaryFields() (fields []*Field) { - for _, field := range scope.Fields() { - if field.IsPrimaryKey { - fields = append(fields, field) - } - } - return fields -} - -// PrimaryField return scope's main primary field, if defined more that one primary fields, will return the one having column name `id` or the first one -func (scope *Scope) PrimaryField() *Field { - if primaryFields := scope.GetModelStruct().PrimaryFields; len(primaryFields) > 0 { - if len(primaryFields) > 1 { - if field, ok := scope.FieldByName("id"); ok { - return field - } - } - return scope.PrimaryFields()[0] - } - return nil -} - -// PrimaryKey get main primary field's db name -func (scope *Scope) PrimaryKey() string { - if field := scope.PrimaryField(); field != nil { - return field.DBName - } - return "" -} - -// PrimaryKeyZero check main primary field's value is blank or not -func (scope *Scope) PrimaryKeyZero() bool { - field := scope.PrimaryField() - return field == nil || field.IsBlank -} - -// PrimaryKeyValue get the primary key's value -func (scope *Scope) PrimaryKeyValue() interface{} { - if field := scope.PrimaryField(); field != nil && field.Field.IsValid() { - return field.Field.Interface() - } - return 0 -} - -// HasColumn to check if has column -func (scope *Scope) HasColumn(column string) bool { - for _, field := range scope.GetStructFields() { - if field.IsNormal && (field.Name == column || field.DBName == column) { - return true - } - } - return false -} - -// SetColumn to set the column's value, column could be field or field's name/dbname -func (scope *Scope) SetColumn(column interface{}, value interface{}) error { - var updateAttrs = map[string]interface{}{} - if attrs, ok := scope.InstanceGet("gorm:update_attrs"); ok { - updateAttrs = attrs.(map[string]interface{}) - defer scope.InstanceSet("gorm:update_attrs", updateAttrs) - } - - if field, ok := column.(*Field); ok { - updateAttrs[field.DBName] = value - return field.Set(value) - } else if name, ok := column.(string); ok { - var ( - dbName = ToDBName(name) - mostMatchedField *Field - ) - for _, field := range scope.Fields() { - if field.DBName == value { - updateAttrs[field.DBName] = value - return field.Set(value) - } - if !field.IsIgnored && ((field.DBName == dbName) || (field.Name == name && mostMatchedField == nil)) { - mostMatchedField = field - } - } - - if mostMatchedField != nil { - updateAttrs[mostMatchedField.DBName] = value - return mostMatchedField.Set(value) - } - } - return errors.New("could not convert column to field") -} - -// CallMethod call scope value's method, if it is a slice, will call its element's method one by one -func (scope *Scope) CallMethod(methodName string) { - if scope.Value == nil { - return - } - - if indirectScopeValue := scope.IndirectValue(); indirectScopeValue.Kind() == reflect.Slice { - for i := 0; i < indirectScopeValue.Len(); i++ { - scope.callMethod(methodName, indirectScopeValue.Index(i)) - } - } else { - scope.callMethod(methodName, indirectScopeValue) - } -} - -// AddToVars add value as sql's vars, used to prevent SQL injection -func (scope *Scope) AddToVars(value interface{}) string { - _, skipBindVar := scope.InstanceGet("skip_bindvar") - - if expr, ok := value.(*SqlExpr); ok { - exp := expr.expr - for _, arg := range expr.args { - if skipBindVar { - scope.AddToVars(arg) - } else { - exp = strings.Replace(exp, "?", scope.AddToVars(arg), 1) - } - } - return exp - } - - scope.SQLVars = append(scope.SQLVars, value) - - if skipBindVar { - return "?" - } - return scope.Dialect().BindVar(len(scope.SQLVars)) -} - -// SelectAttrs return selected attributes -func (scope *Scope) SelectAttrs() []string { - if scope.selectAttrs == nil { - attrs := []string{} - for _, value := range scope.Search.selects { - if str, ok := value.(string); ok { - attrs = append(attrs, str) - } else if strs, ok := value.([]string); ok { - attrs = append(attrs, strs...) - } else if strs, ok := value.([]interface{}); ok { - for _, str := range strs { - attrs = append(attrs, fmt.Sprintf("%v", str)) - } - } - } - scope.selectAttrs = &attrs - } - return *scope.selectAttrs -} - -// OmitAttrs return omitted attributes -func (scope *Scope) OmitAttrs() []string { - return scope.Search.omits -} - -type tabler interface { - TableName() string -} - -type dbTabler interface { - TableName(*DB) string -} - -// TableName return table name -func (scope *Scope) TableName() string { - if scope.Search != nil && len(scope.Search.tableName) > 0 { - return scope.Search.tableName - } - - if tabler, ok := scope.Value.(tabler); ok { - return tabler.TableName() - } - - if tabler, ok := scope.Value.(dbTabler); ok { - return tabler.TableName(scope.db) - } - - return scope.GetModelStruct().TableName(scope.db.Model(scope.Value)) -} - -// QuotedTableName return quoted table name -func (scope *Scope) QuotedTableName() (name string) { - if scope.Search != nil && len(scope.Search.tableName) > 0 { - if strings.Contains(scope.Search.tableName, " ") { - return scope.Search.tableName - } - return scope.Quote(scope.Search.tableName) - } - - return scope.Quote(scope.TableName()) -} - -// CombinedConditionSql return combined condition sql -func (scope *Scope) CombinedConditionSql() string { - joinSQL := scope.joinsSQL() - whereSQL := scope.whereSQL() - if scope.Search.raw { - whereSQL = strings.TrimSuffix(strings.TrimPrefix(whereSQL, "WHERE ("), ")") - } - return joinSQL + whereSQL + scope.groupSQL() + - scope.havingSQL() + scope.orderSQL() + scope.limitAndOffsetSQL() -} - -// Raw set raw sql -func (scope *Scope) Raw(sql string) *Scope { - scope.SQL = strings.Replace(sql, "$$$", "?", -1) - return scope -} - -// Exec perform generated SQL -func (scope *Scope) Exec() *Scope { - defer scope.trace(NowFunc()) - - if !scope.HasError() { - if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { - if count, err := result.RowsAffected(); scope.Err(err) == nil { - scope.db.RowsAffected = count - } - } - } - return scope -} - -// Set set value by name -func (scope *Scope) Set(name string, value interface{}) *Scope { - scope.db.InstantSet(name, value) - return scope -} - -// Get get setting by name -func (scope *Scope) Get(name string) (interface{}, bool) { - return scope.db.Get(name) -} - -// InstanceID get InstanceID for scope -func (scope *Scope) InstanceID() string { - if scope.instanceID == "" { - scope.instanceID = fmt.Sprintf("%v%v", &scope, &scope.db) - } - return scope.instanceID -} - -// InstanceSet set instance setting for current operation, but not for operations in callbacks, like saving associations callback -func (scope *Scope) InstanceSet(name string, value interface{}) *Scope { - return scope.Set(name+scope.InstanceID(), value) -} - -// InstanceGet get instance setting from current operation -func (scope *Scope) InstanceGet(name string) (interface{}, bool) { - return scope.Get(name + scope.InstanceID()) -} - -// Begin start a transaction -func (scope *Scope) Begin() *Scope { - if db, ok := scope.SQLDB().(sqlDb); ok { - if tx, err := db.Begin(); scope.Err(err) == nil { - scope.db.db = interface{}(tx).(SQLCommon) - scope.InstanceSet("gorm:started_transaction", true) - } - } - return scope -} - -// CommitOrRollback commit current transaction if no error happened, otherwise will rollback it -func (scope *Scope) CommitOrRollback() *Scope { - if _, ok := scope.InstanceGet("gorm:started_transaction"); ok { - if db, ok := scope.db.db.(sqlTx); ok { - if scope.HasError() { - db.Rollback() - } else { - scope.Err(db.Commit()) - } - scope.db.db = scope.db.parent.db - } - } - return scope -} - -//////////////////////////////////////////////////////////////////////////////// -// Private Methods For *gorm.Scope -//////////////////////////////////////////////////////////////////////////////// - -func (scope *Scope) callMethod(methodName string, reflectValue reflect.Value) { - // Only get address from non-pointer - if reflectValue.CanAddr() && reflectValue.Kind() != reflect.Ptr { - reflectValue = reflectValue.Addr() - } - - if methodValue := reflectValue.MethodByName(methodName); methodValue.IsValid() { - switch method := methodValue.Interface().(type) { - case func(): - method() - case func(*Scope): - method(scope) - case func(*DB): - newDB := scope.NewDB() - method(newDB) - scope.Err(newDB.Error) - case func() error: - scope.Err(method()) - case func(*Scope) error: - scope.Err(method(scope)) - case func(*DB) error: - newDB := scope.NewDB() - scope.Err(method(newDB)) - scope.Err(newDB.Error) - default: - scope.Err(fmt.Errorf("unsupported function %v", methodName)) - } - } -} - -var ( - columnRegexp = regexp.MustCompile("^[a-zA-Z\\d]+(\\.[a-zA-Z\\d]+)*$") // only match string like `name`, `users.name` - isNumberRegexp = regexp.MustCompile("^\\s*\\d+\\s*$") // match if string is number - comparisonRegexp = regexp.MustCompile("(?i) (=|<>|(>|<)(=?)|LIKE|IS|IN) ") - countingQueryRegexp = regexp.MustCompile("(?i)^count(.+)$") -) - -func (scope *Scope) quoteIfPossible(str string) string { - if columnRegexp.MatchString(str) { - return scope.Quote(str) - } - return str -} - -func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field) { - var ( - ignored interface{} - values = make([]interface{}, len(columns)) - selectFields []*Field - selectedColumnsMap = map[string]int{} - resetFields = map[int]*Field{} - ) - - for index, column := range columns { - values[index] = &ignored - - selectFields = fields - offset := 0 - if idx, ok := selectedColumnsMap[column]; ok { - offset = idx + 1 - selectFields = selectFields[offset:] - } - - for fieldIndex, field := range selectFields { - if field.DBName == column { - if field.Field.Kind() == reflect.Ptr { - values[index] = field.Field.Addr().Interface() - } else { - reflectValue := reflect.New(reflect.PtrTo(field.Struct.Type)) - reflectValue.Elem().Set(field.Field.Addr()) - values[index] = reflectValue.Interface() - resetFields[index] = field - } - - selectedColumnsMap[column] = offset + fieldIndex - - if field.IsNormal { - break - } - } - } - } - - scope.Err(rows.Scan(values...)) - - for index, field := range resetFields { - if v := reflect.ValueOf(values[index]).Elem().Elem(); v.IsValid() { - field.Field.Set(v) - } - } -} - -func (scope *Scope) primaryCondition(value interface{}) string { - return fmt.Sprintf("(%v.%v = %v)", scope.QuotedTableName(), scope.Quote(scope.PrimaryKey()), value) -} - -func (scope *Scope) buildCondition(clause map[string]interface{}, include bool) (str string) { - var ( - quotedTableName = scope.QuotedTableName() - quotedPrimaryKey = scope.Quote(scope.PrimaryKey()) - equalSQL = "=" - inSQL = "IN" - ) - - // If building not conditions - if !include { - equalSQL = "<>" - inSQL = "NOT IN" - } - - switch value := clause["query"].(type) { - case sql.NullInt64: - return fmt.Sprintf("(%v.%v %s %v)", quotedTableName, quotedPrimaryKey, equalSQL, value.Int64) - case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: - return fmt.Sprintf("(%v.%v %s %v)", quotedTableName, quotedPrimaryKey, equalSQL, value) - case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string, []interface{}: - if !include && reflect.ValueOf(value).Len() == 0 { - return - } - str = fmt.Sprintf("(%v.%v %s (?))", quotedTableName, quotedPrimaryKey, inSQL) - clause["args"] = []interface{}{value} - case string: - if isNumberRegexp.MatchString(value) { - return fmt.Sprintf("(%v.%v %s %v)", quotedTableName, quotedPrimaryKey, equalSQL, scope.AddToVars(value)) - } - - if value != "" { - if !include { - if comparisonRegexp.MatchString(value) { - str = fmt.Sprintf("NOT (%v)", value) - } else { - str = fmt.Sprintf("(%v.%v NOT IN (?))", quotedTableName, scope.Quote(value)) - } - } else { - str = fmt.Sprintf("(%v)", value) - } - } - case map[string]interface{}: - var sqls []string - for key, value := range value { - if value != nil { - sqls = append(sqls, fmt.Sprintf("(%v.%v %s %v)", quotedTableName, scope.Quote(key), equalSQL, scope.AddToVars(value))) - } else { - if !include { - sqls = append(sqls, fmt.Sprintf("(%v.%v IS NOT NULL)", quotedTableName, scope.Quote(key))) - } else { - sqls = append(sqls, fmt.Sprintf("(%v.%v IS NULL)", quotedTableName, scope.Quote(key))) - } - } - } - return strings.Join(sqls, " AND ") - case interface{}: - var sqls []string - newScope := scope.New(value) - - if len(newScope.Fields()) == 0 { - scope.Err(fmt.Errorf("invalid query condition: %v", value)) - return - } - scopeQuotedTableName := newScope.QuotedTableName() - for _, field := range newScope.Fields() { - if !field.IsIgnored && !field.IsBlank { - sqls = append(sqls, fmt.Sprintf("(%v.%v %s %v)", scopeQuotedTableName, scope.Quote(field.DBName), equalSQL, scope.AddToVars(field.Field.Interface()))) - } - } - return strings.Join(sqls, " AND ") - default: - scope.Err(fmt.Errorf("invalid query condition: %v", value)) - return - } - - replacements := []string{} - args := clause["args"].([]interface{}) - for _, arg := range args { - var err error - switch reflect.ValueOf(arg).Kind() { - case reflect.Slice: // For where("id in (?)", []int64{1,2}) - if scanner, ok := interface{}(arg).(driver.Valuer); ok { - arg, err = scanner.Value() - replacements = append(replacements, scope.AddToVars(arg)) - } else if b, ok := arg.([]byte); ok { - replacements = append(replacements, scope.AddToVars(b)) - } else if as, ok := arg.([][]interface{}); ok { - var tempMarks []string - for _, a := range as { - var arrayMarks []string - for _, v := range a { - arrayMarks = append(arrayMarks, scope.AddToVars(v)) - } - - if len(arrayMarks) > 0 { - tempMarks = append(tempMarks, fmt.Sprintf("(%v)", strings.Join(arrayMarks, ","))) - } - } - - if len(tempMarks) > 0 { - replacements = append(replacements, strings.Join(tempMarks, ",")) - } - } else if values := reflect.ValueOf(arg); values.Len() > 0 { - var tempMarks []string - for i := 0; i < values.Len(); i++ { - tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface())) - } - replacements = append(replacements, strings.Join(tempMarks, ",")) - } else { - replacements = append(replacements, scope.AddToVars(Expr("NULL"))) - } - default: - if valuer, ok := interface{}(arg).(driver.Valuer); ok { - arg, err = valuer.Value() - } - - replacements = append(replacements, scope.AddToVars(arg)) - } - - if err != nil { - scope.Err(err) - } - } - - buff := bytes.NewBuffer([]byte{}) - i := 0 - for _, s := range str { - if s == '?' && len(replacements) > i { - buff.WriteString(replacements[i]) - i++ - } else { - buff.WriteRune(s) - } - } - - str = buff.String() - - return -} - -func (scope *Scope) buildSelectQuery(clause map[string]interface{}) (str string) { - switch value := clause["query"].(type) { - case string: - str = value - case []string: - str = strings.Join(value, ", ") - } - - args := clause["args"].([]interface{}) - replacements := []string{} - for _, arg := range args { - switch reflect.ValueOf(arg).Kind() { - case reflect.Slice: - values := reflect.ValueOf(arg) - var tempMarks []string - for i := 0; i < values.Len(); i++ { - tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface())) - } - replacements = append(replacements, strings.Join(tempMarks, ",")) - default: - if valuer, ok := interface{}(arg).(driver.Valuer); ok { - arg, _ = valuer.Value() - } - replacements = append(replacements, scope.AddToVars(arg)) - } - } - - buff := bytes.NewBuffer([]byte{}) - i := 0 - for pos, char := range str { - if str[pos] == '?' { - buff.WriteString(replacements[i]) - i++ - } else { - buff.WriteRune(char) - } - } - - str = buff.String() - - return -} - -func (scope *Scope) whereSQL() (sql string) { - var ( - quotedTableName = scope.QuotedTableName() - deletedAtField, hasDeletedAtField = scope.FieldByName("DeletedAt") - primaryConditions, andConditions, orConditions []string - ) - - if !scope.Search.Unscoped && hasDeletedAtField { - sql := fmt.Sprintf("%v.%v IS NULL", quotedTableName, scope.Quote(deletedAtField.DBName)) - primaryConditions = append(primaryConditions, sql) - } - - if !scope.PrimaryKeyZero() { - for _, field := range scope.PrimaryFields() { - sql := fmt.Sprintf("%v.%v = %v", quotedTableName, scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())) - primaryConditions = append(primaryConditions, sql) - } - } - - for _, clause := range scope.Search.whereConditions { - if sql := scope.buildCondition(clause, true); sql != "" { - andConditions = append(andConditions, sql) - } - } - - for _, clause := range scope.Search.orConditions { - if sql := scope.buildCondition(clause, true); sql != "" { - orConditions = append(orConditions, sql) - } - } - - for _, clause := range scope.Search.notConditions { - if sql := scope.buildCondition(clause, false); sql != "" { - andConditions = append(andConditions, sql) - } - } - - orSQL := strings.Join(orConditions, " OR ") - combinedSQL := strings.Join(andConditions, " AND ") - if len(combinedSQL) > 0 { - if len(orSQL) > 0 { - combinedSQL = combinedSQL + " OR " + orSQL - } - } else { - combinedSQL = orSQL - } - - if len(primaryConditions) > 0 { - sql = "WHERE " + strings.Join(primaryConditions, " AND ") - if len(combinedSQL) > 0 { - sql = sql + " AND (" + combinedSQL + ")" - } - } else if len(combinedSQL) > 0 { - sql = "WHERE " + combinedSQL - } - return -} - -func (scope *Scope) selectSQL() string { - if len(scope.Search.selects) == 0 { - if len(scope.Search.joinConditions) > 0 { - return fmt.Sprintf("%v.*", scope.QuotedTableName()) - } - return "*" - } - return scope.buildSelectQuery(scope.Search.selects) -} - -func (scope *Scope) orderSQL() string { - if len(scope.Search.orders) == 0 || scope.Search.ignoreOrderQuery { - return "" - } - - var orders []string - for _, order := range scope.Search.orders { - if str, ok := order.(string); ok { - orders = append(orders, scope.quoteIfPossible(str)) - } else if expr, ok := order.(*SqlExpr); ok { - exp := expr.expr - for _, arg := range expr.args { - exp = strings.Replace(exp, "?", scope.AddToVars(arg), 1) - } - orders = append(orders, exp) - } - } - return " ORDER BY " + strings.Join(orders, ",") -} - -func (scope *Scope) limitAndOffsetSQL() string { - sql, err := scope.Dialect().LimitAndOffsetSQL(scope.Search.limit, scope.Search.offset) - scope.Err(err) - return sql -} - -func (scope *Scope) groupSQL() string { - if len(scope.Search.group) == 0 { - return "" - } - return " GROUP BY " + scope.Search.group -} - -func (scope *Scope) havingSQL() string { - if len(scope.Search.havingConditions) == 0 { - return "" - } - - var andConditions []string - for _, clause := range scope.Search.havingConditions { - if sql := scope.buildCondition(clause, true); sql != "" { - andConditions = append(andConditions, sql) - } - } - - combinedSQL := strings.Join(andConditions, " AND ") - if len(combinedSQL) == 0 { - return "" - } - - return " HAVING " + combinedSQL -} - -func (scope *Scope) joinsSQL() string { - var joinConditions []string - for _, clause := range scope.Search.joinConditions { - if sql := scope.buildCondition(clause, true); sql != "" { - joinConditions = append(joinConditions, strings.TrimSuffix(strings.TrimPrefix(sql, "("), ")")) - } - } - - return strings.Join(joinConditions, " ") + " " -} - -func (scope *Scope) prepareQuerySQL() { - if scope.Search.raw { - scope.Raw(scope.CombinedConditionSql()) - } else { - scope.Raw(fmt.Sprintf("SELECT %v FROM %v %v", scope.selectSQL(), scope.QuotedTableName(), scope.CombinedConditionSql())) - } - return -} - -func (scope *Scope) inlineCondition(values ...interface{}) *Scope { - if len(values) > 0 { - scope.Search.Where(values[0], values[1:]...) - } - return scope -} - -func (scope *Scope) callCallbacks(funcs []*func(s *Scope)) *Scope { - defer func() { - if err := recover(); err != nil { - if db, ok := scope.db.db.(sqlTx); ok { - db.Rollback() - } - panic(err) - } - }() - for _, f := range funcs { - (*f)(scope) - if scope.skipLeft { - break - } - } - return scope -} - -func convertInterfaceToMap(values interface{}, withIgnoredField bool, db *DB) map[string]interface{} { - var attrs = map[string]interface{}{} - - switch value := values.(type) { - case map[string]interface{}: - return value - case []interface{}: - for _, v := range value { - for key, value := range convertInterfaceToMap(v, withIgnoredField, db) { - attrs[key] = value - } - } - case interface{}: - reflectValue := reflect.ValueOf(values) - - switch reflectValue.Kind() { - case reflect.Map: - for _, key := range reflectValue.MapKeys() { - attrs[ToColumnName(key.Interface().(string))] = reflectValue.MapIndex(key).Interface() - } - default: - for _, field := range (&Scope{Value: values, db: db}).Fields() { - if !field.IsBlank && (withIgnoredField || !field.IsIgnored) { - attrs[field.DBName] = field.Field.Interface() - } - } - } - } - return attrs -} - -func (scope *Scope) updatedAttrsWithValues(value interface{}) (results map[string]interface{}, hasUpdate bool) { - if scope.IndirectValue().Kind() != reflect.Struct { - return convertInterfaceToMap(value, false, scope.db), true - } - - results = map[string]interface{}{} - - for key, value := range convertInterfaceToMap(value, true, scope.db) { - if field, ok := scope.FieldByName(key); ok && scope.changeableField(field) { - if _, ok := value.(*SqlExpr); ok { - hasUpdate = true - results[field.DBName] = value - } else { - err := field.Set(value) - if field.IsNormal && !field.IsIgnored { - hasUpdate = true - if err == ErrUnaddressable { - results[field.DBName] = value - } else { - results[field.DBName] = field.Field.Interface() - } - } - } - } - } - return -} - -func (scope *Scope) row() *sql.Row { - defer scope.trace(NowFunc()) - - result := &RowQueryResult{} - scope.InstanceSet("row_query_result", result) - scope.callCallbacks(scope.db.parent.callbacks.rowQueries) - - return result.Row -} - -func (scope *Scope) rows() (*sql.Rows, error) { - defer scope.trace(NowFunc()) - - result := &RowsQueryResult{} - scope.InstanceSet("row_query_result", result) - scope.callCallbacks(scope.db.parent.callbacks.rowQueries) - - return result.Rows, result.Error -} - -func (scope *Scope) initialize() *Scope { - for _, clause := range scope.Search.whereConditions { - scope.updatedAttrsWithValues(clause["query"]) - } - scope.updatedAttrsWithValues(scope.Search.initAttrs) - scope.updatedAttrsWithValues(scope.Search.assignAttrs) - return scope -} - -func (scope *Scope) isQueryForColumn(query interface{}, column string) bool { - queryStr := strings.ToLower(fmt.Sprint(query)) - if queryStr == column { - return true - } - - if strings.HasSuffix(queryStr, "as "+column) { - return true - } - - if strings.HasSuffix(queryStr, "as "+scope.Quote(column)) { - return true - } - - return false -} - -func (scope *Scope) pluck(column string, value interface{}) *Scope { - dest := reflect.Indirect(reflect.ValueOf(value)) - if dest.Kind() != reflect.Slice { - scope.Err(fmt.Errorf("results should be a slice, not %s", dest.Kind())) - return scope - } - - if dest.Len() > 0 { - dest.Set(reflect.Zero(dest.Type())) - } - - if query, ok := scope.Search.selects["query"]; !ok || !scope.isQueryForColumn(query, column) { - scope.Search.Select(column) - } - - rows, err := scope.rows() - if scope.Err(err) == nil { - defer rows.Close() - for rows.Next() { - elem := reflect.New(dest.Type().Elem()).Interface() - scope.Err(rows.Scan(elem)) - dest.Set(reflect.Append(dest, reflect.ValueOf(elem).Elem())) - } - - if err := rows.Err(); err != nil { - scope.Err(err) - } - } - return scope -} - -func (scope *Scope) count(value interface{}) *Scope { - if query, ok := scope.Search.selects["query"]; !ok || !countingQueryRegexp.MatchString(fmt.Sprint(query)) { - if len(scope.Search.group) != 0 { - if len(scope.Search.havingConditions) != 0 { - scope.prepareQuerySQL() - scope.Search = &search{} - scope.Search.Select("count(*)") - scope.Search.Table(fmt.Sprintf("( %s ) AS count_table", scope.SQL)) - } else { - scope.Search.Select("count(*) FROM ( SELECT count(*) as name ") - scope.Search.group += " ) AS count_table" - } - } else { - scope.Search.Select("count(*)") - } - } - scope.Search.ignoreOrderQuery = true - scope.Err(scope.row().Scan(value)) - return scope -} - -func (scope *Scope) typeName() string { - typ := scope.IndirectValue().Type() - - for typ.Kind() == reflect.Slice || typ.Kind() == reflect.Ptr { - typ = typ.Elem() - } - - return typ.Name() -} - -// trace print sql log -func (scope *Scope) trace(t time.Time) { - if len(scope.SQL) > 0 { - scope.db.slog(scope.SQL, t, scope.SQLVars...) - } -} - -func (scope *Scope) changeableField(field *Field) bool { - if selectAttrs := scope.SelectAttrs(); len(selectAttrs) > 0 { - for _, attr := range selectAttrs { - if field.Name == attr || field.DBName == attr { - return true - } - } - return false - } - - for _, attr := range scope.OmitAttrs() { - if field.Name == attr || field.DBName == attr { - return false - } - } - - return true -} - -func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope { - toScope := scope.db.NewScope(value) - tx := scope.db.Set("gorm:association:source", scope.Value) - - for _, foreignKey := range append(foreignKeys, toScope.typeName()+"Id", scope.typeName()+"Id") { - fromField, _ := scope.FieldByName(foreignKey) - toField, _ := toScope.FieldByName(foreignKey) - - if fromField != nil { - if relationship := fromField.Relationship; relationship != nil { - if relationship.Kind == "many_to_many" { - joinTableHandler := relationship.JoinTableHandler - scope.Err(joinTableHandler.JoinWith(joinTableHandler, tx, scope.Value).Find(value).Error) - } else if relationship.Kind == "belongs_to" { - for idx, foreignKey := range relationship.ForeignDBNames { - if field, ok := scope.FieldByName(foreignKey); ok { - tx = tx.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.AssociationForeignDBNames[idx])), field.Field.Interface()) - } - } - scope.Err(tx.Find(value).Error) - } else if relationship.Kind == "has_many" || relationship.Kind == "has_one" { - for idx, foreignKey := range relationship.ForeignDBNames { - if field, ok := scope.FieldByName(relationship.AssociationForeignDBNames[idx]); ok { - tx = tx.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) - } - } - - if relationship.PolymorphicType != "" { - tx = tx.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), relationship.PolymorphicValue) - } - scope.Err(tx.Find(value).Error) - } - } else { - sql := fmt.Sprintf("%v = ?", scope.Quote(toScope.PrimaryKey())) - scope.Err(tx.Where(sql, fromField.Field.Interface()).Find(value).Error) - } - return scope - } else if toField != nil { - sql := fmt.Sprintf("%v = ?", scope.Quote(toField.DBName)) - scope.Err(tx.Where(sql, scope.PrimaryKeyValue()).Find(value).Error) - return scope - } - } - - scope.Err(fmt.Errorf("invalid association %v", foreignKeys)) - return scope -} - -// getTableOptions return the table options string or an empty string if the table options does not exist -func (scope *Scope) getTableOptions() string { - tableOptions, ok := scope.Get("gorm:table_options") - if !ok { - return "" - } - return " " + tableOptions.(string) -} - -func (scope *Scope) createJoinTable(field *StructField) { - if relationship := field.Relationship; relationship != nil && relationship.JoinTableHandler != nil { - joinTableHandler := relationship.JoinTableHandler - joinTable := joinTableHandler.Table(scope.db) - if !scope.Dialect().HasTable(joinTable) { - toScope := &Scope{Value: reflect.New(field.Struct.Type).Interface()} - - var sqlTypes, primaryKeys []string - for idx, fieldName := range relationship.ForeignFieldNames { - if field, ok := scope.FieldByName(fieldName); ok { - foreignKeyStruct := field.clone() - foreignKeyStruct.IsPrimaryKey = false - foreignKeyStruct.TagSettingsSet("IS_JOINTABLE_FOREIGNKEY", "true") - foreignKeyStruct.TagSettingsDelete("AUTO_INCREMENT") - sqlTypes = append(sqlTypes, scope.Quote(relationship.ForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(foreignKeyStruct)) - primaryKeys = append(primaryKeys, scope.Quote(relationship.ForeignDBNames[idx])) - } - } - - for idx, fieldName := range relationship.AssociationForeignFieldNames { - if field, ok := toScope.FieldByName(fieldName); ok { - foreignKeyStruct := field.clone() - foreignKeyStruct.IsPrimaryKey = false - foreignKeyStruct.TagSettingsSet("IS_JOINTABLE_FOREIGNKEY", "true") - foreignKeyStruct.TagSettingsDelete("AUTO_INCREMENT") - sqlTypes = append(sqlTypes, scope.Quote(relationship.AssociationForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(foreignKeyStruct)) - primaryKeys = append(primaryKeys, scope.Quote(relationship.AssociationForeignDBNames[idx])) - } - } - - scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v, PRIMARY KEY (%v))%s", scope.Quote(joinTable), strings.Join(sqlTypes, ","), strings.Join(primaryKeys, ","), scope.getTableOptions())).Error) - } - scope.NewDB().Table(joinTable).AutoMigrate(joinTableHandler) - } -} - -func (scope *Scope) createTable() *Scope { - var tags []string - var primaryKeys []string - var primaryKeyInColumnType = false - for _, field := range scope.GetModelStruct().StructFields { - if field.IsNormal { - sqlTag := scope.Dialect().DataTypeOf(field) - - // Check if the primary key constraint was specified as - // part of the column type. If so, we can only support - // one column as the primary key. - if strings.Contains(strings.ToLower(sqlTag), "primary key") { - primaryKeyInColumnType = true - } - - tags = append(tags, scope.Quote(field.DBName)+" "+sqlTag) - } - - if field.IsPrimaryKey { - primaryKeys = append(primaryKeys, scope.Quote(field.DBName)) - } - scope.createJoinTable(field) - } - - var primaryKeyStr string - if len(primaryKeys) > 0 && !primaryKeyInColumnType { - primaryKeyStr = fmt.Sprintf(", PRIMARY KEY (%v)", strings.Join(primaryKeys, ",")) - } - - scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v %v)%s", scope.QuotedTableName(), strings.Join(tags, ","), primaryKeyStr, scope.getTableOptions())).Exec() - - scope.autoIndex() - return scope -} - -func (scope *Scope) dropTable() *Scope { - scope.Raw(fmt.Sprintf("DROP TABLE %v", scope.QuotedTableName())).Exec() - return scope -} - -func (scope *Scope) modifyColumn(column string, typ string) { - scope.db.AddError(scope.Dialect().ModifyColumn(scope.QuotedTableName(), scope.Quote(column), typ)) -} - -func (scope *Scope) dropColumn(column string) { - scope.Raw(fmt.Sprintf("ALTER TABLE %v DROP COLUMN %v", scope.QuotedTableName(), scope.Quote(column))).Exec() -} - -func (scope *Scope) addIndex(unique bool, indexName string, column ...string) { - if scope.Dialect().HasIndex(scope.TableName(), indexName) { - return - } - - var columns []string - for _, name := range column { - columns = append(columns, scope.quoteIfPossible(name)) - } - - sqlCreate := "CREATE INDEX" - if unique { - sqlCreate = "CREATE UNIQUE INDEX" - } - - scope.Raw(fmt.Sprintf("%s %v ON %v(%v) %v", sqlCreate, indexName, scope.QuotedTableName(), strings.Join(columns, ", "), scope.whereSQL())).Exec() -} - -func (scope *Scope) addForeignKey(field string, dest string, onDelete string, onUpdate string) { - // Compatible with old generated key - keyName := scope.Dialect().BuildKeyName(scope.TableName(), field, dest, "foreign") - - if scope.Dialect().HasForeignKey(scope.TableName(), keyName) { - return - } - var query = `ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s ON DELETE %s ON UPDATE %s;` - scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName), scope.quoteIfPossible(field), dest, onDelete, onUpdate)).Exec() -} - -func (scope *Scope) removeForeignKey(field string, dest string) { - keyName := scope.Dialect().BuildKeyName(scope.TableName(), field, dest, "foreign") - if !scope.Dialect().HasForeignKey(scope.TableName(), keyName) { - return - } - var mysql mysql - var query string - if scope.Dialect().GetName() == mysql.GetName() { - query = `ALTER TABLE %s DROP FOREIGN KEY %s;` - } else { - query = `ALTER TABLE %s DROP CONSTRAINT %s;` - } - - scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName))).Exec() -} - -func (scope *Scope) removeIndex(indexName string) { - scope.Dialect().RemoveIndex(scope.TableName(), indexName) -} - -func (scope *Scope) autoMigrate() *Scope { - tableName := scope.TableName() - quotedTableName := scope.QuotedTableName() - - if !scope.Dialect().HasTable(tableName) { - scope.createTable() - } else { - for _, field := range scope.GetModelStruct().StructFields { - if !scope.Dialect().HasColumn(tableName, field.DBName) { - if field.IsNormal { - sqlTag := scope.Dialect().DataTypeOf(field) - scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, scope.Quote(field.DBName), sqlTag)).Exec() - } - } - scope.createJoinTable(field) - } - scope.autoIndex() - } - return scope -} - -func (scope *Scope) autoIndex() *Scope { - var indexes = map[string][]string{} - var uniqueIndexes = map[string][]string{} - - for _, field := range scope.GetStructFields() { - if name, ok := field.TagSettingsGet("INDEX"); ok { - names := strings.Split(name, ",") - - for _, name := range names { - if name == "INDEX" || name == "" { - name = scope.Dialect().BuildKeyName("idx", scope.TableName(), field.DBName) - } - name, column := scope.Dialect().NormalizeIndexAndColumn(name, field.DBName) - indexes[name] = append(indexes[name], column) - } - } - - if name, ok := field.TagSettingsGet("UNIQUE_INDEX"); ok { - names := strings.Split(name, ",") - - for _, name := range names { - if name == "UNIQUE_INDEX" || name == "" { - name = scope.Dialect().BuildKeyName("uix", scope.TableName(), field.DBName) - } - name, column := scope.Dialect().NormalizeIndexAndColumn(name, field.DBName) - uniqueIndexes[name] = append(uniqueIndexes[name], column) - } - } - } - - for name, columns := range indexes { - if db := scope.NewDB().Table(scope.TableName()).Model(scope.Value).AddIndex(name, columns...); db.Error != nil { - scope.db.AddError(db.Error) - } - } - - for name, columns := range uniqueIndexes { - if db := scope.NewDB().Table(scope.TableName()).Model(scope.Value).AddUniqueIndex(name, columns...); db.Error != nil { - scope.db.AddError(db.Error) - } - } - - return scope -} - -func (scope *Scope) getColumnAsArray(columns []string, values ...interface{}) (results [][]interface{}) { - resultMap := make(map[string][]interface{}) - for _, value := range values { - indirectValue := indirect(reflect.ValueOf(value)) - - switch indirectValue.Kind() { - case reflect.Slice: - for i := 0; i < indirectValue.Len(); i++ { - var result []interface{} - var object = indirect(indirectValue.Index(i)) - var hasValue = false - for _, column := range columns { - field := object.FieldByName(column) - if hasValue || !isBlank(field) { - hasValue = true - } - result = append(result, field.Interface()) - } - - if hasValue { - h := fmt.Sprint(result...) - if _, exist := resultMap[h]; !exist { - resultMap[h] = result - } - } - } - case reflect.Struct: - var result []interface{} - var hasValue = false - for _, column := range columns { - field := indirectValue.FieldByName(column) - if hasValue || !isBlank(field) { - hasValue = true - } - result = append(result, field.Interface()) - } - - if hasValue { - h := fmt.Sprint(result...) - if _, exist := resultMap[h]; !exist { - resultMap[h] = result - } - } - } - } - for _, v := range resultMap { - results = append(results, v) - } - return -} - -func (scope *Scope) getColumnAsScope(column string) *Scope { - indirectScopeValue := scope.IndirectValue() - - switch indirectScopeValue.Kind() { - case reflect.Slice: - if fieldStruct, ok := scope.GetModelStruct().ModelType.FieldByName(column); ok { - fieldType := fieldStruct.Type - if fieldType.Kind() == reflect.Slice || fieldType.Kind() == reflect.Ptr { - fieldType = fieldType.Elem() - } - - resultsMap := map[interface{}]bool{} - results := reflect.New(reflect.SliceOf(reflect.PtrTo(fieldType))).Elem() - - for i := 0; i < indirectScopeValue.Len(); i++ { - result := indirect(indirect(indirectScopeValue.Index(i)).FieldByName(column)) - - if result.Kind() == reflect.Slice { - for j := 0; j < result.Len(); j++ { - if elem := result.Index(j); elem.CanAddr() && resultsMap[elem.Addr()] != true { - resultsMap[elem.Addr()] = true - results = reflect.Append(results, elem.Addr()) - } - } - } else if result.CanAddr() && resultsMap[result.Addr()] != true { - resultsMap[result.Addr()] = true - results = reflect.Append(results, result.Addr()) - } - } - return scope.New(results.Interface()) - } - case reflect.Struct: - if field := indirectScopeValue.FieldByName(column); field.CanAddr() { - return scope.New(field.Addr().Interface()) - } - } - return nil -} - -func (scope *Scope) hasConditions() bool { - return !scope.PrimaryKeyZero() || - len(scope.Search.whereConditions) > 0 || - len(scope.Search.orConditions) > 0 || - len(scope.Search.notConditions) > 0 -} diff --git a/scope_test.go b/scope_test.go deleted file mode 100644 index f7f1ed08..00000000 --- a/scope_test.go +++ /dev/null @@ -1,93 +0,0 @@ -package gorm_test - -import ( - "encoding/hex" - "math/rand" - "strings" - "testing" - - "github.com/jinzhu/gorm" -) - -func NameIn1And2(d *gorm.DB) *gorm.DB { - return d.Where("name in (?)", []string{"ScopeUser1", "ScopeUser2"}) -} - -func NameIn2And3(d *gorm.DB) *gorm.DB { - return d.Where("name in (?)", []string{"ScopeUser2", "ScopeUser3"}) -} - -func NameIn(names []string) func(d *gorm.DB) *gorm.DB { - return func(d *gorm.DB) *gorm.DB { - return d.Where("name in (?)", names) - } -} - -func TestScopes(t *testing.T) { - user1 := User{Name: "ScopeUser1", Age: 1} - user2 := User{Name: "ScopeUser2", Age: 1} - user3 := User{Name: "ScopeUser3", Age: 2} - DB.Save(&user1).Save(&user2).Save(&user3) - - var users1, users2, users3 []User - DB.Scopes(NameIn1And2).Find(&users1) - if len(users1) != 2 { - t.Errorf("Should found two users's name in 1, 2") - } - - DB.Scopes(NameIn1And2, NameIn2And3).Find(&users2) - if len(users2) != 1 { - t.Errorf("Should found one user's name is 2") - } - - DB.Scopes(NameIn([]string{user1.Name, user3.Name})).Find(&users3) - if len(users3) != 2 { - t.Errorf("Should found two users's name in 1, 3") - } -} - -func randName() string { - data := make([]byte, 8) - rand.Read(data) - - return "n-" + hex.EncodeToString(data) -} - -func TestValuer(t *testing.T) { - name := randName() - - origUser := User{Name: name, Age: 1, Password: EncryptedData("pass1"), PasswordHash: []byte("abc")} - if err := DB.Save(&origUser).Error; err != nil { - t.Errorf("No error should happen when saving user, but got %v", err) - } - - var user2 User - if err := DB.Where("name = ? AND password = ? AND password_hash = ?", name, EncryptedData("pass1"), []byte("abc")).First(&user2).Error; err != nil { - t.Errorf("No error should happen when querying user with valuer, but got %v", err) - } -} - -func TestFailedValuer(t *testing.T) { - name := randName() - - err := DB.Exec("INSERT INTO users(name, password) VALUES(?, ?)", name, EncryptedData("xpass1")).Error - - if err == nil { - t.Errorf("There should be an error should happen when insert data") - } else if !strings.HasPrefix(err.Error(), "Should not start with") { - t.Errorf("The error should be returned from Valuer, but get %v", err) - } -} - -func TestDropTableWithTableOptions(t *testing.T) { - type UserWithOptions struct { - gorm.Model - } - DB.AutoMigrate(&UserWithOptions{}) - - DB = DB.Set("gorm:table_options", "CHARSET=utf8") - err := DB.DropTable(&UserWithOptions{}).Error - if err != nil { - t.Errorf("Table must be dropped, got error %s", err) - } -} diff --git a/search.go b/search.go deleted file mode 100644 index 7c4cc184..00000000 --- a/search.go +++ /dev/null @@ -1,153 +0,0 @@ -package gorm - -import ( - "fmt" -) - -type search struct { - db *DB - whereConditions []map[string]interface{} - orConditions []map[string]interface{} - notConditions []map[string]interface{} - havingConditions []map[string]interface{} - joinConditions []map[string]interface{} - initAttrs []interface{} - assignAttrs []interface{} - selects map[string]interface{} - omits []string - orders []interface{} - preload []searchPreload - offset interface{} - limit interface{} - group string - tableName string - raw bool - Unscoped bool - ignoreOrderQuery bool -} - -type searchPreload struct { - schema string - conditions []interface{} -} - -func (s *search) clone() *search { - clone := *s - return &clone -} - -func (s *search) Where(query interface{}, values ...interface{}) *search { - s.whereConditions = append(s.whereConditions, map[string]interface{}{"query": query, "args": values}) - return s -} - -func (s *search) Not(query interface{}, values ...interface{}) *search { - s.notConditions = append(s.notConditions, map[string]interface{}{"query": query, "args": values}) - return s -} - -func (s *search) Or(query interface{}, values ...interface{}) *search { - s.orConditions = append(s.orConditions, map[string]interface{}{"query": query, "args": values}) - return s -} - -func (s *search) Attrs(attrs ...interface{}) *search { - s.initAttrs = append(s.initAttrs, toSearchableMap(attrs...)) - return s -} - -func (s *search) Assign(attrs ...interface{}) *search { - s.assignAttrs = append(s.assignAttrs, toSearchableMap(attrs...)) - return s -} - -func (s *search) Order(value interface{}, reorder ...bool) *search { - if len(reorder) > 0 && reorder[0] { - s.orders = []interface{}{} - } - - if value != nil && value != "" { - s.orders = append(s.orders, value) - } - return s -} - -func (s *search) Select(query interface{}, args ...interface{}) *search { - s.selects = map[string]interface{}{"query": query, "args": args} - return s -} - -func (s *search) Omit(columns ...string) *search { - s.omits = columns - return s -} - -func (s *search) Limit(limit interface{}) *search { - s.limit = limit - return s -} - -func (s *search) Offset(offset interface{}) *search { - s.offset = offset - return s -} - -func (s *search) Group(query string) *search { - s.group = s.getInterfaceAsSQL(query) - return s -} - -func (s *search) Having(query interface{}, values ...interface{}) *search { - if val, ok := query.(*SqlExpr); ok { - s.havingConditions = append(s.havingConditions, map[string]interface{}{"query": val.expr, "args": val.args}) - } else { - s.havingConditions = append(s.havingConditions, map[string]interface{}{"query": query, "args": values}) - } - return s -} - -func (s *search) Joins(query string, values ...interface{}) *search { - s.joinConditions = append(s.joinConditions, map[string]interface{}{"query": query, "args": values}) - return s -} - -func (s *search) Preload(schema string, values ...interface{}) *search { - var preloads []searchPreload - for _, preload := range s.preload { - if preload.schema != schema { - preloads = append(preloads, preload) - } - } - preloads = append(preloads, searchPreload{schema, values}) - s.preload = preloads - return s -} - -func (s *search) Raw(b bool) *search { - s.raw = b - return s -} - -func (s *search) unscoped() *search { - s.Unscoped = true - return s -} - -func (s *search) Table(name string) *search { - s.tableName = name - return s -} - -func (s *search) getInterfaceAsSQL(value interface{}) (str string) { - switch value.(type) { - case string, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: - str = fmt.Sprintf("%v", value) - default: - s.db.AddError(ErrInvalidSQL) - } - - if str == "-1" { - return "" - } - return -} diff --git a/search_test.go b/search_test.go deleted file mode 100644 index 4db7ab6a..00000000 --- a/search_test.go +++ /dev/null @@ -1,30 +0,0 @@ -package gorm - -import ( - "reflect" - "testing" -) - -func TestCloneSearch(t *testing.T) { - s := new(search) - s.Where("name = ?", "jinzhu").Order("name").Attrs("name", "jinzhu").Select("name, age") - - s1 := s.clone() - s1.Where("age = ?", 20).Order("age").Attrs("email", "a@e.org").Select("email") - - if reflect.DeepEqual(s.whereConditions, s1.whereConditions) { - t.Errorf("Where should be copied") - } - - if reflect.DeepEqual(s.orders, s1.orders) { - t.Errorf("Order should be copied") - } - - if reflect.DeepEqual(s.initAttrs, s1.initAttrs) { - t.Errorf("InitAttrs should be copied") - } - - if reflect.DeepEqual(s.Select, s1.Select) { - t.Errorf("selectStr should be copied") - } -} diff --git a/test_all.sh b/test_all.sh deleted file mode 100755 index 5cfb3321..00000000 --- a/test_all.sh +++ /dev/null @@ -1,5 +0,0 @@ -dialects=("postgres" "mysql" "mssql" "sqlite") - -for dialect in "${dialects[@]}" ; do - DEBUG=false GORM_DIALECT=${dialect} go test -done diff --git a/update_test.go b/update_test.go deleted file mode 100644 index 85d53e5f..00000000 --- a/update_test.go +++ /dev/null @@ -1,465 +0,0 @@ -package gorm_test - -import ( - "testing" - "time" - - "github.com/jinzhu/gorm" -) - -func TestUpdate(t *testing.T) { - product1 := Product{Code: "product1code"} - product2 := Product{Code: "product2code"} - - DB.Save(&product1).Save(&product2).Update("code", "product2newcode") - - if product2.Code != "product2newcode" { - t.Errorf("Record should be updated") - } - - DB.First(&product1, product1.Id) - DB.First(&product2, product2.Id) - updatedAt1 := product1.UpdatedAt - - if DB.First(&Product{}, "code = ?", product1.Code).RecordNotFound() { - t.Errorf("Product1 should not be updated") - } - - if !DB.First(&Product{}, "code = ?", "product2code").RecordNotFound() { - t.Errorf("Product2's code should be updated") - } - - if DB.First(&Product{}, "code = ?", "product2newcode").RecordNotFound() { - t.Errorf("Product2's code should be updated") - } - - DB.Table("products").Where("code in (?)", []string{"product1code"}).Update("code", "product1newcode") - - var product4 Product - DB.First(&product4, product1.Id) - if updatedAt1.Format(time.RFC3339Nano) != product4.UpdatedAt.Format(time.RFC3339Nano) { - t.Errorf("updatedAt should be updated if something changed") - } - - if !DB.First(&Product{}, "code = 'product1code'").RecordNotFound() { - t.Errorf("Product1's code should be updated") - } - - if DB.First(&Product{}, "code = 'product1newcode'").RecordNotFound() { - t.Errorf("Product should not be changed to 789") - } - - if DB.Model(product2).Update("CreatedAt", time.Now().Add(time.Hour)).Error != nil { - t.Error("No error should raise when update with CamelCase") - } - - if DB.Model(&product2).UpdateColumn("CreatedAt", time.Now().Add(time.Hour)).Error != nil { - t.Error("No error should raise when update_column with CamelCase") - } - - var products []Product - DB.Find(&products) - if count := DB.Model(Product{}).Update("CreatedAt", time.Now().Add(2*time.Hour)).RowsAffected; count != int64(len(products)) { - t.Error("RowsAffected should be correct when do batch update") - } - - DB.First(&product4, product4.Id) - updatedAt4 := product4.UpdatedAt - DB.Model(&product4).Update("price", gorm.Expr("price + ? - ?", 100, 50)) - var product5 Product - DB.First(&product5, product4.Id) - if product5.Price != product4.Price+100-50 { - t.Errorf("Update with expression") - } - if product4.UpdatedAt.Format(time.RFC3339Nano) == updatedAt4.Format(time.RFC3339Nano) { - t.Errorf("Update with expression should update UpdatedAt") - } -} - -func TestUpdateWithNoStdPrimaryKeyAndDefaultValues(t *testing.T) { - animal := Animal{Name: "Ferdinand"} - DB.Save(&animal) - updatedAt1 := animal.UpdatedAt - - DB.Save(&animal).Update("name", "Francis") - - if updatedAt1.Format(time.RFC3339Nano) == animal.UpdatedAt.Format(time.RFC3339Nano) { - t.Errorf("updatedAt should not be updated if nothing changed") - } - - var animals []Animal - DB.Find(&animals) - if count := DB.Model(Animal{}).Update("CreatedAt", time.Now().Add(2*time.Hour)).RowsAffected; count != int64(len(animals)) { - t.Error("RowsAffected should be correct when do batch update") - } - - animal = Animal{From: "somewhere"} // No name fields, should be filled with the default value (galeone) - DB.Save(&animal).Update("From", "a nice place") // The name field shoul be untouched - DB.First(&animal, animal.Counter) - if animal.Name != "galeone" { - t.Errorf("Name fields shouldn't be changed if untouched, but got %v", animal.Name) - } - - // When changing a field with a default value, the change must occur - animal.Name = "amazing horse" - DB.Save(&animal) - DB.First(&animal, animal.Counter) - if animal.Name != "amazing horse" { - t.Errorf("Update a filed with a default value should occur. But got %v\n", animal.Name) - } - - // When changing a field with a default value with blank value - animal.Name = "" - DB.Save(&animal) - DB.First(&animal, animal.Counter) - if animal.Name != "" { - t.Errorf("Update a filed to blank with a default value should occur. But got %v\n", animal.Name) - } -} - -func TestUpdates(t *testing.T) { - product1 := Product{Code: "product1code", Price: 10} - product2 := Product{Code: "product2code", Price: 10} - DB.Save(&product1).Save(&product2) - DB.Model(&product1).Updates(map[string]interface{}{"code": "product1newcode", "price": 100}) - if product1.Code != "product1newcode" || product1.Price != 100 { - t.Errorf("Record should be updated also with map") - } - - DB.First(&product1, product1.Id) - DB.First(&product2, product2.Id) - updatedAt2 := product2.UpdatedAt - - if DB.First(&Product{}, "code = ? and price = ?", product2.Code, product2.Price).RecordNotFound() { - t.Errorf("Product2 should not be updated") - } - - if DB.First(&Product{}, "code = ?", "product1newcode").RecordNotFound() { - t.Errorf("Product1 should be updated") - } - - DB.Table("products").Where("code in (?)", []string{"product2code"}).Updates(Product{Code: "product2newcode"}) - if !DB.First(&Product{}, "code = 'product2code'").RecordNotFound() { - t.Errorf("Product2's code should be updated") - } - - var product4 Product - DB.First(&product4, product2.Id) - if updatedAt2.Format(time.RFC3339Nano) != product4.UpdatedAt.Format(time.RFC3339Nano) { - t.Errorf("updatedAt should be updated if something changed") - } - - if DB.First(&Product{}, "code = ?", "product2newcode").RecordNotFound() { - t.Errorf("product2's code should be updated") - } - - updatedAt4 := product4.UpdatedAt - DB.Model(&product4).Updates(map[string]interface{}{"price": gorm.Expr("price + ?", 100)}) - var product5 Product - DB.First(&product5, product4.Id) - if product5.Price != product4.Price+100 { - t.Errorf("Updates with expression") - } - // product4's UpdatedAt will be reset when updating - if product4.UpdatedAt.Format(time.RFC3339Nano) == updatedAt4.Format(time.RFC3339Nano) { - t.Errorf("Updates with expression should update UpdatedAt") - } -} - -func TestUpdateColumn(t *testing.T) { - product1 := Product{Code: "product1code", Price: 10} - product2 := Product{Code: "product2code", Price: 20} - DB.Save(&product1).Save(&product2).UpdateColumn(map[string]interface{}{"code": "product2newcode", "price": 100}) - if product2.Code != "product2newcode" || product2.Price != 100 { - t.Errorf("product 2 should be updated with update column") - } - - var product3 Product - DB.First(&product3, product1.Id) - if product3.Code != "product1code" || product3.Price != 10 { - t.Errorf("product 1 should not be updated") - } - - DB.First(&product2, product2.Id) - updatedAt2 := product2.UpdatedAt - DB.Model(product2).UpdateColumn("code", "update_column_new") - var product4 Product - DB.First(&product4, product2.Id) - if updatedAt2.Format(time.RFC3339Nano) != product4.UpdatedAt.Format(time.RFC3339Nano) { - t.Errorf("updatedAt should not be updated with update column") - } - - DB.Model(&product4).UpdateColumn("price", gorm.Expr("price + 100 - 50")) - var product5 Product - DB.First(&product5, product4.Id) - if product5.Price != product4.Price+100-50 { - t.Errorf("UpdateColumn with expression") - } - if product5.UpdatedAt.Format(time.RFC3339Nano) != product4.UpdatedAt.Format(time.RFC3339Nano) { - t.Errorf("UpdateColumn with expression should not update UpdatedAt") - } -} - -func TestSelectWithUpdate(t *testing.T) { - user := getPreparedUser("select_user", "select_with_update") - DB.Create(user) - - var reloadUser User - DB.First(&reloadUser, user.Id) - reloadUser.Name = "new_name" - reloadUser.Age = 50 - reloadUser.BillingAddress = Address{Address1: "New Billing Address"} - reloadUser.ShippingAddress = Address{Address1: "New ShippingAddress Address"} - reloadUser.CreditCard = CreditCard{Number: "987654321"} - reloadUser.Emails = []Email{ - {Email: "new_user_1@example1.com"}, {Email: "new_user_2@example2.com"}, {Email: "new_user_3@example2.com"}, - } - reloadUser.Company = Company{Name: "new company"} - - DB.Select("Name", "BillingAddress", "CreditCard", "Company", "Emails").Save(&reloadUser) - - var queryUser User - DB.Preload("BillingAddress").Preload("ShippingAddress"). - Preload("CreditCard").Preload("Emails").Preload("Company").First(&queryUser, user.Id) - - if queryUser.Name == user.Name || queryUser.Age != user.Age { - t.Errorf("Should only update users with name column") - } - - if queryUser.BillingAddressID.Int64 == user.BillingAddressID.Int64 || - queryUser.ShippingAddressId != user.ShippingAddressId || - queryUser.CreditCard.ID == user.CreditCard.ID || - len(queryUser.Emails) == len(user.Emails) || queryUser.Company.Id == user.Company.Id { - t.Errorf("Should only update selected relationships") - } -} - -func TestSelectWithUpdateWithMap(t *testing.T) { - user := getPreparedUser("select_user", "select_with_update_map") - DB.Create(user) - - updateValues := map[string]interface{}{ - "Name": "new_name", - "Age": 50, - "BillingAddress": Address{Address1: "New Billing Address"}, - "ShippingAddress": Address{Address1: "New ShippingAddress Address"}, - "CreditCard": CreditCard{Number: "987654321"}, - "Emails": []Email{ - {Email: "new_user_1@example1.com"}, {Email: "new_user_2@example2.com"}, {Email: "new_user_3@example2.com"}, - }, - "Company": Company{Name: "new company"}, - } - - var reloadUser User - DB.First(&reloadUser, user.Id) - DB.Model(&reloadUser).Select("Name", "BillingAddress", "CreditCard", "Company", "Emails").Update(updateValues) - - var queryUser User - DB.Preload("BillingAddress").Preload("ShippingAddress"). - Preload("CreditCard").Preload("Emails").Preload("Company").First(&queryUser, user.Id) - - if queryUser.Name == user.Name || queryUser.Age != user.Age { - t.Errorf("Should only update users with name column") - } - - if queryUser.BillingAddressID.Int64 == user.BillingAddressID.Int64 || - queryUser.ShippingAddressId != user.ShippingAddressId || - queryUser.CreditCard.ID == user.CreditCard.ID || - len(queryUser.Emails) == len(user.Emails) || queryUser.Company.Id == user.Company.Id { - t.Errorf("Should only update selected relationships") - } -} - -func TestOmitWithUpdate(t *testing.T) { - user := getPreparedUser("omit_user", "omit_with_update") - DB.Create(user) - - var reloadUser User - DB.First(&reloadUser, user.Id) - reloadUser.Name = "new_name" - reloadUser.Age = 50 - reloadUser.BillingAddress = Address{Address1: "New Billing Address"} - reloadUser.ShippingAddress = Address{Address1: "New ShippingAddress Address"} - reloadUser.CreditCard = CreditCard{Number: "987654321"} - reloadUser.Emails = []Email{ - {Email: "new_user_1@example1.com"}, {Email: "new_user_2@example2.com"}, {Email: "new_user_3@example2.com"}, - } - reloadUser.Company = Company{Name: "new company"} - - DB.Omit("Name", "BillingAddress", "CreditCard", "Company", "Emails").Save(&reloadUser) - - var queryUser User - DB.Preload("BillingAddress").Preload("ShippingAddress"). - Preload("CreditCard").Preload("Emails").Preload("Company").First(&queryUser, user.Id) - - if queryUser.Name != user.Name || queryUser.Age == user.Age { - t.Errorf("Should only update users with name column") - } - - if queryUser.BillingAddressID.Int64 != user.BillingAddressID.Int64 || - queryUser.ShippingAddressId == user.ShippingAddressId || - queryUser.CreditCard.ID != user.CreditCard.ID || - len(queryUser.Emails) != len(user.Emails) || queryUser.Company.Id != user.Company.Id { - t.Errorf("Should only update relationships that not omitted") - } -} - -func TestOmitWithUpdateWithMap(t *testing.T) { - user := getPreparedUser("select_user", "select_with_update_map") - DB.Create(user) - - updateValues := map[string]interface{}{ - "Name": "new_name", - "Age": 50, - "BillingAddress": Address{Address1: "New Billing Address"}, - "ShippingAddress": Address{Address1: "New ShippingAddress Address"}, - "CreditCard": CreditCard{Number: "987654321"}, - "Emails": []Email{ - {Email: "new_user_1@example1.com"}, {Email: "new_user_2@example2.com"}, {Email: "new_user_3@example2.com"}, - }, - "Company": Company{Name: "new company"}, - } - - var reloadUser User - DB.First(&reloadUser, user.Id) - DB.Model(&reloadUser).Omit("Name", "BillingAddress", "CreditCard", "Company", "Emails").Update(updateValues) - - var queryUser User - DB.Preload("BillingAddress").Preload("ShippingAddress"). - Preload("CreditCard").Preload("Emails").Preload("Company").First(&queryUser, user.Id) - - if queryUser.Name != user.Name || queryUser.Age == user.Age { - t.Errorf("Should only update users with name column") - } - - if queryUser.BillingAddressID.Int64 != user.BillingAddressID.Int64 || - queryUser.ShippingAddressId == user.ShippingAddressId || - queryUser.CreditCard.ID != user.CreditCard.ID || - len(queryUser.Emails) != len(user.Emails) || queryUser.Company.Id != user.Company.Id { - t.Errorf("Should only update relationships not omitted") - } -} - -func TestSelectWithUpdateColumn(t *testing.T) { - user := getPreparedUser("select_user", "select_with_update_map") - DB.Create(user) - - updateValues := map[string]interface{}{"Name": "new_name", "Age": 50} - - var reloadUser User - DB.First(&reloadUser, user.Id) - DB.Model(&reloadUser).Select("Name").UpdateColumn(updateValues) - - var queryUser User - DB.First(&queryUser, user.Id) - - if queryUser.Name == user.Name || queryUser.Age != user.Age { - t.Errorf("Should only update users with name column") - } -} - -func TestOmitWithUpdateColumn(t *testing.T) { - user := getPreparedUser("select_user", "select_with_update_map") - DB.Create(user) - - updateValues := map[string]interface{}{"Name": "new_name", "Age": 50} - - var reloadUser User - DB.First(&reloadUser, user.Id) - DB.Model(&reloadUser).Omit("Name").UpdateColumn(updateValues) - - var queryUser User - DB.First(&queryUser, user.Id) - - if queryUser.Name != user.Name || queryUser.Age == user.Age { - t.Errorf("Should omit name column when update user") - } -} - -func TestUpdateColumnsSkipsAssociations(t *testing.T) { - user := getPreparedUser("update_columns_user", "special_role") - user.Age = 99 - address1 := "first street" - user.BillingAddress = Address{Address1: address1} - DB.Save(user) - - // Update a single field of the user and verify that the changed address is not stored. - newAge := int64(100) - user.BillingAddress.Address1 = "second street" - db := DB.Model(user).UpdateColumns(User{Age: newAge}) - if db.RowsAffected != 1 { - t.Errorf("Expected RowsAffected=1 but instead RowsAffected=%v", DB.RowsAffected) - } - - // Verify that Age now=`newAge`. - freshUser := &User{Id: user.Id} - DB.First(freshUser) - if freshUser.Age != newAge { - t.Errorf("Expected freshly queried user to have Age=%v but instead found Age=%v", newAge, freshUser.Age) - } - - // Verify that user's BillingAddress.Address1 is not changed and is still "first street". - DB.First(&freshUser.BillingAddress, freshUser.BillingAddressID) - if freshUser.BillingAddress.Address1 != address1 { - t.Errorf("Expected user's BillingAddress.Address1=%s to remain unchanged after UpdateColumns invocation, but BillingAddress.Address1=%s", address1, freshUser.BillingAddress.Address1) - } -} - -func TestUpdatesWithBlankValues(t *testing.T) { - product := Product{Code: "product1", Price: 10} - DB.Save(&product) - - DB.Model(&Product{Id: product.Id}).Updates(&Product{Price: 100}) - - var product1 Product - DB.First(&product1, product.Id) - - if product1.Code != "product1" || product1.Price != 100 { - t.Errorf("product's code should not be updated") - } -} - -type ElementWithIgnoredField struct { - Id int64 - Value string - IgnoredField int64 `sql:"-"` -} - -func (e ElementWithIgnoredField) TableName() string { - return "element_with_ignored_field" -} - -func TestUpdatesTableWithIgnoredValues(t *testing.T) { - elem := ElementWithIgnoredField{Value: "foo", IgnoredField: 10} - DB.Save(&elem) - - DB.Table(elem.TableName()). - Where("id = ?", elem.Id). - // DB.Model(&ElementWithIgnoredField{Id: elem.Id}). - Updates(&ElementWithIgnoredField{Value: "bar", IgnoredField: 100}) - - var elem1 ElementWithIgnoredField - err := DB.First(&elem1, elem.Id).Error - if err != nil { - t.Errorf("error getting an element from database: %s", err.Error()) - } - - if elem1.IgnoredField != 0 { - t.Errorf("element's ignored field should not be updated") - } -} - -func TestUpdateDecodeVirtualAttributes(t *testing.T) { - var user = User{ - Name: "jinzhu", - IgnoreMe: 88, - } - - DB.Save(&user) - - DB.Model(&user).Updates(User{Name: "jinzhu2", IgnoreMe: 100}) - - if user.IgnoreMe != 100 { - t.Errorf("should decode virtual attributes to struct, so it could be used in callbacks") - } -} diff --git a/utils.go b/utils.go deleted file mode 100644 index d2ae9465..00000000 --- a/utils.go +++ /dev/null @@ -1,226 +0,0 @@ -package gorm - -import ( - "database/sql/driver" - "fmt" - "reflect" - "regexp" - "runtime" - "strings" - "sync" - "time" -) - -// NowFunc returns current time, this function is exported in order to be able -// to give the flexibility to the developer to customize it according to their -// needs, e.g: -// gorm.NowFunc = func() time.Time { -// return time.Now().UTC() -// } -var NowFunc = func() time.Time { - return time.Now() -} - -// Copied from golint -var commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UID", "UI", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"} -var commonInitialismsReplacer *strings.Replacer - -var goSrcRegexp = regexp.MustCompile(`jinzhu/gorm(@.*)?/.*.go`) -var goTestRegexp = regexp.MustCompile(`jinzhu/gorm(@.*)?/.*test.go`) - -func init() { - var commonInitialismsForReplacer []string - for _, initialism := range commonInitialisms { - commonInitialismsForReplacer = append(commonInitialismsForReplacer, initialism, strings.Title(strings.ToLower(initialism))) - } - commonInitialismsReplacer = strings.NewReplacer(commonInitialismsForReplacer...) -} - -type safeMap struct { - m map[string]string - l *sync.RWMutex -} - -func (s *safeMap) Set(key string, value string) { - s.l.Lock() - defer s.l.Unlock() - s.m[key] = value -} - -func (s *safeMap) Get(key string) string { - s.l.RLock() - defer s.l.RUnlock() - return s.m[key] -} - -func newSafeMap() *safeMap { - return &safeMap{l: new(sync.RWMutex), m: make(map[string]string)} -} - -// SQL expression -type SqlExpr struct { - expr string - args []interface{} -} - -// Expr generate raw SQL expression, for example: -// DB.Model(&product).Update("price", gorm.Expr("price * ? + ?", 2, 100)) -func Expr(expression string, args ...interface{}) *SqlExpr { - return &SqlExpr{expr: expression, args: args} -} - -func indirect(reflectValue reflect.Value) reflect.Value { - for reflectValue.Kind() == reflect.Ptr { - reflectValue = reflectValue.Elem() - } - return reflectValue -} - -func toQueryMarks(primaryValues [][]interface{}) string { - var results []string - - for _, primaryValue := range primaryValues { - var marks []string - for range primaryValue { - marks = append(marks, "?") - } - - if len(marks) > 1 { - results = append(results, fmt.Sprintf("(%v)", strings.Join(marks, ","))) - } else { - results = append(results, strings.Join(marks, "")) - } - } - return strings.Join(results, ",") -} - -func toQueryCondition(scope *Scope, columns []string) string { - var newColumns []string - for _, column := range columns { - newColumns = append(newColumns, scope.Quote(column)) - } - - if len(columns) > 1 { - return fmt.Sprintf("(%v)", strings.Join(newColumns, ",")) - } - return strings.Join(newColumns, ",") -} - -func toQueryValues(values [][]interface{}) (results []interface{}) { - for _, value := range values { - for _, v := range value { - results = append(results, v) - } - } - return -} - -func fileWithLineNum() string { - for i := 2; i < 15; i++ { - _, file, line, ok := runtime.Caller(i) - if ok && (!goSrcRegexp.MatchString(file) || goTestRegexp.MatchString(file)) { - return fmt.Sprintf("%v:%v", file, line) - } - } - return "" -} - -func isBlank(value reflect.Value) bool { - switch value.Kind() { - case reflect.String: - return value.Len() == 0 - case reflect.Bool: - return !value.Bool() - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return value.Int() == 0 - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: - return value.Uint() == 0 - case reflect.Float32, reflect.Float64: - return value.Float() == 0 - case reflect.Interface, reflect.Ptr: - return value.IsNil() - } - - return reflect.DeepEqual(value.Interface(), reflect.Zero(value.Type()).Interface()) -} - -func toSearchableMap(attrs ...interface{}) (result interface{}) { - if len(attrs) > 1 { - if str, ok := attrs[0].(string); ok { - result = map[string]interface{}{str: attrs[1]} - } - } else if len(attrs) == 1 { - if attr, ok := attrs[0].(map[string]interface{}); ok { - result = attr - } - - if attr, ok := attrs[0].(interface{}); ok { - result = attr - } - } - return -} - -func equalAsString(a interface{}, b interface{}) bool { - return toString(a) == toString(b) -} - -func toString(str interface{}) string { - if values, ok := str.([]interface{}); ok { - var results []string - for _, value := range values { - results = append(results, toString(value)) - } - return strings.Join(results, "_") - } else if bytes, ok := str.([]byte); ok { - return string(bytes) - } else if reflectValue := reflect.Indirect(reflect.ValueOf(str)); reflectValue.IsValid() { - return fmt.Sprintf("%v", reflectValue.Interface()) - } - return "" -} - -func makeSlice(elemType reflect.Type) interface{} { - if elemType.Kind() == reflect.Slice { - elemType = elemType.Elem() - } - sliceType := reflect.SliceOf(elemType) - slice := reflect.New(sliceType) - slice.Elem().Set(reflect.MakeSlice(sliceType, 0, 0)) - return slice.Interface() -} - -func strInSlice(a string, list []string) bool { - for _, b := range list { - if b == a { - return true - } - } - return false -} - -// getValueFromFields return given fields's value -func getValueFromFields(value reflect.Value, fieldNames []string) (results []interface{}) { - // If value is a nil pointer, Indirect returns a zero Value! - // Therefor we need to check for a zero value, - // as FieldByName could panic - if indirectValue := reflect.Indirect(value); indirectValue.IsValid() { - for _, fieldName := range fieldNames { - if fieldValue := reflect.Indirect(indirectValue.FieldByName(fieldName)); fieldValue.IsValid() { - result := fieldValue.Interface() - if r, ok := result.(driver.Valuer); ok { - result, _ = r.Value() - } - results = append(results, result) - } - } - } - return -} - -func addExtraSpaceIfExist(str string) string { - if str != "" { - return " " + str - } - return "" -} diff --git a/wercker.yml b/wercker.yml deleted file mode 100644 index c74fa4d4..00000000 --- a/wercker.yml +++ /dev/null @@ -1,154 +0,0 @@ -# use the default golang container from Docker Hub -box: golang - -services: - - name: mariadb - id: mariadb:latest - env: - MYSQL_DATABASE: gorm - MYSQL_USER: gorm - MYSQL_PASSWORD: gorm - MYSQL_RANDOM_ROOT_PASSWORD: "yes" - - name: mysql - id: mysql:latest - env: - MYSQL_DATABASE: gorm - MYSQL_USER: gorm - MYSQL_PASSWORD: gorm - MYSQL_RANDOM_ROOT_PASSWORD: "yes" - - name: mysql57 - id: mysql:5.7 - env: - MYSQL_DATABASE: gorm - MYSQL_USER: gorm - MYSQL_PASSWORD: gorm - MYSQL_RANDOM_ROOT_PASSWORD: "yes" - - name: mysql56 - id: mysql:5.6 - env: - MYSQL_DATABASE: gorm - MYSQL_USER: gorm - MYSQL_PASSWORD: gorm - MYSQL_RANDOM_ROOT_PASSWORD: "yes" - - name: postgres - id: postgres:latest - env: - POSTGRES_USER: gorm - POSTGRES_PASSWORD: gorm - POSTGRES_DB: gorm - - name: postgres96 - id: postgres:9.6 - env: - POSTGRES_USER: gorm - POSTGRES_PASSWORD: gorm - POSTGRES_DB: gorm - - name: postgres95 - id: postgres:9.5 - env: - POSTGRES_USER: gorm - POSTGRES_PASSWORD: gorm - POSTGRES_DB: gorm - - name: postgres94 - id: postgres:9.4 - env: - POSTGRES_USER: gorm - POSTGRES_PASSWORD: gorm - POSTGRES_DB: gorm - - name: postgres93 - id: postgres:9.3 - env: - POSTGRES_USER: gorm - POSTGRES_PASSWORD: gorm - POSTGRES_DB: gorm - - name: mssql - id: mcmoe/mssqldocker:latest - env: - ACCEPT_EULA: Y - SA_PASSWORD: LoremIpsum86 - MSSQL_DB: gorm - MSSQL_USER: gorm - MSSQL_PASSWORD: LoremIpsum86 - -# The steps that will be executed in the build pipeline -build: - # The steps that will be executed on build - steps: - # Sets the go workspace and places you package - # at the right place in the workspace tree - - setup-go-workspace - - # Gets the dependencies - - script: - name: go get - code: | - cd $WERCKER_SOURCE_DIR - go version - go get -t -v ./... - - # Build the project - - script: - name: go build - code: | - go build ./... - - # Test the project - - script: - name: test sqlite - code: | - go test -race -v ./... - - - script: - name: test mariadb - code: | - GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mariadb:3306)/gorm?charset=utf8&parseTime=True" go test -race ./... - - - script: - name: test mysql - code: | - GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql:3306)/gorm?charset=utf8&parseTime=True" go test -race ./... - - - script: - name: test mysql5.7 - code: | - GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql57:3306)/gorm?charset=utf8&parseTime=True" go test -race ./... - - - script: - name: test mysql5.6 - code: | - GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql56:3306)/gorm?charset=utf8&parseTime=True" go test -race ./... - - - script: - name: test postgres - code: | - GORM_DIALECT=postgres GORM_DSN="host=postgres user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test -race ./... - - - script: - name: test postgres96 - code: | - GORM_DIALECT=postgres GORM_DSN="host=postgres96 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test -race ./... - - - script: - name: test postgres95 - code: | - GORM_DIALECT=postgres GORM_DSN="host=postgres95 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test -race ./... - - - script: - name: test postgres94 - code: | - GORM_DIALECT=postgres GORM_DSN="host=postgres94 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test -race ./... - - - script: - name: test postgres93 - code: | - GORM_DIALECT=postgres GORM_DSN="host=postgres93 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test -race ./... - - - script: - name: test mssql - code: | - GORM_DIALECT=mssql GORM_DSN="sqlserver://gorm:LoremIpsum86@mssql:1433?database=gorm" go test -race ./... - - - script: - name: codecov - code: | - go test -race -coverprofile=coverage.txt -covermode=atomic ./... - bash <(curl -s https://codecov.io/bash) From 8eae7e4ab934df7ca645f563a74e33a3e7367e74 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 28 Jan 2020 23:01:35 +0800 Subject: [PATCH 207/881] Add migrator --- .gitignore | 1 + go.mod | 2 ++ gorm.go | 46 ++++++++++++++++++++++++++++++++++++++++++++ logger/logger.go | 5 +++++ migrator.go | 44 ++++++++++++++++++++++++++++++++++++++++++ migrator/migrator.go | 14 ++++++++++++++ 6 files changed, 112 insertions(+) create mode 100644 gorm.go create mode 100644 logger/logger.go create mode 100644 migrator.go create mode 100644 migrator/migrator.go diff --git a/.gitignore b/.gitignore index 117f92f5..912d58f7 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ +TODO documents coverage.txt _book diff --git a/go.mod b/go.mod index 0b3e3065..d0a110ba 100644 --- a/go.mod +++ b/go.mod @@ -1 +1,3 @@ module github.com/jinzhu/gorm + +go 1.13 diff --git a/gorm.go b/gorm.go new file mode 100644 index 00000000..274f4c62 --- /dev/null +++ b/gorm.go @@ -0,0 +1,46 @@ +package gorm + +import ( + "time" + + "github.com/jinzhu/gorm/logger" +) + +// Config GORM config +type Config struct { + // Set true to use singular table name, by default, GORM will pluralize your struct's name as table name + // Refer https://github.com/jinzhu/inflection for inflection rules + SingularTable bool + + // GORM perform single create, update, delete operations in transactions by default to ensure database data integrity + // You can cancel it by setting `SkipDefaultTransaction` to true + SkipDefaultTransaction bool + + // Logger + Logger logger.Interface + + // NowFunc the function to be used when creating a new timestamp + NowFunc func() time.Time +} + +// Model a basic GoLang struct which includes the following fields: ID, CreatedAt, UpdatedAt, DeletedAt +// It may be embeded into your model or you may build your own model without it +// type User struct { +// gorm.Model +// } +type Model struct { + ID uint `gorm:"primary_key"` + CreatedAt time.Time + UpdatedAt time.Time + DeletedAt *time.Time `gorm:"index"` +} + +// Dialector GORM database dialector +type Dialector interface { + Migrator() Migrator +} + +// DB GORM DB definition +type DB struct { + *Config +} diff --git a/logger/logger.go b/logger/logger.go new file mode 100644 index 00000000..87b71013 --- /dev/null +++ b/logger/logger.go @@ -0,0 +1,5 @@ +package logger + +// Interface logger interface +type Interface interface { +} diff --git a/migrator.go b/migrator.go new file mode 100644 index 00000000..c21cda42 --- /dev/null +++ b/migrator.go @@ -0,0 +1,44 @@ +package gorm + +import ( + "database/sql" +) + +// ViewOption view option +type ViewOption struct { + Replace bool + CheckOption string + Query *DB +} + +type Migrator interface { + // AutoMigrate + AutoMigrate(dst ...interface{}) error + + // Tables + CreateTable(dst ...interface{}) error + DropTable(dst ...interface{}) error + HasTable(dst ...interface{}) error + RenameTable(oldName, newName string) error + + // Columns + AddColumn(dst interface{}, field string) error + DropColumn(dst interface{}, field string) error + AlterColumn(dst interface{}, field string) error + RenameColumn(dst interface{}, oldName, field string) error + ColumnTypes(dst interface{}) ([]*sql.ColumnType, error) + + // Views + CreateView(name string, option ViewOption) error + DropView(name string) error + + // Constraints + CreateConstraint(dst interface{}, name string) error + DropConstraint(dst interface{}, name string) error + + // Indexes + CreateIndex(dst interface{}, name string) error + DropIndex(dst interface{}, name string) error + HasIndex(dst interface{}, name string) error + RenameIndex(dst interface{}, oldName, newName string) error +} diff --git a/migrator/migrator.go b/migrator/migrator.go new file mode 100644 index 00000000..0ff83ac1 --- /dev/null +++ b/migrator/migrator.go @@ -0,0 +1,14 @@ +package migrator + +import "github.com/jinzhu/gorm" + +// Migrator migrator struct +type Migrator struct { + *Config +} + +// Config schema config +type Config struct { + CheckExistsBeforeDropping bool + DB *gorm.DB +} From b9cce2be6a47d4cd8ea11674226bf67d8e39082d Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 29 Jan 2020 19:22:44 +0800 Subject: [PATCH 208/881] Add clause, DB API, model definition --- .gitignore | 2 +- association.go | 5 ++ chainable_api.go | 138 ++++++++++++++++++++++++++++++ clause/clause.go | 53 ++++++++++++ clause/expr.go | 19 ++++ clause/operators.go | 195 ++++++++++++++++++++++++++++++++++++++++++ finisher_api.go | 154 +++++++++++++++++++++++++++++++++ gorm.go | 62 ++++++++++++++ model/model.go | 37 ++++++++ model/relationship.go | 37 ++++++++ statement.go | 68 +++++++++++++++ 11 files changed, 769 insertions(+), 1 deletion(-) create mode 100644 association.go create mode 100644 chainable_api.go create mode 100644 clause/clause.go create mode 100644 clause/expr.go create mode 100644 clause/operators.go create mode 100644 finisher_api.go create mode 100644 model/model.go create mode 100644 model/relationship.go create mode 100644 statement.go diff --git a/.gitignore b/.gitignore index 912d58f7..c14d6005 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,4 @@ -TODO +TODO* documents coverage.txt _book diff --git a/association.go b/association.go new file mode 100644 index 00000000..17f8f4a5 --- /dev/null +++ b/association.go @@ -0,0 +1,5 @@ +package gorm + +// Association Mode contains some helper methods to handle relationship things easily. +type Association struct { +} diff --git a/chainable_api.go b/chainable_api.go new file mode 100644 index 00000000..d8f2116c --- /dev/null +++ b/chainable_api.go @@ -0,0 +1,138 @@ +package gorm + +// Model specify the model you would like to run db operations +// // update all users's name to `hello` +// db.Model(&User{}).Update("name", "hello") +// // if user's primary key is non-blank, will use it as condition, then will only update the user's name to `hello` +// db.Model(&user).Update("name", "hello") +func (db *DB) Model(value interface{}) (tx *DB) { + tx = db.getInstance() + return +} + +// Table specify the table you would like to run db operations +func (db *DB) Table(name string) (tx *DB) { + tx = db.getInstance() + return +} + +// Select specify fields that you want when querying, creating, updating +func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) { + tx = db.getInstance() + return +} + +// Omit specify fields that you want to ignore when creating, updating and querying +func (db *DB) Omit(columns ...string) (tx *DB) { + tx = db.getInstance() + return +} + +func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) { + tx = db.getInstance() + return +} + +// Not add NOT condition +func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) { + tx = db.getInstance() + return +} + +// Or add OR conditions +func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) { + tx = db.getInstance() + return +} + +// Joins specify Joins conditions +// db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user) +func (db *DB) Joins(query string, args ...interface{}) (tx *DB) { + tx = db.getInstance() + return +} + +// Group specify the group method on the find +func (db *DB) Group(column string) (tx *DB) { + tx = db.getInstance() + return +} + +// Having specify HAVING conditions for GROUP BY +func (db *DB) Having(query interface{}, args ...interface{}) (tx *DB) { + tx = db.getInstance() + return +} + +// Order specify order when retrieve records from database +// db.Order("name DESC") +// db.Order(gorm.Expr("name = ? DESC", "first")) // sql expression +func (db *DB) Order(value interface{}) (tx *DB) { + tx = db.getInstance() + return +} + +// Limit specify the number of records to be retrieved +func (db *DB) Limit(limit int64) (tx *DB) { + tx = db.getInstance() + return +} + +// Offset specify the number of records to skip before starting to return the records +func (db *DB) Offset(offset int64) (tx *DB) { + tx = db.getInstance() + return +} + +// Scopes pass current database connection to arguments `func(*DB) *DB`, which could be used to add conditions dynamically +// func AmountGreaterThan1000(db *gorm.DB) *gorm.DB { +// return db.Where("amount > ?", 1000) +// } +// +// func OrderStatus(status []string) func (db *gorm.DB) *gorm.DB { +// return func (db *gorm.DB) *gorm.DB { +// return db.Scopes(AmountGreaterThan1000).Where("status in (?)", status) +// } +// } +// +// db.Scopes(AmountGreaterThan1000, OrderStatus([]string{"paid", "shipped"})).Find(&orders) +// Refer https://jinzhu.github.io/gorm/crud.html#scopes +func (db *DB) Scopes(funcs ...func(*DB) *DB) (tx *DB) { + for _, f := range funcs { + db = f(db) + } + return db +} + +//Preloads only preloads relations, don`t touch out +func (db *DB) Preloads(out interface{}) (tx *DB) { + tx = db.getInstance() + return +} + +// Preload preload associations with given conditions +// db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users) +func (db *DB) Preload(column string, conditions ...interface{}) (tx *DB) { + tx = db.getInstance() + return +} + +func (db *DB) Assign(attrs ...interface{}) (tx *DB) { + tx = db.getInstance() + return +} + +func (db *DB) Attrs(attrs ...interface{}) (tx *DB) { + tx = db.getInstance() + return +} + +func (db *DB) Unscoped() (tx *DB) { + tx = db.getInstance() + return +} + +func (db *DB) Raw(sql string, values ...interface{}) (tx *DB) { + tx = db.getInstance() + return +} diff --git a/clause/clause.go b/clause/clause.go new file mode 100644 index 00000000..4495a9d5 --- /dev/null +++ b/clause/clause.go @@ -0,0 +1,53 @@ +package clause + +// Builder builder interface +type BuilderInterface interface { + Write(sql ...string) error + WriteQuoted(field interface{}) error + AddVar(vars ...interface{}) string + Quote(field interface{}) string +} + +// Interface clause interface +type Interface interface { + Name() string + Build(builder BuilderInterface) +} + +// NegationBuilder negation condition builder +type NegationBuilder interface { + NegationBuild(builder BuilderInterface) +} + +// Where where clause +type Where struct { +} + +// Select select attrs when querying, updating, creating +type Select struct { + Omit bool +} + +// Join join clause +type Join struct { +} + +// GroupBy group by clause +type GroupBy struct { +} + +// Having having clause +type Having struct { +} + +// Order order clause +type Order struct { +} + +// Limit limit clause +type Limit struct { +} + +// Offset offset clause +type Offset struct { +} diff --git a/clause/expr.go b/clause/expr.go new file mode 100644 index 00000000..94edb702 --- /dev/null +++ b/clause/expr.go @@ -0,0 +1,19 @@ +package clause + +type ExprInterface interface { +} + +type Expr struct { +} + +type Average struct { +} + +type Minimum struct { +} + +type Maximum struct { +} + +type Sum struct { +} diff --git a/clause/operators.go b/clause/operators.go new file mode 100644 index 00000000..331abea7 --- /dev/null +++ b/clause/operators.go @@ -0,0 +1,195 @@ +package clause + +import "strings" + +type AddConditions []Interface + +func (cs AddConditions) Build(builder BuilderInterface) { + for idx, c := range cs { + if idx > 0 { + builder.Write(" AND ") + } + c.Build(builder) + } +} + +type ORConditions []Interface + +func (cs ORConditions) Build(builder BuilderInterface) { + for idx, c := range cs { + if idx > 0 { + builder.Write(" OR ") + } + c.Build(builder) + } +} + +type NotConditions []Interface + +func (cs NotConditions) Build(builder BuilderInterface) { + for idx, c := range cs { + if idx > 0 { + builder.Write(" AND ") + } + + if negationBuilder, ok := c.(NegationBuilder); ok { + negationBuilder.NegationBuild(builder) + } else { + builder.Write(" NOT ") + c.Build(builder) + } + } +} + +// Raw raw sql for where +type Raw struct { + SQL string + Values []interface{} +} + +func (raw Raw) Build(builder BuilderInterface) { + sql := raw.SQL + for _, v := range raw.Values { + sql = strings.Replace(sql, " ? ", " "+builder.AddVar(v)+" ", 1) + } + builder.Write(sql) +} + +// IN Whether a value is within a set of values +type IN struct { + Column interface{} + Values []interface{} +} + +func (in IN) Build(builder BuilderInterface) { + builder.WriteQuoted(in.Column) + + if len(in.Values) == 0 { + builder.Write(" IN (NULL)") + } else { + builder.Write(" IN (", builder.AddVar(in.Values...), ")") + } +} + +func (in IN) NegationBuild(builder BuilderInterface) { + if len(in.Values) != 0 { + builder.WriteQuoted(in.Column) + builder.Write(" NOT IN (", builder.AddVar(in.Values...), ")") + } +} + +// Eq equal to for where +type Eq struct { + Column interface{} + Value interface{} +} + +func (eq Eq) Build(builder BuilderInterface) { + builder.WriteQuoted(eq.Column) + + if eq.Value == nil { + builder.Write(" IS NULL") + } else { + builder.Write(" = ", builder.AddVar(eq.Value)) + } +} + +func (eq Eq) NegationBuild(builder BuilderInterface) { + Neq{eq.Column, eq.Value}.Build(builder) +} + +// Neq not equal to for where +type Neq struct { + Column interface{} + Value interface{} +} + +func (neq Neq) Build(builder BuilderInterface) { + builder.WriteQuoted(neq.Column) + + if neq.Value == nil { + builder.Write(" IS NOT NULL") + } else { + builder.Write(" <> ", builder.AddVar(neq.Value)) + } +} + +func (neq Neq) NegationBuild(builder BuilderInterface) { + Eq{neq.Column, neq.Value}.Build(builder) +} + +// Gt greater than for where +type Gt struct { + Column interface{} + Value interface{} +} + +func (gt Gt) Build(builder BuilderInterface) { + builder.WriteQuoted(gt.Column) + builder.Write(" > ", builder.AddVar(gt.Value)) +} + +func (gt Gt) NegationBuild(builder BuilderInterface) { + Lte{gt.Column, gt.Value}.Build(builder) +} + +// Gte greater than or equal to for where +type Gte struct { + Column interface{} + Value interface{} +} + +func (gte Gte) Build(builder BuilderInterface) { + builder.WriteQuoted(gte.Column) + builder.Write(" >= ", builder.AddVar(gte.Value)) +} + +func (gte Gte) NegationBuild(builder BuilderInterface) { + Lt{gte.Column, gte.Value}.Build(builder) +} + +// Lt less than for where +type Lt struct { + Column interface{} + Value interface{} +} + +func (lt Lt) Build(builder BuilderInterface) { + builder.WriteQuoted(lt.Column) + builder.Write(" < ", builder.AddVar(lt.Value)) +} + +func (lt Lt) NegationBuild(builder BuilderInterface) { + Gte{lt.Column, lt.Value}.Build(builder) +} + +// Lte less than or equal to for where +type Lte struct { + Column interface{} + Value interface{} +} + +func (lte Lte) Build(builder BuilderInterface) { + builder.WriteQuoted(lte.Column) + builder.Write(" <= ", builder.AddVar(lte.Value)) +} + +func (lte Lte) NegationBuild(builder BuilderInterface) { + Gt{lte.Column, lte.Value}.Build(builder) +} + +// Like whether string matches regular expression +type Like struct { + Column interface{} + Value interface{} +} + +func (like Like) Build(builder BuilderInterface) { + builder.WriteQuoted(like.Column) + builder.Write(" LIKE ", builder.AddVar(like.Value)) +} + +func (like Like) NegationBuild(builder BuilderInterface) { + builder.WriteQuoted(like.Column) + builder.Write(" NOT LIKE ", builder.AddVar(like.Value)) +} diff --git a/finisher_api.go b/finisher_api.go new file mode 100644 index 00000000..687843e3 --- /dev/null +++ b/finisher_api.go @@ -0,0 +1,154 @@ +package gorm + +import ( + "database/sql" +) + +func (db *DB) Count(sql string, values ...interface{}) (tx *DB) { + tx = db.getInstance() + return +} + +// First find first record that match given conditions, order by primary key +func (db *DB) First(out interface{}, where ...interface{}) (tx *DB) { + tx = db.getInstance() + return +} + +// Take return a record that match given conditions, the order will depend on the database implementation +func (db *DB) Take(out interface{}, where ...interface{}) (tx *DB) { + tx = db.getInstance() + return +} + +// Last find last record that match given conditions, order by primary key +func (db *DB) Last(out interface{}, where ...interface{}) (tx *DB) { + tx = db.getInstance() + return +} + +// Find find records that match given conditions +func (db *DB) Find(out interface{}, where ...interface{}) (tx *DB) { + tx = db.getInstance() + return +} + +// Scan scan value to a struct + +func (db *DB) Row() *sql.Row { + // TODO + return nil +} + +func (db *DB) Rows() (*sql.Rows, error) { + // TODO + return nil, nil +} + +func (db *DB) Scan(dest interface{}) (tx *DB) { + tx = db.getInstance() + return +} + +func (db *DB) ScanRows(rows *sql.Rows, result interface{}) error { + return nil +} + +// Create insert the value into database +func (db *DB) Create(value interface{}) (tx *DB) { + tx = db.getInstance() + return +} + +// Save update value in database, if the value doesn't have primary key, will insert it +func (db *DB) Save(value interface{}) (tx *DB) { + tx = db.getInstance() + return +} + +// Update update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update +func (db *DB) Update(column string, value interface{}) (tx *DB) { + tx = db.getInstance() + return +} + +// Updates update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update +func (db *DB) Updates(values interface{}) (tx *DB) { + tx = db.getInstance() + return +} + +func (db *DB) UpdateColumn(attrs ...interface{}) (tx *DB) { + tx = db.getInstance() + return +} + +func (db *DB) UpdateColumns(values interface{}) (tx *DB) { + tx = db.getInstance() + return +} + +func (db *DB) FirstOrCreate(out interface{}, where ...interface{}) (tx *DB) { + tx = db.getInstance() + return +} + +func (db *DB) FirstOrInit(out interface{}, where ...interface{}) (tx *DB) { + tx = db.getInstance() + return +} + +// Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition +func (db *DB) Delete(value interface{}, where ...interface{}) (tx *DB) { + tx = db.getInstance() + return +} + +func (db *DB) Related(value interface{}, foreignKeys ...string) (tx *DB) { + tx = db.getInstance() + return +} + +func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err error) { + panicked := true + tx := db.Begin(opts...) + defer func() { + // Make sure to rollback when panic, Block error or Commit error + if panicked || err != nil { + tx.Rollback() + } + }() + + err = fc(tx) + + if err == nil { + err = tx.Commit().Error + } + + panicked = false + return +} + +func (db *DB) Begin(opts ...*sql.TxOptions) (tx *DB) { + tx = db.getInstance() + return +} + +func (db *DB) Commit() (tx *DB) { + tx = db.getInstance() + return +} + +func (db *DB) Rollback() (tx *DB) { + tx = db.getInstance() + return +} + +func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) { + tx = db.getInstance() + return +} + +func (db *DB) Association(column string) *Association { + return nil +} diff --git a/gorm.go b/gorm.go index 274f4c62..1b6d88df 100644 --- a/gorm.go +++ b/gorm.go @@ -1,8 +1,10 @@ package gorm import ( + "context" "time" + "github.com/jinzhu/gorm/clause" "github.com/jinzhu/gorm/logger" ) @@ -38,9 +40,69 @@ type Model struct { // Dialector GORM database dialector type Dialector interface { Migrator() Migrator + BindVar(stmt Statement, v interface{}) string +} + +// Result +type Result struct { + Error error + RowsAffected int64 + Statement *Statement } // DB GORM DB definition type DB struct { *Config + Dialector + Result + Context context.Context +} + +// WithContext change current instance db's context to ctx +func (db *DB) WithContext(ctx context.Context) *DB { + tx := db.getInstance() + tx.Context = ctx + return tx +} + +// Set store value with key into current db instance's context +func (db *DB) Set(key string, value interface{}) *DB { + tx := db.getInstance() + tx.Statement.Settings.Store(key, value) + return tx +} + +// Get get value with key from current db instance's context +func (db *DB) Get(key string) (interface{}, bool) { + if db.Statement != nil { + return db.Statement.Settings.Load(key) + } + return nil, false +} + +func (db *DB) Close() *DB { + // TODO + return db +} + +func (db *DB) getInstance() *DB { + // db.Result.Statement == nil means root DB + if db.Result.Statement == nil { + return &DB{ + Config: db.Config, + Dialector: db.Dialector, + Context: context.Background(), + Result: Result{ + Statement: &Statement{DB: db, Clauses: map[string][]clause.Interface{}}, + }, + } + } + + return db +} + +// Debug start debug mode +func (db *DB) Debug() (tx *DB) { + tx = db.getInstance() + return } diff --git a/model/model.go b/model/model.go new file mode 100644 index 00000000..316f3ab5 --- /dev/null +++ b/model/model.go @@ -0,0 +1,37 @@ +package model + +import ( + "reflect" +) + +type Model struct { + ModelType reflect.Type + Table string + PrioritizedPrimaryField *Field + PrimaryFields []*Field + Fields []*Field + FieldsByName map[string]*Field + FieldsByDBName map[string]*Field + Relationships Relationships +} + +type Field struct { + Name string + DBName string + DataType reflect.Type + DBDataType string + Tag reflect.StructTag + TagSettings map[string]string + PrimaryKey bool + AutoIncrement bool + Creatable bool + Updatable bool + Nullable bool + Unique bool + Precision int + Size int + HasDefaultValue bool + DefaultValue string + StructField reflect.StructField + Model *Model +} diff --git a/model/relationship.go b/model/relationship.go new file mode 100644 index 00000000..60b0751e --- /dev/null +++ b/model/relationship.go @@ -0,0 +1,37 @@ +package model + +// RelationshipType relationship type +type RelationshipType string + +const ( + HasOneRel RelationshipType = "has_one" // HasOneRel has one relationship + HasManyRel RelationshipType = "has_many" // HasManyRel has many relationship + BelongsToRel RelationshipType = "belongs_to" // BelongsToRel belongs to relationship + Many2ManyRel RelationshipType = "many_to_many" // Many2ManyRel many to many relationship +) + +type Relationships struct { + HasOne map[string]*Relationship + BelongsTo map[string]*Relationship + HasMany map[string]*Relationship + Many2Many map[string]*Relationship +} + +type Relationship struct { + Type RelationshipType + ForeignKeys []*RelationField // self + AssociationForeignKeys []*RelationField // association + JoinTable *JoinTable +} + +type RelationField struct { + *Field + PolymorphicField *Field + PolymorphicValue string +} + +type JoinTable struct { + Table string + ForeignKeys []*RelationField + AssociationForeignKeys []*RelationField +} diff --git a/statement.go b/statement.go new file mode 100644 index 00000000..21e95e11 --- /dev/null +++ b/statement.go @@ -0,0 +1,68 @@ +package gorm + +import ( + "bytes" + "context" + "database/sql" + "fmt" + "strings" + "sync" + + "github.com/jinzhu/gorm/clause" +) + +// Statement statement +type Statement struct { + Dest interface{} + Table interface{} + Clauses map[string][]clause.Interface + Settings sync.Map + Context context.Context + DB *DB + StatementBuilder +} + +// StatementBuilder statement builder +type StatementBuilder struct { + SQL bytes.Buffer + Vars []interface{} + NamedVars []sql.NamedArg +} + +// Write write string +func (stmt Statement) Write(sql ...string) (err error) { + for _, s := range sql { + _, err = stmt.SQL.WriteString(s) + } + return +} + +// WriteQuoted write quoted field +func (stmt Statement) WriteQuoted(field interface{}) (err error) { + _, err = stmt.SQL.WriteString(stmt.Quote(field)) + return +} + +// Write write string +func (stmt Statement) AddVar(vars ...interface{}) string { + var placeholders []string + for _, v := range vars { + if namedArg, ok := v.(sql.NamedArg); ok && len(namedArg.Name) > 0 { + stmt.NamedVars = append(stmt.NamedVars, namedArg) + placeholders = append(placeholders, "@"+namedArg.Name) + } else { + placeholders = append(placeholders, stmt.DB.Dialector.BindVar(stmt, v)) + } + } + return strings.Join(placeholders, ",") +} + +// Quote returns quoted value +func (stmt Statement) Quote(field interface{}) (str string) { + return fmt.Sprint(field) +} + +// AddClause add clause +func (s Statement) AddClause(clause clause.Interface) { + s.Clauses[clause.Name()] = append(s.Clauses[clause.Name()], clause) +} From 85bfd175c6bf18cecac0e9c7403b3956a6c4ed54 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 30 Jan 2020 03:03:06 +0800 Subject: [PATCH 209/881] Implement build conditions --- chainable_api.go | 2 ++ clause/clause.go | 5 +++ clause/operators.go | 66 ++++++++++++++++++++++++++++++---- gorm.go | 8 ++++- statement.go | 88 +++++++++++++++++++++++++++++++++++++++++---- 5 files changed, 154 insertions(+), 15 deletions(-) diff --git a/chainable_api.go b/chainable_api.go index d8f2116c..75e0fa2a 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -7,12 +7,14 @@ package gorm // db.Model(&user).Update("name", "hello") func (db *DB) Model(value interface{}) (tx *DB) { tx = db.getInstance() + tx.Statement.Model = value return } // Table specify the table you would like to run db operations func (db *DB) Table(name string) (tx *DB) { tx = db.getInstance() + tx.Statement.Table = name return } diff --git a/clause/clause.go b/clause/clause.go index 4495a9d5..1afb120e 100644 --- a/clause/clause.go +++ b/clause/clause.go @@ -11,6 +11,11 @@ type BuilderInterface interface { // Interface clause interface type Interface interface { Name() string + Builder +} + +// Builder condition builder +type Builder interface { Build(builder BuilderInterface) } diff --git a/clause/operators.go b/clause/operators.go index 331abea7..a6bdb4aa 100644 --- a/clause/operators.go +++ b/clause/operators.go @@ -2,7 +2,8 @@ package clause import "strings" -type AddConditions []Interface +type Condition Builder +type AddConditions []Condition func (cs AddConditions) Build(builder BuilderInterface) { for idx, c := range cs { @@ -13,7 +14,7 @@ func (cs AddConditions) Build(builder BuilderInterface) { } } -type ORConditions []Interface +type ORConditions []Condition func (cs ORConditions) Build(builder BuilderInterface) { for idx, c := range cs { @@ -24,7 +25,7 @@ func (cs ORConditions) Build(builder BuilderInterface) { } } -type NotConditions []Interface +type NotConditions []Condition func (cs NotConditions) Build(builder BuilderInterface) { for idx, c := range cs { @@ -64,16 +65,22 @@ type IN struct { func (in IN) Build(builder BuilderInterface) { builder.WriteQuoted(in.Column) - if len(in.Values) == 0 { + switch len(in.Values) { + case 0: builder.Write(" IN (NULL)") - } else { + case 1: + builder.Write(" = ", builder.AddVar(in.Values...)) + default: builder.Write(" IN (", builder.AddVar(in.Values...), ")") } } func (in IN) NegationBuild(builder BuilderInterface) { - if len(in.Values) != 0 { - builder.WriteQuoted(in.Column) + switch len(in.Values) { + case 0: + case 1: + builder.Write(" <> ", builder.AddVar(in.Values...)) + default: builder.Write(" NOT IN (", builder.AddVar(in.Values...), ")") } } @@ -193,3 +200,48 @@ func (like Like) NegationBuild(builder BuilderInterface) { builder.WriteQuoted(like.Column) builder.Write(" NOT LIKE ", builder.AddVar(like.Value)) } + +// Map +type Map map[interface{}]interface{} + +func (m Map) Build(builder BuilderInterface) { + // TODO +} + +func (m Map) NegationBuild(builder BuilderInterface) { + // TODO +} + +// Attrs +type Attrs struct { + Value interface{} + Select []string + Omit []string +} + +func (attrs Attrs) Build(builder BuilderInterface) { + // TODO + // builder.WriteQuoted(like.Column) + // builder.Write(" LIKE ", builder.AddVar(like.Value)) +} + +func (attrs Attrs) NegationBuild(builder BuilderInterface) { + // TODO +} + +// ID +type ID struct { + Value []interface{} +} + +func (id ID) Build(builder BuilderInterface) { + if len(id.Value) == 1 { + } + // TODO + // builder.WriteQuoted(like.Column) + // builder.Write(" LIKE ", builder.AddVar(like.Value)) +} + +func (id ID) NegationBuild(builder BuilderInterface) { + // TODO +} diff --git a/gorm.go b/gorm.go index 1b6d88df..86d5af9a 100644 --- a/gorm.go +++ b/gorm.go @@ -93,7 +93,7 @@ func (db *DB) getInstance() *DB { Dialector: db.Dialector, Context: context.Background(), Result: Result{ - Statement: &Statement{DB: db, Clauses: map[string][]clause.Interface{}}, + Statement: &Statement{DB: db, Clauses: map[string][]clause.Condition{}}, }, } } @@ -106,3 +106,9 @@ func (db *DB) Debug() (tx *DB) { tx = db.getInstance() return } + +// Session start session mode +func (db *DB) Session() (tx *DB) { + tx = db.getInstance() + return +} diff --git a/statement.go b/statement.go index 21e95e11..5dab59b3 100644 --- a/statement.go +++ b/statement.go @@ -4,7 +4,9 @@ import ( "bytes" "context" "database/sql" + "database/sql/driver" "fmt" + "strconv" "strings" "sync" @@ -13,9 +15,10 @@ import ( // Statement statement type Statement struct { + Model interface{} Dest interface{} - Table interface{} - Clauses map[string][]clause.Interface + Table string + Clauses map[string][]clause.Condition Settings sync.Map Context context.Context DB *DB @@ -45,16 +48,29 @@ func (stmt Statement) WriteQuoted(field interface{}) (err error) { // Write write string func (stmt Statement) AddVar(vars ...interface{}) string { - var placeholders []string - for _, v := range vars { + var placeholders strings.Builder + for idx, v := range vars { + if idx > 0 { + placeholders.WriteByte(',') + } + if namedArg, ok := v.(sql.NamedArg); ok && len(namedArg.Name) > 0 { stmt.NamedVars = append(stmt.NamedVars, namedArg) - placeholders = append(placeholders, "@"+namedArg.Name) + placeholders.WriteByte('@') + placeholders.WriteString(namedArg.Name) + } else if arrs, ok := v.([]interface{}); ok { + placeholders.WriteByte('(') + if len(arrs) > 0 { + placeholders.WriteString(stmt.AddVar(arrs...)) + } else { + placeholders.WriteString("NULL") + } + placeholders.WriteByte(')') } else { - placeholders = append(placeholders, stmt.DB.Dialector.BindVar(stmt, v)) + placeholders.WriteString(stmt.DB.Dialector.BindVar(stmt, v)) } } - return strings.Join(placeholders, ",") + return placeholders.String() } // Quote returns quoted value @@ -66,3 +82,61 @@ func (stmt Statement) Quote(field interface{}) (str string) { func (s Statement) AddClause(clause clause.Interface) { s.Clauses[clause.Name()] = append(s.Clauses[clause.Name()], clause) } + +// BuildCondtions build conditions +func (s Statement) BuildCondtions(query interface{}, args ...interface{}) (conditions []clause.Condition) { + if sql, ok := query.(string); ok { + if i, err := strconv.Atoi(sql); err != nil { + query = i + } else if len(args) == 0 || (len(args) > 0 && strings.Contains(sql, "?")) || strings.Contains(sql, "@") { + return []clause.Condition{clause.Raw{SQL: sql, Values: args}} + } + } + + args = append([]interface{}{query}, args...) + for _, arg := range args { + if valuer, ok := arg.(driver.Valuer); ok { + arg, _ = valuer.Value() + } + + switch v := arg.(type) { + case clause.Builder: + conditions = append(conditions, v) + case *DB: + if v.Statement == nil { + if cs, ok := v.Statement.Clauses["WHERE"]; ok { + conditions = append(conditions, cs...) + } + } + case map[interface{}]interface{}: + var clauseMap = clause.Map{} + for i, j := range v { + clauseMap[i] = j + } + conditions = append(conditions, clauseMap) + case map[string]string: + var clauseMap = clause.Map{} + for i, j := range v { + clauseMap[i] = j + } + conditions = append(conditions, clauseMap) + case map[string]interface{}: + var clauseMap = clause.Map{} + for i, j := range v { + clauseMap[i] = j + } + conditions = append(conditions, clauseMap) + default: + // TODO check is struct + // struct, slice -> ids + } + } + + if len(conditions) == 0 { + conditions = append(conditions, clause.ID{Value: args}) + } + return conditions +} + +func (s Statement) AddError(err error) { +} From 9d5b9834d91f81400d5c8561c46746153bc2d176 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 30 Jan 2020 15:14:48 +0800 Subject: [PATCH 210/881] Refactor builder --- chainable_api.go | 39 +++++++-- clause/clause.go | 128 ++++++++++++++++++++++++++---- clause/expr.go | 19 ----- clause/expression.go | 30 +++++++ clause/{operators.go => query.go} | 73 +++++++++-------- errors.go | 22 +++++ finisher_api.go | 21 +++-- gorm.go | 105 +++++++++++++----------- logger/logger.go | 9 +++ model.go | 15 ++++ statement.go | 107 ++++++++++++++++++------- 11 files changed, 412 insertions(+), 156 deletions(-) delete mode 100644 clause/expr.go create mode 100644 clause/expression.go rename clause/{operators.go => query.go} (66%) create mode 100644 errors.go create mode 100644 model.go diff --git a/chainable_api.go b/chainable_api.go index 75e0fa2a..95d5975c 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -1,5 +1,7 @@ package gorm +import "github.com/jinzhu/gorm/clause" + // Model specify the model you would like to run db operations // // update all users's name to `hello` // db.Model(&User{}).Update("name", "hello") @@ -11,6 +13,27 @@ func (db *DB) Model(value interface{}) (tx *DB) { return } +// Clauses Add clauses +func (db *DB) Clauses(conds ...clause.Expression) (tx *DB) { + tx = db.getInstance() + var whereConds []interface{} + + for _, cond := range conds { + if c, ok := cond.(clause.Interface); ok { + tx.Statement.AddClause(c) + } else { + whereConds = append(whereConds, cond) + } + } + + if len(whereConds) > 0 { + tx.Statement.AddClause(clause.Where{ + AndConditions: tx.Statement.BuildCondtion(whereConds[0], whereConds[1:]...), + }) + } + return +} + // Table specify the table you would like to run db operations func (db *DB) Table(name string) (tx *DB) { tx = db.getInstance() @@ -32,18 +55,25 @@ func (db *DB) Omit(columns ...string) (tx *DB) { func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) { tx = db.getInstance() + tx.Statement.AddClause(clause.Where{AndConditions: tx.Statement.BuildCondtion(query, args...)}) return } // Not add NOT condition func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) { tx = db.getInstance() + tx.Statement.AddClause(clause.Where{ + AndConditions: []clause.Expression{clause.NotConditions(tx.Statement.BuildCondtion(query, args...))}, + }) return } // Or add OR conditions func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) { tx = db.getInstance() + tx.Statement.AddClause(clause.Where{ + ORConditions: []clause.ORConditions{tx.Statement.BuildCondtion(query, args...)}, + }) return } @@ -98,20 +128,13 @@ func (db *DB) Offset(offset int64) (tx *DB) { // } // // db.Scopes(AmountGreaterThan1000, OrderStatus([]string{"paid", "shipped"})).Find(&orders) -// Refer https://jinzhu.github.io/gorm/crud.html#scopes -func (db *DB) Scopes(funcs ...func(*DB) *DB) (tx *DB) { +func (db *DB) Scopes(funcs ...func(*DB) *DB) *DB { for _, f := range funcs { db = f(db) } return db } -//Preloads only preloads relations, don`t touch out -func (db *DB) Preloads(out interface{}) (tx *DB) { - tx = db.getInstance() - return -} - // Preload preload associations with given conditions // db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users) func (db *DB) Preload(column string, conditions ...interface{}) (tx *DB) { diff --git a/clause/clause.go b/clause/clause.go index 1afb120e..b0507f44 100644 --- a/clause/clause.go +++ b/clause/clause.go @@ -1,31 +1,131 @@ package clause -// Builder builder interface -type BuilderInterface interface { - Write(sql ...string) error - WriteQuoted(field interface{}) error - AddVar(vars ...interface{}) string - Quote(field interface{}) string +// Clause +type Clause struct { + Name string // WHERE + Priority float64 + BeforeExpressions []Expression + AfterNameExpressions []Expression + AfterExpressions []Expression + Expression Expression + Builder ClauseBuilder +} + +// ClauseBuilder clause builder, allows to custmize how to build clause +type ClauseBuilder interface { + Build(Clause, Builder) +} + +// Build build clause +func (c Clause) Build(builder Builder) { + if c.Builder != nil { + c.Builder.Build(c, builder) + } else { + builders := c.BeforeExpressions + if c.Name != "" { + builders = append(builders, Expr{c.Name}) + } + + builders = append(builders, c.AfterNameExpressions...) + if c.Expression != nil { + builders = append(builders, c.Expression) + } + + for idx, expr := range append(builders, c.AfterExpressions...) { + if idx != 0 { + builder.WriteByte(' ') + } + expr.Build(builder) + } + } } // Interface clause interface type Interface interface { Name() string - Builder + Build(Builder) + MergeExpression(Expression) } -// Builder condition builder -type Builder interface { - Build(builder BuilderInterface) +type OverrideNameInterface interface { + OverrideName() string } -// NegationBuilder negation condition builder -type NegationBuilder interface { - NegationBuild(builder BuilderInterface) -} +//////////////////////////////////////////////////////////////////////////////// +// Predefined Clauses +//////////////////////////////////////////////////////////////////////////////// // Where where clause type Where struct { + AndConditions AddConditions + ORConditions []ORConditions + Builders []Expression +} + +func (where Where) Name() string { + return "WHERE" +} + +func (where Where) Build(builder Builder) { + var withConditions bool + + if len(where.AndConditions) > 0 { + withConditions = true + where.AndConditions.Build(builder) + } + + if len(where.Builders) > 0 { + for _, b := range where.Builders { + if withConditions { + builder.Write(" AND ") + } + withConditions = true + b.Build(builder) + } + } + + var singleOrConditions []ORConditions + for _, or := range where.ORConditions { + if len(or) == 1 { + if withConditions { + builder.Write(" OR ") + or.Build(builder) + } else { + singleOrConditions = append(singleOrConditions, or) + } + } else { + withConditions = true + builder.Write(" AND (") + or.Build(builder) + builder.WriteByte(')') + } + } + + for _, or := range singleOrConditions { + if withConditions { + builder.Write(" AND ") + or.Build(builder) + } else { + withConditions = true + or.Build(builder) + } + } + + if !withConditions { + builder.Write(" FALSE") + } + + return +} + +func (where Where) MergeExpression(expr Expression) { + if w, ok := expr.(Where); ok { + where.AndConditions = append(where.AndConditions, w.AndConditions...) + where.ORConditions = append(where.ORConditions, w.ORConditions...) + where.Builders = append(where.Builders, w.Builders...) + } else { + where.Builders = append(where.Builders, expr) + } } // Select select attrs when querying, updating, creating diff --git a/clause/expr.go b/clause/expr.go deleted file mode 100644 index 94edb702..00000000 --- a/clause/expr.go +++ /dev/null @@ -1,19 +0,0 @@ -package clause - -type ExprInterface interface { -} - -type Expr struct { -} - -type Average struct { -} - -type Minimum struct { -} - -type Maximum struct { -} - -type Sum struct { -} diff --git a/clause/expression.go b/clause/expression.go new file mode 100644 index 00000000..17313d43 --- /dev/null +++ b/clause/expression.go @@ -0,0 +1,30 @@ +package clause + +// Expression expression interface +type Expression interface { + Build(builder Builder) +} + +// NegationExpressionBuilder negation expression builder +type NegationExpressionBuilder interface { + NegationBuild(builder Builder) +} + +// Builder builder interface +type Builder interface { + WriteByte(byte) error + Write(sql ...string) error + WriteQuoted(field interface{}) error + AddVar(vars ...interface{}) string + Quote(field interface{}) string +} + +// Expr raw expression +type Expr struct { + Value string +} + +// Build build raw expression +func (expr Expr) Build(builder Builder) { + builder.Write(expr.Value) +} diff --git a/clause/operators.go b/clause/query.go similarity index 66% rename from clause/operators.go rename to clause/query.go index a6bdb4aa..949678d9 100644 --- a/clause/operators.go +++ b/clause/query.go @@ -2,10 +2,13 @@ package clause import "strings" -type Condition Builder -type AddConditions []Condition +//////////////////////////////////////////////////////////////////////////////// +// Query Expressions +//////////////////////////////////////////////////////////////////////////////// -func (cs AddConditions) Build(builder BuilderInterface) { +type AddConditions []Expression + +func (cs AddConditions) Build(builder Builder) { for idx, c := range cs { if idx > 0 { builder.Write(" AND ") @@ -14,9 +17,9 @@ func (cs AddConditions) Build(builder BuilderInterface) { } } -type ORConditions []Condition +type ORConditions []Expression -func (cs ORConditions) Build(builder BuilderInterface) { +func (cs ORConditions) Build(builder Builder) { for idx, c := range cs { if idx > 0 { builder.Write(" OR ") @@ -25,15 +28,15 @@ func (cs ORConditions) Build(builder BuilderInterface) { } } -type NotConditions []Condition +type NotConditions []Expression -func (cs NotConditions) Build(builder BuilderInterface) { +func (cs NotConditions) Build(builder Builder) { for idx, c := range cs { if idx > 0 { builder.Write(" AND ") } - if negationBuilder, ok := c.(NegationBuilder); ok { + if negationBuilder, ok := c.(NegationExpressionBuilder); ok { negationBuilder.NegationBuild(builder) } else { builder.Write(" NOT ") @@ -42,15 +45,15 @@ func (cs NotConditions) Build(builder BuilderInterface) { } } -// Raw raw sql for where -type Raw struct { +// String raw sql for where +type String struct { SQL string Values []interface{} } -func (raw Raw) Build(builder BuilderInterface) { - sql := raw.SQL - for _, v := range raw.Values { +func (str String) Build(builder Builder) { + sql := str.SQL + for _, v := range str.Values { sql = strings.Replace(sql, " ? ", " "+builder.AddVar(v)+" ", 1) } builder.Write(sql) @@ -62,7 +65,7 @@ type IN struct { Values []interface{} } -func (in IN) Build(builder BuilderInterface) { +func (in IN) Build(builder Builder) { builder.WriteQuoted(in.Column) switch len(in.Values) { @@ -75,7 +78,7 @@ func (in IN) Build(builder BuilderInterface) { } } -func (in IN) NegationBuild(builder BuilderInterface) { +func (in IN) NegationBuild(builder Builder) { switch len(in.Values) { case 0: case 1: @@ -91,7 +94,7 @@ type Eq struct { Value interface{} } -func (eq Eq) Build(builder BuilderInterface) { +func (eq Eq) Build(builder Builder) { builder.WriteQuoted(eq.Column) if eq.Value == nil { @@ -101,7 +104,7 @@ func (eq Eq) Build(builder BuilderInterface) { } } -func (eq Eq) NegationBuild(builder BuilderInterface) { +func (eq Eq) NegationBuild(builder Builder) { Neq{eq.Column, eq.Value}.Build(builder) } @@ -111,7 +114,7 @@ type Neq struct { Value interface{} } -func (neq Neq) Build(builder BuilderInterface) { +func (neq Neq) Build(builder Builder) { builder.WriteQuoted(neq.Column) if neq.Value == nil { @@ -121,7 +124,7 @@ func (neq Neq) Build(builder BuilderInterface) { } } -func (neq Neq) NegationBuild(builder BuilderInterface) { +func (neq Neq) NegationBuild(builder Builder) { Eq{neq.Column, neq.Value}.Build(builder) } @@ -131,12 +134,12 @@ type Gt struct { Value interface{} } -func (gt Gt) Build(builder BuilderInterface) { +func (gt Gt) Build(builder Builder) { builder.WriteQuoted(gt.Column) builder.Write(" > ", builder.AddVar(gt.Value)) } -func (gt Gt) NegationBuild(builder BuilderInterface) { +func (gt Gt) NegationBuild(builder Builder) { Lte{gt.Column, gt.Value}.Build(builder) } @@ -146,12 +149,12 @@ type Gte struct { Value interface{} } -func (gte Gte) Build(builder BuilderInterface) { +func (gte Gte) Build(builder Builder) { builder.WriteQuoted(gte.Column) builder.Write(" >= ", builder.AddVar(gte.Value)) } -func (gte Gte) NegationBuild(builder BuilderInterface) { +func (gte Gte) NegationBuild(builder Builder) { Lt{gte.Column, gte.Value}.Build(builder) } @@ -161,12 +164,12 @@ type Lt struct { Value interface{} } -func (lt Lt) Build(builder BuilderInterface) { +func (lt Lt) Build(builder Builder) { builder.WriteQuoted(lt.Column) builder.Write(" < ", builder.AddVar(lt.Value)) } -func (lt Lt) NegationBuild(builder BuilderInterface) { +func (lt Lt) NegationBuild(builder Builder) { Gte{lt.Column, lt.Value}.Build(builder) } @@ -176,12 +179,12 @@ type Lte struct { Value interface{} } -func (lte Lte) Build(builder BuilderInterface) { +func (lte Lte) Build(builder Builder) { builder.WriteQuoted(lte.Column) builder.Write(" <= ", builder.AddVar(lte.Value)) } -func (lte Lte) NegationBuild(builder BuilderInterface) { +func (lte Lte) NegationBuild(builder Builder) { Gt{lte.Column, lte.Value}.Build(builder) } @@ -191,12 +194,12 @@ type Like struct { Value interface{} } -func (like Like) Build(builder BuilderInterface) { +func (like Like) Build(builder Builder) { builder.WriteQuoted(like.Column) builder.Write(" LIKE ", builder.AddVar(like.Value)) } -func (like Like) NegationBuild(builder BuilderInterface) { +func (like Like) NegationBuild(builder Builder) { builder.WriteQuoted(like.Column) builder.Write(" NOT LIKE ", builder.AddVar(like.Value)) } @@ -204,11 +207,11 @@ func (like Like) NegationBuild(builder BuilderInterface) { // Map type Map map[interface{}]interface{} -func (m Map) Build(builder BuilderInterface) { +func (m Map) Build(builder Builder) { // TODO } -func (m Map) NegationBuild(builder BuilderInterface) { +func (m Map) NegationBuild(builder Builder) { // TODO } @@ -219,13 +222,13 @@ type Attrs struct { Omit []string } -func (attrs Attrs) Build(builder BuilderInterface) { +func (attrs Attrs) Build(builder Builder) { // TODO // builder.WriteQuoted(like.Column) // builder.Write(" LIKE ", builder.AddVar(like.Value)) } -func (attrs Attrs) NegationBuild(builder BuilderInterface) { +func (attrs Attrs) NegationBuild(builder Builder) { // TODO } @@ -234,7 +237,7 @@ type ID struct { Value []interface{} } -func (id ID) Build(builder BuilderInterface) { +func (id ID) Build(builder Builder) { if len(id.Value) == 1 { } // TODO @@ -242,6 +245,6 @@ func (id ID) Build(builder BuilderInterface) { // builder.Write(" LIKE ", builder.AddVar(like.Value)) } -func (id ID) NegationBuild(builder BuilderInterface) { +func (id ID) NegationBuild(builder Builder) { // TODO } diff --git a/errors.go b/errors.go new file mode 100644 index 00000000..c66408be --- /dev/null +++ b/errors.go @@ -0,0 +1,22 @@ +package gorm + +import "errors" + +var ( + // ErrRecordNotFound record not found error + ErrRecordNotFound = errors.New("record not found") + // ErrInvalidSQL invalid SQL error, happens when you passed invalid SQL + ErrInvalidSQL = errors.New("invalid SQL") + // ErrInvalidTransaction invalid transaction when you are trying to `Commit` or `Rollback` + ErrInvalidTransaction = errors.New("no valid transaction") + // ErrUnaddressable unaddressable value + ErrUnaddressable = errors.New("using unaddressable value") +) + +type Error struct { + Err error +} + +func (e Error) Unwrap() error { + return e.Err +} diff --git a/finisher_api.go b/finisher_api.go index 687843e3..2668e1fe 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -33,8 +33,6 @@ func (db *DB) Find(out interface{}, where ...interface{}) (tx *DB) { return } -// Scan scan value to a struct - func (db *DB) Row() *sql.Row { // TODO return nil @@ -45,6 +43,7 @@ func (db *DB) Rows() (*sql.Rows, error) { return nil, nil } +// Scan scan value to a struct func (db *DB) Scan(dest interface{}) (tx *DB) { tx = db.getInstance() return @@ -88,12 +87,12 @@ func (db *DB) UpdateColumns(values interface{}) (tx *DB) { return } -func (db *DB) FirstOrCreate(out interface{}, where ...interface{}) (tx *DB) { +func (db *DB) FirstOrInit(out interface{}, where ...interface{}) (tx *DB) { tx = db.getInstance() return } -func (db *DB) FirstOrInit(out interface{}, where ...interface{}) (tx *DB) { +func (db *DB) FirstOrCreate(out interface{}, where ...interface{}) (tx *DB) { tx = db.getInstance() return } @@ -109,6 +108,16 @@ func (db *DB) Related(value interface{}, foreignKeys ...string) (tx *DB) { return } +//Preloads only preloads relations, don`t touch out +func (db *DB) Preloads(out interface{}) (tx *DB) { + tx = db.getInstance() + return +} + +func (db *DB) Association(column string) *Association { + return nil +} + func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err error) { panicked := true tx := db.Begin(opts...) @@ -148,7 +157,3 @@ func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) { tx = db.getInstance() return } - -func (db *DB) Association(column string) *Association { - return nil -} diff --git a/gorm.go b/gorm.go index 86d5af9a..838f2862 100644 --- a/gorm.go +++ b/gorm.go @@ -25,44 +25,72 @@ type Config struct { NowFunc func() time.Time } -// Model a basic GoLang struct which includes the following fields: ID, CreatedAt, UpdatedAt, DeletedAt -// It may be embeded into your model or you may build your own model without it -// type User struct { -// gorm.Model -// } -type Model struct { - ID uint `gorm:"primary_key"` - CreatedAt time.Time - UpdatedAt time.Time - DeletedAt *time.Time `gorm:"index"` -} - // Dialector GORM database dialector type Dialector interface { Migrator() Migrator BindVar(stmt Statement, v interface{}) string } -// Result -type Result struct { - Error error - RowsAffected int64 - Statement *Statement -} - // DB GORM DB definition type DB struct { *Config Dialector - Result + Instance + clone bool +} + +// Session session config when create new session +type Session struct { Context context.Context + Logger logger.Interface + NowFunc func() time.Time +} + +// Open initialize db session based on dialector +func Open(dialector Dialector, config *Config) (db *DB, err error) { + return &DB{ + Config: config, + Dialector: dialector, + clone: true, + }, nil +} + +// Session create new db session +func (db *DB) Session(config *Session) *DB { + var ( + tx = db.getInstance() + txConfig = *tx.Config + ) + + if config.Context != nil { + tx.Context = config.Context + } + + if config.Logger != nil { + txConfig.Logger = config.Logger + } + + if config.NowFunc != nil { + txConfig.NowFunc = config.NowFunc + } + + tx.Config = &txConfig + tx.clone = true + return tx } // WithContext change current instance db's context to ctx func (db *DB) WithContext(ctx context.Context) *DB { - tx := db.getInstance() - tx.Context = ctx - return tx + return db.Session(&Session{Context: ctx}) +} + +// Debug start debug mode +func (db *DB) Debug() (tx *DB) { + return db.Session(&Session{Logger: db.Logger.LogMode(logger.Info)}) +} + +func (db *DB) Close() error { + return nil } // Set store value with key into current db instance's context @@ -80,35 +108,22 @@ func (db *DB) Get(key string) (interface{}, bool) { return nil, false } -func (db *DB) Close() *DB { - // TODO - return db -} - func (db *DB) getInstance() *DB { - // db.Result.Statement == nil means root DB - if db.Result.Statement == nil { + if db.clone { + ctx := db.Instance.Context + if ctx == nil { + ctx = context.Background() + } + return &DB{ Config: db.Config, Dialector: db.Dialector, - Context: context.Background(), - Result: Result{ - Statement: &Statement{DB: db, Clauses: map[string][]clause.Condition{}}, + Instance: Instance{ + Context: ctx, + Statement: &Statement{DB: db, Clauses: map[string]clause.Clause{}}, }, } } return db } - -// Debug start debug mode -func (db *DB) Debug() (tx *DB) { - tx = db.getInstance() - return -} - -// Session start session mode -func (db *DB) Session() (tx *DB) { - tx = db.getInstance() - return -} diff --git a/logger/logger.go b/logger/logger.go index 87b71013..389a6763 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -1,5 +1,14 @@ package logger +type LogLevel int + +const ( + Info LogLevel = iota + 1 + Warn + Error +) + // Interface logger interface type Interface interface { + LogMode(LogLevel) Interface } diff --git a/model.go b/model.go new file mode 100644 index 00000000..118d8f14 --- /dev/null +++ b/model.go @@ -0,0 +1,15 @@ +package gorm + +import "time" + +// Model a basic GoLang struct which includes the following fields: ID, CreatedAt, UpdatedAt, DeletedAt +// It may be embeded into your model or you may build your own model without it +// type User struct { +// gorm.Model +// } +type Model struct { + ID uint `gorm:"primary_key"` + CreatedAt time.Time + UpdatedAt time.Time + DeletedAt *time.Time `gorm:"index"` +} diff --git a/statement.go b/statement.go index 5dab59b3..30d45b98 100644 --- a/statement.go +++ b/statement.go @@ -1,7 +1,6 @@ package gorm import ( - "bytes" "context" "database/sql" "database/sql/driver" @@ -13,25 +12,43 @@ import ( "github.com/jinzhu/gorm/clause" ) -// Statement statement -type Statement struct { - Model interface{} - Dest interface{} - Table string - Clauses map[string][]clause.Condition - Settings sync.Map - Context context.Context - DB *DB - StatementBuilder +// Instance db instance +type Instance struct { + Error error + RowsAffected int64 + Context context.Context + Statement *Statement } -// StatementBuilder statement builder -type StatementBuilder struct { - SQL bytes.Buffer +// AddError add error to instance +func (inst Instance) AddError(err error) { + if inst.Error == nil { + inst.Error = err + } else { + inst.Error = fmt.Errorf("%v; %w", inst.Error, err) + } +} + +// Statement statement +type Statement struct { + Table string + Model interface{} + Dest interface{} + Clauses map[string]clause.Clause + Settings sync.Map + DB *DB + + // SQL Builder + SQL strings.Builder Vars []interface{} NamedVars []sql.NamedArg } +// StatementOptimizer statement optimizer interface +type StatementOptimizer interface { + OptimizeStatement(Statement) +} + // Write write string func (stmt Statement) Write(sql ...string) (err error) { for _, s := range sql { @@ -40,12 +57,23 @@ func (stmt Statement) Write(sql ...string) (err error) { return } +// Write write string +func (stmt Statement) WriteByte(c byte) (err error) { + return stmt.SQL.WriteByte(c) +} + // WriteQuoted write quoted field func (stmt Statement) WriteQuoted(field interface{}) (err error) { _, err = stmt.SQL.WriteString(stmt.Quote(field)) return } +// Quote returns quoted value +func (stmt Statement) Quote(field interface{}) (str string) { + // FIXME + return fmt.Sprint(field) +} + // Write write string func (stmt Statement) AddVar(vars ...interface{}) string { var placeholders strings.Builder @@ -73,23 +101,34 @@ func (stmt Statement) AddVar(vars ...interface{}) string { return placeholders.String() } -// Quote returns quoted value -func (stmt Statement) Quote(field interface{}) (str string) { - return fmt.Sprint(field) -} - // AddClause add clause -func (s Statement) AddClause(clause clause.Interface) { - s.Clauses[clause.Name()] = append(s.Clauses[clause.Name()], clause) +func (stmt Statement) AddClause(v clause.Interface) { + if optimizer, ok := v.(StatementOptimizer); ok { + optimizer.OptimizeStatement(stmt) + } + + c, _ := stmt.Clauses[v.Name()] + if namer, ok := v.(clause.OverrideNameInterface); ok { + c.Name = namer.OverrideName() + } else { + c.Name = v.Name() + } + + if c.Expression != nil { + v.MergeExpression(c.Expression) + } + + c.Expression = v + stmt.Clauses[v.Name()] = c } -// BuildCondtions build conditions -func (s Statement) BuildCondtions(query interface{}, args ...interface{}) (conditions []clause.Condition) { +// BuildCondtion build condition +func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (conditions []clause.Expression) { if sql, ok := query.(string); ok { if i, err := strconv.Atoi(sql); err != nil { query = i } else if len(args) == 0 || (len(args) > 0 && strings.Contains(sql, "?")) || strings.Contains(sql, "@") { - return []clause.Condition{clause.Raw{SQL: sql, Values: args}} + return []clause.Expression{clause.String{SQL: sql, Values: args}} } } @@ -100,12 +139,12 @@ func (s Statement) BuildCondtions(query interface{}, args ...interface{}) (condi } switch v := arg.(type) { - case clause.Builder: + case clause.Expression: conditions = append(conditions, v) case *DB: if v.Statement == nil { if cs, ok := v.Statement.Clauses["WHERE"]; ok { - conditions = append(conditions, cs...) + conditions = append(conditions, cs.Expression) } } case map[interface{}]interface{}: @@ -135,8 +174,22 @@ func (s Statement) BuildCondtions(query interface{}, args ...interface{}) (condi if len(conditions) == 0 { conditions = append(conditions, clause.ID{Value: args}) } + return conditions } -func (s Statement) AddError(err error) { +// Build build sql with clauses names +func (stmt Statement) Build(clauses ...string) { + var includeSpace bool + + for _, name := range clauses { + if c, ok := stmt.Clauses[name]; ok { + if includeSpace { + stmt.WriteByte(' ') + } + + includeSpace = true + c.Build(stmt) + } + } } From e509b3100daa35df7b7e80e8928bcf74aacf3a9e Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 31 Jan 2020 06:35:25 +0800 Subject: [PATCH 211/881] Implement callbacks --- callbacks.go | 211 ++++++++++++++++++++++++++++++++++++++++ callbacks_test.go | 131 +++++++++++++++++++++++++ errors.go => helpers.go | 21 ++-- logger/logger.go | 46 +++++++++ model.go | 15 --- utils/utils.go | 20 ++++ 6 files changed, 422 insertions(+), 22 deletions(-) create mode 100644 callbacks.go create mode 100644 callbacks_test.go rename errors.go => helpers.go (55%) delete mode 100644 model.go create mode 100644 utils/utils.go diff --git a/callbacks.go b/callbacks.go new file mode 100644 index 00000000..d53e8049 --- /dev/null +++ b/callbacks.go @@ -0,0 +1,211 @@ +package gorm + +import ( + "fmt" + "log" + + "github.com/jinzhu/gorm/logger" + "github.com/jinzhu/gorm/utils" +) + +// Callbacks gorm callbacks manager +type Callbacks struct { + creates []func(*DB) + queries []func(*DB) + updates []func(*DB) + deletes []func(*DB) + row []func(*DB) + raw []func(*DB) + db *DB + processors []*processor +} + +type processor struct { + kind string + name string + before string + after string + remove bool + replace bool + match func(*DB) bool + handler func(*DB) + callbacks *Callbacks +} + +func (cs *Callbacks) Create() *processor { + return &processor{callbacks: cs, kind: "create"} +} + +func (cs *Callbacks) Query() *processor { + return &processor{callbacks: cs, kind: "query"} +} + +func (cs *Callbacks) Update() *processor { + return &processor{callbacks: cs, kind: "update"} +} + +func (cs *Callbacks) Delete() *processor { + return &processor{callbacks: cs, kind: "delete"} +} + +func (cs *Callbacks) Row() *processor { + return &processor{callbacks: cs, kind: "row"} +} + +func (cs *Callbacks) Raw() *processor { + return &processor{callbacks: cs, kind: "raw"} +} + +func (p *processor) Before(name string) *processor { + p.before = name + return p +} + +func (p *processor) After(name string) *processor { + p.after = name + return p +} + +func (p *processor) Match(fc func(*DB) bool) *processor { + p.match = fc + return p +} + +func (p *processor) Get(name string) func(*DB) { + for i := len(p.callbacks.processors) - 1; i >= 0; i-- { + if v := p.callbacks.processors[i]; v.name == name && v.kind == v.kind && !v.remove { + return v.handler + } + } + return nil +} + +func (p *processor) Register(name string, fn func(*DB)) { + p.name = name + p.handler = fn + p.callbacks.processors = append(p.callbacks.processors, p) + p.callbacks.compile(p.callbacks.db) +} + +func (p *processor) Remove(name string) { + logger.Default.Info("removing callback `%v` from %v\n", name, utils.FileWithLineNum()) + p.name = name + p.remove = true + p.callbacks.processors = append(p.callbacks.processors, p) + p.callbacks.compile(p.callbacks.db) +} + +func (p *processor) Replace(name string, fn func(*DB)) { + logger.Default.Info("[info] replacing callback `%v` from %v\n", name, utils.FileWithLineNum()) + p.name = name + p.handler = fn + p.replace = true + p.callbacks.processors = append(p.callbacks.processors, p) + p.callbacks.compile(p.callbacks.db) +} + +// getRIndex get right index from string slice +func getRIndex(strs []string, str string) int { + for i := len(strs) - 1; i >= 0; i-- { + if strs[i] == str { + return i + } + } + return -1 +} + +func sortProcessors(ps []*processor) []func(*DB) { + var ( + allNames, sortedNames []string + sortProcessor func(*processor) error + ) + + for _, p := range ps { + // show warning message the callback name already exists + if idx := getRIndex(allNames, p.name); idx > -1 && !p.replace && !p.remove && !ps[idx].remove { + log.Printf("[warning] duplicated callback `%v` from %v\n", p.name, utils.FileWithLineNum()) + } + allNames = append(allNames, p.name) + } + + sortProcessor = func(p *processor) error { + if getRIndex(sortedNames, p.name) == -1 { // if not sorted + if p.before != "" { // if defined before callback + if sortedIdx := getRIndex(sortedNames, p.before); sortedIdx != -1 { + if curIdx := getRIndex(sortedNames, p.name); curIdx != -1 || true { + // if before callback already sorted, append current callback just after it + sortedNames = append(sortedNames[:sortedIdx], append([]string{p.name}, sortedNames[sortedIdx:]...)...) + } else if curIdx > sortedIdx { + return fmt.Errorf("conflicting callback %v with before %v", p.name, p.before) + } + } else if idx := getRIndex(allNames, p.before); idx != -1 { + // if before callback exists + ps[idx].after = p.name + } + } + + if p.after != "" { // if defined after callback + if sortedIdx := getRIndex(sortedNames, p.after); sortedIdx != -1 { + // if after callback sorted, append current callback to last + sortedNames = append(sortedNames, p.name) + } else if idx := getRIndex(allNames, p.after); idx != -1 { + // if after callback exists but haven't sorted + // set after callback's before callback to current callback + if after := ps[idx]; after.before == "" { + after.before = p.name + sortProcessor(after) + } + } + } + + // if current callback haven't been sorted, append it to last + if getRIndex(sortedNames, p.name) == -1 { + sortedNames = append(sortedNames, p.name) + } + } + + return nil + } + + for _, p := range ps { + sortProcessor(p) + } + + var fns []func(*DB) + for _, name := range sortedNames { + if idx := getRIndex(allNames, name); !ps[idx].remove { + fns = append(fns, ps[idx].handler) + } + } + + return fns +} + +// compile processors +func (cs *Callbacks) compile(db *DB) { + processors := map[string][]*processor{} + for _, p := range cs.processors { + if p.name != "" { + if p.match == nil || p.match(db) { + processors[p.kind] = append(processors[p.kind], p) + } + } + } + + for name, ps := range processors { + switch name { + case "create": + cs.creates = sortProcessors(ps) + case "query": + cs.queries = sortProcessors(ps) + case "update": + cs.updates = sortProcessors(ps) + case "delete": + cs.deletes = sortProcessors(ps) + case "row": + cs.row = sortProcessors(ps) + case "raw": + cs.raw = sortProcessors(ps) + } + } +} diff --git a/callbacks_test.go b/callbacks_test.go new file mode 100644 index 00000000..547cdca1 --- /dev/null +++ b/callbacks_test.go @@ -0,0 +1,131 @@ +package gorm + +import ( + "fmt" + "reflect" + "runtime" + "strings" + "testing" +) + +func assertCallbacks(funcs []func(*DB), fnames []string) (result bool, msg string) { + var got []string + + for _, f := range funcs { + got = append(got, getFuncName(f)) + } + + return fmt.Sprint(got) == fmt.Sprint(fnames), fmt.Sprintf("expects %v, got %v", fnames, got) +} + +func getFuncName(fc func(*DB)) string { + fnames := strings.Split(runtime.FuncForPC(reflect.ValueOf(fc).Pointer()).Name(), ".") + return fnames[len(fnames)-1] +} + +func c1(*DB) {} +func c2(*DB) {} +func c3(*DB) {} +func c4(*DB) {} +func c5(*DB) {} + +func TestCallbacks(t *testing.T) { + type callback struct { + name string + before string + after string + remove bool + replace bool + err error + match func(*DB) bool + h func(*DB) + } + + datas := []struct { + callbacks []callback + results []string + }{ + { + callbacks: []callback{{h: c1}, {h: c2}, {h: c3}, {h: c4}, {h: c5}}, + results: []string{"c1", "c2", "c3", "c4", "c5"}, + }, + { + callbacks: []callback{{h: c1}, {h: c2}, {h: c3}, {h: c4}, {h: c5, before: "c4"}}, + results: []string{"c1", "c2", "c3", "c5", "c4"}, + }, + { + callbacks: []callback{{h: c1}, {h: c2}, {h: c3}, {h: c4, after: "c5"}, {h: c5}}, + results: []string{"c1", "c2", "c3", "c5", "c4"}, + }, + { + callbacks: []callback{{h: c1}, {h: c2}, {h: c3}, {h: c4, after: "c5"}, {h: c5, before: "c4"}}, + results: []string{"c1", "c2", "c3", "c5", "c4"}, + }, + { + callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3}, {h: c4}, {h: c5}}, + results: []string{"c1", "c5", "c2", "c3", "c4"}, + }, + { + callbacks: []callback{{h: c1, before: "c3", after: "c4"}, {h: c2, before: "c4", after: "c5"}, {h: c3, before: "c5"}, {h: c4}, {h: c5}}, + results: []string{"c1", "c3", "c5", "c2", "c4"}, + }, + { + callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3}, {h: c4}, {h: c5}, {h: c2, remove: true}}, + results: []string{"c1", "c5", "c3", "c4"}, + }, + { + callbacks: []callback{{h: c1}, {name: "c", h: c2}, {h: c3}, {name: "c", h: c4, replace: true}}, + results: []string{"c1", "c4", "c3"}, + }, + } + + // func TestRegisterCallbackWithComplexOrder(t *testing.T) { + // var callback2 = &Callback{logger: defaultLogger} + + // callback2.Delete().Before("after_create1").After("before_create1").Register("create", create) + // callback2.Delete().Before("create").Register("before_create1", beforeCreate1) + // callback2.Delete().After("before_create1").Register("before_create2", beforeCreate2) + // callback2.Delete().Register("after_create1", afterCreate1) + // callback2.Delete().After("after_create1").Register("after_create2", afterCreate2) + + // if !equalFuncs(callback2.deletes, []string{"beforeCreate1", "beforeCreate2", "create", "afterCreate1", "afterCreate2"}) { + // t.Errorf("register callback with order") + // } + // } + + for idx, data := range datas { + callbacks := &Callbacks{} + + for _, c := range data.callbacks { + p := callbacks.Create() + + if c.name == "" { + c.name = getFuncName(c.h) + } + + if c.before != "" { + p = p.Before(c.before) + } + + if c.after != "" { + p = p.After(c.after) + } + + if c.match != nil { + p = p.Match(c.match) + } + + if c.remove { + p.Remove(c.name) + } else if c.replace { + p.Replace(c.name, c.h) + } else { + p.Register(c.name, c.h) + } + } + + if ok, msg := assertCallbacks(callbacks.creates, data.results); !ok { + t.Errorf("callbacks tests #%v failed, got %v", idx+1, msg) + } + } +} diff --git a/errors.go b/helpers.go similarity index 55% rename from errors.go rename to helpers.go index c66408be..8f9df009 100644 --- a/errors.go +++ b/helpers.go @@ -1,6 +1,9 @@ package gorm -import "errors" +import ( + "errors" + "time" +) var ( // ErrRecordNotFound record not found error @@ -13,10 +16,14 @@ var ( ErrUnaddressable = errors.New("using unaddressable value") ) -type Error struct { - Err error -} - -func (e Error) Unwrap() error { - return e.Err +// Model a basic GoLang struct which includes the following fields: ID, CreatedAt, UpdatedAt, DeletedAt +// It may be embeded into your model or you may build your own model without it +// type User struct { +// gorm.Model +// } +type Model struct { + ID uint `gorm:"primary_key"` + CreatedAt time.Time + UpdatedAt time.Time + DeletedAt *time.Time `gorm:"index"` } diff --git a/logger/logger.go b/logger/logger.go index 389a6763..9d6e70bf 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -1,7 +1,15 @@ package logger +import ( + "fmt" + "log" + "os" +) + type LogLevel int +var Default Interface = Logger{Writer: log.New(os.Stdout, "\r\n", 0)} + const ( Info LogLevel = iota + 1 Warn @@ -11,4 +19,42 @@ const ( // Interface logger interface type Interface interface { LogMode(LogLevel) Interface + Info(string, ...interface{}) + Warn(string, ...interface{}) + Error(string, ...interface{}) +} + +// Writer log writer interface +type Writer interface { + Print(...interface{}) +} + +type Logger struct { + Writer + logLevel LogLevel +} + +func (logger Logger) LogMode(level LogLevel) Interface { + return Logger{Writer: logger.Writer, logLevel: level} +} + +// Info print info +func (logger Logger) Info(msg string, data ...interface{}) { + if logger.logLevel >= Info { + logger.Print("[info] " + fmt.Sprintf(msg, data...)) + } +} + +// Warn print warn messages +func (logger Logger) Warn(msg string, data ...interface{}) { + if logger.logLevel >= Warn { + logger.Print("[warn] " + fmt.Sprintf(msg, data...)) + } +} + +// Error print error messages +func (logger Logger) Error(msg string, data ...interface{}) { + if logger.logLevel >= Error { + logger.Print("[error] " + fmt.Sprintf(msg, data...)) + } } diff --git a/model.go b/model.go deleted file mode 100644 index 118d8f14..00000000 --- a/model.go +++ /dev/null @@ -1,15 +0,0 @@ -package gorm - -import "time" - -// Model a basic GoLang struct which includes the following fields: ID, CreatedAt, UpdatedAt, DeletedAt -// It may be embeded into your model or you may build your own model without it -// type User struct { -// gorm.Model -// } -type Model struct { - ID uint `gorm:"primary_key"` - CreatedAt time.Time - UpdatedAt time.Time - DeletedAt *time.Time `gorm:"index"` -} diff --git a/utils/utils.go b/utils/utils.go new file mode 100644 index 00000000..81ac8b30 --- /dev/null +++ b/utils/utils.go @@ -0,0 +1,20 @@ +package utils + +import ( + "fmt" + "regexp" + "runtime" +) + +var goSrcRegexp = regexp.MustCompile(`jinzhu/gorm/.*.go`) +var goTestRegexp = regexp.MustCompile(`jinzhu/gorm/.*test.go`) + +func FileWithLineNum() string { + for i := 2; i < 15; i++ { + _, file, line, ok := runtime.Caller(i) + if ok && (!goSrcRegexp.MatchString(file) || goTestRegexp.MatchString(file)) { + return fmt.Sprintf("%v:%v", file, line) + } + } + return "" +} From 5959c81be67187142fa11159e7d6dc8043f0af82 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 31 Jan 2020 08:29:35 +0800 Subject: [PATCH 212/881] Refactor callbacks --- callbacks.go | 285 ++++++++++++++++++++++------------------ callbacks_test.go | 131 ------------------ tests/callbacks_test.go | 158 ++++++++++++++++++++++ 3 files changed, 313 insertions(+), 261 deletions(-) delete mode 100644 callbacks_test.go create mode 100644 tests/callbacks_test.go diff --git a/callbacks.go b/callbacks.go index d53e8049..a7f30612 100644 --- a/callbacks.go +++ b/callbacks.go @@ -2,26 +2,36 @@ package gorm import ( "fmt" - "log" "github.com/jinzhu/gorm/logger" "github.com/jinzhu/gorm/utils" ) -// Callbacks gorm callbacks manager -type Callbacks struct { - creates []func(*DB) - queries []func(*DB) - updates []func(*DB) - deletes []func(*DB) - row []func(*DB) - raw []func(*DB) - db *DB - processors []*processor +func InitializeCallbacks() *callbacks { + return &callbacks{ + processors: map[string]*processor{ + "create": &processor{}, + "query": &processor{}, + "update": &processor{}, + "delete": &processor{}, + "row": &processor{}, + "raw": &processor{}, + }, + } +} + +// callbacks gorm callbacks manager +type callbacks struct { + processors map[string]*processor } type processor struct { - kind string + db *DB + fns []func(*DB) + callbacks []*callback +} + +type callback struct { name string before string after string @@ -29,79 +39,111 @@ type processor struct { replace bool match func(*DB) bool handler func(*DB) - callbacks *Callbacks + processor *processor } -func (cs *Callbacks) Create() *processor { - return &processor{callbacks: cs, kind: "create"} +func (cs *callbacks) Create() *processor { + return cs.processors["create"] } -func (cs *Callbacks) Query() *processor { - return &processor{callbacks: cs, kind: "query"} +func (cs *callbacks) Query() *processor { + return cs.processors["query"] } -func (cs *Callbacks) Update() *processor { - return &processor{callbacks: cs, kind: "update"} +func (cs *callbacks) Update() *processor { + return cs.processors["update"] } -func (cs *Callbacks) Delete() *processor { - return &processor{callbacks: cs, kind: "delete"} +func (cs *callbacks) Delete() *processor { + return cs.processors["delete"] } -func (cs *Callbacks) Row() *processor { - return &processor{callbacks: cs, kind: "row"} +func (cs *callbacks) Row() *processor { + return cs.processors["row"] } -func (cs *Callbacks) Raw() *processor { - return &processor{callbacks: cs, kind: "raw"} +func (cs *callbacks) Raw() *processor { + return cs.processors["raw"] } -func (p *processor) Before(name string) *processor { - p.before = name - return p -} - -func (p *processor) After(name string) *processor { - p.after = name - return p -} - -func (p *processor) Match(fc func(*DB) bool) *processor { - p.match = fc - return p +func (p *processor) Execute(db *DB) { + for _, f := range p.fns { + f(db) + } } func (p *processor) Get(name string) func(*DB) { - for i := len(p.callbacks.processors) - 1; i >= 0; i-- { - if v := p.callbacks.processors[i]; v.name == name && v.kind == v.kind && !v.remove { + for i := len(p.callbacks) - 1; i >= 0; i-- { + if v := p.callbacks[i]; v.name == name && !v.remove { return v.handler } } return nil } -func (p *processor) Register(name string, fn func(*DB)) { - p.name = name - p.handler = fn - p.callbacks.processors = append(p.callbacks.processors, p) - p.callbacks.compile(p.callbacks.db) +func (p *processor) Before(name string) *callback { + return &callback{before: name, processor: p} } -func (p *processor) Remove(name string) { - logger.Default.Info("removing callback `%v` from %v\n", name, utils.FileWithLineNum()) - p.name = name - p.remove = true - p.callbacks.processors = append(p.callbacks.processors, p) - p.callbacks.compile(p.callbacks.db) +func (p *processor) After(name string) *callback { + return &callback{after: name, processor: p} } -func (p *processor) Replace(name string, fn func(*DB)) { - logger.Default.Info("[info] replacing callback `%v` from %v\n", name, utils.FileWithLineNum()) - p.name = name - p.handler = fn - p.replace = true - p.callbacks.processors = append(p.callbacks.processors, p) - p.callbacks.compile(p.callbacks.db) +func (p *processor) Match(fc func(*DB) bool) *callback { + return &callback{match: fc, processor: p} +} + +func (p *processor) Register(name string, fn func(*DB)) error { + return (&callback{processor: p}).Register(name, fn) +} + +func (p *processor) Remove(name string) error { + return (&callback{processor: p}).Remove(name) +} + +func (p *processor) Replace(name string, fn func(*DB)) error { + return (&callback{processor: p}).Replace(name, fn) +} + +func (p *processor) compile(db *DB) (err error) { + if p.fns, err = sortCallbacks(p.callbacks); err != nil { + logger.Default.Error("Got error when compile callbacks, got %v", err) + } + return +} + +func (c *callback) Before(name string) *callback { + c.before = name + return c +} + +func (c *callback) After(name string) *callback { + c.after = name + return c +} + +func (c *callback) Register(name string, fn func(*DB)) error { + c.name = name + c.handler = fn + c.processor.callbacks = append(c.processor.callbacks, c) + return c.processor.compile(c.processor.db) +} + +func (c *callback) Remove(name string) error { + logger.Default.Warn("removing callback `%v` from %v\n", name, utils.FileWithLineNum()) + c.name = name + c.remove = true + c.processor.callbacks = append(c.processor.callbacks, c) + return c.processor.compile(c.processor.db) +} + +func (c *callback) Replace(name string, fn func(*DB)) error { + logger.Default.Info("replacing callback `%v` from %v\n", name, utils.FileWithLineNum()) + c.name = name + c.handler = fn + c.replace = true + c.processor.callbacks = append(c.processor.callbacks, c) + return c.processor.compile(c.processor.db) } // getRIndex get right index from string slice @@ -114,98 +156,81 @@ func getRIndex(strs []string, str string) int { return -1 } -func sortProcessors(ps []*processor) []func(*DB) { +func sortCallbacks(cs []*callback) (fns []func(*DB), err error) { var ( - allNames, sortedNames []string - sortProcessor func(*processor) error + names, sorted []string + sortCallback func(*callback) error ) - for _, p := range ps { + for _, c := range cs { // show warning message the callback name already exists - if idx := getRIndex(allNames, p.name); idx > -1 && !p.replace && !p.remove && !ps[idx].remove { - log.Printf("[warning] duplicated callback `%v` from %v\n", p.name, utils.FileWithLineNum()) + if idx := getRIndex(names, c.name); idx > -1 && !c.replace && !c.remove && !cs[idx].remove { + logger.Default.Warn("duplicated callback `%v` from %v\n", c.name, utils.FileWithLineNum()) } - allNames = append(allNames, p.name) + names = append(names, c.name) } - sortProcessor = func(p *processor) error { - if getRIndex(sortedNames, p.name) == -1 { // if not sorted - if p.before != "" { // if defined before callback - if sortedIdx := getRIndex(sortedNames, p.before); sortedIdx != -1 { - if curIdx := getRIndex(sortedNames, p.name); curIdx != -1 || true { - // if before callback already sorted, append current callback just after it - sortedNames = append(sortedNames[:sortedIdx], append([]string{p.name}, sortedNames[sortedIdx:]...)...) - } else if curIdx > sortedIdx { - return fmt.Errorf("conflicting callback %v with before %v", p.name, p.before) - } - } else if idx := getRIndex(allNames, p.before); idx != -1 { - // if before callback exists - ps[idx].after = p.name + sortCallback = func(c *callback) error { + if c.before != "" { // if defined before callback + if sortedIdx := getRIndex(sorted, c.before); sortedIdx != -1 { + if curIdx := getRIndex(sorted, c.name); curIdx == -1 { + // if before callback already sorted, append current callback just after it + sorted = append(sorted[:sortedIdx], append([]string{c.name}, sorted[sortedIdx:]...)...) + } else if curIdx > sortedIdx { + return fmt.Errorf("conflicting callback %v with before %v", c.name, c.before) } + } else if idx := getRIndex(names, c.before); idx != -1 { + // if before callback exists + cs[idx].after = c.name } + } - if p.after != "" { // if defined after callback - if sortedIdx := getRIndex(sortedNames, p.after); sortedIdx != -1 { + if c.after != "" { // if defined after callback + if sortedIdx := getRIndex(sorted, c.after); sortedIdx != -1 { + if curIdx := getRIndex(sorted, c.name); curIdx == -1 { // if after callback sorted, append current callback to last - sortedNames = append(sortedNames, p.name) - } else if idx := getRIndex(allNames, p.after); idx != -1 { - // if after callback exists but haven't sorted - // set after callback's before callback to current callback - if after := ps[idx]; after.before == "" { - after.before = p.name - sortProcessor(after) - } + sorted = append(sorted, c.name) + } else if curIdx < sortedIdx { + return fmt.Errorf("conflicting callback %v with before %v", c.name, c.after) + } + } else if idx := getRIndex(names, c.after); idx != -1 { + // if after callback exists but haven't sorted + // set after callback's before callback to current callback + after := cs[idx] + + if after.before == "" { + after.before = c.name + } + + if err := sortCallback(after); err != nil { + return err + } + + if err := sortCallback(c); err != nil { + return err } } + } - // if current callback haven't been sorted, append it to last - if getRIndex(sortedNames, p.name) == -1 { - sortedNames = append(sortedNames, p.name) - } + // if current callback haven't been sorted, append it to last + if getRIndex(sorted, c.name) == -1 { + sorted = append(sorted, c.name) } return nil } - for _, p := range ps { - sortProcessor(p) - } - - var fns []func(*DB) - for _, name := range sortedNames { - if idx := getRIndex(allNames, name); !ps[idx].remove { - fns = append(fns, ps[idx].handler) + for _, c := range cs { + if err = sortCallback(c); err != nil { + return } } - return fns -} - -// compile processors -func (cs *Callbacks) compile(db *DB) { - processors := map[string][]*processor{} - for _, p := range cs.processors { - if p.name != "" { - if p.match == nil || p.match(db) { - processors[p.kind] = append(processors[p.kind], p) - } - } - } - - for name, ps := range processors { - switch name { - case "create": - cs.creates = sortProcessors(ps) - case "query": - cs.queries = sortProcessors(ps) - case "update": - cs.updates = sortProcessors(ps) - case "delete": - cs.deletes = sortProcessors(ps) - case "row": - cs.row = sortProcessors(ps) - case "raw": - cs.raw = sortProcessors(ps) - } - } + for _, name := range sorted { + if idx := getRIndex(names, name); !cs[idx].remove { + fns = append(fns, cs[idx].handler) + } + } + + return } diff --git a/callbacks_test.go b/callbacks_test.go deleted file mode 100644 index 547cdca1..00000000 --- a/callbacks_test.go +++ /dev/null @@ -1,131 +0,0 @@ -package gorm - -import ( - "fmt" - "reflect" - "runtime" - "strings" - "testing" -) - -func assertCallbacks(funcs []func(*DB), fnames []string) (result bool, msg string) { - var got []string - - for _, f := range funcs { - got = append(got, getFuncName(f)) - } - - return fmt.Sprint(got) == fmt.Sprint(fnames), fmt.Sprintf("expects %v, got %v", fnames, got) -} - -func getFuncName(fc func(*DB)) string { - fnames := strings.Split(runtime.FuncForPC(reflect.ValueOf(fc).Pointer()).Name(), ".") - return fnames[len(fnames)-1] -} - -func c1(*DB) {} -func c2(*DB) {} -func c3(*DB) {} -func c4(*DB) {} -func c5(*DB) {} - -func TestCallbacks(t *testing.T) { - type callback struct { - name string - before string - after string - remove bool - replace bool - err error - match func(*DB) bool - h func(*DB) - } - - datas := []struct { - callbacks []callback - results []string - }{ - { - callbacks: []callback{{h: c1}, {h: c2}, {h: c3}, {h: c4}, {h: c5}}, - results: []string{"c1", "c2", "c3", "c4", "c5"}, - }, - { - callbacks: []callback{{h: c1}, {h: c2}, {h: c3}, {h: c4}, {h: c5, before: "c4"}}, - results: []string{"c1", "c2", "c3", "c5", "c4"}, - }, - { - callbacks: []callback{{h: c1}, {h: c2}, {h: c3}, {h: c4, after: "c5"}, {h: c5}}, - results: []string{"c1", "c2", "c3", "c5", "c4"}, - }, - { - callbacks: []callback{{h: c1}, {h: c2}, {h: c3}, {h: c4, after: "c5"}, {h: c5, before: "c4"}}, - results: []string{"c1", "c2", "c3", "c5", "c4"}, - }, - { - callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3}, {h: c4}, {h: c5}}, - results: []string{"c1", "c5", "c2", "c3", "c4"}, - }, - { - callbacks: []callback{{h: c1, before: "c3", after: "c4"}, {h: c2, before: "c4", after: "c5"}, {h: c3, before: "c5"}, {h: c4}, {h: c5}}, - results: []string{"c1", "c3", "c5", "c2", "c4"}, - }, - { - callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3}, {h: c4}, {h: c5}, {h: c2, remove: true}}, - results: []string{"c1", "c5", "c3", "c4"}, - }, - { - callbacks: []callback{{h: c1}, {name: "c", h: c2}, {h: c3}, {name: "c", h: c4, replace: true}}, - results: []string{"c1", "c4", "c3"}, - }, - } - - // func TestRegisterCallbackWithComplexOrder(t *testing.T) { - // var callback2 = &Callback{logger: defaultLogger} - - // callback2.Delete().Before("after_create1").After("before_create1").Register("create", create) - // callback2.Delete().Before("create").Register("before_create1", beforeCreate1) - // callback2.Delete().After("before_create1").Register("before_create2", beforeCreate2) - // callback2.Delete().Register("after_create1", afterCreate1) - // callback2.Delete().After("after_create1").Register("after_create2", afterCreate2) - - // if !equalFuncs(callback2.deletes, []string{"beforeCreate1", "beforeCreate2", "create", "afterCreate1", "afterCreate2"}) { - // t.Errorf("register callback with order") - // } - // } - - for idx, data := range datas { - callbacks := &Callbacks{} - - for _, c := range data.callbacks { - p := callbacks.Create() - - if c.name == "" { - c.name = getFuncName(c.h) - } - - if c.before != "" { - p = p.Before(c.before) - } - - if c.after != "" { - p = p.After(c.after) - } - - if c.match != nil { - p = p.Match(c.match) - } - - if c.remove { - p.Remove(c.name) - } else if c.replace { - p.Replace(c.name, c.h) - } else { - p.Register(c.name, c.h) - } - } - - if ok, msg := assertCallbacks(callbacks.creates, data.results); !ok { - t.Errorf("callbacks tests #%v failed, got %v", idx+1, msg) - } - } -} diff --git a/tests/callbacks_test.go b/tests/callbacks_test.go new file mode 100644 index 00000000..878384a7 --- /dev/null +++ b/tests/callbacks_test.go @@ -0,0 +1,158 @@ +package gorm_test + +import ( + "fmt" + "reflect" + "runtime" + "strings" + "testing" + + "github.com/jinzhu/gorm" +) + +func assertCallbacks(v interface{}, fnames []string) (result bool, msg string) { + var ( + got []string + funcs = reflect.ValueOf(v).Elem().FieldByName("fns") + ) + + for i := 0; i < funcs.Len(); i++ { + got = append(got, getFuncName(funcs.Index(i))) + } + + return fmt.Sprint(got) == fmt.Sprint(fnames), fmt.Sprintf("expects %v, got %v", fnames, got) +} + +func getFuncName(fc interface{}) string { + reflectValue, ok := fc.(reflect.Value) + if !ok { + reflectValue = reflect.ValueOf(fc) + } + + fnames := strings.Split(runtime.FuncForPC(reflectValue.Pointer()).Name(), ".") + return fnames[len(fnames)-1] +} + +func c1(*gorm.DB) {} +func c2(*gorm.DB) {} +func c3(*gorm.DB) {} +func c4(*gorm.DB) {} +func c5(*gorm.DB) {} + +func TestCallbacks(t *testing.T) { + type callback struct { + name string + before string + after string + remove bool + replace bool + err string + match func(*gorm.DB) bool + h func(*gorm.DB) + } + + datas := []struct { + callbacks []callback + err string + results []string + }{ + { + callbacks: []callback{{h: c1}, {h: c2}, {h: c3}, {h: c4}, {h: c5}}, + results: []string{"c1", "c2", "c3", "c4", "c5"}, + }, + { + callbacks: []callback{{h: c1}, {h: c2}, {h: c3}, {h: c4}, {h: c5, before: "c4"}}, + results: []string{"c1", "c2", "c3", "c5", "c4"}, + }, + { + callbacks: []callback{{h: c1}, {h: c2}, {h: c3}, {h: c4, after: "c5"}, {h: c5}}, + results: []string{"c1", "c2", "c3", "c5", "c4"}, + }, + { + callbacks: []callback{{h: c1}, {h: c2}, {h: c3}, {h: c4, after: "c5"}, {h: c5, before: "c4"}}, + results: []string{"c1", "c2", "c3", "c5", "c4"}, + }, + { + callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3}, {h: c4}, {h: c5}}, + results: []string{"c1", "c5", "c2", "c3", "c4"}, + }, + { + callbacks: []callback{{h: c1, after: "c3"}, {h: c2, before: "c4", after: "c5"}, {h: c3, before: "c5"}, {h: c4}, {h: c5}}, + results: []string{"c3", "c1", "c5", "c2", "c4"}, + }, + { + callbacks: []callback{{h: c1, before: "c4", after: "c3"}, {h: c2, before: "c4", after: "c5"}, {h: c3, before: "c5"}, {h: c4}, {h: c5}}, + results: []string{"c3", "c1", "c5", "c2", "c4"}, + }, + { + callbacks: []callback{{h: c1, before: "c3", after: "c4"}, {h: c2, before: "c4", after: "c5"}, {h: c3, before: "c5"}, {h: c4}, {h: c5}}, + err: "conflicting", + }, + { + callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3}, {h: c4}, {h: c5}, {h: c2, remove: true}}, + results: []string{"c1", "c5", "c3", "c4"}, + }, + { + callbacks: []callback{{h: c1}, {name: "c", h: c2}, {h: c3}, {name: "c", h: c4, replace: true}}, + results: []string{"c1", "c4", "c3"}, + }, + } + + for idx, data := range datas { + var err error + callbacks := gorm.InitializeCallbacks() + + for _, c := range data.callbacks { + var v interface{} = callbacks.Create() + callMethod := func(s interface{}, name string, args ...interface{}) { + var argValues []reflect.Value + for _, arg := range args { + argValues = append(argValues, reflect.ValueOf(arg)) + } + + results := reflect.ValueOf(s).MethodByName(name).Call(argValues) + if len(results) > 0 { + v = results[0].Interface() + } + } + + if c.name == "" { + c.name = getFuncName(c.h) + } + + if c.before != "" { + callMethod(v, "Before", c.before) + } + + if c.after != "" { + callMethod(v, "After", c.after) + } + + if c.match != nil { + callMethod(v, "Match", c.match) + } + + if c.remove { + callMethod(v, "Remove", c.name) + } else if c.replace { + callMethod(v, "Replace", c.name, c.h) + } else { + callMethod(v, "Register", c.name, c.h) + } + + if e, ok := v.(error); !ok || e != nil { + err = e + } + } + + if len(data.err) > 0 && err == nil { + t.Errorf("callbacks tests #%v should got error %v, but not", idx+1, data.err) + } else if len(data.err) == 0 && err != nil { + t.Errorf("callbacks tests #%v should not got error, but got %v", idx+1, err) + } + + if ok, msg := assertCallbacks(callbacks.Create(), data.results); !ok { + t.Errorf("callbacks tests #%v failed, got %v", idx+1, msg) + } + } +} From 1079e17caf327efd28c941e48decc7cde6cccaf0 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 31 Jan 2020 12:22:37 +0800 Subject: [PATCH 213/881] Implement schema parser --- model/model.go | 37 ------ schema/field.go | 202 ++++++++++++++++++++++++++++++ {model => schema}/relationship.go | 8 +- schema/schema.go | 80 ++++++++++++ schema/utils.go | 31 +++++ 5 files changed, 320 insertions(+), 38 deletions(-) delete mode 100644 model/model.go create mode 100644 schema/field.go rename {model => schema}/relationship.go (89%) create mode 100644 schema/schema.go create mode 100644 schema/utils.go diff --git a/model/model.go b/model/model.go deleted file mode 100644 index 316f3ab5..00000000 --- a/model/model.go +++ /dev/null @@ -1,37 +0,0 @@ -package model - -import ( - "reflect" -) - -type Model struct { - ModelType reflect.Type - Table string - PrioritizedPrimaryField *Field - PrimaryFields []*Field - Fields []*Field - FieldsByName map[string]*Field - FieldsByDBName map[string]*Field - Relationships Relationships -} - -type Field struct { - Name string - DBName string - DataType reflect.Type - DBDataType string - Tag reflect.StructTag - TagSettings map[string]string - PrimaryKey bool - AutoIncrement bool - Creatable bool - Updatable bool - Nullable bool - Unique bool - Precision int - Size int - HasDefaultValue bool - DefaultValue string - StructField reflect.StructField - Model *Model -} diff --git a/schema/field.go b/schema/field.go new file mode 100644 index 00000000..9d3b3033 --- /dev/null +++ b/schema/field.go @@ -0,0 +1,202 @@ +package schema + +import ( + "database/sql/driver" + "reflect" + "strconv" + "sync" + "time" +) + +type FieldType string + +const ( + Bool FieldType = "bool" + Int = "int" + Uint = "uint" + Float = "float" + String = "string" + Time = "time" + Bytes = "bytes" +) + +type Field struct { + Name string + DBName string + BindNames []string + DataType FieldType + DBDataType string + PrimaryKey bool + AutoIncrement bool + Creatable bool + Updatable bool + HasDefaultValue bool + DefaultValue string + NotNull bool + Unique bool + Comment string + Size int + Precision int + FieldType reflect.Type + StructField reflect.StructField + Tag reflect.StructTag + TagSettings map[string]string + Schema *Schema + EmbeddedbSchema *Schema + Relationship string +} + +func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { + field := &Field{ + Name: fieldStruct.Name, + BindNames: []string{fieldStruct.Name}, + FieldType: fieldStruct.Type, + StructField: fieldStruct, + Creatable: true, + Updatable: true, + Tag: fieldStruct.Tag, + TagSettings: parseTagSetting(fieldStruct.Tag), + } + + for field.FieldType.Kind() == reflect.Ptr { + field.FieldType = field.FieldType.Elem() + } + + fieldValue := reflect.New(field.FieldType) + + // if field is valuer, used its value or first fields as data type + if valuer, isValuer := fieldValue.Interface().(driver.Valuer); isValuer { + var overrideFieldValue bool + if v, err := valuer.Value(); v != nil && err == nil { + overrideFieldValue = true + fieldValue = reflect.ValueOf(v) + } + + if field.FieldType.Kind() == reflect.Struct { + for i := 0; i < field.FieldType.NumField(); i++ { + if !overrideFieldValue { + newFieldType := field.FieldType.Field(i).Type + for newFieldType.Kind() == reflect.Ptr { + newFieldType = newFieldType.Elem() + } + + fieldValue = reflect.New(newFieldType) + overrideFieldValue = true + } + + // copy tag settings from valuer + for key, value := range parseTagSetting(field.FieldType.Field(i).Tag) { + if _, ok := field.TagSettings[key]; !ok { + field.TagSettings[key] = value + } + } + } + } + } + + // setup permission + if _, ok := field.TagSettings["-"]; ok { + field.Creatable = false + field.Updatable = false + } + + if dbName, ok := field.TagSettings["COLUMN"]; ok { + field.DBName = dbName + } + + if val, ok := field.TagSettings["PRIMARY_KEY"]; ok && checkTruth(val) { + field.PrimaryKey = true + } + + if val, ok := field.TagSettings["AUTO_INCREMENT"]; ok && checkTruth(val) { + field.AutoIncrement = true + field.HasDefaultValue = true + } + + if v, ok := field.TagSettings["DEFAULT"]; ok { + field.HasDefaultValue = true + field.DefaultValue = v + } + + if num, ok := field.TagSettings["SIZE"]; ok { + field.Size, _ = strconv.Atoi(num) + } + + if p, ok := field.TagSettings["PRECISION"]; ok { + field.Precision, _ = strconv.Atoi(p) + } + + if val, ok := field.TagSettings["NOT NULL"]; ok && checkTruth(val) { + field.NotNull = true + } + + if val, ok := field.TagSettings["UNIQUE"]; ok && checkTruth(val) { + field.Unique = true + } + + if val, ok := field.TagSettings["COMMENT"]; ok { + field.Comment = val + } + + if val, ok := field.TagSettings["TYPE"]; ok { + field.DBDataType = val + } + + switch fieldValue.Kind() { + case reflect.Bool: + field.DataType = Bool + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + field.DataType = Int + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + field.DataType = Uint + case reflect.Float32, reflect.Float64: + field.DataType = Float + case reflect.String: + field.DataType = String + case reflect.Struct: + if _, ok := fieldValue.Interface().(time.Time); ok { + field.DataType = Time + } + case reflect.Array, reflect.Slice: + if fieldValue.Type().Elem() == reflect.TypeOf(uint8(0)) { + field.DataType = Bytes + } + } + + if field.Size == 0 { + switch fieldValue.Kind() { + case reflect.Int, reflect.Int64, reflect.Uint, reflect.Uint64, reflect.Float64: + field.Size = 64 + case reflect.Int8, reflect.Uint8: + field.Size = 8 + case reflect.Int16, reflect.Uint16: + field.Size = 16 + case reflect.Int32, reflect.Uint32, reflect.Float32: + field.Size = 32 + } + } + + if _, ok := field.TagSettings["EMBEDDED"]; ok || fieldStruct.Anonymous { + field.EmbeddedbSchema = Parse(fieldValue, sync.Map{}) + for _, ef := range field.EmbeddedbSchema.Fields { + ef.BindNames = append([]string{fieldStruct.Name}, ef.BindNames...) + + if prefix, ok := field.TagSettings["EMBEDDED_PREFIX"]; ok { + ef.DBName = prefix + ef.DBName + } + + for k, v := range field.TagSettings { + ef.TagSettings[k] = v + } + } + } else { + switch fieldValue.Kind() { + case reflect.Struct: + field.Relationship = "one" + case reflect.Slice: + field.Relationship = "many" + } + } + + return field +} diff --git a/model/relationship.go b/schema/relationship.go similarity index 89% rename from model/relationship.go rename to schema/relationship.go index 60b0751e..b0c630be 100644 --- a/model/relationship.go +++ b/schema/relationship.go @@ -1,4 +1,4 @@ -package model +package schema // RelationshipType relationship type type RelationshipType string @@ -35,3 +35,9 @@ type JoinTable struct { ForeignKeys []*RelationField AssociationForeignKeys []*RelationField } + +func (schema *Schema) buildToOneRel(field *Field) { +} + +func (schema *Schema) buildToManyRel(field *Field) { +} diff --git a/schema/schema.go b/schema/schema.go new file mode 100644 index 00000000..6d85af8c --- /dev/null +++ b/schema/schema.go @@ -0,0 +1,80 @@ +package schema + +import ( + "go/ast" + "reflect" + "strings" + "sync" +) + +type Schema struct { + ModelType reflect.Type + Table string + PrioritizedPrimaryField *Field + PrimaryFields []*Field + Fields []*Field + FieldsByName map[string]*Field + FieldsByDBName map[string]*Field + Relationships Relationships +} + +// get data type from dialector +func Parse(dest interface{}, cacheStore sync.Map) *Schema { + modelType := reflect.ValueOf(dest).Type() + for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Ptr { + modelType = modelType.Elem() + } + + if modelType.Kind() != reflect.Struct { + return nil + } + + if v, ok := cacheStore.Load(modelType); ok { + return v.(*Schema) + } + + schema := &Schema{ + ModelType: modelType, + FieldsByName: map[string]*Field{}, + FieldsByDBName: map[string]*Field{}, + } + + for i := 0; i < modelType.NumField(); i++ { + fieldStruct := modelType.Field(i) + if !ast.IsExported(fieldStruct.Name) { + continue + } + + schema.Fields = append(schema.Fields, schema.ParseField(fieldStruct)) + // db namer + } + + for _, field := range schema.Fields { + if field.DBName != "" { + // nonexistence or shortest path or first appear prioritized + if v, ok := schema.FieldsByDBName[field.DBName]; !ok || len(field.BindNames) < len(v.BindNames) { + schema.FieldsByDBName[field.DBName] = field + schema.FieldsByName[field.Name] = field + } + } + + if _, ok := schema.FieldsByName[field.Name]; !ok { + schema.FieldsByName[field.Name] = field + } + } + + for db, field := range schema.FieldsByDBName { + if strings.ToLower(db) == "id" { + schema.PrioritizedPrimaryField = field + } + + if field.PrimaryKey { + if schema.PrioritizedPrimaryField == nil { + schema.PrioritizedPrimaryField = field + } + schema.PrimaryFields = append(schema.PrimaryFields, field) + } + } + + return schema +} diff --git a/schema/utils.go b/schema/utils.go new file mode 100644 index 00000000..1b0f5eac --- /dev/null +++ b/schema/utils.go @@ -0,0 +1,31 @@ +package schema + +import ( + "reflect" + "strings" +) + +func parseTagSetting(tags reflect.StructTag) map[string]string { + setting := map[string]string{} + + for _, value := range strings.Split(tags.Get("gorm"), ";") { + if value != "" { + v := strings.Split(value, ":") + k := strings.TrimSpace(strings.ToUpper(v[0])) + + if len(v) >= 2 { + setting[k] = strings.Join(v[1:], ":") + } else { + setting[k] = k + } + } + } + return setting +} + +func checkTruth(val string) bool { + if strings.ToLower(val) == "false" { + return false + } + return true +} From bc68fde6aa9892b734cdbd569bb22d58e9493f46 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 31 Jan 2020 14:17:02 +0800 Subject: [PATCH 214/881] Implement naming strategy --- go.mod | 2 + go.sum | 2 + gorm.go | 12 ++++-- schema/naming.go | 96 +++++++++++++++++++++++++++++++++++++++++++ schema/naming_test.go | 34 +++++++++++++++ 5 files changed, 142 insertions(+), 4 deletions(-) create mode 100644 go.sum create mode 100644 schema/naming.go create mode 100644 schema/naming_test.go diff --git a/go.mod b/go.mod index d0a110ba..516a9759 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,5 @@ module github.com/jinzhu/gorm go 1.13 + +require github.com/jinzhu/inflection v1.0.0 diff --git a/go.sum b/go.sum new file mode 100644 index 00000000..a310b071 --- /dev/null +++ b/go.sum @@ -0,0 +1,2 @@ +github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= +github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= diff --git a/gorm.go b/gorm.go index 838f2862..6ceac412 100644 --- a/gorm.go +++ b/gorm.go @@ -6,18 +6,18 @@ import ( "github.com/jinzhu/gorm/clause" "github.com/jinzhu/gorm/logger" + "github.com/jinzhu/gorm/schema" ) // Config GORM config type Config struct { - // Set true to use singular table name, by default, GORM will pluralize your struct's name as table name - // Refer https://github.com/jinzhu/inflection for inflection rules - SingularTable bool - // GORM perform single create, update, delete operations in transactions by default to ensure database data integrity // You can cancel it by setting `SkipDefaultTransaction` to true SkipDefaultTransaction bool + // NamingStrategy tables, columns naming strategy + NamingStrategy schema.Namer + // Logger Logger logger.Interface @@ -48,6 +48,10 @@ type Session struct { // Open initialize db session based on dialector func Open(dialector Dialector, config *Config) (db *DB, err error) { + if config.NamingStrategy == nil { + config.NamingStrategy = schema.NamingStrategy{} + } + return &DB{ Config: config, Dialector: dialector, diff --git a/schema/naming.go b/schema/naming.go new file mode 100644 index 00000000..1baa8558 --- /dev/null +++ b/schema/naming.go @@ -0,0 +1,96 @@ +package schema + +import ( + "fmt" + "strings" + "sync" + + "github.com/jinzhu/inflection" +) + +// Namer namer interface +type Namer interface { + TableName(string) string + ColumnName(string) string +} + +// NamingStrategy tables, columns naming strategy +type NamingStrategy struct { + TablePrefix string + SingularTable bool +} + +// TableName convert string to table name +func (ns NamingStrategy) TableName(str string) string { + if ns.SingularTable { + return ns.TablePrefix + toDBName(str) + } + return ns.TablePrefix + inflection.Plural(toDBName(str)) +} + +// ColumnName convert string to column name +func (ns NamingStrategy) ColumnName(str string) string { + return toDBName(str) +} + +var ( + smap sync.Map + // https://github.com/golang/lint/blob/master/lint.go#L770 + commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UID", "UI", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"} + commonInitialismsReplacer *strings.Replacer +) + +func init() { + var commonInitialismsForReplacer []string + for _, initialism := range commonInitialisms { + commonInitialismsForReplacer = append(commonInitialismsForReplacer, initialism, strings.Title(strings.ToLower(initialism))) + } + commonInitialismsReplacer = strings.NewReplacer(commonInitialismsForReplacer...) +} + +func toDBName(name string) string { + if name == "" { + return "" + } else if v, ok := smap.Load(name); ok { + return fmt.Sprint(v) + } + + var ( + value = commonInitialismsReplacer.Replace(name) + buf strings.Builder + lastCase, nextCase, nextNumber bool // upper case == true + curCase = value[0] <= 'Z' && value[0] >= 'A' + ) + + for i, v := range value[:len(value)-1] { + nextCase = value[i+1] <= 'Z' && value[i+1] >= 'A' + nextNumber = value[i+1] >= '0' && value[i+1] <= '9' + + if curCase { + if lastCase && (nextCase || nextNumber) { + buf.WriteRune(v + 32) + } else { + if i > 0 && value[i-1] != '_' && value[i+1] != '_' { + buf.WriteByte('_') + } + buf.WriteRune(v + 32) + } + } else { + buf.WriteRune(v) + } + + lastCase = curCase + curCase = nextCase + } + + if curCase { + if !lastCase && len(value) > 1 { + buf.WriteByte('_') + } + buf.WriteByte(value[len(value)-1] + 32) + } else { + buf.WriteByte(value[len(value)-1]) + } + + return buf.String() +} diff --git a/schema/naming_test.go b/schema/naming_test.go new file mode 100644 index 00000000..96b83ced --- /dev/null +++ b/schema/naming_test.go @@ -0,0 +1,34 @@ +package schema + +import ( + "testing" +) + +func TestToDBName(t *testing.T) { + var maps = map[string]string{ + "": "", + "x": "x", + "X": "x", + "userRestrictions": "user_restrictions", + "ThisIsATest": "this_is_a_test", + "PFAndESI": "pf_and_esi", + "AbcAndJkl": "abc_and_jkl", + "EmployeeID": "employee_id", + "SKU_ID": "sku_id", + "FieldX": "field_x", + "HTTPAndSMTP": "http_and_smtp", + "HTTPServerHandlerForURLID": "http_server_handler_for_url_id", + "UUID": "uuid", + "HTTPURL": "http_url", + "HTTP_URL": "http_url", + "SHA256Hash": "sha256_hash", + "SHA256HASH": "sha256_hash", + "ThisIsActuallyATestSoWeMayBeAbleToUseThisCodeInGormPackageAlsoIdCanBeUsedAtTheEndAsID": "this_is_actually_a_test_so_we_may_be_able_to_use_this_code_in_gorm_package_also_id_can_be_used_at_the_end_as_id", + } + + for key, value := range maps { + if toDBName(key) != value { + t.Errorf("%v toName should equal %v, but got %v", key, value, toDBName(key)) + } + } +} From 010dc7e6ddca1751ffe7bd08769debcbcb0c2ce1 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 31 Jan 2020 14:31:15 +0800 Subject: [PATCH 215/881] Add namer when generate schema --- schema/field.go | 2 +- schema/schema.go | 21 ++++++++++++++++----- 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/schema/field.go b/schema/field.go index 9d3b3033..88a0d3fb 100644 --- a/schema/field.go +++ b/schema/field.go @@ -177,7 +177,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } if _, ok := field.TagSettings["EMBEDDED"]; ok || fieldStruct.Anonymous { - field.EmbeddedbSchema = Parse(fieldValue, sync.Map{}) + field.EmbeddedbSchema = Parse(fieldValue, sync.Map{}, schema.namer) for _, ef := range field.EmbeddedbSchema.Fields { ef.BindNames = append([]string{fieldStruct.Name}, ef.BindNames...) diff --git a/schema/schema.go b/schema/schema.go index 6d85af8c..5069bb44 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -16,10 +16,11 @@ type Schema struct { FieldsByName map[string]*Field FieldsByDBName map[string]*Field Relationships Relationships + namer Namer } // get data type from dialector -func Parse(dest interface{}, cacheStore sync.Map) *Schema { +func Parse(dest interface{}, cacheStore sync.Map, namer Namer) *Schema { modelType := reflect.ValueOf(dest).Type() for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Ptr { modelType = modelType.Elem() @@ -35,6 +36,7 @@ func Parse(dest interface{}, cacheStore sync.Map) *Schema { schema := &Schema{ ModelType: modelType, + Table: namer.TableName(modelType.Name()), FieldsByName: map[string]*Field{}, FieldsByDBName: map[string]*Field{}, } @@ -45,14 +47,23 @@ func Parse(dest interface{}, cacheStore sync.Map) *Schema { continue } - schema.Fields = append(schema.Fields, schema.ParseField(fieldStruct)) - // db namer + field := schema.ParseField(fieldStruct) + schema.Fields = append(schema.Fields, field) + if field.EmbeddedbSchema != nil { + for _, f := range field.EmbeddedbSchema.Fields { + schema.Fields = append(schema.Fields, f) + } + } } for _, field := range schema.Fields { + if field.DBName == "" { + field.DBName = namer.ColumnName(field.Name) + } + if field.DBName != "" { - // nonexistence or shortest path or first appear prioritized - if v, ok := schema.FieldsByDBName[field.DBName]; !ok || len(field.BindNames) < len(v.BindNames) { + // nonexistence or shortest path or first appear prioritized if has permission + if v, ok := schema.FieldsByDBName[field.DBName]; !ok || (field.Creatable && len(field.BindNames) < len(v.BindNames)) { schema.FieldsByDBName[field.DBName] = field schema.FieldsByName[field.Name] = field } From eea78f3f309eafef7b4fe5833506f283d2c850f3 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 1 Feb 2020 12:46:52 +0800 Subject: [PATCH 216/881] Implement parse relationship architecture --- clause/clause.go | 26 ++++++-- clause/query.go | 6 ++ schema/field.go | 32 ++++----- schema/naming.go | 16 ++++- schema/relationship.go | 144 ++++++++++++++++++++++++++++++++++------- schema/schema.go | 55 ++++++++++++---- schema/utils.go | 9 +++ 7 files changed, 226 insertions(+), 62 deletions(-) diff --git a/clause/clause.go b/clause/clause.go index b0507f44..1b4a7e85 100644 --- a/clause/clause.go +++ b/clause/clause.go @@ -59,7 +59,7 @@ type OverrideNameInterface interface { type Where struct { AndConditions AddConditions ORConditions []ORConditions - Builders []Expression + builders []Expression } func (where Where) Name() string { @@ -74,8 +74,8 @@ func (where Where) Build(builder Builder) { where.AndConditions.Build(builder) } - if len(where.Builders) > 0 { - for _, b := range where.Builders { + if len(where.builders) > 0 { + for _, b := range where.builders { if withConditions { builder.Write(" AND ") } @@ -122,9 +122,9 @@ func (where Where) MergeExpression(expr Expression) { if w, ok := expr.(Where); ok { where.AndConditions = append(where.AndConditions, w.AndConditions...) where.ORConditions = append(where.ORConditions, w.ORConditions...) - where.Builders = append(where.Builders, w.Builders...) + where.builders = append(where.builders, w.builders...) } else { - where.Builders = append(where.Builders, expr) + where.builders = append(where.builders, expr) } } @@ -135,6 +135,22 @@ type Select struct { // Join join clause type Join struct { + Table string + Type string // left join books on + ON []Expression + builders []Expression +} + +func (join Join) Build(builder Builder) { + // TODO +} + +func (join Join) MergeExpression(expr Expression) { + if j, ok := expr.(Join); ok { + join.builders = append(join.builders, j.builders...) + } else { + join.builders = append(join.builders, expr) + } } // GroupBy group by clause diff --git a/clause/query.go b/clause/query.go index 949678d9..7b5491e5 100644 --- a/clause/query.go +++ b/clause/query.go @@ -2,6 +2,12 @@ package clause import "strings" +// Column quote with name +type Column struct { + Table string + Name string +} + //////////////////////////////////////////////////////////////////////////////// // Query Expressions //////////////////////////////////////////////////////////////////////////////// diff --git a/schema/field.go b/schema/field.go index 88a0d3fb..005fd4e3 100644 --- a/schema/field.go +++ b/schema/field.go @@ -8,23 +8,23 @@ import ( "time" ) -type FieldType string +type DataType string const ( - Bool FieldType = "bool" - Int = "int" - Uint = "uint" - Float = "float" - String = "string" - Time = "time" - Bytes = "bytes" + Bool DataType = "bool" + Int = "int" + Uint = "uint" + Float = "float" + String = "string" + Time = "time" + Bytes = "bytes" ) type Field struct { Name string DBName string BindNames []string - DataType FieldType + DataType DataType DBDataType string PrimaryKey bool AutoIncrement bool @@ -42,8 +42,7 @@ type Field struct { Tag reflect.StructTag TagSettings map[string]string Schema *Schema - EmbeddedbSchema *Schema - Relationship string + EmbeddedSchema *Schema } func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { @@ -177,8 +176,8 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } if _, ok := field.TagSettings["EMBEDDED"]; ok || fieldStruct.Anonymous { - field.EmbeddedbSchema = Parse(fieldValue, sync.Map{}, schema.namer) - for _, ef := range field.EmbeddedbSchema.Fields { + field.EmbeddedSchema, schema.err = Parse(fieldValue, sync.Map{}, schema.namer) + for _, ef := range field.EmbeddedSchema.Fields { ef.BindNames = append([]string{fieldStruct.Name}, ef.BindNames...) if prefix, ok := field.TagSettings["EMBEDDED_PREFIX"]; ok { @@ -189,13 +188,6 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { ef.TagSettings[k] = v } } - } else { - switch fieldValue.Kind() { - case reflect.Struct: - field.Relationship = "one" - case reflect.Slice: - field.Relationship = "many" - } } return field diff --git a/schema/naming.go b/schema/naming.go index 1baa8558..6df80d2a 100644 --- a/schema/naming.go +++ b/schema/naming.go @@ -10,8 +10,10 @@ import ( // Namer namer interface type Namer interface { - TableName(string) string - ColumnName(string) string + TableName(table string) string + ColumnName(column string) string + JoinTableName(table string) string + JoinTableColumnName(table, column string) string } // NamingStrategy tables, columns naming strategy @@ -33,6 +35,16 @@ func (ns NamingStrategy) ColumnName(str string) string { return toDBName(str) } +// JoinTableName convert string to join table name +func (ns NamingStrategy) JoinTableName(str string) string { + return ns.TablePrefix + toDBName(str) +} + +// JoinTableColumnName convert string to join table column name +func (ns NamingStrategy) JoinTableColumnName(referenceTable, referenceColumn string) string { + return inflection.Singular(toDBName(referenceTable)) + toDBName(referenceColumn) +} + var ( smap sync.Map // https://github.com/golang/lint/blob/master/lint.go#L770 diff --git a/schema/relationship.go b/schema/relationship.go index b0c630be..95f56f6d 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -1,43 +1,143 @@ package schema +import ( + "fmt" + "reflect" + "strings" +) + // RelationshipType relationship type type RelationshipType string const ( - HasOneRel RelationshipType = "has_one" // HasOneRel has one relationship - HasManyRel RelationshipType = "has_many" // HasManyRel has many relationship - BelongsToRel RelationshipType = "belongs_to" // BelongsToRel belongs to relationship - Many2ManyRel RelationshipType = "many_to_many" // Many2ManyRel many to many relationship + HasOne RelationshipType = "has_one" // HasOneRel has one relationship + HasMany RelationshipType = "has_many" // HasManyRel has many relationship + BelongsTo RelationshipType = "belongs_to" // BelongsToRel belongs to relationship + Many2Many RelationshipType = "many_to_many" // Many2ManyRel many to many relationship ) type Relationships struct { - HasOne map[string]*Relationship - BelongsTo map[string]*Relationship - HasMany map[string]*Relationship - Many2Many map[string]*Relationship + HasOne []*Relationship + BelongsTo []*Relationship + HasMany []*Relationship + Many2Many []*Relationship + Relations map[string]*Relationship } type Relationship struct { - Type RelationshipType - ForeignKeys []*RelationField // self - AssociationForeignKeys []*RelationField // association - JoinTable *JoinTable + Name string + Type RelationshipType + Field *Field + Polymorphic *Polymorphic + References []Reference + Schema *Schema + FieldSchema *Schema + JoinTable *Schema + ForeignKeys, AssociationForeignKeys []string } -type RelationField struct { - *Field - PolymorphicField *Field - PolymorphicValue string +type Polymorphic struct { + PolymorphicID *Field + PolymorphicType *Field + Value string } -type JoinTable struct { - Table string - ForeignKeys []*RelationField - AssociationForeignKeys []*RelationField +type Reference struct { + PriamryKey *Field + PriamryValue string + ForeignKey *Field + OwnPriamryKey bool } -func (schema *Schema) buildToOneRel(field *Field) { +func (schema *Schema) parseRelation(field *Field) { + var ( + fieldValue = reflect.New(field.FieldType).Interface() + relation = &Relationship{ + Name: field.Name, + Field: field, + Schema: schema, + Type: RelationshipType(strings.ToLower(strings.TrimSpace(field.TagSettings["REL"]))), + ForeignKeys: toColumns(field.TagSettings["FOREIGNKEY"]), + AssociationForeignKeys: toColumns(field.TagSettings["ASSOCIATION_FOREIGNKEY"]), + } + ) + + if relation.FieldSchema, schema.err = Parse(fieldValue, schema.cacheStore, schema.namer); schema.err != nil { + return + } + + // User has many Toys, its `Polymorphic` is `Owner`, Pet has one Toy, its `Polymorphic` is `Owner` + // type User struct { + // Toys []Toy `gorm:"polymorphic:Owner;"` + // } + // type Pet struct { + // Toy Toy `gorm:"polymorphic:Owner;"` + // } + // type Toy struct { + // OwnerID int + // OwnerType string + // } + if polymorphic, _ := field.TagSettings["POLYMORPHIC"]; polymorphic != "" { + relation.Polymorphic = &Polymorphic{ + Value: schema.Table, + PolymorphicType: relation.FieldSchema.FieldsByName[polymorphic+"Type"], + PolymorphicID: relation.FieldSchema.FieldsByName[polymorphic+"ID"], + } + + if value, ok := field.TagSettings["POLYMORPHIC_VALUE"]; ok { + relation.Polymorphic.Value = strings.TrimSpace(value) + } + + if relation.Polymorphic.PolymorphicType == nil { + schema.err = fmt.Errorf("invalid polymorphic type: %v for %v on field %v, missing field %v", relation.FieldSchema, schema, field.Name, polymorphic+"Type") + } + + if relation.Polymorphic.PolymorphicID == nil { + schema.err = fmt.Errorf("invalid polymorphic type: %v for %v on field %v, missing field %v", relation.FieldSchema, schema, field.Name, polymorphic+"ID") + } + + if schema.err == nil { + relation.References = append(relation.References, Reference{ + PriamryValue: relation.Polymorphic.Value, + ForeignKey: relation.Polymorphic.PolymorphicType, + }) + + primaryKeyField := schema.PrioritizedPrimaryField + if len(relation.ForeignKeys) > 0 { + if primaryKeyField = schema.LookUpField(relation.ForeignKeys[0]); primaryKeyField == nil || len(relation.ForeignKeys) > 1 { + schema.err = fmt.Errorf("invalid polymorphic foreign key: %+v for %v on field %v", relation.ForeignKeys, schema, field.Name) + } + } + relation.References = append(relation.References, Reference{ + PriamryKey: primaryKeyField, + ForeignKey: relation.Polymorphic.PolymorphicType, + OwnPriamryKey: true, + }) + } + + switch field.FieldType.Kind() { + case reflect.Struct: + relation.Type = HasOne + case reflect.Slice: + relation.Type = HasMany + } + return + } + + switch field.FieldType.Kind() { + case reflect.Struct: + schema.parseStructRelation(relation, field) + case reflect.Slice: + schema.parseSliceRelation(relation, field) + default: + schema.err = fmt.Errorf("unsupported data type: %v (in %v#%v ", field.FieldType.PkgPath(), schema, field.Name) + } } -func (schema *Schema) buildToManyRel(field *Field) { +func (schema *Schema) parseStructRelation(relation *Relationship, field *Field) error { + return nil +} + +func (schema *Schema) parseSliceRelation(relation *Relationship, field *Field) error { + return nil } diff --git a/schema/schema.go b/schema/schema.go index 5069bb44..f18cb7a6 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -1,6 +1,7 @@ package schema import ( + "fmt" "go/ast" "reflect" "strings" @@ -8,6 +9,7 @@ import ( ) type Schema struct { + Name string ModelType reflect.Type Table string PrioritizedPrimaryField *Field @@ -16,42 +18,64 @@ type Schema struct { FieldsByName map[string]*Field FieldsByDBName map[string]*Field Relationships Relationships + err error namer Namer + cacheStore sync.Map +} + +func (schema Schema) String() string { + return schema.ModelType.PkgPath() +} + +func (schema Schema) LookUpField(name string) *Field { + if field, ok := schema.FieldsByDBName[name]; ok { + return field + } + if field, ok := schema.FieldsByName[name]; ok { + return field + } + return nil } // get data type from dialector -func Parse(dest interface{}, cacheStore sync.Map, namer Namer) *Schema { +func Parse(dest interface{}, cacheStore sync.Map, namer Namer) (*Schema, error) { modelType := reflect.ValueOf(dest).Type() for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Ptr { modelType = modelType.Elem() } if modelType.Kind() != reflect.Struct { - return nil + if modelType.PkgPath() == "" { + return nil, fmt.Errorf("unsupported data %+v when parsing model", dest) + } + return nil, fmt.Errorf("unsupported data type %v when parsing model", modelType.PkgPath()) } if v, ok := cacheStore.Load(modelType); ok { - return v.(*Schema) + return v.(*Schema), nil } schema := &Schema{ + Name: modelType.Name(), ModelType: modelType, Table: namer.TableName(modelType.Name()), FieldsByName: map[string]*Field{}, FieldsByDBName: map[string]*Field{}, + cacheStore: cacheStore, } - for i := 0; i < modelType.NumField(); i++ { - fieldStruct := modelType.Field(i) - if !ast.IsExported(fieldStruct.Name) { - continue + defer func() { + if schema.err != nil { + cacheStore.Delete(modelType) } + }() - field := schema.ParseField(fieldStruct) - schema.Fields = append(schema.Fields, field) - if field.EmbeddedbSchema != nil { - for _, f := range field.EmbeddedbSchema.Fields { - schema.Fields = append(schema.Fields, f) + for i := 0; i < modelType.NumField(); i++ { + if fieldStruct := modelType.Field(i); ast.IsExported(fieldStruct.Name) { + field := schema.ParseField(fieldStruct) + schema.Fields = append(schema.Fields, field) + if field.EmbeddedSchema != nil { + schema.Fields = append(schema.Fields, field.EmbeddedSchema.Fields...) } } } @@ -85,7 +109,12 @@ func Parse(dest interface{}, cacheStore sync.Map, namer Namer) *Schema { } schema.PrimaryFields = append(schema.PrimaryFields, field) } + + if field.DataType == "" { + defer schema.parseRelation(field) + } } - return schema + cacheStore.Store(modelType, schema) + return schema, schema.err } diff --git a/schema/utils.go b/schema/utils.go index 1b0f5eac..4f4bfa50 100644 --- a/schema/utils.go +++ b/schema/utils.go @@ -29,3 +29,12 @@ func checkTruth(val string) bool { } return true } + +func toColumns(val string) (results []string) { + if val != "" { + for _, v := range strings.Split(val, ",") { + results = append(results, strings.TrimSpace(v)) + } + } + return +} From a9c20291e495c777f9b74ee95f33285748e1c61c Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 1 Feb 2020 15:23:45 +0800 Subject: [PATCH 217/881] Implement guess relation --- schema/relationship.go | 140 ++++++++++++++++++++++++++++++++--------- 1 file changed, 110 insertions(+), 30 deletions(-) diff --git a/schema/relationship.go b/schema/relationship.go index 95f56f6d..5081d540 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -25,15 +25,15 @@ type Relationships struct { } type Relationship struct { - Name string - Type RelationshipType - Field *Field - Polymorphic *Polymorphic - References []Reference - Schema *Schema - FieldSchema *Schema - JoinTable *Schema - ForeignKeys, AssociationForeignKeys []string + Name string + Type RelationshipType + Field *Field + Polymorphic *Polymorphic + References []Reference + Schema *Schema + FieldSchema *Schema + JoinTable *Schema + ForeignKeys, PrimaryKeys []string } type Polymorphic struct { @@ -53,12 +53,11 @@ func (schema *Schema) parseRelation(field *Field) { var ( fieldValue = reflect.New(field.FieldType).Interface() relation = &Relationship{ - Name: field.Name, - Field: field, - Schema: schema, - Type: RelationshipType(strings.ToLower(strings.TrimSpace(field.TagSettings["REL"]))), - ForeignKeys: toColumns(field.TagSettings["FOREIGNKEY"]), - AssociationForeignKeys: toColumns(field.TagSettings["ASSOCIATION_FOREIGNKEY"]), + Name: field.Name, + Field: field, + Schema: schema, + ForeignKeys: toColumns(field.TagSettings["FOREIGNKEY"]), + PrimaryKeys: toColumns(field.TagSettings["PRIMARYKEY"]), } ) @@ -66,6 +65,8 @@ func (schema *Schema) parseRelation(field *Field) { return } + // Parse Polymorphic relations + // // User has many Toys, its `Polymorphic` is `Owner`, Pet has one Toy, its `Polymorphic` is `Owner` // type User struct { // Toys []Toy `gorm:"polymorphic:Owner;"` @@ -89,11 +90,11 @@ func (schema *Schema) parseRelation(field *Field) { } if relation.Polymorphic.PolymorphicType == nil { - schema.err = fmt.Errorf("invalid polymorphic type: %v for %v on field %v, missing field %v", relation.FieldSchema, schema, field.Name, polymorphic+"Type") + schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %v, missing field %v", relation.FieldSchema, schema, field.Name, polymorphic+"Type") } if relation.Polymorphic.PolymorphicID == nil { - schema.err = fmt.Errorf("invalid polymorphic type: %v for %v on field %v, missing field %v", relation.FieldSchema, schema, field.Name, polymorphic+"ID") + schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %v, missing field %v", relation.FieldSchema, schema, field.Name, polymorphic+"ID") } if schema.err == nil { @@ -105,7 +106,7 @@ func (schema *Schema) parseRelation(field *Field) { primaryKeyField := schema.PrioritizedPrimaryField if len(relation.ForeignKeys) > 0 { if primaryKeyField = schema.LookUpField(relation.ForeignKeys[0]); primaryKeyField == nil || len(relation.ForeignKeys) > 1 { - schema.err = fmt.Errorf("invalid polymorphic foreign key: %+v for %v on field %v", relation.ForeignKeys, schema, field.Name) + schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %v", relation.ForeignKeys, schema, field.Name) } } relation.References = append(relation.References, Reference{ @@ -115,29 +116,108 @@ func (schema *Schema) parseRelation(field *Field) { }) } + relation.Type = "has" + } else { + switch field.FieldType.Kind() { + case reflect.Struct: + schema.guessRelation(relation, field, true) + case reflect.Slice: + schema.guessRelation(relation, field, true) + default: + schema.err = fmt.Errorf("unsupported data type %v for %v on field %v", relation.FieldSchema, schema, field.Name) + } + } + + if relation.Type == "has" { switch field.FieldType.Kind() { case reflect.Struct: relation.Type = HasOne case reflect.Slice: relation.Type = HasMany } + } +} + +func (schema *Schema) guessRelation(relation *Relationship, field *Field, guessHas bool) { + var ( + primaryFields, foreignFields []*Field + primarySchema, foreignSchema = schema, relation.FieldSchema + ) + + if !guessHas { + primarySchema, foreignSchema = relation.FieldSchema, schema + } + + reguessOrErr := func(err string, args ...interface{}) { + if guessHas { + schema.guessRelation(relation, field, false) + } else { + schema.err = fmt.Errorf(err, args...) + } + } + + if len(relation.ForeignKeys) > 0 { + for _, foreignKey := range relation.ForeignKeys { + if f := foreignSchema.LookUpField(foreignKey); f != nil { + foreignFields = append(foreignFields, f) + } else { + reguessOrErr("unsupported relations %v for %v on field %v with foreign keys %v", relation.FieldSchema, schema, field.Name, relation.ForeignKeys) + return + } + } + } else { + for _, primaryField := range primarySchema.PrimaryFields { + if f := foreignSchema.LookUpField(field.Name + primaryField.Name); f != nil { + foreignFields = append(foreignFields, f) + primaryFields = append(primaryFields, primaryField) + } + } + } + + if len(foreignFields) == 0 { + reguessOrErr("failed to guess %v's relations with %v's field %v", relation.FieldSchema, schema, field.Name) return + } else if len(relation.PrimaryKeys) > 0 { + for idx, primaryKey := range relation.PrimaryKeys { + if f := primarySchema.LookUpField(primaryKey); f != nil { + if len(primaryFields) < idx+1 { + primaryFields = append(primaryFields, f) + } else if f != primaryFields[idx] { + reguessOrErr("unsupported relations %v for %v on field %v with primary keys %v", relation.FieldSchema, schema, field.Name, relation.PrimaryKeys) + return + } + } else { + reguessOrErr("unsupported relations %v for %v on field %v with primary keys %v", relation.FieldSchema, schema, field.Name, relation.PrimaryKeys) + return + } + } + } else if len(primaryFields) == 0 { + if len(foreignFields) == 1 { + primaryFields = append(primaryFields, primarySchema.PrioritizedPrimaryField) + } else if len(primarySchema.PrimaryFields) == len(foreignFields) { + primaryFields = append(primaryFields, primarySchema.PrimaryFields...) + } else { + reguessOrErr("unsupported relations %v for %v on field %v", relation.FieldSchema, schema, field.Name) + return + } } - switch field.FieldType.Kind() { - case reflect.Struct: - schema.parseStructRelation(relation, field) - case reflect.Slice: - schema.parseSliceRelation(relation, field) - default: - schema.err = fmt.Errorf("unsupported data type: %v (in %v#%v ", field.FieldType.PkgPath(), schema, field.Name) + // build references + for idx, foreignField := range foreignFields { + relation.References = append(relation.References, Reference{ + PriamryKey: primaryFields[idx], + ForeignKey: foreignField, + OwnPriamryKey: schema == primarySchema, + }) + } + + if guessHas { + relation.Type = "has" + } else { + relation.Type = "belongs_to" } } -func (schema *Schema) parseStructRelation(relation *Relationship, field *Field) error { - return nil -} - -func (schema *Schema) parseSliceRelation(relation *Relationship, field *Field) error { +func (schema *Schema) parseMany2ManyRelation(relation *Relationship, field *Field) error { return nil } From fd9b688084d3021927721b8925a655d19762918f Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 1 Feb 2020 18:02:19 +0800 Subject: [PATCH 218/881] Implement parse many2many relation --- schema/field.go | 6 +- schema/naming.go | 6 -- schema/relationship.go | 162 ++++++++++++++++++++++++++--------------- schema/utils.go | 5 ++ schema/utils_test.go | 23 ++++++ 5 files changed, 133 insertions(+), 69 deletions(-) create mode 100644 schema/utils_test.go diff --git a/schema/field.go b/schema/field.go index 005fd4e3..d2747100 100644 --- a/schema/field.go +++ b/schema/field.go @@ -103,11 +103,11 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { field.DBName = dbName } - if val, ok := field.TagSettings["PRIMARY_KEY"]; ok && checkTruth(val) { + if val, ok := field.TagSettings["PRIMARYKEY"]; ok && checkTruth(val) { field.PrimaryKey = true } - if val, ok := field.TagSettings["AUTO_INCREMENT"]; ok && checkTruth(val) { + if val, ok := field.TagSettings["AUTOINCREMENT"]; ok && checkTruth(val) { field.AutoIncrement = true field.HasDefaultValue = true } @@ -180,7 +180,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { for _, ef := range field.EmbeddedSchema.Fields { ef.BindNames = append([]string{fieldStruct.Name}, ef.BindNames...) - if prefix, ok := field.TagSettings["EMBEDDED_PREFIX"]; ok { + if prefix, ok := field.TagSettings["EMBEDDEDPREFIX"]; ok { ef.DBName = prefix + ef.DBName } diff --git a/schema/naming.go b/schema/naming.go index 6df80d2a..5a2311b6 100644 --- a/schema/naming.go +++ b/schema/naming.go @@ -13,7 +13,6 @@ type Namer interface { TableName(table string) string ColumnName(column string) string JoinTableName(table string) string - JoinTableColumnName(table, column string) string } // NamingStrategy tables, columns naming strategy @@ -40,11 +39,6 @@ func (ns NamingStrategy) JoinTableName(str string) string { return ns.TablePrefix + toDBName(str) } -// JoinTableColumnName convert string to join table column name -func (ns NamingStrategy) JoinTableColumnName(referenceTable, referenceColumn string) string { - return inflection.Singular(toDBName(referenceTable)) + toDBName(referenceColumn) -} - var ( smap sync.Map // https://github.com/golang/lint/blob/master/lint.go#L770 diff --git a/schema/relationship.go b/schema/relationship.go index 5081d540..5195589d 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -57,7 +57,7 @@ func (schema *Schema) parseRelation(field *Field) { Field: field, Schema: schema, ForeignKeys: toColumns(field.TagSettings["FOREIGNKEY"]), - PrimaryKeys: toColumns(field.TagSettings["PRIMARYKEY"]), + PrimaryKeys: toColumns(field.TagSettings["REFERENCES"]), } ) @@ -65,63 +65,13 @@ func (schema *Schema) parseRelation(field *Field) { return } - // Parse Polymorphic relations - // - // User has many Toys, its `Polymorphic` is `Owner`, Pet has one Toy, its `Polymorphic` is `Owner` - // type User struct { - // Toys []Toy `gorm:"polymorphic:Owner;"` - // } - // type Pet struct { - // Toy Toy `gorm:"polymorphic:Owner;"` - // } - // type Toy struct { - // OwnerID int - // OwnerType string - // } if polymorphic, _ := field.TagSettings["POLYMORPHIC"]; polymorphic != "" { - relation.Polymorphic = &Polymorphic{ - Value: schema.Table, - PolymorphicType: relation.FieldSchema.FieldsByName[polymorphic+"Type"], - PolymorphicID: relation.FieldSchema.FieldsByName[polymorphic+"ID"], - } - - if value, ok := field.TagSettings["POLYMORPHIC_VALUE"]; ok { - relation.Polymorphic.Value = strings.TrimSpace(value) - } - - if relation.Polymorphic.PolymorphicType == nil { - schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %v, missing field %v", relation.FieldSchema, schema, field.Name, polymorphic+"Type") - } - - if relation.Polymorphic.PolymorphicID == nil { - schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %v, missing field %v", relation.FieldSchema, schema, field.Name, polymorphic+"ID") - } - - if schema.err == nil { - relation.References = append(relation.References, Reference{ - PriamryValue: relation.Polymorphic.Value, - ForeignKey: relation.Polymorphic.PolymorphicType, - }) - - primaryKeyField := schema.PrioritizedPrimaryField - if len(relation.ForeignKeys) > 0 { - if primaryKeyField = schema.LookUpField(relation.ForeignKeys[0]); primaryKeyField == nil || len(relation.ForeignKeys) > 1 { - schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %v", relation.ForeignKeys, schema, field.Name) - } - } - relation.References = append(relation.References, Reference{ - PriamryKey: primaryKeyField, - ForeignKey: relation.Polymorphic.PolymorphicType, - OwnPriamryKey: true, - }) - } - - relation.Type = "has" + schema.buildPolymorphicRelation(relation, field, polymorphic) + } else if many2many, _ := field.TagSettings["MANY2MANY"]; many2many != "" { + schema.buildMany2ManyRelation(relation, field, many2many) } else { switch field.FieldType.Kind() { - case reflect.Struct: - schema.guessRelation(relation, field, true) - case reflect.Slice: + case reflect.Struct, reflect.Slice: schema.guessRelation(relation, field, true) default: schema.err = fmt.Errorf("unsupported data type %v for %v on field %v", relation.FieldSchema, schema, field.Name) @@ -138,6 +88,102 @@ func (schema *Schema) parseRelation(field *Field) { } } +// User has many Toys, its `Polymorphic` is `Owner`, Pet has one Toy, its `Polymorphic` is `Owner` +// type User struct { +// Toys []Toy `gorm:"polymorphic:Owner;"` +// } +// type Pet struct { +// Toy Toy `gorm:"polymorphic:Owner;"` +// } +// type Toy struct { +// OwnerID int +// OwnerType string +// } +func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Field, polymorphic string) { + relation.Polymorphic = &Polymorphic{ + Value: schema.Table, + PolymorphicType: relation.FieldSchema.FieldsByName[polymorphic+"Type"], + PolymorphicID: relation.FieldSchema.FieldsByName[polymorphic+"ID"], + } + + if value, ok := field.TagSettings["POLYMORPHIC_VALUE"]; ok { + relation.Polymorphic.Value = strings.TrimSpace(value) + } + + if relation.Polymorphic.PolymorphicType == nil { + schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %v, missing field %v", relation.FieldSchema, schema, field.Name, polymorphic+"Type") + } + + if relation.Polymorphic.PolymorphicID == nil { + schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %v, missing field %v", relation.FieldSchema, schema, field.Name, polymorphic+"ID") + } + + if schema.err == nil { + relation.References = append(relation.References, Reference{ + PriamryValue: relation.Polymorphic.Value, + ForeignKey: relation.Polymorphic.PolymorphicType, + }) + + primaryKeyField := schema.PrioritizedPrimaryField + if len(relation.ForeignKeys) > 0 { + if primaryKeyField = schema.LookUpField(relation.ForeignKeys[0]); primaryKeyField == nil || len(relation.ForeignKeys) > 1 { + schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %v", relation.ForeignKeys, schema, field.Name) + } + } + relation.References = append(relation.References, Reference{ + PriamryKey: primaryKeyField, + ForeignKey: relation.Polymorphic.PolymorphicType, + OwnPriamryKey: true, + }) + } + + relation.Type = "has" +} + +func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Field, many2many string) { + relation.Type = Many2Many + + var ( + joinTableFields []reflect.StructField + fieldsMap = map[string]*Field{} + ) + + for _, s := range []*Schema{schema, relation.Schema} { + for _, primaryField := range s.PrimaryFields { + fieldName := s.Name + primaryField.Name + if _, ok := fieldsMap[fieldName]; ok { + if field.Name != s.Name { + fieldName = field.Name + primaryField.Name + } else { + fieldName = s.Name + primaryField.Name + "Reference" + } + } + + fieldsMap[fieldName] = primaryField + joinTableFields = append(joinTableFields, reflect.StructField{ + Name: fieldName, + PkgPath: primaryField.StructField.PkgPath, + Type: primaryField.StructField.Type, + Tag: removeSettingFromTag(primaryField.StructField.Tag, "column"), + }) + } + } + + relation.JoinTable, schema.err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, schema.namer) + relation.JoinTable.Name = many2many + relation.JoinTable.Table = schema.namer.JoinTableName(many2many) + + // build references + for _, f := range relation.JoinTable.Fields { + relation.References = append(relation.References, Reference{ + PriamryKey: fieldsMap[f.Name], + ForeignKey: f, + OwnPriamryKey: schema == fieldsMap[f.Name].Schema, + }) + } + return +} + func (schema *Schema) guessRelation(relation *Relationship, field *Field, guessHas bool) { var ( primaryFields, foreignFields []*Field @@ -214,10 +260,6 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, guessH if guessHas { relation.Type = "has" } else { - relation.Type = "belongs_to" + relation.Type = BelongsTo } } - -func (schema *Schema) parseMany2ManyRelation(relation *Relationship, field *Field) error { - return nil -} diff --git a/schema/utils.go b/schema/utils.go index 4f4bfa50..f2dd90af 100644 --- a/schema/utils.go +++ b/schema/utils.go @@ -2,6 +2,7 @@ package schema import ( "reflect" + "regexp" "strings" ) @@ -38,3 +39,7 @@ func toColumns(val string) (results []string) { } return } + +func removeSettingFromTag(tag reflect.StructTag, name string) reflect.StructTag { + return reflect.StructTag(regexp.MustCompile(`(?i)(gorm:.*?)(`+name+`:.*?)(;|("))`).ReplaceAllString(string(tag), "${1}${4}")) +} diff --git a/schema/utils_test.go b/schema/utils_test.go new file mode 100644 index 00000000..e70169bf --- /dev/null +++ b/schema/utils_test.go @@ -0,0 +1,23 @@ +package schema + +import ( + "reflect" + "testing" +) + +func TestRemoveSettingFromTag(t *testing.T) { + tags := map[string]string{ + `gorm:"before:value;column:db;after:value" other:"before:value;column:db;after:value"`: `gorm:"before:value;after:value" other:"before:value;column:db;after:value"`, + `gorm:"before:value;column:db;" other:"before:value;column:db;after:value"`: `gorm:"before:value;" other:"before:value;column:db;after:value"`, + `gorm:"before:value;column:db" other:"before:value;column:db;after:value"`: `gorm:"before:value;" other:"before:value;column:db;after:value"`, + `gorm:"column:db" other:"before:value;column:db;after:value"`: `gorm:"" other:"before:value;column:db;after:value"`, + `gorm:"before:value;column:db ;after:value" other:"before:value;column:db;after:value"`: `gorm:"before:value;after:value" other:"before:value;column:db;after:value"`, + `gorm:"before:value;column:db; after:value" other:"before:value;column:db;after:value"`: `gorm:"before:value; after:value" other:"before:value;column:db;after:value"`, + } + + for k, v := range tags { + if string(removeSettingFromTag(reflect.StructTag(k), "column")) != v { + t.Errorf("%v after removeSettingFromTag should equal %v, but got %v", k, v, removeSettingFromTag(reflect.StructTag(k), "column")) + } + } +} From 14724ddeae2e269093327f0d5f982f690aeee739 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 1 Feb 2020 20:18:25 +0800 Subject: [PATCH 219/881] Add tests model definition and basic fields tests --- helpers.go | 2 +- schema/field.go | 10 +++--- schema/schema.go | 8 +++-- schema/schema_test.go | 78 +++++++++++++++++++++++++++++++++++++++++ schema/utils.go | 2 +- tests/callbacks_test.go | 2 +- tests/model.go | 58 ++++++++++++++++++++++++++++++ 7 files changed, 150 insertions(+), 10 deletions(-) create mode 100644 schema/schema_test.go create mode 100644 tests/model.go diff --git a/helpers.go b/helpers.go index 8f9df009..77bbece8 100644 --- a/helpers.go +++ b/helpers.go @@ -22,7 +22,7 @@ var ( // gorm.Model // } type Model struct { - ID uint `gorm:"primary_key"` + ID uint `gorm:"primarykey"` CreatedAt time.Time UpdatedAt time.Time DeletedAt *time.Time `gorm:"index"` diff --git a/schema/field.go b/schema/field.go index d2747100..47250aa8 100644 --- a/schema/field.go +++ b/schema/field.go @@ -54,7 +54,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { Creatable: true, Updatable: true, Tag: fieldStruct.Tag, - TagSettings: parseTagSetting(fieldStruct.Tag), + TagSettings: ParseTagSetting(fieldStruct.Tag), } for field.FieldType.Kind() == reflect.Ptr { @@ -84,7 +84,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } // copy tag settings from valuer - for key, value := range parseTagSetting(field.FieldType.Field(i).Tag) { + for key, value := range ParseTagSetting(field.FieldType.Field(i).Tag) { if _, ok := field.TagSettings[key]; !ok { field.TagSettings[key] = value } @@ -141,7 +141,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { field.DBDataType = val } - switch fieldValue.Kind() { + switch fieldValue.Elem().Kind() { case reflect.Bool: field.DataType = Bool case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: @@ -153,7 +153,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { case reflect.String: field.DataType = String case reflect.Struct: - if _, ok := fieldValue.Interface().(time.Time); ok { + if _, ok := fieldValue.Interface().(*time.Time); ok { field.DataType = Time } case reflect.Array, reflect.Slice: @@ -176,7 +176,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } if _, ok := field.TagSettings["EMBEDDED"]; ok || fieldStruct.Anonymous { - field.EmbeddedSchema, schema.err = Parse(fieldValue, sync.Map{}, schema.namer) + field.EmbeddedSchema, schema.err = Parse(fieldValue.Interface(), &sync.Map{}, schema.namer) for _, ef := range field.EmbeddedSchema.Fields { ef.BindNames = append([]string{fieldStruct.Name}, ef.BindNames...) diff --git a/schema/schema.go b/schema/schema.go index f18cb7a6..0b5548e3 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -6,6 +6,8 @@ import ( "reflect" "strings" "sync" + + "github.com/jinzhu/gorm/logger" ) type Schema struct { @@ -20,7 +22,7 @@ type Schema struct { Relationships Relationships err error namer Namer - cacheStore sync.Map + cacheStore *sync.Map } func (schema Schema) String() string { @@ -38,7 +40,7 @@ func (schema Schema) LookUpField(name string) *Field { } // get data type from dialector -func Parse(dest interface{}, cacheStore sync.Map, namer Namer) (*Schema, error) { +func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) { modelType := reflect.ValueOf(dest).Type() for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Ptr { modelType = modelType.Elem() @@ -62,10 +64,12 @@ func Parse(dest interface{}, cacheStore sync.Map, namer Namer) (*Schema, error) FieldsByName: map[string]*Field{}, FieldsByDBName: map[string]*Field{}, cacheStore: cacheStore, + namer: namer, } defer func() { if schema.err != nil { + logger.Default.Error(schema.err.Error()) cacheStore.Delete(modelType) } }() diff --git a/schema/schema_test.go b/schema/schema_test.go new file mode 100644 index 00000000..eefac98b --- /dev/null +++ b/schema/schema_test.go @@ -0,0 +1,78 @@ +package schema_test + +import ( + "reflect" + "sync" + "testing" + + "github.com/jinzhu/gorm/schema" + "github.com/jinzhu/gorm/tests" +) + +func TestParseSchema(t *testing.T) { + cacheMap := sync.Map{} + user, err := schema.Parse(&tests.User{}, &cacheMap, schema.NamingStrategy{}) + + if err != nil { + t.Fatalf("failed to parse user, got error %v", err) + } + + checkSchemaFields(t, user) +} + +func checkSchemaFields(t *testing.T, s *schema.Schema) { + fields := []schema.Field{ + schema.Field{ + Name: "ID", DBName: "id", BindNames: []string{"Model", "ID"}, DataType: schema.Uint, + PrimaryKey: true, Tag: `gorm:"primarykey"`, TagSettings: map[string]string{"PRIMARYKEY": "PRIMARYKEY"}, + }, + schema.Field{Name: "CreatedAt", DBName: "created_at", BindNames: []string{"Model", "CreatedAt"}, DataType: schema.Time}, + schema.Field{Name: "UpdatedAt", DBName: "updated_at", BindNames: []string{"Model", "UpdatedAt"}, DataType: schema.Time}, + schema.Field{Name: "DeletedAt", DBName: "deleted_at", BindNames: []string{"Model", "DeletedAt"}, Tag: `gorm:"index"`, DataType: schema.Time}, + schema.Field{Name: "Name", DBName: "name", BindNames: []string{"Name"}, DataType: schema.String}, + schema.Field{Name: "Age", DBName: "age", BindNames: []string{"Age"}, DataType: schema.Uint}, + schema.Field{Name: "Birthday", DBName: "birthday", BindNames: []string{"Birthday"}, DataType: schema.Time}, + schema.Field{Name: "CompanyID", DBName: "company_id", BindNames: []string{"CompanyID"}, DataType: schema.Int}, + schema.Field{Name: "ManagerID", DBName: "manager_id", BindNames: []string{"ManagerID"}, DataType: schema.Uint}, + } + + for _, f := range fields { + f.Creatable = true + f.Updatable = true + if f.TagSettings == nil { + if f.Tag != "" { + f.TagSettings = schema.ParseTagSetting(f.Tag) + } else { + f.TagSettings = map[string]string{} + } + } + + if foundField, ok := s.FieldsByName[f.Name]; !ok { + t.Errorf("schema %v failed to look up field with name %v", s, f.Name) + } else { + checkSchemaField(t, foundField, f) + + if field, ok := s.FieldsByDBName[f.DBName]; !ok || foundField != field { + t.Errorf("schema %v failed to look up field with dbname %v", s, f.DBName) + } + + for _, name := range []string{f.DBName, f.Name} { + if field := s.LookUpField(name); field == nil || foundField != field { + t.Errorf("schema %v failed to look up field with dbname %v", s, f.DBName) + } + } + } + } +} + +func checkSchemaField(t *testing.T, parsedField *schema.Field, field schema.Field) { + equalFieldNames := []string{"Name", "DBName", "BindNames", "DataType", "DBDataType", "PrimaryKey", "AutoIncrement", "Creatable", "Updatable", "HasDefaultValue", "DefaultValue", "NotNull", "Unique", "Comment", "Size", "Precision", "Tag", "TagSettings"} + + for _, name := range equalFieldNames { + got := reflect.ValueOf(parsedField).Elem().FieldByName(name).Interface() + expects := reflect.ValueOf(field).FieldByName(name).Interface() + if !reflect.DeepEqual(got, expects) { + t.Errorf("%v is not equal, expects: %v, got %v", name, expects, got) + } + } +} diff --git a/schema/utils.go b/schema/utils.go index f2dd90af..4774fd75 100644 --- a/schema/utils.go +++ b/schema/utils.go @@ -6,7 +6,7 @@ import ( "strings" ) -func parseTagSetting(tags reflect.StructTag) map[string]string { +func ParseTagSetting(tags reflect.StructTag) map[string]string { setting := map[string]string{} for _, value := range strings.Split(tags.Get("gorm"), ";") { diff --git a/tests/callbacks_test.go b/tests/callbacks_test.go index 878384a7..af975a55 100644 --- a/tests/callbacks_test.go +++ b/tests/callbacks_test.go @@ -1,4 +1,4 @@ -package gorm_test +package tests_test import ( "fmt" diff --git a/tests/model.go b/tests/model.go new file mode 100644 index 00000000..0be3e97a --- /dev/null +++ b/tests/model.go @@ -0,0 +1,58 @@ +package tests + +import ( + "database/sql" + "time" + + "github.com/jinzhu/gorm" +) + +// User has one `Account` (has one), many `Pets` (has many) and `Toys` (has many - polymorphic) +// He works in a Company (belongs to), he has a Manager (belongs to - single-table), and also managed a Team (has many - single-table) +// He speaks many languages (many to many) and has many friends (many to many - single-table) +// His pet also has one Toy (has one - polymorphic) +type User struct { + gorm.Model + Name string + Age uint + Birthday *time.Time + Account Account + Pets []*Pet + Toys []Toy `gorm:"polymorphic:Owner"` + CompanyID *int + Company Company + ManagerID uint + Manager *User + Team []User `foreignkey:ManagerID` + Friends []*User `gorm:"many2many:user_friends"` + Languages []Language `gorm:"many2many:user_speaks"` +} + +type Account struct { + gorm.Model + UserID sql.NullInt64 + Number string +} + +type Pet struct { + gorm.Model + UserID uint + Name string + Toy Toy `gorm:"polymorphic:Owner;"` +} + +type Toy struct { + gorm.Model + OwnerID string + OwnerType string +} + +type Company struct { + ID uint + Name string +} + +type Language struct { + Code string `gorm:primarykey` + Name string +} From a4a0895a8589acc0116fc84eb4ce0139f52917a7 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 1 Feb 2020 21:48:06 +0800 Subject: [PATCH 220/881] Test parse schema relations --- logger/logger.go | 8 +-- schema/field.go | 7 +- schema/relationship.go | 58 ++++++++++++----- schema/schema.go | 62 ++++++++++++------ schema/schema_helper_test.go | 123 +++++++++++++++++++++++++++++++++++ schema/schema_test.go | 77 +++++++--------------- tests/model.go | 2 +- 7 files changed, 242 insertions(+), 95 deletions(-) create mode 100644 schema/schema_helper_test.go diff --git a/logger/logger.go b/logger/logger.go index 9d6e70bf..cad9be16 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -8,7 +8,7 @@ import ( type LogLevel int -var Default Interface = Logger{Writer: log.New(os.Stdout, "\r\n", 0)} +var Default Interface = Logger{Writer: log.New(os.Stdout, "\r\n", log.LstdFlags)} const ( Info LogLevel = iota + 1 @@ -40,21 +40,21 @@ func (logger Logger) LogMode(level LogLevel) Interface { // Info print info func (logger Logger) Info(msg string, data ...interface{}) { - if logger.logLevel >= Info { + if logger.logLevel <= Info { logger.Print("[info] " + fmt.Sprintf(msg, data...)) } } // Warn print warn messages func (logger Logger) Warn(msg string, data ...interface{}) { - if logger.logLevel >= Warn { + if logger.logLevel <= Warn { logger.Print("[warn] " + fmt.Sprintf(msg, data...)) } } // Error print error messages func (logger Logger) Error(msg string, data ...interface{}) { - if logger.logLevel >= Error { + if logger.logLevel <= Error { logger.Print("[error] " + fmt.Sprintf(msg, data...)) } } diff --git a/schema/field.go b/schema/field.go index 47250aa8..f1cd022b 100644 --- a/schema/field.go +++ b/schema/field.go @@ -176,7 +176,12 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } if _, ok := field.TagSettings["EMBEDDED"]; ok || fieldStruct.Anonymous { - field.EmbeddedSchema, schema.err = Parse(fieldValue.Interface(), &sync.Map{}, schema.namer) + var err error + field.Creatable = false + field.Updatable = false + if field.EmbeddedSchema, err = Parse(fieldValue.Interface(), &sync.Map{}, schema.namer); err != nil { + schema.err = err + } for _, ef := range field.EmbeddedSchema.Fields { ef.BindNames = append([]string{fieldStruct.Name}, ef.BindNames...) diff --git a/schema/relationship.go b/schema/relationship.go index 5195589d..358d13e7 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -33,7 +33,7 @@ type Relationship struct { Schema *Schema FieldSchema *Schema JoinTable *Schema - ForeignKeys, PrimaryKeys []string + foreignKeys, primaryKeys []string } type Polymorphic struct { @@ -51,17 +51,19 @@ type Reference struct { func (schema *Schema) parseRelation(field *Field) { var ( + err error fieldValue = reflect.New(field.FieldType).Interface() relation = &Relationship{ Name: field.Name, Field: field, Schema: schema, - ForeignKeys: toColumns(field.TagSettings["FOREIGNKEY"]), - PrimaryKeys: toColumns(field.TagSettings["REFERENCES"]), + foreignKeys: toColumns(field.TagSettings["FOREIGNKEY"]), + primaryKeys: toColumns(field.TagSettings["REFERENCES"]), } ) - if relation.FieldSchema, schema.err = Parse(fieldValue, schema.cacheStore, schema.namer); schema.err != nil { + if relation.FieldSchema, err = Parse(fieldValue, schema.cacheStore, schema.namer); err != nil { + schema.err = err return } @@ -86,6 +88,20 @@ func (schema *Schema) parseRelation(field *Field) { relation.Type = HasMany } } + + if schema.err == nil { + schema.Relationships.Relations[relation.Name] = relation + switch relation.Type { + case HasOne: + schema.Relationships.HasOne = append(schema.Relationships.HasOne, relation) + case HasMany: + schema.Relationships.HasMany = append(schema.Relationships.HasMany, relation) + case BelongsTo: + schema.Relationships.BelongsTo = append(schema.Relationships.BelongsTo, relation) + case Many2Many: + schema.Relationships.Many2Many = append(schema.Relationships.Many2Many, relation) + } + } } // User has many Toys, its `Polymorphic` is `Owner`, Pet has one Toy, its `Polymorphic` is `Owner` @@ -125,9 +141,9 @@ func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Fi }) primaryKeyField := schema.PrioritizedPrimaryField - if len(relation.ForeignKeys) > 0 { - if primaryKeyField = schema.LookUpField(relation.ForeignKeys[0]); primaryKeyField == nil || len(relation.ForeignKeys) > 1 { - schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %v", relation.ForeignKeys, schema, field.Name) + if len(relation.foreignKeys) > 0 { + if primaryKeyField = schema.LookUpField(relation.foreignKeys[0]); primaryKeyField == nil || len(relation.foreignKeys) > 1 { + schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %v", relation.foreignKeys, schema, field.Name) } } relation.References = append(relation.References, Reference{ @@ -144,6 +160,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel relation.Type = Many2Many var ( + err error joinTableFields []reflect.StructField fieldsMap = map[string]*Field{} ) @@ -169,7 +186,9 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel } } - relation.JoinTable, schema.err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, schema.namer) + if relation.JoinTable, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, schema.namer); err != nil { + schema.err = err + } relation.JoinTable.Name = many2many relation.JoinTable.Table = schema.namer.JoinTableName(many2many) @@ -202,18 +221,23 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, guessH } } - if len(relation.ForeignKeys) > 0 { - for _, foreignKey := range relation.ForeignKeys { + if len(relation.foreignKeys) > 0 { + for _, foreignKey := range relation.foreignKeys { if f := foreignSchema.LookUpField(foreignKey); f != nil { foreignFields = append(foreignFields, f) } else { - reguessOrErr("unsupported relations %v for %v on field %v with foreign keys %v", relation.FieldSchema, schema, field.Name, relation.ForeignKeys) + reguessOrErr("unsupported relations %v for %v on field %v with foreign keys %v", relation.FieldSchema, schema, field.Name, relation.foreignKeys) return } } } else { for _, primaryField := range primarySchema.PrimaryFields { - if f := foreignSchema.LookUpField(field.Name + primaryField.Name); f != nil { + lookUpName := schema.Name + primaryField.Name + if !guessHas { + lookUpName = field.Name + primaryField.Name + } + + if f := foreignSchema.LookUpField(lookUpName); f != nil { foreignFields = append(foreignFields, f) primaryFields = append(primaryFields, primaryField) } @@ -221,19 +245,19 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, guessH } if len(foreignFields) == 0 { - reguessOrErr("failed to guess %v's relations with %v's field %v", relation.FieldSchema, schema, field.Name) + reguessOrErr("failed to guess %v's relations with %v's field %v 1 g %v", relation.FieldSchema, schema, field.Name, guessHas) return - } else if len(relation.PrimaryKeys) > 0 { - for idx, primaryKey := range relation.PrimaryKeys { + } else if len(relation.primaryKeys) > 0 { + for idx, primaryKey := range relation.primaryKeys { if f := primarySchema.LookUpField(primaryKey); f != nil { if len(primaryFields) < idx+1 { primaryFields = append(primaryFields, f) } else if f != primaryFields[idx] { - reguessOrErr("unsupported relations %v for %v on field %v with primary keys %v", relation.FieldSchema, schema, field.Name, relation.PrimaryKeys) + reguessOrErr("unsupported relations %v for %v on field %v with primary keys %v", relation.FieldSchema, schema, field.Name, relation.primaryKeys) return } } else { - reguessOrErr("unsupported relations %v for %v on field %v with primary keys %v", relation.FieldSchema, schema, field.Name, relation.PrimaryKeys) + reguessOrErr("unsupported relations %v for %v on field %v with primary keys %v", relation.FieldSchema, schema, field.Name, relation.primaryKeys) return } } diff --git a/schema/schema.go b/schema/schema.go index 0b5548e3..d3404312 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -4,7 +4,6 @@ import ( "fmt" "go/ast" "reflect" - "strings" "sync" "github.com/jinzhu/gorm/logger" @@ -26,7 +25,7 @@ type Schema struct { } func (schema Schema) String() string { - return schema.ModelType.PkgPath() + return fmt.Sprintf("%v.%v", schema.ModelType.PkgPath(), schema.ModelType.Name()) } func (schema Schema) LookUpField(name string) *Field { @@ -63,6 +62,7 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) Table: namer.TableName(modelType.Name()), FieldsByName: map[string]*Field{}, FieldsByDBName: map[string]*Field{}, + Relationships: Relationships{Relations: map[string]*Relationship{}}, cacheStore: cacheStore, namer: namer, } @@ -76,10 +76,10 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) for i := 0; i < modelType.NumField(); i++ { if fieldStruct := modelType.Field(i); ast.IsExported(fieldStruct.Name) { - field := schema.ParseField(fieldStruct) - schema.Fields = append(schema.Fields, field) - if field.EmbeddedSchema != nil { + if field := schema.ParseField(fieldStruct); field.EmbeddedSchema != nil { schema.Fields = append(schema.Fields, field.EmbeddedSchema.Fields...) + } else { + schema.Fields = append(schema.Fields, field) } } } @@ -94,6 +94,27 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) if v, ok := schema.FieldsByDBName[field.DBName]; !ok || (field.Creatable && len(field.BindNames) < len(v.BindNames)) { schema.FieldsByDBName[field.DBName] = field schema.FieldsByName[field.Name] = field + + if v != nil && v.PrimaryKey { + if schema.PrioritizedPrimaryField == v { + schema.PrioritizedPrimaryField = nil + } + + for idx, f := range schema.PrimaryFields { + if f == v { + schema.PrimaryFields = append(schema.PrimaryFields[0:idx], schema.PrimaryFields[idx+1:]...) + } else if schema.PrioritizedPrimaryField == nil { + schema.PrioritizedPrimaryField = f + } + } + } + + if field.PrimaryKey { + if schema.PrioritizedPrimaryField == nil { + schema.PrioritizedPrimaryField = field + } + schema.PrimaryFields = append(schema.PrimaryFields, field) + } } } @@ -102,23 +123,26 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) } } - for db, field := range schema.FieldsByDBName { - if strings.ToLower(db) == "id" { - schema.PrioritizedPrimaryField = field - } - - if field.PrimaryKey { - if schema.PrioritizedPrimaryField == nil { - schema.PrioritizedPrimaryField = field - } - schema.PrimaryFields = append(schema.PrimaryFields, field) - } - - if field.DataType == "" { - defer schema.parseRelation(field) + if f := schema.LookUpField("id"); f != nil { + if f.PrimaryKey { + schema.PrioritizedPrimaryField = f + } else if len(schema.PrimaryFields) == 0 { + f.PrimaryKey = true + schema.PrioritizedPrimaryField = f + schema.PrimaryFields = append(schema.PrimaryFields, f) } } cacheStore.Store(modelType, schema) + + // parse relations for unidentified fields + for _, field := range schema.Fields { + if field.DataType == "" && field.Creatable { + if schema.parseRelation(field); schema.err != nil { + return schema, schema.err + } + } + } + return schema, schema.err } diff --git a/schema/schema_helper_test.go b/schema/schema_helper_test.go new file mode 100644 index 00000000..eb0085c2 --- /dev/null +++ b/schema/schema_helper_test.go @@ -0,0 +1,123 @@ +package schema_test + +import ( + "reflect" + "testing" + + "github.com/jinzhu/gorm/schema" +) + +func checkSchema(t *testing.T, s *schema.Schema, v schema.Schema, primaryFields []string) { + equalFieldNames := []string{"Name", "Table"} + + for _, name := range equalFieldNames { + got := reflect.ValueOf(s).Elem().FieldByName(name).Interface() + expects := reflect.ValueOf(v).FieldByName(name).Interface() + if !reflect.DeepEqual(got, expects) { + t.Errorf("schema %v %v is not equal, expects: %v, got %v", s, name, expects, got) + } + } + + for idx, field := range primaryFields { + var found bool + for _, f := range s.PrimaryFields { + if f.Name == field { + found = true + } + } + + if idx == 0 { + if field != s.PrioritizedPrimaryField.Name { + t.Errorf("schema %v prioritized primary field should be %v, but got %v", s, field, s.PrioritizedPrimaryField.Name) + } + } + + if !found { + t.Errorf("schema %v failed to found priamry key: %v", s, field) + } + } +} + +func checkSchemaField(t *testing.T, s *schema.Schema, f *schema.Field, fc func(*schema.Field)) { + if fc != nil { + fc(f) + } + + if f.TagSettings == nil { + if f.Tag != "" { + f.TagSettings = schema.ParseTagSetting(f.Tag) + } else { + f.TagSettings = map[string]string{} + } + } + + if parsedField, ok := s.FieldsByName[f.Name]; !ok { + t.Errorf("schema %v failed to look up field with name %v", s, f.Name) + } else { + equalFieldNames := []string{"Name", "DBName", "BindNames", "DataType", "DBDataType", "PrimaryKey", "AutoIncrement", "Creatable", "Updatable", "HasDefaultValue", "DefaultValue", "NotNull", "Unique", "Comment", "Size", "Precision", "Tag", "TagSettings"} + + for _, name := range equalFieldNames { + got := reflect.ValueOf(parsedField).Elem().FieldByName(name).Interface() + expects := reflect.ValueOf(f).Elem().FieldByName(name).Interface() + if !reflect.DeepEqual(got, expects) { + t.Errorf("%v is not equal, expects: %v, got %v", name, expects, got) + } + } + + if field, ok := s.FieldsByDBName[f.DBName]; !ok || parsedField != field { + t.Errorf("schema %v failed to look up field with dbname %v", s, f.DBName) + } + + for _, name := range []string{f.DBName, f.Name} { + if field := s.LookUpField(name); field == nil || parsedField != field { + t.Errorf("schema %v failed to look up field with dbname %v", s, f.DBName) + } + } + + if f.PrimaryKey { + var found bool + for _, primaryField := range s.PrimaryFields { + if primaryField == parsedField { + found = true + } + } + + if !found { + t.Errorf("schema %v doesn't include field %v", s, f.Name) + } + } + } +} + +type Relation struct { + Name string + Type schema.RelationshipType + Polymorphic schema.Polymorphic + Schema string + FieldSchema string + JoinTable string + JoinTableFields []schema.Field + References []Reference +} + +type Reference struct { + PrimaryKey string + PrimarySchema string + ForeignKey string + ForeignSchema string + OwnPriamryKey bool +} + +func checkSchemaRelation(t *testing.T, s *schema.Schema, relation Relation) { + if r, ok := s.Relationships.Relations[relation.Name]; ok { + if r.Name != relation.Name { + t.Errorf("schema %v relation name expects %v, but got %v", s, relation.Name, r.Name) + } + + if r.Type != relation.Type { + t.Errorf("schema %v relation name expects %v, but got %v", s, relation.Type, r.Type) + } + } else { + t.Errorf("schema %v failed to find relations by name %v", s, relation.Name) + } +} diff --git a/schema/schema_test.go b/schema/schema_test.go index eefac98b..8ea219e1 100644 --- a/schema/schema_test.go +++ b/schema/schema_test.go @@ -1,7 +1,6 @@ package schema_test import ( - "reflect" "sync" "testing" @@ -11,68 +10,40 @@ import ( func TestParseSchema(t *testing.T) { cacheMap := sync.Map{} - user, err := schema.Parse(&tests.User{}, &cacheMap, schema.NamingStrategy{}) + user, err := schema.Parse(&tests.User{}, &cacheMap, schema.NamingStrategy{}) if err != nil { t.Fatalf("failed to parse user, got error %v", err) } - checkSchemaFields(t, user) -} + // check schema + checkSchema(t, user, schema.Schema{Name: "User", Table: "users"}, []string{"ID"}) -func checkSchemaFields(t *testing.T, s *schema.Schema) { + // check fields fields := []schema.Field{ - schema.Field{ - Name: "ID", DBName: "id", BindNames: []string{"Model", "ID"}, DataType: schema.Uint, - PrimaryKey: true, Tag: `gorm:"primarykey"`, TagSettings: map[string]string{"PRIMARYKEY": "PRIMARYKEY"}, - }, - schema.Field{Name: "CreatedAt", DBName: "created_at", BindNames: []string{"Model", "CreatedAt"}, DataType: schema.Time}, - schema.Field{Name: "UpdatedAt", DBName: "updated_at", BindNames: []string{"Model", "UpdatedAt"}, DataType: schema.Time}, - schema.Field{Name: "DeletedAt", DBName: "deleted_at", BindNames: []string{"Model", "DeletedAt"}, Tag: `gorm:"index"`, DataType: schema.Time}, - schema.Field{Name: "Name", DBName: "name", BindNames: []string{"Name"}, DataType: schema.String}, - schema.Field{Name: "Age", DBName: "age", BindNames: []string{"Age"}, DataType: schema.Uint}, - schema.Field{Name: "Birthday", DBName: "birthday", BindNames: []string{"Birthday"}, DataType: schema.Time}, - schema.Field{Name: "CompanyID", DBName: "company_id", BindNames: []string{"CompanyID"}, DataType: schema.Int}, - schema.Field{Name: "ManagerID", DBName: "manager_id", BindNames: []string{"ManagerID"}, DataType: schema.Uint}, + {Name: "ID", DBName: "id", BindNames: []string{"Model", "ID"}, DataType: schema.Uint, PrimaryKey: true, Tag: `gorm:"primarykey"`, TagSettings: map[string]string{"PRIMARYKEY": "PRIMARYKEY"}}, + {Name: "CreatedAt", DBName: "created_at", BindNames: []string{"Model", "CreatedAt"}, DataType: schema.Time}, + {Name: "UpdatedAt", DBName: "updated_at", BindNames: []string{"Model", "UpdatedAt"}, DataType: schema.Time}, + {Name: "DeletedAt", DBName: "deleted_at", BindNames: []string{"Model", "DeletedAt"}, Tag: `gorm:"index"`, DataType: schema.Time}, + {Name: "Name", DBName: "name", BindNames: []string{"Name"}, DataType: schema.String}, + {Name: "Age", DBName: "age", BindNames: []string{"Age"}, DataType: schema.Uint}, + {Name: "Birthday", DBName: "birthday", BindNames: []string{"Birthday"}, DataType: schema.Time}, + {Name: "CompanyID", DBName: "company_id", BindNames: []string{"CompanyID"}, DataType: schema.Int}, + {Name: "ManagerID", DBName: "manager_id", BindNames: []string{"ManagerID"}, DataType: schema.Uint}, } for _, f := range fields { - f.Creatable = true - f.Updatable = true - if f.TagSettings == nil { - if f.Tag != "" { - f.TagSettings = schema.ParseTagSetting(f.Tag) - } else { - f.TagSettings = map[string]string{} - } - } + checkSchemaField(t, user, &f, func(f *schema.Field) { + f.Creatable = true + f.Updatable = true + }) + } - if foundField, ok := s.FieldsByName[f.Name]; !ok { - t.Errorf("schema %v failed to look up field with name %v", s, f.Name) - } else { - checkSchemaField(t, foundField, f) - - if field, ok := s.FieldsByDBName[f.DBName]; !ok || foundField != field { - t.Errorf("schema %v failed to look up field with dbname %v", s, f.DBName) - } - - for _, name := range []string{f.DBName, f.Name} { - if field := s.LookUpField(name); field == nil || foundField != field { - t.Errorf("schema %v failed to look up field with dbname %v", s, f.DBName) - } - } - } - } -} - -func checkSchemaField(t *testing.T, parsedField *schema.Field, field schema.Field) { - equalFieldNames := []string{"Name", "DBName", "BindNames", "DataType", "DBDataType", "PrimaryKey", "AutoIncrement", "Creatable", "Updatable", "HasDefaultValue", "DefaultValue", "NotNull", "Unique", "Comment", "Size", "Precision", "Tag", "TagSettings"} - - for _, name := range equalFieldNames { - got := reflect.ValueOf(parsedField).Elem().FieldByName(name).Interface() - expects := reflect.ValueOf(field).FieldByName(name).Interface() - if !reflect.DeepEqual(got, expects) { - t.Errorf("%v is not equal, expects: %v, got %v", name, expects, got) - } + // check relations + relations := []Relation{ + {Name: "Pets", Type: schema.HasMany, Schema: "User", FieldSchema: "Pet", References: []Reference{{"ID", "User", "UserID", "Pet", true}}}, + } + for _, relation := range relations { + checkSchemaRelation(t, user, relation) } } diff --git a/tests/model.go b/tests/model.go index 0be3e97a..e2b69abc 100644 --- a/tests/model.go +++ b/tests/model.go @@ -23,7 +23,7 @@ type User struct { Company Company ManagerID uint Manager *User - Team []User `foreignkey:ManagerID` + Team []User `gorm:"foreignkey:ManagerID"` Friends []*User `gorm:"many2many:user_friends"` Languages []Language `gorm:"many2many:user_speaks"` } From 3cbd233758499f55bebf640264a2158aafe07096 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 2 Feb 2020 00:03:56 +0800 Subject: [PATCH 221/881] Add more tests for parse schema relations --- schema/field.go | 2 + schema/naming.go | 6 +-- schema/relationship.go | 31 ++++++----- schema/schema.go | 5 +- schema/schema_helper_test.go | 100 +++++++++++++++++++++++++++++++---- schema/schema_test.go | 55 ++++++++++++++++++- tests/model.go | 4 +- 7 files changed, 172 insertions(+), 31 deletions(-) diff --git a/schema/field.go b/schema/field.go index f1cd022b..570b3c50 100644 --- a/schema/field.go +++ b/schema/field.go @@ -55,6 +55,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { Updatable: true, Tag: fieldStruct.Tag, TagSettings: ParseTagSetting(fieldStruct.Tag), + Schema: schema, } for field.FieldType.Kind() == reflect.Ptr { @@ -183,6 +184,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { schema.err = err } for _, ef := range field.EmbeddedSchema.Fields { + ef.Schema = schema ef.BindNames = append([]string{fieldStruct.Name}, ef.BindNames...) if prefix, ok := field.TagSettings["EMBEDDEDPREFIX"]; ok { diff --git a/schema/naming.go b/schema/naming.go index 5a2311b6..e6a5625e 100644 --- a/schema/naming.go +++ b/schema/naming.go @@ -11,7 +11,7 @@ import ( // Namer namer interface type Namer interface { TableName(table string) string - ColumnName(column string) string + ColumnName(table, column string) string JoinTableName(table string) string } @@ -30,13 +30,13 @@ func (ns NamingStrategy) TableName(str string) string { } // ColumnName convert string to column name -func (ns NamingStrategy) ColumnName(str string) string { +func (ns NamingStrategy) ColumnName(table, str string) string { return toDBName(str) } // JoinTableName convert string to join table name func (ns NamingStrategy) JoinTableName(str string) string { - return ns.TablePrefix + toDBName(str) + return ns.TablePrefix + inflection.Plural(toDBName(str)) } var ( diff --git a/schema/relationship.go b/schema/relationship.go index 358d13e7..b6aaefbd 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -4,6 +4,8 @@ import ( "fmt" "reflect" "strings" + + "github.com/jinzhu/inflection" ) // RelationshipType relationship type @@ -43,10 +45,10 @@ type Polymorphic struct { } type Reference struct { - PriamryKey *Field - PriamryValue string + PrimaryKey *Field + PrimaryValue string ForeignKey *Field - OwnPriamryKey bool + OwnPrimaryKey bool } func (schema *Schema) parseRelation(field *Field) { @@ -136,7 +138,7 @@ func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Fi if schema.err == nil { relation.References = append(relation.References, Reference{ - PriamryValue: relation.Polymorphic.Value, + PrimaryValue: relation.Polymorphic.Value, ForeignKey: relation.Polymorphic.PolymorphicType, }) @@ -147,9 +149,9 @@ func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Fi } } relation.References = append(relation.References, Reference{ - PriamryKey: primaryKeyField, - ForeignKey: relation.Polymorphic.PolymorphicType, - OwnPriamryKey: true, + PrimaryKey: primaryKeyField, + ForeignKey: relation.Polymorphic.PolymorphicID, + OwnPrimaryKey: true, }) } @@ -163,17 +165,20 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel err error joinTableFields []reflect.StructField fieldsMap = map[string]*Field{} + ownFieldsMap = map[string]bool{} // fix self join many2many ) - for _, s := range []*Schema{schema, relation.Schema} { + for _, s := range []*Schema{schema, relation.FieldSchema} { for _, primaryField := range s.PrimaryFields { fieldName := s.Name + primaryField.Name if _, ok := fieldsMap[fieldName]; ok { if field.Name != s.Name { - fieldName = field.Name + primaryField.Name + fieldName = inflection.Singular(field.Name) + primaryField.Name } else { fieldName = s.Name + primaryField.Name + "Reference" } + } else { + ownFieldsMap[fieldName] = true } fieldsMap[fieldName] = primaryField @@ -195,9 +200,9 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel // build references for _, f := range relation.JoinTable.Fields { relation.References = append(relation.References, Reference{ - PriamryKey: fieldsMap[f.Name], + PrimaryKey: fieldsMap[f.Name], ForeignKey: f, - OwnPriamryKey: schema == fieldsMap[f.Name].Schema, + OwnPrimaryKey: schema == fieldsMap[f.Name].Schema && ownFieldsMap[f.Name], }) } return @@ -275,9 +280,9 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, guessH // build references for idx, foreignField := range foreignFields { relation.References = append(relation.References, Reference{ - PriamryKey: primaryFields[idx], + PrimaryKey: primaryFields[idx], ForeignKey: foreignField, - OwnPriamryKey: schema == primarySchema, + OwnPrimaryKey: schema == primarySchema && guessHas, }) } diff --git a/schema/schema.go b/schema/schema.go index d3404312..5cd6146b 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -25,6 +25,9 @@ type Schema struct { } func (schema Schema) String() string { + if schema.ModelType.Name() == "" { + return fmt.Sprintf("%v(%v)", schema.Name, schema.Table) + } return fmt.Sprintf("%v.%v", schema.ModelType.PkgPath(), schema.ModelType.Name()) } @@ -86,7 +89,7 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) for _, field := range schema.Fields { if field.DBName == "" { - field.DBName = namer.ColumnName(field.Name) + field.DBName = namer.ColumnName(schema.Table, field.Name) } if field.DBName != "" { diff --git a/schema/schema_helper_test.go b/schema/schema_helper_test.go index eb0085c2..ce91d8d1 100644 --- a/schema/schema_helper_test.go +++ b/schema/schema_helper_test.go @@ -1,7 +1,9 @@ package schema_test import ( + "fmt" "reflect" + "strings" "testing" "github.com/jinzhu/gorm/schema" @@ -90,14 +92,25 @@ func checkSchemaField(t *testing.T, s *schema.Schema, f *schema.Field, fc func(* } type Relation struct { - Name string - Type schema.RelationshipType - Polymorphic schema.Polymorphic - Schema string - FieldSchema string - JoinTable string - JoinTableFields []schema.Field - References []Reference + Name string + Type schema.RelationshipType + Schema string + FieldSchema string + Polymorphic Polymorphic + JoinTable JoinTable + References []Reference +} + +type Polymorphic struct { + ID string + Type string + Value string +} + +type JoinTable struct { + Name string + Table string + Fields []schema.Field } type Reference struct { @@ -105,17 +118,82 @@ type Reference struct { PrimarySchema string ForeignKey string ForeignSchema string - OwnPriamryKey bool + PrimaryValue string + OwnPrimaryKey bool } func checkSchemaRelation(t *testing.T, s *schema.Schema, relation Relation) { if r, ok := s.Relationships.Relations[relation.Name]; ok { if r.Name != relation.Name { - t.Errorf("schema %v relation name expects %v, but got %v", s, relation.Name, r.Name) + t.Errorf("schema %v relation name expects %v, but got %v", s, r.Name, relation.Name) } if r.Type != relation.Type { - t.Errorf("schema %v relation name expects %v, but got %v", s, relation.Type, r.Type) + t.Errorf("schema %v relation name expects %v, but got %v", s, r.Type, relation.Type) + } + + if r.Schema.Name != relation.Schema { + t.Errorf("schema %v relation's schema expects %v, but got %v", s, relation.Schema, r.Schema.Name) + } + + if r.FieldSchema.Name != relation.FieldSchema { + t.Errorf("schema %v relation's schema expects %v, but got %v", s, relation.Schema, r.Schema.Name) + } + + if r.Polymorphic != nil { + if r.Polymorphic.PolymorphicID.Name != relation.Polymorphic.ID { + t.Errorf("schema %v relation's polymorphic id field expects %v, but got %v", s, relation.Polymorphic.ID, r.Polymorphic.PolymorphicID.Name) + } + + if r.Polymorphic.PolymorphicType.Name != relation.Polymorphic.Type { + t.Errorf("schema %v relation's polymorphic type field expects %v, but got %v", s, relation.Polymorphic.Type, r.Polymorphic.PolymorphicType.Name) + } + + if r.Polymorphic.Value != relation.Polymorphic.Value { + t.Errorf("schema %v relation's polymorphic value expects %v, but got %v", s, relation.Polymorphic.Value, r.Polymorphic.Value) + } + } + + if r.JoinTable != nil { + if r.JoinTable.Name != relation.JoinTable.Name { + t.Errorf("schema %v relation's join table name expects %v, but got %v", s, relation.JoinTable.Name, r.JoinTable.Name) + } + + if r.JoinTable.Table != relation.JoinTable.Table { + t.Errorf("schema %v relation's join table tablename expects %v, but got %v", s, relation.JoinTable.Table, r.JoinTable.Table) + } + + for _, f := range relation.JoinTable.Fields { + checkSchemaField(t, r.JoinTable, &f, nil) + } + } + + if len(relation.References) != len(r.References) { + t.Errorf("schema %v relation's reference's count doesn't match, expects %v, but got %v", s, len(relation.References), len(r.References)) + } + + for _, ref := range relation.References { + var found bool + for _, rf := range r.References { + if (rf.PrimaryKey == nil || (rf.PrimaryKey.Name == ref.PrimaryKey && rf.PrimaryKey.Schema.Name == ref.PrimarySchema)) && (rf.PrimaryValue == ref.PrimaryValue) && (rf.ForeignKey.Name == ref.ForeignKey && rf.ForeignKey.Schema.Name == ref.ForeignSchema) && (rf.OwnPrimaryKey == ref.OwnPrimaryKey) { + found = true + } + } + + if !found { + var refs []string + for _, rf := range r.References { + var primaryKey, primaryKeySchema string + if rf.PrimaryKey != nil { + primaryKey, primaryKeySchema = rf.PrimaryKey.Name, rf.PrimaryKey.Schema.Name + } + refs = append(refs, fmt.Sprintf( + "{PrimaryKey: %v PrimaryKeySchame: %v ForeignKey: %v ForeignKeySchema: %v PrimaryValue: %v OwnPrimaryKey: %v}", + primaryKey, primaryKeySchema, rf.ForeignKey.Name, rf.ForeignKey.Schema.Name, rf.PrimaryValue, rf.OwnPrimaryKey, + )) + } + t.Errorf("schema %v relation %v failed to found reference %+v, has %v", s, relation.Name, ref, strings.Join(refs, ", ")) + } } } else { t.Errorf("schema %v failed to find relations by name %v", s, relation.Name) diff --git a/schema/schema_test.go b/schema/schema_test.go index 8ea219e1..526a98bd 100644 --- a/schema/schema_test.go +++ b/schema/schema_test.go @@ -41,8 +41,61 @@ func TestParseSchema(t *testing.T) { // check relations relations := []Relation{ - {Name: "Pets", Type: schema.HasMany, Schema: "User", FieldSchema: "Pet", References: []Reference{{"ID", "User", "UserID", "Pet", true}}}, + { + Name: "Account", Type: schema.HasOne, Schema: "User", FieldSchema: "Account", + References: []Reference{{"ID", "User", "UserID", "Account", "", true}}, + }, + { + Name: "Pets", Type: schema.HasMany, Schema: "User", FieldSchema: "Pet", + References: []Reference{{"ID", "User", "UserID", "Pet", "", true}}, + }, + { + Name: "Toys", Type: schema.HasMany, Schema: "User", FieldSchema: "Toy", + Polymorphic: Polymorphic{ID: "OwnerID", Type: "OwnerType", Value: "users"}, + References: []Reference{{"ID", "User", "OwnerID", "Toy", "", true}, {"", "", "OwnerType", "Toy", "users", false}}, + }, + { + Name: "Company", Type: schema.BelongsTo, Schema: "User", FieldSchema: "Company", + References: []Reference{{"ID", "Company", "CompanyID", "User", "", false}}, + }, + { + Name: "Manager", Type: schema.BelongsTo, Schema: "User", FieldSchema: "User", + References: []Reference{{"ID", "User", "ManagerID", "User", "", false}}, + }, + { + Name: "Team", Type: schema.HasMany, Schema: "User", FieldSchema: "User", + References: []Reference{{"ID", "User", "ManagerID", "User", "", true}}, + }, + { + Name: "Languages", Type: schema.Many2Many, Schema: "User", FieldSchema: "Language", + JoinTable: JoinTable{Name: "UserSpeak", Table: "user_speaks", Fields: []schema.Field{ + { + Name: "UserID", DBName: "user_id", BindNames: []string{"UserID"}, DataType: schema.Uint, + Tag: `gorm:"primarykey"`, Creatable: true, Updatable: true, PrimaryKey: true, + }, + { + Name: "LanguageCode", DBName: "language_code", BindNames: []string{"LanguageCode"}, DataType: schema.String, + Tag: `gorm:"primarykey"`, Creatable: true, Updatable: true, PrimaryKey: true, + }, + }}, + References: []Reference{{"ID", "User", "UserID", "UserSpeak", "", true}, {"Code", "Language", "LanguageCode", "UserSpeak", "", false}}, + }, + { + Name: "Friends", Type: schema.Many2Many, Schema: "User", FieldSchema: "User", + JoinTable: JoinTable{Name: "user_friends", Table: "user_friends", Fields: []schema.Field{ + { + Name: "UserID", DBName: "user_id", BindNames: []string{"UserID"}, DataType: schema.Uint, + Tag: `gorm:"primarykey"`, Creatable: true, Updatable: true, PrimaryKey: true, + }, + { + Name: "FriendID", DBName: "friend_id", BindNames: []string{"FriendID"}, DataType: schema.Uint, + Tag: `gorm:"primarykey"`, Creatable: true, Updatable: true, PrimaryKey: true, + }, + }}, + References: []Reference{{"ID", "User", "UserID", "user_friends", "", true}, {"ID", "User", "FriendID", "user_friends", "", false}}, + }, } + for _, relation := range relations { checkSchemaRelation(t, user, relation) } diff --git a/tests/model.go b/tests/model.go index e2b69abc..62000352 100644 --- a/tests/model.go +++ b/tests/model.go @@ -24,8 +24,8 @@ type User struct { ManagerID uint Manager *User Team []User `gorm:"foreignkey:ManagerID"` + Languages []Language `gorm:"many2many:UserSpeak"` Friends []*User `gorm:"many2many:user_friends"` - Languages []Language `gorm:"many2many:user_speaks"` } type Account struct { @@ -53,6 +53,6 @@ type Company struct { } type Language struct { - Code string `gorm:primarykey` + Code string `gorm:"primarykey"` Name string } From 8cb15cadde6e2c3ff1cc19e1182ce98b734ea7d5 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 2 Feb 2020 08:35:01 +0800 Subject: [PATCH 222/881] Improve test structure --- callbacks/callbacks.go | 12 ++ callbacks/create.go | 24 ++++ callbacks/interface.go | 11 ++ dialects/mysql/go.mod | 7 + dialects/mysql/mysql.go | 29 ++++ dialects/mysql/mysql_test.go | 12 ++ dialects/sqlite/go.mod | 7 + dialects/sqlite/sqlite.go | 28 ++++ dialects/sqlite/sqlite_test.go | 15 ++ finisher_api.go | 1 + gorm.go | 33 ++++- schema/schema_helper_test.go | 250 +++++++++++++++++---------------- tests/create_test.go | 1 + 13 files changed, 304 insertions(+), 126 deletions(-) create mode 100644 callbacks/callbacks.go create mode 100644 callbacks/create.go create mode 100644 callbacks/interface.go create mode 100644 dialects/mysql/go.mod create mode 100644 dialects/mysql/mysql.go create mode 100644 dialects/mysql/mysql_test.go create mode 100644 dialects/sqlite/go.mod create mode 100644 dialects/sqlite/sqlite.go create mode 100644 dialects/sqlite/sqlite_test.go create mode 100644 tests/create_test.go diff --git a/callbacks/callbacks.go b/callbacks/callbacks.go new file mode 100644 index 00000000..7fd12cb7 --- /dev/null +++ b/callbacks/callbacks.go @@ -0,0 +1,12 @@ +package callbacks + +import "github.com/jinzhu/gorm" + +func RegisterDefaultCallbacks(db *gorm.DB) { + callback := db.Callback() + callback.Create().Register("gorm:before_create", BeforeCreate) + callback.Create().Register("gorm:save_before_associations", SaveBeforeAssociations) + callback.Create().Register("gorm:create", Create) + callback.Create().Register("gorm:save_after_associations", SaveAfterAssociations) + callback.Create().Register("gorm:after_create", AfterCreate) +} diff --git a/callbacks/create.go b/callbacks/create.go new file mode 100644 index 00000000..2fe27140 --- /dev/null +++ b/callbacks/create.go @@ -0,0 +1,24 @@ +package callbacks + +import "github.com/jinzhu/gorm" + +func BeforeCreate(db *gorm.DB) { + // before save + // before create + + // assign timestamp +} + +func SaveBeforeAssociations(db *gorm.DB) { +} + +func Create(db *gorm.DB) { +} + +func SaveAfterAssociations(db *gorm.DB) { +} + +func AfterCreate(db *gorm.DB) { + // after save + // after create +} diff --git a/callbacks/interface.go b/callbacks/interface.go new file mode 100644 index 00000000..0ef64fcd --- /dev/null +++ b/callbacks/interface.go @@ -0,0 +1,11 @@ +package callbacks + +import "github.com/jinzhu/gorm" + +type beforeSaveInterface interface { + BeforeSave(*gorm.DB) error +} + +type beforeCreateInterface interface { + BeforeCreate(*gorm.DB) error +} diff --git a/dialects/mysql/go.mod b/dialects/mysql/go.mod new file mode 100644 index 00000000..a1f29122 --- /dev/null +++ b/dialects/mysql/go.mod @@ -0,0 +1,7 @@ +module github.com/jinzhu/gorm/dialects/mysql + +go 1.13 + +require ( + github.com/go-sql-driver/mysql v1.5.0 +) diff --git a/dialects/mysql/mysql.go b/dialects/mysql/mysql.go new file mode 100644 index 00000000..ba306889 --- /dev/null +++ b/dialects/mysql/mysql.go @@ -0,0 +1,29 @@ +package mysql + +import ( + _ "github.com/go-sql-driver/mysql" + "github.com/jinzhu/gorm" + "github.com/jinzhu/gorm/callbacks" +) + +type Dialector struct { +} + +func Open(dsn string) gorm.Dialector { + return &Dialector{} +} + +func (Dialector) Initialize(db *gorm.DB) error { + // register callbacks + callbacks.RegisterDefaultCallbacks(db) + + return nil +} + +func (Dialector) Migrator() gorm.Migrator { + return nil +} + +func (Dialector) BindVar(stmt gorm.Statement, v interface{}) string { + return "?" +} diff --git a/dialects/mysql/mysql_test.go b/dialects/mysql/mysql_test.go new file mode 100644 index 00000000..49c26915 --- /dev/null +++ b/dialects/mysql/mysql_test.go @@ -0,0 +1,12 @@ +package mysql_test + +import ( + "testing" + + "github.com/jinzhu/gorm" + "github.com/jinzhu/gorm/dialects/mysql" +) + +func TestOpen(t *testing.T) { + gorm.Open(mysql.Open("gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True"), nil) +} diff --git a/dialects/sqlite/go.mod b/dialects/sqlite/go.mod new file mode 100644 index 00000000..db3370e9 --- /dev/null +++ b/dialects/sqlite/go.mod @@ -0,0 +1,7 @@ +module github.com/jinzhu/gorm/dialects/mysql + +go 1.13 + +require ( + github.com/mattn/go-sqlite3 v2.0.3+incompatible +) diff --git a/dialects/sqlite/sqlite.go b/dialects/sqlite/sqlite.go new file mode 100644 index 00000000..f3c3f0c7 --- /dev/null +++ b/dialects/sqlite/sqlite.go @@ -0,0 +1,28 @@ +package sqlite + +import ( + "github.com/jinzhu/gorm/callbacks" + _ "github.com/mattn/go-sqlite3" +) + +type Dialector struct { +} + +func Open(dsn string) gorm.Dialector { + return &Dialector{} +} + +func (Dialector) Initialize(db *gorm.DB) error { + // register callbacks + callbacks.RegisterDefaultCallbacks(db) + + return nil +} + +func (Dialector) Migrator() gorm.Migrator { + return nil +} + +func (Dialector) BindVar(stmt gorm.Statement, v interface{}) string { + return "?" +} diff --git a/dialects/sqlite/sqlite_test.go b/dialects/sqlite/sqlite_test.go new file mode 100644 index 00000000..f0429a12 --- /dev/null +++ b/dialects/sqlite/sqlite_test.go @@ -0,0 +1,15 @@ +package sqlite_test + +import ( + "os" + "path/filepath" + "testing" + + "github.com/jinzhu/gorm" +) + +var DB *gorm.DB + +func TestOpen(t *testing.T) { + db, err = gorm.Open("sqlite3", filepath.Join(os.TempDir(), "gorm.db")) +} diff --git a/finisher_api.go b/finisher_api.go index 2668e1fe..b155e90d 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -12,6 +12,7 @@ func (db *DB) Count(sql string, values ...interface{}) (tx *DB) { // First find first record that match given conditions, order by primary key func (db *DB) First(out interface{}, where ...interface{}) (tx *DB) { tx = db.getInstance() + tx.callbacks.Create().Execute(tx.Limit(1).Order("id")) return } diff --git a/gorm.go b/gorm.go index 6ceac412..896d07f9 100644 --- a/gorm.go +++ b/gorm.go @@ -13,7 +13,7 @@ import ( type Config struct { // GORM perform single create, update, delete operations in transactions by default to ensure database data integrity // You can cancel it by setting `SkipDefaultTransaction` to true - SkipDefaultTransaction bool + SkipDefaultTransaction bool // TODO // NamingStrategy tables, columns naming strategy NamingStrategy schema.Namer @@ -27,6 +27,7 @@ type Config struct { // Dialector GORM database dialector type Dialector interface { + Initialize(*DB) error Migrator() Migrator BindVar(stmt Statement, v interface{}) string } @@ -36,7 +37,8 @@ type DB struct { *Config Dialector Instance - clone bool + clone bool + callbacks *callbacks } // Session session config when create new session @@ -48,15 +50,33 @@ type Session struct { // Open initialize db session based on dialector func Open(dialector Dialector, config *Config) (db *DB, err error) { + if config == nil { + config = &Config{} + } + if config.NamingStrategy == nil { config.NamingStrategy = schema.NamingStrategy{} } - return &DB{ + if config.Logger == nil { + config.Logger = logger.Default + } + + if config.NowFunc == nil { + config.NowFunc = func() time.Time { return time.Now().Local() } + } + + db = &DB{ Config: config, Dialector: dialector, clone: true, - }, nil + callbacks: InitializeCallbacks(), + } + + if dialector != nil { + err = dialector.Initialize(db) + } + return } // Session create new db session @@ -112,6 +132,11 @@ func (db *DB) Get(key string) (interface{}, bool) { return nil, false } +// Callback returns callback manager +func (db *DB) Callback() *callbacks { + return db.callbacks +} + func (db *DB) getInstance() *DB { if db.clone { ctx := db.Instance.Context diff --git a/schema/schema_helper_test.go b/schema/schema_helper_test.go index ce91d8d1..05f41131 100644 --- a/schema/schema_helper_test.go +++ b/schema/schema_helper_test.go @@ -10,85 +10,89 @@ import ( ) func checkSchema(t *testing.T, s *schema.Schema, v schema.Schema, primaryFields []string) { - equalFieldNames := []string{"Name", "Table"} - - for _, name := range equalFieldNames { - got := reflect.ValueOf(s).Elem().FieldByName(name).Interface() - expects := reflect.ValueOf(v).FieldByName(name).Interface() - if !reflect.DeepEqual(got, expects) { - t.Errorf("schema %v %v is not equal, expects: %v, got %v", s, name, expects, got) - } - } - - for idx, field := range primaryFields { - var found bool - for _, f := range s.PrimaryFields { - if f.Name == field { - found = true - } - } - - if idx == 0 { - if field != s.PrioritizedPrimaryField.Name { - t.Errorf("schema %v prioritized primary field should be %v, but got %v", s, field, s.PrioritizedPrimaryField.Name) - } - } - - if !found { - t.Errorf("schema %v failed to found priamry key: %v", s, field) - } - } -} - -func checkSchemaField(t *testing.T, s *schema.Schema, f *schema.Field, fc func(*schema.Field)) { - if fc != nil { - fc(f) - } - - if f.TagSettings == nil { - if f.Tag != "" { - f.TagSettings = schema.ParseTagSetting(f.Tag) - } else { - f.TagSettings = map[string]string{} - } - } - - if parsedField, ok := s.FieldsByName[f.Name]; !ok { - t.Errorf("schema %v failed to look up field with name %v", s, f.Name) - } else { - equalFieldNames := []string{"Name", "DBName", "BindNames", "DataType", "DBDataType", "PrimaryKey", "AutoIncrement", "Creatable", "Updatable", "HasDefaultValue", "DefaultValue", "NotNull", "Unique", "Comment", "Size", "Precision", "Tag", "TagSettings"} + t.Run("CheckSchema/"+s.Name, func(t *testing.T) { + equalFieldNames := []string{"Name", "Table"} for _, name := range equalFieldNames { - got := reflect.ValueOf(parsedField).Elem().FieldByName(name).Interface() - expects := reflect.ValueOf(f).Elem().FieldByName(name).Interface() + got := reflect.ValueOf(s).Elem().FieldByName(name).Interface() + expects := reflect.ValueOf(v).FieldByName(name).Interface() if !reflect.DeepEqual(got, expects) { - t.Errorf("%v is not equal, expects: %v, got %v", name, expects, got) + t.Errorf("schema %v %v is not equal, expects: %v, got %v", s, name, expects, got) } } - if field, ok := s.FieldsByDBName[f.DBName]; !ok || parsedField != field { - t.Errorf("schema %v failed to look up field with dbname %v", s, f.DBName) - } - - for _, name := range []string{f.DBName, f.Name} { - if field := s.LookUpField(name); field == nil || parsedField != field { - t.Errorf("schema %v failed to look up field with dbname %v", s, f.DBName) - } - } - - if f.PrimaryKey { + for idx, field := range primaryFields { var found bool - for _, primaryField := range s.PrimaryFields { - if primaryField == parsedField { + for _, f := range s.PrimaryFields { + if f.Name == field { found = true } } + if idx == 0 { + if field != s.PrioritizedPrimaryField.Name { + t.Errorf("schema %v prioritized primary field should be %v, but got %v", s, field, s.PrioritizedPrimaryField.Name) + } + } + if !found { - t.Errorf("schema %v doesn't include field %v", s, f.Name) + t.Errorf("schema %v failed to found priamry key: %v", s, field) } } - } + }) +} + +func checkSchemaField(t *testing.T, s *schema.Schema, f *schema.Field, fc func(*schema.Field)) { + t.Run("CheckField/"+f.Name, func(t *testing.T) { + if fc != nil { + fc(f) + } + + if f.TagSettings == nil { + if f.Tag != "" { + f.TagSettings = schema.ParseTagSetting(f.Tag) + } else { + f.TagSettings = map[string]string{} + } + } + + if parsedField, ok := s.FieldsByName[f.Name]; !ok { + t.Errorf("schema %v failed to look up field with name %v", s, f.Name) + } else { + equalFieldNames := []string{"Name", "DBName", "BindNames", "DataType", "DBDataType", "PrimaryKey", "AutoIncrement", "Creatable", "Updatable", "HasDefaultValue", "DefaultValue", "NotNull", "Unique", "Comment", "Size", "Precision", "Tag", "TagSettings"} + + for _, name := range equalFieldNames { + got := reflect.ValueOf(parsedField).Elem().FieldByName(name).Interface() + expects := reflect.ValueOf(f).Elem().FieldByName(name).Interface() + if !reflect.DeepEqual(got, expects) { + t.Errorf("%v is not equal, expects: %v, got %v", name, expects, got) + } + } + + if field, ok := s.FieldsByDBName[f.DBName]; !ok || parsedField != field { + t.Errorf("schema %v failed to look up field with dbname %v", s, f.DBName) + } + + for _, name := range []string{f.DBName, f.Name} { + if field := s.LookUpField(name); field == nil || parsedField != field { + t.Errorf("schema %v failed to look up field with dbname %v", s, f.DBName) + } + } + + if f.PrimaryKey { + var found bool + for _, primaryField := range s.PrimaryFields { + if primaryField == parsedField { + found = true + } + } + + if !found { + t.Errorf("schema %v doesn't include field %v", s, f.Name) + } + } + } + }) } type Relation struct { @@ -123,79 +127,81 @@ type Reference struct { } func checkSchemaRelation(t *testing.T, s *schema.Schema, relation Relation) { - if r, ok := s.Relationships.Relations[relation.Name]; ok { - if r.Name != relation.Name { - t.Errorf("schema %v relation name expects %v, but got %v", s, r.Name, relation.Name) - } - - if r.Type != relation.Type { - t.Errorf("schema %v relation name expects %v, but got %v", s, r.Type, relation.Type) - } - - if r.Schema.Name != relation.Schema { - t.Errorf("schema %v relation's schema expects %v, but got %v", s, relation.Schema, r.Schema.Name) - } - - if r.FieldSchema.Name != relation.FieldSchema { - t.Errorf("schema %v relation's schema expects %v, but got %v", s, relation.Schema, r.Schema.Name) - } - - if r.Polymorphic != nil { - if r.Polymorphic.PolymorphicID.Name != relation.Polymorphic.ID { - t.Errorf("schema %v relation's polymorphic id field expects %v, but got %v", s, relation.Polymorphic.ID, r.Polymorphic.PolymorphicID.Name) + t.Run("CheckRelation/"+relation.Name, func(t *testing.T) { + if r, ok := s.Relationships.Relations[relation.Name]; ok { + if r.Name != relation.Name { + t.Errorf("schema %v relation name expects %v, but got %v", s, r.Name, relation.Name) } - if r.Polymorphic.PolymorphicType.Name != relation.Polymorphic.Type { - t.Errorf("schema %v relation's polymorphic type field expects %v, but got %v", s, relation.Polymorphic.Type, r.Polymorphic.PolymorphicType.Name) + if r.Type != relation.Type { + t.Errorf("schema %v relation name expects %v, but got %v", s, r.Type, relation.Type) } - if r.Polymorphic.Value != relation.Polymorphic.Value { - t.Errorf("schema %v relation's polymorphic value expects %v, but got %v", s, relation.Polymorphic.Value, r.Polymorphic.Value) - } - } - - if r.JoinTable != nil { - if r.JoinTable.Name != relation.JoinTable.Name { - t.Errorf("schema %v relation's join table name expects %v, but got %v", s, relation.JoinTable.Name, r.JoinTable.Name) + if r.Schema.Name != relation.Schema { + t.Errorf("schema %v relation's schema expects %v, but got %v", s, relation.Schema, r.Schema.Name) } - if r.JoinTable.Table != relation.JoinTable.Table { - t.Errorf("schema %v relation's join table tablename expects %v, but got %v", s, relation.JoinTable.Table, r.JoinTable.Table) + if r.FieldSchema.Name != relation.FieldSchema { + t.Errorf("schema %v relation's schema expects %v, but got %v", s, relation.Schema, r.Schema.Name) } - for _, f := range relation.JoinTable.Fields { - checkSchemaField(t, r.JoinTable, &f, nil) - } - } + if r.Polymorphic != nil { + if r.Polymorphic.PolymorphicID.Name != relation.Polymorphic.ID { + t.Errorf("schema %v relation's polymorphic id field expects %v, but got %v", s, relation.Polymorphic.ID, r.Polymorphic.PolymorphicID.Name) + } - if len(relation.References) != len(r.References) { - t.Errorf("schema %v relation's reference's count doesn't match, expects %v, but got %v", s, len(relation.References), len(r.References)) - } + if r.Polymorphic.PolymorphicType.Name != relation.Polymorphic.Type { + t.Errorf("schema %v relation's polymorphic type field expects %v, but got %v", s, relation.Polymorphic.Type, r.Polymorphic.PolymorphicType.Name) + } - for _, ref := range relation.References { - var found bool - for _, rf := range r.References { - if (rf.PrimaryKey == nil || (rf.PrimaryKey.Name == ref.PrimaryKey && rf.PrimaryKey.Schema.Name == ref.PrimarySchema)) && (rf.PrimaryValue == ref.PrimaryValue) && (rf.ForeignKey.Name == ref.ForeignKey && rf.ForeignKey.Schema.Name == ref.ForeignSchema) && (rf.OwnPrimaryKey == ref.OwnPrimaryKey) { - found = true + if r.Polymorphic.Value != relation.Polymorphic.Value { + t.Errorf("schema %v relation's polymorphic value expects %v, but got %v", s, relation.Polymorphic.Value, r.Polymorphic.Value) } } - if !found { - var refs []string + if r.JoinTable != nil { + if r.JoinTable.Name != relation.JoinTable.Name { + t.Errorf("schema %v relation's join table name expects %v, but got %v", s, relation.JoinTable.Name, r.JoinTable.Name) + } + + if r.JoinTable.Table != relation.JoinTable.Table { + t.Errorf("schema %v relation's join table tablename expects %v, but got %v", s, relation.JoinTable.Table, r.JoinTable.Table) + } + + for _, f := range relation.JoinTable.Fields { + checkSchemaField(t, r.JoinTable, &f, nil) + } + } + + if len(relation.References) != len(r.References) { + t.Errorf("schema %v relation's reference's count doesn't match, expects %v, but got %v", s, len(relation.References), len(r.References)) + } + + for _, ref := range relation.References { + var found bool for _, rf := range r.References { - var primaryKey, primaryKeySchema string - if rf.PrimaryKey != nil { - primaryKey, primaryKeySchema = rf.PrimaryKey.Name, rf.PrimaryKey.Schema.Name + if (rf.PrimaryKey == nil || (rf.PrimaryKey.Name == ref.PrimaryKey && rf.PrimaryKey.Schema.Name == ref.PrimarySchema)) && (rf.PrimaryValue == ref.PrimaryValue) && (rf.ForeignKey.Name == ref.ForeignKey && rf.ForeignKey.Schema.Name == ref.ForeignSchema) && (rf.OwnPrimaryKey == ref.OwnPrimaryKey) { + found = true } - refs = append(refs, fmt.Sprintf( - "{PrimaryKey: %v PrimaryKeySchame: %v ForeignKey: %v ForeignKeySchema: %v PrimaryValue: %v OwnPrimaryKey: %v}", - primaryKey, primaryKeySchema, rf.ForeignKey.Name, rf.ForeignKey.Schema.Name, rf.PrimaryValue, rf.OwnPrimaryKey, - )) } - t.Errorf("schema %v relation %v failed to found reference %+v, has %v", s, relation.Name, ref, strings.Join(refs, ", ")) + + if !found { + var refs []string + for _, rf := range r.References { + var primaryKey, primaryKeySchema string + if rf.PrimaryKey != nil { + primaryKey, primaryKeySchema = rf.PrimaryKey.Name, rf.PrimaryKey.Schema.Name + } + refs = append(refs, fmt.Sprintf( + "{PrimaryKey: %v PrimaryKeySchame: %v ForeignKey: %v ForeignKeySchema: %v PrimaryValue: %v OwnPrimaryKey: %v}", + primaryKey, primaryKeySchema, rf.ForeignKey.Name, rf.ForeignKey.Schema.Name, rf.PrimaryValue, rf.OwnPrimaryKey, + )) + } + t.Errorf("schema %v relation %v failed to found reference %+v, has %v", s, relation.Name, ref, strings.Join(refs, ", ")) + } } + } else { + t.Errorf("schema %v failed to find relations by name %v", s, relation.Name) } - } else { - t.Errorf("schema %v failed to find relations by name %v", s, relation.Name) - } + }) } diff --git a/tests/create_test.go b/tests/create_test.go new file mode 100644 index 00000000..ca8701d2 --- /dev/null +++ b/tests/create_test.go @@ -0,0 +1 @@ +package tests From d833efe8b941e301ab5e983b9ee7eed447fec6f5 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 2 Feb 2020 14:40:44 +0800 Subject: [PATCH 223/881] Work on clauses --- callbacks.go | 13 ++++ callbacks/create.go | 23 +++++- callbacks/query.go | 13 ++++ clause/clause.go | 129 ++++----------------------------- clause/from.go | 22 ++++++ clause/group_by.go | 6 ++ clause/join.go | 23 ++++++ clause/limit.go | 6 ++ clause/order_by.go | 4 + clause/query.go | 6 -- clause/select.go | 45 ++++++++++++ clause/where.go | 77 ++++++++++++++++++++ clause/with.go | 4 + dialects/sqlite/go.mod | 6 +- dialects/sqlite/go.sum | 2 + dialects/sqlite/sqlite.go | 1 + dialects/sqlite/sqlite_test.go | 18 ++++- finisher_api.go | 6 +- go.mod | 5 +- gorm.go | 43 +++++------ interfaces.go | 21 ++++++ schema/schema.go | 10 ++- schema/schema_helper_test.go | 22 +----- statement.go | 31 +++++++- tests/create_test.go | 1 - tests/tests.go | 42 +++++++++++ tests/utils.go | 19 +++++ 27 files changed, 413 insertions(+), 185 deletions(-) create mode 100644 callbacks/query.go create mode 100644 clause/from.go create mode 100644 clause/group_by.go create mode 100644 clause/join.go create mode 100644 clause/limit.go create mode 100644 clause/order_by.go create mode 100644 clause/select.go create mode 100644 clause/where.go create mode 100644 clause/with.go create mode 100644 dialects/sqlite/go.sum create mode 100644 interfaces.go delete mode 100644 tests/create_test.go create mode 100644 tests/tests.go create mode 100644 tests/utils.go diff --git a/callbacks.go b/callbacks.go index a7f30612..22d2eda3 100644 --- a/callbacks.go +++ b/callbacks.go @@ -1,9 +1,11 @@ package gorm import ( + "errors" "fmt" "github.com/jinzhu/gorm/logger" + "github.com/jinzhu/gorm/schema" "github.com/jinzhu/gorm/utils" ) @@ -67,6 +69,17 @@ func (cs *callbacks) Raw() *processor { } func (p *processor) Execute(db *DB) { + if stmt := db.Statement; stmt != nil && stmt.Dest != nil { + var err error + stmt.Schema, err = schema.Parse(stmt.Dest, db.cacheStore, db.NamingStrategy) + + if err != nil && !errors.Is(err, schema.ErrUnsupportedDataType) { + db.AddError(err) + } else if stmt.Table == "" && stmt.Schema != nil { + stmt.Table = stmt.Schema.Table + } + } + for _, f := range p.fns { f(db) } diff --git a/callbacks/create.go b/callbacks/create.go index 2fe27140..5a3aaa24 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -1,6 +1,10 @@ package callbacks -import "github.com/jinzhu/gorm" +import ( + "fmt" + + "github.com/jinzhu/gorm" +) func BeforeCreate(db *gorm.DB) { // before save @@ -13,6 +17,9 @@ func SaveBeforeAssociations(db *gorm.DB) { } func Create(db *gorm.DB) { + db.Statement.Build("WITH", "INSERT", "VALUES", "ON_CONFLICT", "RETURNING") + + fmt.Println(db.Statement.SQL.String(), db.Statement.Vars) } func SaveAfterAssociations(db *gorm.DB) { @@ -22,3 +29,17 @@ func AfterCreate(db *gorm.DB) { // after save // after create } + +func objectToFieldsMap(stmt *gorm.Statement) { + if stmt.Schema != nil { + if s, ok := stmt.Clauses["SELECT"]; ok { + s.Attrs + } + + if s, ok := stmt.Clauses["OMIT"]; ok { + s.Attrs + } + + stmt.Schema.LookUpField(s.S) + } +} diff --git a/callbacks/query.go b/callbacks/query.go new file mode 100644 index 00000000..5d27ea17 --- /dev/null +++ b/callbacks/query.go @@ -0,0 +1,13 @@ +package callbacks + +import "github.com/jinzhu/gorm" + +func Query(db *gorm.DB) { +} + +func Preload(db *gorm.DB) { +} + +func AfterQuery(db *gorm.DB) { + // after find +} diff --git a/clause/clause.go b/clause/clause.go index 1b4a7e85..c0ebe7e2 100644 --- a/clause/clause.go +++ b/clause/clause.go @@ -51,124 +51,21 @@ type OverrideNameInterface interface { OverrideName() string } -//////////////////////////////////////////////////////////////////////////////// -// Predefined Clauses -//////////////////////////////////////////////////////////////////////////////// - -// Where where clause -type Where struct { - AndConditions AddConditions - ORConditions []ORConditions - builders []Expression +// Column quote with name +type Column struct { + Table string + Name string + Alias string + Raw bool } -func (where Where) Name() string { - return "WHERE" +func ToColumns(value ...interface{}) []Column { + return nil } -func (where Where) Build(builder Builder) { - var withConditions bool - - if len(where.AndConditions) > 0 { - withConditions = true - where.AndConditions.Build(builder) - } - - if len(where.builders) > 0 { - for _, b := range where.builders { - if withConditions { - builder.Write(" AND ") - } - withConditions = true - b.Build(builder) - } - } - - var singleOrConditions []ORConditions - for _, or := range where.ORConditions { - if len(or) == 1 { - if withConditions { - builder.Write(" OR ") - or.Build(builder) - } else { - singleOrConditions = append(singleOrConditions, or) - } - } else { - withConditions = true - builder.Write(" AND (") - or.Build(builder) - builder.WriteByte(')') - } - } - - for _, or := range singleOrConditions { - if withConditions { - builder.Write(" AND ") - or.Build(builder) - } else { - withConditions = true - or.Build(builder) - } - } - - if !withConditions { - builder.Write(" FALSE") - } - - return -} - -func (where Where) MergeExpression(expr Expression) { - if w, ok := expr.(Where); ok { - where.AndConditions = append(where.AndConditions, w.AndConditions...) - where.ORConditions = append(where.ORConditions, w.ORConditions...) - where.builders = append(where.builders, w.builders...) - } else { - where.builders = append(where.builders, expr) - } -} - -// Select select attrs when querying, updating, creating -type Select struct { - Omit bool -} - -// Join join clause -type Join struct { - Table string - Type string // left join books on - ON []Expression - builders []Expression -} - -func (join Join) Build(builder Builder) { - // TODO -} - -func (join Join) MergeExpression(expr Expression) { - if j, ok := expr.(Join); ok { - join.builders = append(join.builders, j.builders...) - } else { - join.builders = append(join.builders, expr) - } -} - -// GroupBy group by clause -type GroupBy struct { -} - -// Having having clause -type Having struct { -} - -// Order order clause -type Order struct { -} - -// Limit limit clause -type Limit struct { -} - -// Offset offset clause -type Offset struct { +// Table quote with name +type Table struct { + Table string + Alias string + Raw bool } diff --git a/clause/from.go b/clause/from.go new file mode 100644 index 00000000..610d69a4 --- /dev/null +++ b/clause/from.go @@ -0,0 +1,22 @@ +package clause + +// From from clause +type From struct { + Tables []Table +} + +// Name from clause name +func (From) Name() string { + return "FROM" +} + +// Build build from clause +func (from From) Build(builder Builder) { + for idx, table := range from.Tables { + if idx > 0 { + builder.WriteByte(',') + } + + builder.WriteQuoted(table) + } +} diff --git a/clause/group_by.go b/clause/group_by.go new file mode 100644 index 00000000..bce94109 --- /dev/null +++ b/clause/group_by.go @@ -0,0 +1,6 @@ +package clause + +// GroupBy group by clause +type GroupBy struct { + Having Where +} diff --git a/clause/join.go b/clause/join.go new file mode 100644 index 00000000..6b0e8f97 --- /dev/null +++ b/clause/join.go @@ -0,0 +1,23 @@ +package clause + +// Join join clause +type Join struct { + Table From // From + Type string // INNER, LEFT, RIGHT, FULL, CROSS JOIN + Using []Column + ON Where +} + +// TODO multiple joins + +func (join Join) Build(builder Builder) { + // TODO +} + +func (join Join) MergeExpression(expr Expression) { + // if j, ok := expr.(Join); ok { + // join.builders = append(join.builders, j.builders...) + // } else { + // join.builders = append(join.builders, expr) + // } +} diff --git a/clause/limit.go b/clause/limit.go new file mode 100644 index 00000000..8fbc0055 --- /dev/null +++ b/clause/limit.go @@ -0,0 +1,6 @@ +package clause + +// Limit limit clause +type Limit struct { + Offset uint +} diff --git a/clause/order_by.go b/clause/order_by.go new file mode 100644 index 00000000..a11a3c48 --- /dev/null +++ b/clause/order_by.go @@ -0,0 +1,4 @@ +package clause + +type OrderBy struct { +} diff --git a/clause/query.go b/clause/query.go index 7b5491e5..949678d9 100644 --- a/clause/query.go +++ b/clause/query.go @@ -2,12 +2,6 @@ package clause import "strings" -// Column quote with name -type Column struct { - Table string - Name string -} - //////////////////////////////////////////////////////////////////////////////// // Query Expressions //////////////////////////////////////////////////////////////////////////////// diff --git a/clause/select.go b/clause/select.go new file mode 100644 index 00000000..1342c411 --- /dev/null +++ b/clause/select.go @@ -0,0 +1,45 @@ +package clause + +// Select select attrs when querying, updating, creating +type Select struct { + SelectColumns []Column + OmitColumns []Column +} + +// SelectInterface select clause interface +type SelectInterface interface { + Selects() []Column + Omits() []Column +} + +func (s Select) Selects() []Column { + return s.SelectColumns +} + +func (s Select) Omits() []Column { + return s.OmitColumns +} + +func (s Select) Build(builder Builder) { + if len(s.SelectColumns) > 0 { + for idx, column := range s.SelectColumns { + if idx > 0 { + builder.WriteByte(',') + } + builder.WriteQuoted(column) + } + } else { + builder.WriteByte('*') + } +} + +func (s Select) MergeExpression(expr Expression) { + if v, ok := expr.(SelectInterface); ok { + if len(s.SelectColumns) == 0 { + s.SelectColumns = v.Selects() + } + if len(s.OmitColumns) == 0 { + s.OmitColumns = v.Omits() + } + } +} diff --git a/clause/where.go b/clause/where.go new file mode 100644 index 00000000..888b9d07 --- /dev/null +++ b/clause/where.go @@ -0,0 +1,77 @@ +package clause + +// Where where clause +type Where struct { + AndConditions AddConditions + ORConditions []ORConditions + builders []Expression +} + +// Name where clause name +func (where Where) Name() string { + return "WHERE" +} + +// Build build where clause +func (where Where) Build(builder Builder) { + var withConditions bool + + if len(where.AndConditions) > 0 { + withConditions = true + where.AndConditions.Build(builder) + } + + if len(where.builders) > 0 { + for _, b := range where.builders { + if withConditions { + builder.Write(" AND ") + } + withConditions = true + b.Build(builder) + } + } + + var singleOrConditions []ORConditions + for _, or := range where.ORConditions { + if len(or) == 1 { + if withConditions { + builder.Write(" OR ") + or.Build(builder) + } else { + singleOrConditions = append(singleOrConditions, or) + } + } else { + withConditions = true + builder.Write(" AND (") + or.Build(builder) + builder.WriteByte(')') + } + } + + for _, or := range singleOrConditions { + if withConditions { + builder.Write(" AND ") + or.Build(builder) + } else { + withConditions = true + or.Build(builder) + } + } + + if !withConditions { + builder.Write(" FALSE") + } + + return +} + +// MergeExpression merge where clauses +func (where Where) MergeExpression(expr Expression) { + if w, ok := expr.(Where); ok { + where.AndConditions = append(where.AndConditions, w.AndConditions...) + where.ORConditions = append(where.ORConditions, w.ORConditions...) + where.builders = append(where.builders, w.builders...) + } else { + where.builders = append(where.builders, expr) + } +} diff --git a/clause/with.go b/clause/with.go new file mode 100644 index 00000000..7e9eaef1 --- /dev/null +++ b/clause/with.go @@ -0,0 +1,4 @@ +package clause + +type With struct { +} diff --git a/dialects/sqlite/go.mod b/dialects/sqlite/go.mod index db3370e9..79d48da8 100644 --- a/dialects/sqlite/go.mod +++ b/dialects/sqlite/go.mod @@ -1,7 +1,5 @@ -module github.com/jinzhu/gorm/dialects/mysql +module github.com/jinzhu/gorm/dialects/sqlite go 1.13 -require ( - github.com/mattn/go-sqlite3 v2.0.3+incompatible -) +require github.com/mattn/go-sqlite3 v2.0.3+incompatible diff --git a/dialects/sqlite/go.sum b/dialects/sqlite/go.sum new file mode 100644 index 00000000..d6744290 --- /dev/null +++ b/dialects/sqlite/go.sum @@ -0,0 +1,2 @@ +github.com/mattn/go-sqlite3 v2.0.3+incompatible h1:gXHsfypPkaMZrKbD5209QV9jbUTJKjyR5WD3HYQSd+U= +github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= diff --git a/dialects/sqlite/sqlite.go b/dialects/sqlite/sqlite.go index f3c3f0c7..bcd6bd5c 100644 --- a/dialects/sqlite/sqlite.go +++ b/dialects/sqlite/sqlite.go @@ -1,6 +1,7 @@ package sqlite import ( + "github.com/jinzhu/gorm" "github.com/jinzhu/gorm/callbacks" _ "github.com/mattn/go-sqlite3" ) diff --git a/dialects/sqlite/sqlite_test.go b/dialects/sqlite/sqlite_test.go index f0429a12..51c1def0 100644 --- a/dialects/sqlite/sqlite_test.go +++ b/dialects/sqlite/sqlite_test.go @@ -1,15 +1,27 @@ package sqlite_test import ( + "fmt" "os" "path/filepath" "testing" "github.com/jinzhu/gorm" + "github.com/jinzhu/gorm/dialects/sqlite" + "github.com/jinzhu/gorm/tests" ) -var DB *gorm.DB +var ( + DB *gorm.DB + err error +) -func TestOpen(t *testing.T) { - db, err = gorm.Open("sqlite3", filepath.Join(os.TempDir(), "gorm.db")) +func init() { + if DB, err = gorm.Open(sqlite.Open(filepath.Join(os.TempDir(), "gorm.db")), &gorm.Config{}); err != nil { + panic(fmt.Sprintf("failed to initialize database, got error %v", err)) + } +} + +func TestSqlite(t *testing.T) { + tests.RunTestsSuit(t, DB) } diff --git a/finisher_api.go b/finisher_api.go index b155e90d..c79915d2 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -12,7 +12,9 @@ func (db *DB) Count(sql string, values ...interface{}) (tx *DB) { // First find first record that match given conditions, order by primary key func (db *DB) First(out interface{}, where ...interface{}) (tx *DB) { tx = db.getInstance() - tx.callbacks.Create().Execute(tx.Limit(1).Order("id")) + tx.Statement.Dest = out + tx.Limit(1) + tx.callbacks.Query().Execute(tx) return } @@ -35,12 +37,10 @@ func (db *DB) Find(out interface{}, where ...interface{}) (tx *DB) { } func (db *DB) Row() *sql.Row { - // TODO return nil } func (db *DB) Rows() (*sql.Rows, error) { - // TODO return nil, nil } diff --git a/go.mod b/go.mod index 516a9759..820046ba 100644 --- a/go.mod +++ b/go.mod @@ -2,4 +2,7 @@ module github.com/jinzhu/gorm go 1.13 -require github.com/jinzhu/inflection v1.0.0 +require ( + github.com/jinzhu/inflection v1.0.0 + gopkg.in/errgo.v2 v2.1.0 +) diff --git a/gorm.go b/gorm.go index 896d07f9..2264b9ae 100644 --- a/gorm.go +++ b/gorm.go @@ -2,6 +2,7 @@ package gorm import ( "context" + "sync" "time" "github.com/jinzhu/gorm/clause" @@ -12,36 +13,28 @@ import ( // Config GORM config type Config struct { // GORM perform single create, update, delete operations in transactions by default to ensure database data integrity - // You can cancel it by setting `SkipDefaultTransaction` to true - SkipDefaultTransaction bool // TODO - + // You can disable it by setting `SkipDefaultTransaction` to true + SkipDefaultTransaction bool // NamingStrategy tables, columns naming strategy NamingStrategy schema.Namer - // Logger Logger logger.Interface - // NowFunc the function to be used when creating a new timestamp NowFunc func() time.Time } -// Dialector GORM database dialector -type Dialector interface { - Initialize(*DB) error - Migrator() Migrator - BindVar(stmt Statement, v interface{}) string -} - // DB GORM DB definition type DB struct { *Config Dialector Instance - clone bool - callbacks *callbacks + DB CommonDB + clone bool + callbacks *callbacks + cacheStore *sync.Map } -// Session session config when create new session +// Session session config when create session with Session() method type Session struct { Context context.Context Logger logger.Interface @@ -67,10 +60,11 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) { } db = &DB{ - Config: config, - Dialector: dialector, - clone: true, - callbacks: InitializeCallbacks(), + Config: config, + Dialector: dialector, + clone: true, + callbacks: InitializeCallbacks(), + cacheStore: &sync.Map{}, } if dialector != nil { @@ -113,10 +107,6 @@ func (db *DB) Debug() (tx *DB) { return db.Session(&Session{Logger: db.Logger.LogMode(logger.Info)}) } -func (db *DB) Close() error { - return nil -} - // Set store value with key into current db instance's context func (db *DB) Set(key string, value interface{}) *DB { tx := db.getInstance() @@ -145,12 +135,15 @@ func (db *DB) getInstance() *DB { } return &DB{ - Config: db.Config, - Dialector: db.Dialector, Instance: Instance{ Context: ctx, Statement: &Statement{DB: db, Clauses: map[string]clause.Clause{}}, }, + Config: db.Config, + Dialector: db.Dialector, + DB: db.DB, + callbacks: db.callbacks, + cacheStore: db.cacheStore, } } diff --git a/interfaces.go b/interfaces.go new file mode 100644 index 00000000..98d04592 --- /dev/null +++ b/interfaces.go @@ -0,0 +1,21 @@ +package gorm + +import ( + "context" + "database/sql" +) + +// Dialector GORM database dialector +type Dialector interface { + Initialize(*DB) error + Migrator() Migrator + BindVar(stmt Statement, v interface{}) string +} + +// CommonDB common db interface +type CommonDB interface { + ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) + PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) + QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) + QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row +} diff --git a/schema/schema.go b/schema/schema.go index 5cd6146b..53170e18 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -1,6 +1,7 @@ package schema import ( + "errors" "fmt" "go/ast" "reflect" @@ -9,6 +10,9 @@ import ( "github.com/jinzhu/gorm/logger" ) +// ErrUnsupportedDataType unsupported data type +var ErrUnsupportedDataType = errors.New("unsupported data type") + type Schema struct { Name string ModelType reflect.Type @@ -50,9 +54,9 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) if modelType.Kind() != reflect.Struct { if modelType.PkgPath() == "" { - return nil, fmt.Errorf("unsupported data %+v when parsing model", dest) + return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) } - return nil, fmt.Errorf("unsupported data type %v when parsing model", modelType.PkgPath()) + return nil, fmt.Errorf("%w: %v.%v", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name()) } if v, ok := cacheStore.Load(modelType); ok { @@ -88,7 +92,7 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) } for _, field := range schema.Fields { - if field.DBName == "" { + if field.DBName == "" && field.DataType != "" { field.DBName = namer.ColumnName(schema.Table, field.Name) } diff --git a/schema/schema_helper_test.go b/schema/schema_helper_test.go index 05f41131..db38355d 100644 --- a/schema/schema_helper_test.go +++ b/schema/schema_helper_test.go @@ -2,24 +2,16 @@ package schema_test import ( "fmt" - "reflect" "strings" "testing" "github.com/jinzhu/gorm/schema" + "github.com/jinzhu/gorm/tests" ) func checkSchema(t *testing.T, s *schema.Schema, v schema.Schema, primaryFields []string) { t.Run("CheckSchema/"+s.Name, func(t *testing.T) { - equalFieldNames := []string{"Name", "Table"} - - for _, name := range equalFieldNames { - got := reflect.ValueOf(s).Elem().FieldByName(name).Interface() - expects := reflect.ValueOf(v).FieldByName(name).Interface() - if !reflect.DeepEqual(got, expects) { - t.Errorf("schema %v %v is not equal, expects: %v, got %v", s, name, expects, got) - } - } + tests.AssertEqual(t, s, v, "Name", "Table") for idx, field := range primaryFields { var found bool @@ -59,15 +51,7 @@ func checkSchemaField(t *testing.T, s *schema.Schema, f *schema.Field, fc func(* if parsedField, ok := s.FieldsByName[f.Name]; !ok { t.Errorf("schema %v failed to look up field with name %v", s, f.Name) } else { - equalFieldNames := []string{"Name", "DBName", "BindNames", "DataType", "DBDataType", "PrimaryKey", "AutoIncrement", "Creatable", "Updatable", "HasDefaultValue", "DefaultValue", "NotNull", "Unique", "Comment", "Size", "Precision", "Tag", "TagSettings"} - - for _, name := range equalFieldNames { - got := reflect.ValueOf(parsedField).Elem().FieldByName(name).Interface() - expects := reflect.ValueOf(f).Elem().FieldByName(name).Interface() - if !reflect.DeepEqual(got, expects) { - t.Errorf("%v is not equal, expects: %v, got %v", name, expects, got) - } - } + tests.AssertEqual(t, parsedField, f, "Name", "DBName", "BindNames", "DataType", "DBDataType", "PrimaryKey", "AutoIncrement", "Creatable", "Updatable", "HasDefaultValue", "DefaultValue", "NotNull", "Unique", "Comment", "Size", "Precision", "Tag", "TagSettings") if field, ok := s.FieldsByDBName[f.DBName]; !ok || parsedField != field { t.Errorf("schema %v failed to look up field with dbname %v", s, f.DBName) diff --git a/statement.go b/statement.go index 30d45b98..86359177 100644 --- a/statement.go +++ b/statement.go @@ -10,6 +10,7 @@ import ( "sync" "github.com/jinzhu/gorm/clause" + "github.com/jinzhu/gorm/schema" ) // Instance db instance @@ -37,6 +38,7 @@ type Statement struct { Clauses map[string]clause.Clause Settings sync.Map DB *DB + Schema *schema.Schema // SQL Builder SQL strings.Builder @@ -69,9 +71,32 @@ func (stmt Statement) WriteQuoted(field interface{}) (err error) { } // Quote returns quoted value -func (stmt Statement) Quote(field interface{}) (str string) { - // FIXME - return fmt.Sprint(field) +func (stmt Statement) Quote(field interface{}) string { + var str strings.Builder + + switch v := field.(type) { + case clause.Table: + str.WriteString(v.Table) + if v.Alias != "" { + str.WriteString(" AS ") + str.WriteString(v.Alias) + } + case clause.Column: + if v.Table != "" { + str.WriteString(v.Table) + str.WriteByte('.') + } + + str.WriteString(v.Name) + if v.Alias != "" { + str.WriteString(" AS ") + str.WriteString(v.Alias) + } + default: + fmt.Sprint(field) + } + + return str.String() } // Write write string diff --git a/tests/create_test.go b/tests/create_test.go deleted file mode 100644 index ca8701d2..00000000 --- a/tests/create_test.go +++ /dev/null @@ -1 +0,0 @@ -package tests diff --git a/tests/tests.go b/tests/tests.go new file mode 100644 index 00000000..b3246a79 --- /dev/null +++ b/tests/tests.go @@ -0,0 +1,42 @@ +package tests + +import ( + "testing" + "time" + + "github.com/jinzhu/gorm" +) + +func Now() *time.Time { + now := time.Now() + return &now +} + +func RunTestsSuit(t *testing.T, db *gorm.DB) { + TestCreate(t, db) +} + +func TestCreate(t *testing.T, db *gorm.DB) { + t.Run("Create", func(t *testing.T) { + var user = User{ + Name: "create", + Age: 18, + Birthday: Now(), + } + + if err := db.Create(&user).Error; err != nil { + t.Errorf("errors happened when create: %v", err) + } + + if user.ID == 0 { + t.Errorf("user's primary key should has value after create, got : %v", user.ID) + } + + var newUser User + if err := db.Where("id = ?", user.ID).First(&newUser).Error; err != nil { + t.Errorf("errors happened when query: %v", err) + } else { + AssertEqual(t, newUser, user, "Name", "Age", "Birthday") + } + }) +} diff --git a/tests/utils.go b/tests/utils.go new file mode 100644 index 00000000..d12df2dc --- /dev/null +++ b/tests/utils.go @@ -0,0 +1,19 @@ +package tests + +import ( + "reflect" + "testing" +) + +func AssertEqual(t *testing.T, r, e interface{}, names ...string) { + for _, name := range names { + got := reflect.Indirect(reflect.ValueOf(r)).FieldByName(name).Interface() + expects := reflect.Indirect(reflect.ValueOf(e)).FieldByName(name).Interface() + + if !reflect.DeepEqual(got, expects) { + t.Run(name, func(t *testing.T) { + t.Errorf("expects: %v, got %v", expects, got) + }) + } + } +} From 728c0d4470ec02629483fe90b11f7a0dec17bded Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 2 Feb 2020 19:32:27 +0800 Subject: [PATCH 224/881] Add callbacks --- callbacks.go | 29 ++++++++++++++++++----------- callbacks/callbacks.go | 39 +++++++++++++++++++++++++++++++++------ callbacks/create.go | 16 +--------------- callbacks/delete.go | 12 ++++++++++++ callbacks/transaction.go | 9 +++++++++ callbacks/update.go | 12 ++++++++++++ dialects/sqlite/go.mod | 5 ----- dialects/sqlite/go.sum | 2 -- go.mod | 5 +---- go.sum | 2 -- gorm.go | 3 ++- statement.go | 14 +++++++++++--- tests/callbacks_test.go | 4 ++-- 13 files changed, 101 insertions(+), 51 deletions(-) create mode 100644 callbacks/delete.go create mode 100644 callbacks/transaction.go create mode 100644 callbacks/update.go delete mode 100644 dialects/sqlite/go.mod delete mode 100644 dialects/sqlite/go.sum delete mode 100644 go.sum diff --git a/callbacks.go b/callbacks.go index 22d2eda3..51ee150f 100644 --- a/callbacks.go +++ b/callbacks.go @@ -9,15 +9,15 @@ import ( "github.com/jinzhu/gorm/utils" ) -func InitializeCallbacks() *callbacks { +func initializeCallbacks(db *DB) *callbacks { return &callbacks{ processors: map[string]*processor{ - "create": &processor{}, - "query": &processor{}, - "update": &processor{}, - "delete": &processor{}, - "row": &processor{}, - "raw": &processor{}, + "create": &processor{db: db}, + "query": &processor{db: db}, + "update": &processor{db: db}, + "delete": &processor{db: db}, + "row": &processor{db: db}, + "raw": &processor{db: db}, }, } } @@ -118,7 +118,14 @@ func (p *processor) Replace(name string, fn func(*DB)) error { return (&callback{processor: p}).Replace(name, fn) } -func (p *processor) compile(db *DB) (err error) { +func (p *processor) compile() (err error) { + var callbacks []*callback + for _, callback := range p.callbacks { + if callback.match == nil || callback.match(p.db) { + callbacks = append(callbacks, callback) + } + } + if p.fns, err = sortCallbacks(p.callbacks); err != nil { logger.Default.Error("Got error when compile callbacks, got %v", err) } @@ -139,7 +146,7 @@ func (c *callback) Register(name string, fn func(*DB)) error { c.name = name c.handler = fn c.processor.callbacks = append(c.processor.callbacks, c) - return c.processor.compile(c.processor.db) + return c.processor.compile() } func (c *callback) Remove(name string) error { @@ -147,7 +154,7 @@ func (c *callback) Remove(name string) error { c.name = name c.remove = true c.processor.callbacks = append(c.processor.callbacks, c) - return c.processor.compile(c.processor.db) + return c.processor.compile() } func (c *callback) Replace(name string, fn func(*DB)) error { @@ -156,7 +163,7 @@ func (c *callback) Replace(name string, fn func(*DB)) error { c.handler = fn c.replace = true c.processor.callbacks = append(c.processor.callbacks, c) - return c.processor.compile(c.processor.db) + return c.processor.compile() } // getRIndex get right index from string slice diff --git a/callbacks/callbacks.go b/callbacks/callbacks.go index 7fd12cb7..a3e5245b 100644 --- a/callbacks/callbacks.go +++ b/callbacks/callbacks.go @@ -3,10 +3,37 @@ package callbacks import "github.com/jinzhu/gorm" func RegisterDefaultCallbacks(db *gorm.DB) { - callback := db.Callback() - callback.Create().Register("gorm:before_create", BeforeCreate) - callback.Create().Register("gorm:save_before_associations", SaveBeforeAssociations) - callback.Create().Register("gorm:create", Create) - callback.Create().Register("gorm:save_after_associations", SaveAfterAssociations) - callback.Create().Register("gorm:after_create", AfterCreate) + enableTransaction := func(db *gorm.DB) bool { + return !db.SkipDefaultTransaction + } + + createCallback := db.Callback().Create() + createCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction) + createCallback.Register("gorm:before_create", BeforeCreate) + createCallback.Register("gorm:save_before_associations", SaveBeforeAssociations) + createCallback.Register("gorm:create", Create) + createCallback.Register("gorm:save_after_associations", SaveAfterAssociations) + createCallback.Register("gorm:after_create", AfterCreate) + createCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) + + queryCallback := db.Callback().Query() + queryCallback.Register("gorm:query", BeforeCreate) + queryCallback.Register("gorm:preload", Preload) + queryCallback.Register("gorm:after_query", AfterQuery) + + deleteCallback := db.Callback().Delete() + deleteCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction) + deleteCallback.Register("gorm:before_delete", BeforeDelete) + deleteCallback.Register("gorm:delete", Delete) + deleteCallback.Register("gorm:after_delete", AfterDelete) + deleteCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) + + updateCallback := db.Callback().Update() + updateCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction) + updateCallback.Register("gorm:before_update", BeforeUpdate) + updateCallback.Register("gorm:save_before_associations", SaveBeforeAssociations) + updateCallback.Register("gorm:update", Update) + updateCallback.Register("gorm:save_after_associations", SaveAfterAssociations) + updateCallback.Register("gorm:after_update", AfterUpdate) + updateCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) } diff --git a/callbacks/create.go b/callbacks/create.go index 5a3aaa24..028cdbc4 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -18,7 +18,7 @@ func SaveBeforeAssociations(db *gorm.DB) { func Create(db *gorm.DB) { db.Statement.Build("WITH", "INSERT", "VALUES", "ON_CONFLICT", "RETURNING") - + db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) fmt.Println(db.Statement.SQL.String(), db.Statement.Vars) } @@ -29,17 +29,3 @@ func AfterCreate(db *gorm.DB) { // after save // after create } - -func objectToFieldsMap(stmt *gorm.Statement) { - if stmt.Schema != nil { - if s, ok := stmt.Clauses["SELECT"]; ok { - s.Attrs - } - - if s, ok := stmt.Clauses["OMIT"]; ok { - s.Attrs - } - - stmt.Schema.LookUpField(s.S) - } -} diff --git a/callbacks/delete.go b/callbacks/delete.go new file mode 100644 index 00000000..96c392f2 --- /dev/null +++ b/callbacks/delete.go @@ -0,0 +1,12 @@ +package callbacks + +import "github.com/jinzhu/gorm" + +func BeforeDelete(db *gorm.DB) { +} + +func Delete(db *gorm.DB) { +} + +func AfterDelete(db *gorm.DB) { +} diff --git a/callbacks/transaction.go b/callbacks/transaction.go new file mode 100644 index 00000000..253c4e82 --- /dev/null +++ b/callbacks/transaction.go @@ -0,0 +1,9 @@ +package callbacks + +import "github.com/jinzhu/gorm" + +func BeginTransaction(db *gorm.DB) { +} + +func CommitOrRollbackTransaction(db *gorm.DB) { +} diff --git a/callbacks/update.go b/callbacks/update.go new file mode 100644 index 00000000..8e504403 --- /dev/null +++ b/callbacks/update.go @@ -0,0 +1,12 @@ +package callbacks + +import "github.com/jinzhu/gorm" + +func BeforeUpdate(db *gorm.DB) { +} + +func Update(db *gorm.DB) { +} + +func AfterUpdate(db *gorm.DB) { +} diff --git a/dialects/sqlite/go.mod b/dialects/sqlite/go.mod deleted file mode 100644 index 79d48da8..00000000 --- a/dialects/sqlite/go.mod +++ /dev/null @@ -1,5 +0,0 @@ -module github.com/jinzhu/gorm/dialects/sqlite - -go 1.13 - -require github.com/mattn/go-sqlite3 v2.0.3+incompatible diff --git a/dialects/sqlite/go.sum b/dialects/sqlite/go.sum deleted file mode 100644 index d6744290..00000000 --- a/dialects/sqlite/go.sum +++ /dev/null @@ -1,2 +0,0 @@ -github.com/mattn/go-sqlite3 v2.0.3+incompatible h1:gXHsfypPkaMZrKbD5209QV9jbUTJKjyR5WD3HYQSd+U= -github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= diff --git a/go.mod b/go.mod index 820046ba..516a9759 100644 --- a/go.mod +++ b/go.mod @@ -2,7 +2,4 @@ module github.com/jinzhu/gorm go 1.13 -require ( - github.com/jinzhu/inflection v1.0.0 - gopkg.in/errgo.v2 v2.1.0 -) +require github.com/jinzhu/inflection v1.0.0 diff --git a/go.sum b/go.sum deleted file mode 100644 index a310b071..00000000 --- a/go.sum +++ /dev/null @@ -1,2 +0,0 @@ -github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= -github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= diff --git a/gorm.go b/gorm.go index 2264b9ae..8ac7e057 100644 --- a/gorm.go +++ b/gorm.go @@ -63,10 +63,11 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) { Config: config, Dialector: dialector, clone: true, - callbacks: InitializeCallbacks(), cacheStore: &sync.Map{}, } + db.callbacks = initializeCallbacks(db) + if dialector != nil { err = dialector.Initialize(db) } diff --git a/statement.go b/statement.go index 86359177..4d959cbb 100644 --- a/statement.go +++ b/statement.go @@ -21,6 +21,13 @@ type Instance struct { Statement *Statement } +func (instance Instance) ToSQL(clauses ...string) (string, []interface{}) { + if len(clauses) > 0 { + instance.Statement.Build(clauses...) + } + return instance.Statement.SQL.String(), instance.Statement.Vars +} + // AddError add error to instance func (inst Instance) AddError(err error) { if inst.Error == nil { @@ -205,16 +212,17 @@ func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (con // Build build sql with clauses names func (stmt Statement) Build(clauses ...string) { - var includeSpace bool + var firstClauseWritten bool for _, name := range clauses { if c, ok := stmt.Clauses[name]; ok { - if includeSpace { + if firstClauseWritten { stmt.WriteByte(' ') } - includeSpace = true + firstClauseWritten = true c.Build(stmt) } } + // TODO handle named vars } diff --git a/tests/callbacks_test.go b/tests/callbacks_test.go index af975a55..f8dc3e81 100644 --- a/tests/callbacks_test.go +++ b/tests/callbacks_test.go @@ -99,8 +99,8 @@ func TestCallbacks(t *testing.T) { } for idx, data := range datas { - var err error - callbacks := gorm.InitializeCallbacks() + db, err := gorm.Open(nil, nil) + callbacks := db.Callback() for _, c := range data.callbacks { var v interface{} = callbacks.Create() From d52ee0aa44609f407a0148b766754e801a60ec4f Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 3 Feb 2020 10:40:03 +0800 Subject: [PATCH 225/881] Work on create callbacks --- callbacks/create.go | 11 +++++- chainable_api.go | 12 ++++-- clause/insert.go | 34 +++++++++++++++++ clause/value.go | 39 +++++++++++++++++++ dialects/postgres/postgres.go | 33 +++++++++++++++++ dialects/sqlite/sqlite.go | 12 ++++-- finisher_api.go | 70 ++++++++++++++++++----------------- go.mod | 6 ++- gorm.go | 20 +++++----- interfaces.go | 2 +- statement.go | 49 +++++++++++++++++++----- 11 files changed, 224 insertions(+), 64 deletions(-) create mode 100644 clause/insert.go create mode 100644 clause/value.go create mode 100644 dialects/postgres/postgres.go diff --git a/callbacks/create.go b/callbacks/create.go index 028cdbc4..983b95ce 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -4,6 +4,7 @@ import ( "fmt" "github.com/jinzhu/gorm" + "github.com/jinzhu/gorm/clause" ) func BeforeCreate(db *gorm.DB) { @@ -17,8 +18,14 @@ func SaveBeforeAssociations(db *gorm.DB) { } func Create(db *gorm.DB) { - db.Statement.Build("WITH", "INSERT", "VALUES", "ON_CONFLICT", "RETURNING") - db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) + db.Statement.AddClauseIfNotExists(clause.Insert{ + Table: clause.Table{Table: db.Statement.Table}, + }) + + db.Statement.Build("INSERT", "VALUES", "ON_CONFLICT", "RETURNING") + result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) + fmt.Println(err) + fmt.Println(result) fmt.Println(db.Statement.SQL.String(), db.Statement.Vars) } diff --git a/chainable_api.go b/chainable_api.go index 95d5975c..b577d5cf 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -55,7 +55,9 @@ func (db *DB) Omit(columns ...string) (tx *DB) { func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) { tx = db.getInstance() - tx.Statement.AddClause(clause.Where{AndConditions: tx.Statement.BuildCondtion(query, args...)}) + tx.Statement.AddClause(clause.Where{ + AndConditions: tx.Statement.BuildCondtion(query, args...), + }) return } @@ -63,7 +65,9 @@ func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) { func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.AddClause(clause.Where{ - AndConditions: []clause.Expression{clause.NotConditions(tx.Statement.BuildCondtion(query, args...))}, + AndConditions: []clause.Expression{ + clause.NotConditions(tx.Statement.BuildCondtion(query, args...)), + }, }) return } @@ -72,7 +76,9 @@ func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) { func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.AddClause(clause.Where{ - ORConditions: []clause.ORConditions{tx.Statement.BuildCondtion(query, args...)}, + ORConditions: []clause.ORConditions{ + tx.Statement.BuildCondtion(query, args...), + }, }) return } diff --git a/clause/insert.go b/clause/insert.go new file mode 100644 index 00000000..e056b35e --- /dev/null +++ b/clause/insert.go @@ -0,0 +1,34 @@ +package clause + +type Insert struct { + Table Table + Priority string +} + +// Name insert clause name +func (insert Insert) Name() string { + return "INSERT" +} + +// Build build insert clause +func (insert Insert) Build(builder Builder) { + if insert.Priority != "" { + builder.Write(insert.Priority) + builder.WriteByte(' ') + } + + builder.Write("INTO ") + builder.WriteQuoted(insert.Table) +} + +// MergeExpression merge insert clauses +func (insert Insert) MergeExpression(expr Expression) { + if v, ok := expr.(Insert); ok { + if insert.Priority == "" { + insert.Priority = v.Priority + } + if insert.Table.Table == "" { + insert.Table = v.Table + } + } +} diff --git a/clause/value.go b/clause/value.go new file mode 100644 index 00000000..4de0d91e --- /dev/null +++ b/clause/value.go @@ -0,0 +1,39 @@ +package clause + +type Values struct { + Columns []Column + Values [][]interface{} +} + +// Name from clause name +func (Values) Name() string { + return "" +} + +// Build build from clause +func (values Values) Build(builder Builder) { + if len(values.Columns) > 0 { + builder.WriteByte('(') + for idx, column := range values.Columns { + if idx > 0 { + builder.WriteByte(',') + } + builder.WriteQuoted(column) + } + builder.WriteByte(')') + + builder.Write(" VALUES ") + + for idx, value := range values.Values { + builder.WriteByte('(') + if idx > 0 { + builder.WriteByte(',') + } + + builder.Write(builder.AddVar(value...)) + builder.WriteByte(')') + } + } else { + builder.Write("DEFAULT VALUES") + } +} diff --git a/dialects/postgres/postgres.go b/dialects/postgres/postgres.go new file mode 100644 index 00000000..3abf05e3 --- /dev/null +++ b/dialects/postgres/postgres.go @@ -0,0 +1,33 @@ +package postgres + +import ( + "database/sql" + + "github.com/jinzhu/gorm" + "github.com/jinzhu/gorm/callbacks" + _ "github.com/lib/pq" +) + +type Dialector struct { + DSN string +} + +func Open(dsn string) gorm.Dialector { + return &Dialector{DSN: dsn} +} + +func (dialector Dialector) Initialize(db *gorm.DB) (err error) { + // register callbacks + callbacks.RegisterDefaultCallbacks(db) + + db.DB, err = sql.Open("postgres", dialector.DSN) + return +} + +func (Dialector) Migrator() gorm.Migrator { + return nil +} + +func (Dialector) BindVar(stmt *gorm.Statement, v interface{}) string { + return "?" +} diff --git a/dialects/sqlite/sqlite.go b/dialects/sqlite/sqlite.go index bcd6bd5c..91c3389e 100644 --- a/dialects/sqlite/sqlite.go +++ b/dialects/sqlite/sqlite.go @@ -1,29 +1,33 @@ package sqlite import ( + "database/sql" + "github.com/jinzhu/gorm" "github.com/jinzhu/gorm/callbacks" _ "github.com/mattn/go-sqlite3" ) type Dialector struct { + DSN string } func Open(dsn string) gorm.Dialector { - return &Dialector{} + return &Dialector{DSN: dsn} } -func (Dialector) Initialize(db *gorm.DB) error { +func (dialector Dialector) Initialize(db *gorm.DB) (err error) { // register callbacks callbacks.RegisterDefaultCallbacks(db) - return nil + db.DB, err = sql.Open("sqlite3", dialector.DSN) + return } func (Dialector) Migrator() gorm.Migrator { return nil } -func (Dialector) BindVar(stmt gorm.Statement, v interface{}) string { +func (Dialector) BindVar(stmt *gorm.Statement, v interface{}) string { return "?" } diff --git a/finisher_api.go b/finisher_api.go index c79915d2..a311ca78 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -4,7 +4,16 @@ import ( "database/sql" ) -func (db *DB) Count(sql string, values ...interface{}) (tx *DB) { +// Create insert the value into database +func (db *DB) Create(value interface{}) (tx *DB) { + tx = db.getInstance() + tx.Statement.Dest = value + tx.callbacks.Create().Execute(tx) + return +} + +// Save update value in database, if the value doesn't have primary key, will insert it +func (db *DB) Save(value interface{}) (tx *DB) { tx = db.getInstance() return } @@ -36,32 +45,12 @@ func (db *DB) Find(out interface{}, where ...interface{}) (tx *DB) { return } -func (db *DB) Row() *sql.Row { - return nil -} - -func (db *DB) Rows() (*sql.Rows, error) { - return nil, nil -} - -// Scan scan value to a struct -func (db *DB) Scan(dest interface{}) (tx *DB) { +func (db *DB) FirstOrInit(out interface{}, where ...interface{}) (tx *DB) { tx = db.getInstance() return } -func (db *DB) ScanRows(rows *sql.Rows, result interface{}) error { - return nil -} - -// Create insert the value into database -func (db *DB) Create(value interface{}) (tx *DB) { - tx = db.getInstance() - return -} - -// Save update value in database, if the value doesn't have primary key, will insert it -func (db *DB) Save(value interface{}) (tx *DB) { +func (db *DB) FirstOrCreate(out interface{}, where ...interface{}) (tx *DB) { tx = db.getInstance() return } @@ -78,7 +67,7 @@ func (db *DB) Updates(values interface{}) (tx *DB) { return } -func (db *DB) UpdateColumn(attrs ...interface{}) (tx *DB) { +func (db *DB) UpdateColumn(column string, value interface{}) (tx *DB) { tx = db.getInstance() return } @@ -88,16 +77,6 @@ func (db *DB) UpdateColumns(values interface{}) (tx *DB) { return } -func (db *DB) FirstOrInit(out interface{}, where ...interface{}) (tx *DB) { - tx = db.getInstance() - return -} - -func (db *DB) FirstOrCreate(out interface{}, where ...interface{}) (tx *DB) { - tx = db.getInstance() - return -} - // Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition func (db *DB) Delete(value interface{}, where ...interface{}) (tx *DB) { tx = db.getInstance() @@ -119,6 +98,29 @@ func (db *DB) Association(column string) *Association { return nil } +func (db *DB) Count(value interface{}) (tx *DB) { + tx = db.getInstance() + return +} + +func (db *DB) Row() *sql.Row { + return nil +} + +func (db *DB) Rows() (*sql.Rows, error) { + return nil, nil +} + +// Scan scan value to a struct +func (db *DB) Scan(dest interface{}) (tx *DB) { + tx = db.getInstance() + return +} + +func (db *DB) ScanRows(rows *sql.Rows, result interface{}) error { + return nil +} + func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err error) { panicked := true tx := db.Begin(opts...) diff --git a/go.mod b/go.mod index 516a9759..1f4d31a2 100644 --- a/go.mod +++ b/go.mod @@ -2,4 +2,8 @@ module github.com/jinzhu/gorm go 1.13 -require github.com/jinzhu/inflection v1.0.0 +require ( + github.com/jinzhu/inflection v1.0.0 + github.com/lib/pq v1.3.0 + github.com/mattn/go-sqlite3 v2.0.3+incompatible +) diff --git a/gorm.go b/gorm.go index 8ac7e057..a72314bd 100644 --- a/gorm.go +++ b/gorm.go @@ -28,10 +28,11 @@ type DB struct { *Config Dialector Instance - DB CommonDB - clone bool - callbacks *callbacks - cacheStore *sync.Map + DB CommonDB + ClauseBuilders map[string]clause.ClauseBuilder + clone bool + callbacks *callbacks + cacheStore *sync.Map } // Session session config when create session with Session() method @@ -140,11 +141,12 @@ func (db *DB) getInstance() *DB { Context: ctx, Statement: &Statement{DB: db, Clauses: map[string]clause.Clause{}}, }, - Config: db.Config, - Dialector: db.Dialector, - DB: db.DB, - callbacks: db.callbacks, - cacheStore: db.cacheStore, + Config: db.Config, + Dialector: db.Dialector, + ClauseBuilders: db.ClauseBuilders, + DB: db.DB, + callbacks: db.callbacks, + cacheStore: db.cacheStore, } } diff --git a/interfaces.go b/interfaces.go index 98d04592..6ba24dc4 100644 --- a/interfaces.go +++ b/interfaces.go @@ -9,7 +9,7 @@ import ( type Dialector interface { Initialize(*DB) error Migrator() Migrator - BindVar(stmt Statement, v interface{}) string + BindVar(stmt *Statement, v interface{}) string } // CommonDB common db interface diff --git a/statement.go b/statement.go index 4d959cbb..c01be0f5 100644 --- a/statement.go +++ b/statement.go @@ -5,6 +5,7 @@ import ( "database/sql" "database/sql/driver" "fmt" + "log" "strconv" "strings" "sync" @@ -21,7 +22,7 @@ type Instance struct { Statement *Statement } -func (instance Instance) ToSQL(clauses ...string) (string, []interface{}) { +func (instance *Instance) ToSQL(clauses ...string) (string, []interface{}) { if len(clauses) > 0 { instance.Statement.Build(clauses...) } @@ -29,7 +30,7 @@ func (instance Instance) ToSQL(clauses ...string) (string, []interface{}) { } // AddError add error to instance -func (inst Instance) AddError(err error) { +func (inst *Instance) AddError(err error) { if inst.Error == nil { inst.Error = err } else { @@ -55,11 +56,11 @@ type Statement struct { // StatementOptimizer statement optimizer interface type StatementOptimizer interface { - OptimizeStatement(Statement) + OptimizeStatement(*Statement) } // Write write string -func (stmt Statement) Write(sql ...string) (err error) { +func (stmt *Statement) Write(sql ...string) (err error) { for _, s := range sql { _, err = stmt.SQL.WriteString(s) } @@ -67,12 +68,12 @@ func (stmt Statement) Write(sql ...string) (err error) { } // Write write string -func (stmt Statement) WriteByte(c byte) (err error) { +func (stmt *Statement) WriteByte(c byte) (err error) { return stmt.SQL.WriteByte(c) } // WriteQuoted write quoted field -func (stmt Statement) WriteQuoted(field interface{}) (err error) { +func (stmt *Statement) WriteQuoted(field interface{}) (err error) { _, err = stmt.SQL.WriteString(stmt.Quote(field)) return } @@ -107,7 +108,7 @@ func (stmt Statement) Quote(field interface{}) string { } // Write write string -func (stmt Statement) AddVar(vars ...interface{}) string { +func (stmt *Statement) AddVar(vars ...interface{}) string { var placeholders strings.Builder for idx, v := range vars { if idx > 0 { @@ -134,7 +135,7 @@ func (stmt Statement) AddVar(vars ...interface{}) string { } // AddClause add clause -func (stmt Statement) AddClause(v clause.Interface) { +func (stmt *Statement) AddClause(v clause.Interface) { if optimizer, ok := v.(StatementOptimizer); ok { optimizer.OptimizeStatement(stmt) } @@ -154,6 +155,30 @@ func (stmt Statement) AddClause(v clause.Interface) { stmt.Clauses[v.Name()] = c } +// AddClauseIfNotExists add clause if not exists +func (stmt *Statement) AddClauseIfNotExists(v clause.Interface) { + if optimizer, ok := v.(StatementOptimizer); ok { + optimizer.OptimizeStatement(stmt) + } + + log.Println(v.Name()) + if c, ok := stmt.Clauses[v.Name()]; !ok { + if namer, ok := v.(clause.OverrideNameInterface); ok { + c.Name = namer.OverrideName() + } else { + c.Name = v.Name() + } + + if c.Expression != nil { + v.MergeExpression(c.Expression) + } + + c.Expression = v + stmt.Clauses[v.Name()] = c + log.Println(stmt.Clauses[v.Name()]) + } +} + // BuildCondtion build condition func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (conditions []clause.Expression) { if sql, ok := query.(string); ok { @@ -211,7 +236,7 @@ func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (con } // Build build sql with clauses names -func (stmt Statement) Build(clauses ...string) { +func (stmt *Statement) Build(clauses ...string) { var firstClauseWritten bool for _, name := range clauses { @@ -221,7 +246,11 @@ func (stmt Statement) Build(clauses ...string) { } firstClauseWritten = true - c.Build(stmt) + if b, ok := stmt.DB.ClauseBuilders[name]; ok { + b.Build(c, stmt) + } else { + c.Build(stmt) + } } } // TODO handle named vars From 46b1c85f88e332a36dec31b17a3bd8e6eae07da9 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 4 Feb 2020 08:56:15 +0800 Subject: [PATCH 226/881] Add more clauses --- callbacks.go | 20 +++++++++++++------- callbacks/callbacks.go | 6 ++++-- callbacks/create.go | 2 +- callbacks/query.go | 17 ++++++++++++++++- chainable_api.go | 19 ++++++++++++++++++- clause/clause.go | 31 +++++++++++-------------------- clause/expression.go | 25 ++++++++++++++++++------- clause/from.go | 7 +++++++ clause/on_conflict.go | 6 ++++++ clause/order_by.go | 34 ++++++++++++++++++++++++++++++++++ clause/select.go | 12 ++++++++---- finisher_api.go | 8 ++++++-- gorm.go | 9 +++++---- statement.go | 16 +++++++++++++--- 14 files changed, 160 insertions(+), 52 deletions(-) create mode 100644 clause/on_conflict.go diff --git a/callbacks.go b/callbacks.go index 51ee150f..8546ae16 100644 --- a/callbacks.go +++ b/callbacks.go @@ -69,14 +69,20 @@ func (cs *callbacks) Raw() *processor { } func (p *processor) Execute(db *DB) { - if stmt := db.Statement; stmt != nil && stmt.Dest != nil { - var err error - stmt.Schema, err = schema.Parse(stmt.Dest, db.cacheStore, db.NamingStrategy) + if stmt := db.Statement; stmt != nil { + if stmt.Model == nil { + stmt.Model = stmt.Dest + } - if err != nil && !errors.Is(err, schema.ErrUnsupportedDataType) { - db.AddError(err) - } else if stmt.Table == "" && stmt.Schema != nil { - stmt.Table = stmt.Schema.Table + if stmt.Model != nil { + var err error + stmt.Schema, err = schema.Parse(stmt.Model, db.cacheStore, db.NamingStrategy) + + if err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || stmt.Table == "") { + db.AddError(err) + } else if stmt.Table == "" && stmt.Schema != nil { + stmt.Table = stmt.Schema.Table + } } } diff --git a/callbacks/callbacks.go b/callbacks/callbacks.go index a3e5245b..f9d5543d 100644 --- a/callbacks/callbacks.go +++ b/callbacks/callbacks.go @@ -1,6 +1,8 @@ package callbacks -import "github.com/jinzhu/gorm" +import ( + "github.com/jinzhu/gorm" +) func RegisterDefaultCallbacks(db *gorm.DB) { enableTransaction := func(db *gorm.DB) bool { @@ -17,7 +19,7 @@ func RegisterDefaultCallbacks(db *gorm.DB) { createCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) queryCallback := db.Callback().Query() - queryCallback.Register("gorm:query", BeforeCreate) + queryCallback.Register("gorm:query", Query) queryCallback.Register("gorm:preload", Preload) queryCallback.Register("gorm:after_query", AfterQuery) diff --git a/callbacks/create.go b/callbacks/create.go index 983b95ce..58256085 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -22,7 +22,7 @@ func Create(db *gorm.DB) { Table: clause.Table{Table: db.Statement.Table}, }) - db.Statement.Build("INSERT", "VALUES", "ON_CONFLICT", "RETURNING") + db.Statement.Build("INSERT", "VALUES", "ON_CONFLICT") result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) fmt.Println(err) fmt.Println(result) diff --git a/callbacks/query.go b/callbacks/query.go index 5d27ea17..edf8f281 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -1,8 +1,23 @@ package callbacks -import "github.com/jinzhu/gorm" +import ( + "fmt" + + "github.com/jinzhu/gorm" + "github.com/jinzhu/gorm/clause" +) func Query(db *gorm.DB) { + db.Statement.AddClauseIfNotExists(clause.Select{}) + db.Statement.AddClauseIfNotExists(clause.From{ + Tables: []clause.Table{{Table: clause.CurrentTable}}, + }) + + db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR") + result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) + fmt.Println(err) + fmt.Println(result) + fmt.Println(db.Statement.SQL.String(), db.Statement.Vars) } func Preload(db *gorm.DB) { diff --git a/chainable_api.go b/chainable_api.go index b577d5cf..f358d316 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -1,6 +1,10 @@ package gorm -import "github.com/jinzhu/gorm/clause" +import ( + "fmt" + + "github.com/jinzhu/gorm/clause" +) // Model specify the model you would like to run db operations // // update all users's name to `hello` @@ -107,6 +111,19 @@ func (db *DB) Having(query interface{}, args ...interface{}) (tx *DB) { // db.Order(gorm.Expr("name = ? DESC", "first")) // sql expression func (db *DB) Order(value interface{}) (tx *DB) { tx = db.getInstance() + + switch v := value.(type) { + case clause.OrderBy: + db.Statement.AddClause(clause.OrderByClause{ + Columns: []clause.OrderBy{v}, + }) + default: + db.Statement.AddClause(clause.OrderByClause{ + Columns: []clause.OrderBy{{ + Column: clause.Column{Name: fmt.Sprint(value), Raw: true}, + }}, + }) + } return } diff --git a/clause/clause.go b/clause/clause.go index c0ebe7e2..6d4698e9 100644 --- a/clause/clause.go +++ b/clause/clause.go @@ -11,11 +11,6 @@ type Clause struct { Builder ClauseBuilder } -// ClauseBuilder clause builder, allows to custmize how to build clause -type ClauseBuilder interface { - Build(Clause, Builder) -} - // Build build clause func (c Clause) Build(builder Builder) { if c.Builder != nil { @@ -47,25 +42,21 @@ type Interface interface { MergeExpression(Expression) } +// OverrideNameInterface override name interface type OverrideNameInterface interface { OverrideName() string } -// Column quote with name -type Column struct { - Table string - Name string - Alias string - Raw bool +// ClauseBuilder clause builder, allows to custmize how to build clause +type ClauseBuilder interface { + Build(Clause, Builder) } -func ToColumns(value ...interface{}) []Column { - return nil -} - -// Table quote with name -type Table struct { - Table string - Alias string - Raw bool +// Builder builder interface +type Builder interface { + WriteByte(byte) error + Write(sql ...string) error + WriteQuoted(field interface{}) error + AddVar(vars ...interface{}) string + Quote(field interface{}) string } diff --git a/clause/expression.go b/clause/expression.go index 17313d43..722df7c7 100644 --- a/clause/expression.go +++ b/clause/expression.go @@ -1,5 +1,10 @@ package clause +const ( + PrimaryKey string = "@@@priamry_key@@@" + CurrentTable string = "@@@table@@@" +) + // Expression expression interface type Expression interface { Build(builder Builder) @@ -10,13 +15,19 @@ type NegationExpressionBuilder interface { NegationBuild(builder Builder) } -// Builder builder interface -type Builder interface { - WriteByte(byte) error - Write(sql ...string) error - WriteQuoted(field interface{}) error - AddVar(vars ...interface{}) string - Quote(field interface{}) string +// Column quote with name +type Column struct { + Table string + Name string + Alias string + Raw bool +} + +// Table quote with name +type Table struct { + Table string + Alias string + Raw bool } // Expr raw expression diff --git a/clause/from.go b/clause/from.go index 610d69a4..1a7bcb5c 100644 --- a/clause/from.go +++ b/clause/from.go @@ -20,3 +20,10 @@ func (from From) Build(builder Builder) { builder.WriteQuoted(table) } } + +// MergeExpression merge order by clauses +func (from From) MergeExpression(expr Expression) { + if v, ok := expr.(From); ok { + from.Tables = append(v.Tables, from.Tables...) + } +} diff --git a/clause/on_conflict.go b/clause/on_conflict.go new file mode 100644 index 00000000..5cbe3dd7 --- /dev/null +++ b/clause/on_conflict.go @@ -0,0 +1,6 @@ +package clause + +type OnConflict struct { + ON string // duplicate key + Values *Values // update c=c+1 +} diff --git a/clause/order_by.go b/clause/order_by.go index a11a3c48..6025e1ba 100644 --- a/clause/order_by.go +++ b/clause/order_by.go @@ -1,4 +1,38 @@ package clause type OrderBy struct { + Column Column + Desc bool + Reorder bool +} + +type OrderByClause struct { + Columns []OrderBy +} + +// Name where clause name +func (orderBy OrderByClause) Name() string { + return "ORDER BY" +} + +// Build build where clause +func (orderBy OrderByClause) Build(builder Builder) { + for i := len(orderBy.Columns) - 1; i >= 0; i-- { + builder.WriteQuoted(orderBy.Columns[i].Column) + + if orderBy.Columns[i].Desc { + builder.Write(" DESC") + } + + if orderBy.Columns[i].Reorder { + break + } + } +} + +// MergeExpression merge order by clauses +func (orderBy OrderByClause) MergeExpression(expr Expression) { + if v, ok := expr.(OrderByClause); ok { + orderBy.Columns = append(v.Columns, orderBy.Columns...) + } } diff --git a/clause/select.go b/clause/select.go index 1342c411..7f0e4438 100644 --- a/clause/select.go +++ b/clause/select.go @@ -1,15 +1,19 @@ package clause +// SelectInterface select clause interface +type SelectInterface interface { + Selects() []Column + Omits() []Column +} + // Select select attrs when querying, updating, creating type Select struct { SelectColumns []Column OmitColumns []Column } -// SelectInterface select clause interface -type SelectInterface interface { - Selects() []Column - Omits() []Column +func (s Select) Name() string { + return "SELECT" } func (s Select) Selects() []Column { diff --git a/finisher_api.go b/finisher_api.go index a311ca78..06809651 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -2,6 +2,8 @@ package gorm import ( "database/sql" + + "github.com/jinzhu/gorm/clause" ) // Create insert the value into database @@ -20,9 +22,11 @@ func (db *DB) Save(value interface{}) (tx *DB) { // First find first record that match given conditions, order by primary key func (db *DB) First(out interface{}, where ...interface{}) (tx *DB) { - tx = db.getInstance() + tx = db.getInstance().Limit(1).Order(clause.OrderBy{ + Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, + Desc: true, + }) tx.Statement.Dest = out - tx.Limit(1) tx.callbacks.Query().Execute(tx) return } diff --git a/gorm.go b/gorm.go index a72314bd..10d61f80 100644 --- a/gorm.go +++ b/gorm.go @@ -61,10 +61,11 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) { } db = &DB{ - Config: config, - Dialector: dialector, - clone: true, - cacheStore: &sync.Map{}, + Config: config, + Dialector: dialector, + ClauseBuilders: map[string]clause.ClauseBuilder{}, + clone: true, + cacheStore: &sync.Map{}, } db.callbacks = initializeCallbacks(db) diff --git a/statement.go b/statement.go index c01be0f5..b2407599 100644 --- a/statement.go +++ b/statement.go @@ -84,18 +84,28 @@ func (stmt Statement) Quote(field interface{}) string { switch v := field.(type) { case clause.Table: - str.WriteString(v.Table) + if v.Alias != "" { str.WriteString(" AS ") str.WriteString(v.Alias) } case clause.Column: if v.Table != "" { - str.WriteString(v.Table) + if v.Table == clause.CurrentTable { + str.WriteString(stmt.Table) + } else { + str.WriteString(v.Table) + } str.WriteByte('.') } - str.WriteString(v.Name) + if v.Name == clause.PrimaryKey { + if stmt.Schema != nil && stmt.Schema.PrioritizedPrimaryField != nil { + str.WriteString(stmt.Schema.PrioritizedPrimaryField.DBName) + } + } else { + str.WriteString(v.Name) + } if v.Alias != "" { str.WriteString(" AS ") str.WriteString(v.Alias) From 9d19be0826ab6b22b435160af73042e5de82a758 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 4 Feb 2020 09:51:19 +0800 Subject: [PATCH 227/881] Setup clauses tests --- callbacks/query.go | 4 +--- clause/clause_test.go | 54 +++++++++++++++++++++++++++++++++++++++++++ clause/from.go | 16 +++++++++---- statement.go | 5 ++++ 4 files changed, 71 insertions(+), 8 deletions(-) create mode 100644 clause/clause_test.go diff --git a/callbacks/query.go b/callbacks/query.go index edf8f281..8d13095e 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -9,9 +9,7 @@ import ( func Query(db *gorm.DB) { db.Statement.AddClauseIfNotExists(clause.Select{}) - db.Statement.AddClauseIfNotExists(clause.From{ - Tables: []clause.Table{{Table: clause.CurrentTable}}, - }) + db.Statement.AddClauseIfNotExists(clause.From{}) db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR") result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) diff --git a/clause/clause_test.go b/clause/clause_test.go new file mode 100644 index 00000000..97d30f2d --- /dev/null +++ b/clause/clause_test.go @@ -0,0 +1,54 @@ +package clause_test + +import ( + "fmt" + "reflect" + "sync" + "testing" + + "github.com/jinzhu/gorm" + "github.com/jinzhu/gorm/clause" + "github.com/jinzhu/gorm/schema" + "github.com/jinzhu/gorm/tests" +) + +func TestClause(t *testing.T) { + var ( + db, _ = gorm.Open(nil, nil) + results = []struct { + Clauses []clause.Interface + Result string + Vars []interface{} + }{{ + []clause.Interface{clause.Select{}, clause.From{}}, + "SELECT * FROM users", []interface{}{}, + }} + ) + + for idx, result := range results { + t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { + var ( + user, _ = schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) + stmt = gorm.Statement{ + DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}, + } + buildNames []string + ) + + for _, c := range result.Clauses { + buildNames = append(buildNames, c.Name()) + stmt.AddClause(c) + } + + stmt.Build(buildNames...) + + if stmt.SQL.String() != result.Result { + t.Errorf("SQL expects %v got %v", result.Result, stmt.SQL.String()) + } + + if reflect.DeepEqual(stmt.Vars, result.Vars) { + t.Errorf("Vars expects %+v got %v", stmt.Vars, result.Vars) + } + }) + } +} diff --git a/clause/from.go b/clause/from.go index 1a7bcb5c..b7665bc3 100644 --- a/clause/from.go +++ b/clause/from.go @@ -10,14 +10,20 @@ func (From) Name() string { return "FROM" } +var currentTable = Table{Table: CurrentTable} + // Build build from clause func (from From) Build(builder Builder) { - for idx, table := range from.Tables { - if idx > 0 { - builder.WriteByte(',') - } + if len(from.Tables) > 0 { + for idx, table := range from.Tables { + if idx > 0 { + builder.WriteByte(',') + } - builder.WriteQuoted(table) + builder.WriteQuoted(table) + } + } else { + builder.WriteQuoted(currentTable) } } diff --git a/statement.go b/statement.go index b2407599..26acb319 100644 --- a/statement.go +++ b/statement.go @@ -84,6 +84,11 @@ func (stmt Statement) Quote(field interface{}) string { switch v := field.(type) { case clause.Table: + if v.Table == clause.CurrentTable { + str.WriteString(stmt.Table) + } else { + str.WriteString(v.Table) + } if v.Alias != "" { str.WriteString(" AS ") From 0160bab7dccd14a6b936bd8884ec3058e5b45972 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 5 Feb 2020 11:14:58 +0800 Subject: [PATCH 228/881] Add clause tests --- chainable_api.go | 2 +- clause/clause_test.go | 14 ++++++++------ clause/expression.go | 5 +++++ clause/query.go | 12 ++++++++++-- clause/where.go | 8 ++++---- dialects/mysql/mysql.go | 4 ++++ dialects/postgres/postgres.go | 4 ++++ dialects/sqlite/sqlite.go | 4 ++++ go.mod | 5 +++-- gorm.go | 19 +++++++++++++------ interfaces.go | 1 + statement.go | 11 +++++++++++ tests/dummy_dialecter.go | 24 ++++++++++++++++++++++++ 13 files changed, 92 insertions(+), 21 deletions(-) create mode 100644 tests/dummy_dialecter.go diff --git a/chainable_api.go b/chainable_api.go index f358d316..cac7495d 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -80,7 +80,7 @@ func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) { func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.AddClause(clause.Where{ - ORConditions: []clause.ORConditions{ + OrConditions: []clause.OrConditions{ tx.Statement.BuildCondtion(query, args...), }, }) diff --git a/clause/clause_test.go b/clause/clause_test.go index 97d30f2d..37f07686 100644 --- a/clause/clause_test.go +++ b/clause/clause_test.go @@ -12,17 +12,19 @@ import ( "github.com/jinzhu/gorm/tests" ) -func TestClause(t *testing.T) { +func TestClauses(t *testing.T) { var ( - db, _ = gorm.Open(nil, nil) + db, _ = gorm.Open(tests.DummyDialector{}, nil) results = []struct { Clauses []clause.Interface Result string Vars []interface{} - }{{ - []clause.Interface{clause.Select{}, clause.From{}}, - "SELECT * FROM users", []interface{}{}, - }} + }{ + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Where{AndConditions: []clause.Expression{clause.Eq{Column: clause.PrimaryColumn, Value: "1"}}}}, + "SELECT * FROM `users` WHERE `users`.`id` = ?", []interface{}{"1"}, + }, + } ) for idx, result := range results { diff --git a/clause/expression.go b/clause/expression.go index 722df7c7..3ddc146d 100644 --- a/clause/expression.go +++ b/clause/expression.go @@ -5,6 +5,11 @@ const ( CurrentTable string = "@@@table@@@" ) +var PrimaryColumn = Column{ + Table: CurrentTable, + Name: PrimaryKey, +} + // Expression expression interface type Expression interface { Build(builder Builder) diff --git a/clause/query.go b/clause/query.go index 949678d9..ce609014 100644 --- a/clause/query.go +++ b/clause/query.go @@ -6,6 +6,14 @@ import "strings" // Query Expressions //////////////////////////////////////////////////////////////////////////////// +func Add(exprs ...Expression) AddConditions { + return AddConditions(exprs) +} + +func Or(exprs ...Expression) OrConditions { + return OrConditions(exprs) +} + type AddConditions []Expression func (cs AddConditions) Build(builder Builder) { @@ -17,9 +25,9 @@ func (cs AddConditions) Build(builder Builder) { } } -type ORConditions []Expression +type OrConditions []Expression -func (cs ORConditions) Build(builder Builder) { +func (cs OrConditions) Build(builder Builder) { for idx, c := range cs { if idx > 0 { builder.Write(" OR ") diff --git a/clause/where.go b/clause/where.go index 888b9d07..de82662c 100644 --- a/clause/where.go +++ b/clause/where.go @@ -3,7 +3,7 @@ package clause // Where where clause type Where struct { AndConditions AddConditions - ORConditions []ORConditions + OrConditions []OrConditions builders []Expression } @@ -31,8 +31,8 @@ func (where Where) Build(builder Builder) { } } - var singleOrConditions []ORConditions - for _, or := range where.ORConditions { + var singleOrConditions []OrConditions + for _, or := range where.OrConditions { if len(or) == 1 { if withConditions { builder.Write(" OR ") @@ -69,7 +69,7 @@ func (where Where) Build(builder Builder) { func (where Where) MergeExpression(expr Expression) { if w, ok := expr.(Where); ok { where.AndConditions = append(where.AndConditions, w.AndConditions...) - where.ORConditions = append(where.ORConditions, w.ORConditions...) + where.OrConditions = append(where.OrConditions, w.OrConditions...) where.builders = append(where.builders, w.builders...) } else { where.builders = append(where.builders, expr) diff --git a/dialects/mysql/mysql.go b/dialects/mysql/mysql.go index ba306889..b402ef95 100644 --- a/dialects/mysql/mysql.go +++ b/dialects/mysql/mysql.go @@ -27,3 +27,7 @@ func (Dialector) Migrator() gorm.Migrator { func (Dialector) BindVar(stmt gorm.Statement, v interface{}) string { return "?" } + +func (Dialector) QuoteChars() [2]byte { + return [2]byte{'`', '`'} // `name` +} diff --git a/dialects/postgres/postgres.go b/dialects/postgres/postgres.go index 3abf05e3..9ea0048a 100644 --- a/dialects/postgres/postgres.go +++ b/dialects/postgres/postgres.go @@ -31,3 +31,7 @@ func (Dialector) Migrator() gorm.Migrator { func (Dialector) BindVar(stmt *gorm.Statement, v interface{}) string { return "?" } + +func (Dialector) QuoteChars() [2]byte { + return [2]byte{'"', '"'} // "name" +} diff --git a/dialects/sqlite/sqlite.go b/dialects/sqlite/sqlite.go index 91c3389e..80a18cfb 100644 --- a/dialects/sqlite/sqlite.go +++ b/dialects/sqlite/sqlite.go @@ -31,3 +31,7 @@ func (Dialector) Migrator() gorm.Migrator { func (Dialector) BindVar(stmt *gorm.Statement, v interface{}) string { return "?" } + +func (Dialector) QuoteChars() [2]byte { + return [2]byte{'`', '`'} // `name` +} diff --git a/go.mod b/go.mod index 1f4d31a2..e47297fb 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,8 @@ module github.com/jinzhu/gorm go 1.13 require ( + github.com/go-sql-driver/mysql v1.5.0 // indirect github.com/jinzhu/inflection v1.0.0 - github.com/lib/pq v1.3.0 - github.com/mattn/go-sqlite3 v2.0.3+incompatible + github.com/lib/pq v1.3.0 // indirect + github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect ) diff --git a/gorm.go b/gorm.go index 10d61f80..23f812d1 100644 --- a/gorm.go +++ b/gorm.go @@ -23,16 +23,21 @@ type Config struct { NowFunc func() time.Time } +type shared struct { + callbacks *callbacks + cacheStore *sync.Map + quoteChars [2]byte +} + // DB GORM DB definition type DB struct { *Config Dialector Instance - DB CommonDB ClauseBuilders map[string]clause.ClauseBuilder + DB CommonDB clone bool - callbacks *callbacks - cacheStore *sync.Map + *shared } // Session session config when create session with Session() method @@ -65,13 +70,16 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) { Dialector: dialector, ClauseBuilders: map[string]clause.ClauseBuilder{}, clone: true, - cacheStore: &sync.Map{}, + shared: &shared{ + cacheStore: &sync.Map{}, + }, } db.callbacks = initializeCallbacks(db) if dialector != nil { err = dialector.Initialize(db) + db.quoteChars = dialector.QuoteChars() } return } @@ -146,8 +154,7 @@ func (db *DB) getInstance() *DB { Dialector: db.Dialector, ClauseBuilders: db.ClauseBuilders, DB: db.DB, - callbacks: db.callbacks, - cacheStore: db.cacheStore, + shared: db.shared, } } diff --git a/interfaces.go b/interfaces.go index 6ba24dc4..71522455 100644 --- a/interfaces.go +++ b/interfaces.go @@ -10,6 +10,7 @@ type Dialector interface { Initialize(*DB) error Migrator() Migrator BindVar(stmt *Statement, v interface{}) string + QuoteChars() [2]byte } // CommonDB common db interface diff --git a/statement.go b/statement.go index 26acb319..bc07b6e4 100644 --- a/statement.go +++ b/statement.go @@ -81,6 +81,7 @@ func (stmt *Statement) WriteQuoted(field interface{}) (err error) { // Quote returns quoted value func (stmt Statement) Quote(field interface{}) string { var str strings.Builder + str.WriteByte(stmt.DB.quoteChars[0]) switch v := field.(type) { case clause.Table: @@ -91,8 +92,11 @@ func (stmt Statement) Quote(field interface{}) string { } if v.Alias != "" { + str.WriteByte(stmt.DB.quoteChars[1]) str.WriteString(" AS ") + str.WriteByte(stmt.DB.quoteChars[0]) str.WriteString(v.Alias) + str.WriteByte(stmt.DB.quoteChars[1]) } case clause.Column: if v.Table != "" { @@ -101,7 +105,9 @@ func (stmt Statement) Quote(field interface{}) string { } else { str.WriteString(v.Table) } + str.WriteByte(stmt.DB.quoteChars[1]) str.WriteByte('.') + str.WriteByte(stmt.DB.quoteChars[0]) } if v.Name == clause.PrimaryKey { @@ -111,14 +117,19 @@ func (stmt Statement) Quote(field interface{}) string { } else { str.WriteString(v.Name) } + if v.Alias != "" { + str.WriteByte(stmt.DB.quoteChars[1]) str.WriteString(" AS ") + str.WriteByte(stmt.DB.quoteChars[0]) str.WriteString(v.Alias) + str.WriteByte(stmt.DB.quoteChars[1]) } default: fmt.Sprint(field) } + str.WriteByte(stmt.DB.quoteChars[1]) return str.String() } diff --git a/tests/dummy_dialecter.go b/tests/dummy_dialecter.go new file mode 100644 index 00000000..e2cda8fc --- /dev/null +++ b/tests/dummy_dialecter.go @@ -0,0 +1,24 @@ +package tests + +import ( + "github.com/jinzhu/gorm" +) + +type DummyDialector struct { +} + +func (DummyDialector) Initialize(*gorm.DB) error { + return nil +} + +func (DummyDialector) Migrator() gorm.Migrator { + return nil +} + +func (DummyDialector) BindVar(stmt *gorm.Statement, v interface{}) string { + return "?" +} + +func (DummyDialector) QuoteChars() [2]byte { + return [2]byte{'`', '`'} // `name` +} From 1f38ec4410c763aea65e6c086b9c47b8a5318228 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 7 Feb 2020 23:45:35 +0800 Subject: [PATCH 229/881] Finish clauses tests --- chainable_api.go | 30 ++-- clause/clause.go | 66 ++++++--- clause/clause_test.go | 55 +++---- clause/delete.go | 23 +++ clause/delete_test.go | 31 ++++ clause/expression.go | 172 ++++++++++++++++++---- clause/from.go | 63 +++++++- clause/from_test.go | 75 ++++++++++ clause/group_by.go | 33 ++++- clause/group_by_test.go | 40 +++++ clause/insert.go | 25 ++-- clause/insert_test.go | 35 +++++ clause/join.go | 23 --- clause/limit.go | 40 ++++- clause/limit_test.go | 46 ++++++ clause/locking.go | 48 ++++++ clause/locking_test.go | 43 ++++++ clause/on_conflict.go | 6 - clause/order_by.go | 39 +++-- clause/order_by_test.go | 49 +++++++ clause/query.go | 258 --------------------------------- clause/returning.go | 30 ++++ clause/returning_test.go | 36 +++++ clause/select.go | 35 ++--- clause/select_test.go | 41 ++++++ clause/set.go | 37 +++++ clause/set_test.go | 38 +++++ clause/update.go | 38 +++++ clause/update_test.go | 35 +++++ clause/{value.go => values.go} | 10 +- clause/values_test.go | 33 +++++ clause/where.go | 156 +++++++++++++------- clause/where_test.go | 63 ++++++++ finisher_api.go | 2 +- statement.go | 72 ++++----- 35 files changed, 1282 insertions(+), 544 deletions(-) create mode 100644 clause/delete.go create mode 100644 clause/delete_test.go create mode 100644 clause/from_test.go create mode 100644 clause/group_by_test.go create mode 100644 clause/insert_test.go delete mode 100644 clause/join.go create mode 100644 clause/limit_test.go create mode 100644 clause/locking.go create mode 100644 clause/locking_test.go delete mode 100644 clause/on_conflict.go create mode 100644 clause/order_by_test.go delete mode 100644 clause/query.go create mode 100644 clause/returning.go create mode 100644 clause/returning_test.go create mode 100644 clause/select_test.go create mode 100644 clause/set.go create mode 100644 clause/set_test.go create mode 100644 clause/update.go create mode 100644 clause/update_test.go rename clause/{value.go => values.go} (76%) create mode 100644 clause/values_test.go create mode 100644 clause/where_test.go diff --git a/chainable_api.go b/chainable_api.go index cac7495d..432026cf 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -31,8 +31,8 @@ func (db *DB) Clauses(conds ...clause.Expression) (tx *DB) { } if len(whereConds) > 0 { - tx.Statement.AddClause(clause.Where{ - AndConditions: tx.Statement.BuildCondtion(whereConds[0], whereConds[1:]...), + tx.Statement.AddClause(&clause.Where{ + tx.Statement.BuildCondtion(whereConds[0], whereConds[1:]...), }) } return @@ -59,8 +59,8 @@ func (db *DB) Omit(columns ...string) (tx *DB) { func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) { tx = db.getInstance() - tx.Statement.AddClause(clause.Where{ - AndConditions: tx.Statement.BuildCondtion(query, args...), + tx.Statement.AddClause(&clause.Where{ + tx.Statement.BuildCondtion(query, args...), }) return } @@ -68,10 +68,8 @@ func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) { // Not add NOT condition func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) { tx = db.getInstance() - tx.Statement.AddClause(clause.Where{ - AndConditions: []clause.Expression{ - clause.NotConditions(tx.Statement.BuildCondtion(query, args...)), - }, + tx.Statement.AddClause(&clause.Where{ + []clause.Expression{clause.Not(tx.Statement.BuildCondtion(query, args...)...)}, }) return } @@ -79,10 +77,8 @@ func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) { // Or add OR conditions func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) { tx = db.getInstance() - tx.Statement.AddClause(clause.Where{ - OrConditions: []clause.OrConditions{ - tx.Statement.BuildCondtion(query, args...), - }, + tx.Statement.AddClause(&clause.Where{ + []clause.Expression{clause.Or(tx.Statement.BuildCondtion(query, args...)...)}, }) return } @@ -113,13 +109,13 @@ func (db *DB) Order(value interface{}) (tx *DB) { tx = db.getInstance() switch v := value.(type) { - case clause.OrderBy: - db.Statement.AddClause(clause.OrderByClause{ - Columns: []clause.OrderBy{v}, + case clause.OrderByColumn: + db.Statement.AddClause(clause.OrderBy{ + Columns: []clause.OrderByColumn{v}, }) default: - db.Statement.AddClause(clause.OrderByClause{ - Columns: []clause.OrderBy{{ + db.Statement.AddClause(clause.OrderBy{ + Columns: []clause.OrderByColumn{{ Column: clause.Column{Name: fmt.Sprint(value), Raw: true}, }}, }) diff --git a/clause/clause.go b/clause/clause.go index 6d4698e9..df8e3a57 100644 --- a/clause/clause.go +++ b/clause/clause.go @@ -1,5 +1,26 @@ package clause +// Interface clause interface +type Interface interface { + Name() string + Build(Builder) + MergeClause(*Clause) +} + +// ClauseBuilder clause builder, allows to custmize how to build clause +type ClauseBuilder interface { + Build(Clause, Builder) +} + +// Builder builder interface +type Builder interface { + WriteByte(byte) error + Write(sql ...string) error + WriteQuoted(field interface{}) error + AddVar(vars ...interface{}) string + Quote(field interface{}) string +} + // Clause type Clause struct { Name string // WHERE @@ -18,7 +39,7 @@ func (c Clause) Build(builder Builder) { } else { builders := c.BeforeExpressions if c.Name != "" { - builders = append(builders, Expr{c.Name}) + builders = append(builders, Expr{SQL: c.Name}) } builders = append(builders, c.AfterNameExpressions...) @@ -35,28 +56,27 @@ func (c Clause) Build(builder Builder) { } } -// Interface clause interface -type Interface interface { - Name() string - Build(Builder) - MergeExpression(Expression) +const ( + PrimaryKey string = "@@@priamry_key@@@" + CurrentTable string = "@@@table@@@" +) + +var ( + currentTable = Table{Name: CurrentTable} + PrimaryColumn = Column{Table: CurrentTable, Name: PrimaryKey} +) + +// Column quote with name +type Column struct { + Table string + Name string + Alias string + Raw bool } -// OverrideNameInterface override name interface -type OverrideNameInterface interface { - OverrideName() string -} - -// ClauseBuilder clause builder, allows to custmize how to build clause -type ClauseBuilder interface { - Build(Clause, Builder) -} - -// Builder builder interface -type Builder interface { - WriteByte(byte) error - Write(sql ...string) error - WriteQuoted(field interface{}) error - AddVar(vars ...interface{}) string - Quote(field interface{}) string +// Table quote with name +type Table struct { + Name string + Alias string + Raw bool } diff --git a/clause/clause_test.go b/clause/clause_test.go index 37f07686..30ea9343 100644 --- a/clause/clause_test.go +++ b/clause/clause_test.go @@ -1,8 +1,8 @@ package clause_test import ( - "fmt" "reflect" + "strings" "sync" "testing" @@ -12,45 +12,32 @@ import ( "github.com/jinzhu/gorm/tests" ) -func TestClauses(t *testing.T) { +var db, _ = gorm.Open(tests.DummyDialector{}, nil) + +func checkBuildClauses(t *testing.T, clauses []clause.Interface, result string, vars []interface{}) { var ( - db, _ = gorm.Open(tests.DummyDialector{}, nil) - results = []struct { - Clauses []clause.Interface - Result string - Vars []interface{} - }{ - { - []clause.Interface{clause.Select{}, clause.From{}, clause.Where{AndConditions: []clause.Expression{clause.Eq{Column: clause.PrimaryColumn, Value: "1"}}}}, - "SELECT * FROM `users` WHERE `users`.`id` = ?", []interface{}{"1"}, - }, - } + buildNames []string + buildNamesMap = map[string]bool{} + user, _ = schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) + stmt = gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}} ) - for idx, result := range results { - t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { - var ( - user, _ = schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) - stmt = gorm.Statement{ - DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}, - } - buildNames []string - ) + for _, c := range clauses { + if _, ok := buildNamesMap[c.Name()]; !ok { + buildNames = append(buildNames, c.Name()) + buildNamesMap[c.Name()] = true + } - for _, c := range result.Clauses { - buildNames = append(buildNames, c.Name()) - stmt.AddClause(c) - } + stmt.AddClause(c) + } - stmt.Build(buildNames...) + stmt.Build(buildNames...) - if stmt.SQL.String() != result.Result { - t.Errorf("SQL expects %v got %v", result.Result, stmt.SQL.String()) - } + if strings.TrimSpace(stmt.SQL.String()) != result { + t.Errorf("SQL expects %v got %v", result, stmt.SQL.String()) + } - if reflect.DeepEqual(stmt.Vars, result.Vars) { - t.Errorf("Vars expects %+v got %v", stmt.Vars, result.Vars) - } - }) + if !reflect.DeepEqual(stmt.Vars, vars) { + t.Errorf("Vars expects %+v got %v", stmt.Vars, vars) } } diff --git a/clause/delete.go b/clause/delete.go new file mode 100644 index 00000000..2a622b45 --- /dev/null +++ b/clause/delete.go @@ -0,0 +1,23 @@ +package clause + +type Delete struct { + Modifier string +} + +func (d Delete) Name() string { + return "DELETE" +} + +func (d Delete) Build(builder Builder) { + builder.Write("DELETE") + + if d.Modifier != "" { + builder.WriteByte(' ') + builder.Write(d.Modifier) + } +} + +func (d Delete) MergeClause(clause *Clause) { + clause.Name = "" + clause.Expression = d +} diff --git a/clause/delete_test.go b/clause/delete_test.go new file mode 100644 index 00000000..2faf8364 --- /dev/null +++ b/clause/delete_test.go @@ -0,0 +1,31 @@ +package clause_test + +import ( + "fmt" + "testing" + + "github.com/jinzhu/gorm/clause" +) + +func TestDelete(t *testing.T) { + results := []struct { + Clauses []clause.Interface + Result string + Vars []interface{} + }{ + { + []clause.Interface{clause.Delete{}, clause.From{}}, + "DELETE FROM `users`", nil, + }, + { + []clause.Interface{clause.Delete{Modifier: "LOW_PRIORITY"}, clause.From{}}, + "DELETE LOW_PRIORITY FROM `users`", nil, + }, + } + + for idx, result := range results { + t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { + checkBuildClauses(t, result.Clauses, result.Result, result.Vars) + }) + } +} diff --git a/clause/expression.go b/clause/expression.go index 3ddc146d..048b0980 100644 --- a/clause/expression.go +++ b/clause/expression.go @@ -1,14 +1,6 @@ package clause -const ( - PrimaryKey string = "@@@priamry_key@@@" - CurrentTable string = "@@@table@@@" -) - -var PrimaryColumn = Column{ - Table: CurrentTable, - Name: PrimaryKey, -} +import "strings" // Expression expression interface type Expression interface { @@ -20,27 +12,155 @@ type NegationExpressionBuilder interface { NegationBuild(builder Builder) } -// Column quote with name -type Column struct { - Table string - Name string - Alias string - Raw bool -} - -// Table quote with name -type Table struct { - Table string - Alias string - Raw bool -} - // Expr raw expression type Expr struct { - Value string + SQL string + Vars []interface{} } // Build build raw expression func (expr Expr) Build(builder Builder) { - builder.Write(expr.Value) + sql := expr.SQL + for _, v := range expr.Vars { + sql = strings.Replace(sql, " ? ", " "+builder.AddVar(v)+" ", 1) + } + builder.Write(sql) +} + +// IN Whether a value is within a set of values +type IN struct { + Column interface{} + Values []interface{} +} + +func (in IN) Build(builder Builder) { + builder.WriteQuoted(in.Column) + + switch len(in.Values) { + case 0: + builder.Write(" IN (NULL)") + case 1: + builder.Write(" = ", builder.AddVar(in.Values...)) + default: + builder.Write(" IN (", builder.AddVar(in.Values...), ")") + } +} + +func (in IN) NegationBuild(builder Builder) { + switch len(in.Values) { + case 0: + case 1: + builder.Write(" <> ", builder.AddVar(in.Values...)) + default: + builder.Write(" NOT IN (", builder.AddVar(in.Values...), ")") + } +} + +// Eq equal to for where +type Eq struct { + Column interface{} + Value interface{} +} + +func (eq Eq) Build(builder Builder) { + builder.WriteQuoted(eq.Column) + + if eq.Value == nil { + builder.Write(" IS NULL") + } else { + builder.Write(" = ", builder.AddVar(eq.Value)) + } +} + +func (eq Eq) NegationBuild(builder Builder) { + Neq{eq.Column, eq.Value}.Build(builder) +} + +// Neq not equal to for where +type Neq Eq + +func (neq Neq) Build(builder Builder) { + builder.WriteQuoted(neq.Column) + + if neq.Value == nil { + builder.Write(" IS NOT NULL") + } else { + builder.Write(" <> ", builder.AddVar(neq.Value)) + } +} + +func (neq Neq) NegationBuild(builder Builder) { + Eq{neq.Column, neq.Value}.Build(builder) +} + +// Gt greater than for where +type Gt Eq + +func (gt Gt) Build(builder Builder) { + builder.WriteQuoted(gt.Column) + builder.Write(" > ", builder.AddVar(gt.Value)) +} + +func (gt Gt) NegationBuild(builder Builder) { + Lte{gt.Column, gt.Value}.Build(builder) +} + +// Gte greater than or equal to for where +type Gte Eq + +func (gte Gte) Build(builder Builder) { + builder.WriteQuoted(gte.Column) + builder.Write(" >= ", builder.AddVar(gte.Value)) +} + +func (gte Gte) NegationBuild(builder Builder) { + Lt{gte.Column, gte.Value}.Build(builder) +} + +// Lt less than for where +type Lt Eq + +func (lt Lt) Build(builder Builder) { + builder.WriteQuoted(lt.Column) + builder.Write(" < ", builder.AddVar(lt.Value)) +} + +func (lt Lt) NegationBuild(builder Builder) { + Gte{lt.Column, lt.Value}.Build(builder) +} + +// Lte less than or equal to for where +type Lte Eq + +func (lte Lte) Build(builder Builder) { + builder.WriteQuoted(lte.Column) + builder.Write(" <= ", builder.AddVar(lte.Value)) +} + +func (lte Lte) NegationBuild(builder Builder) { + Gt{lte.Column, lte.Value}.Build(builder) +} + +// Like whether string matches regular expression +type Like Eq + +func (like Like) Build(builder Builder) { + builder.WriteQuoted(like.Column) + builder.Write(" LIKE ", builder.AddVar(like.Value)) +} + +func (like Like) NegationBuild(builder Builder) { + builder.WriteQuoted(like.Column) + builder.Write(" NOT LIKE ", builder.AddVar(like.Value)) +} + +// Map +type Map map[interface{}]interface{} + +func (m Map) Build(builder Builder) { + // TODO +} + +func (m Map) NegationBuild(builder Builder) { + // TODO } diff --git a/clause/from.go b/clause/from.go index b7665bc3..f01065b5 100644 --- a/clause/from.go +++ b/clause/from.go @@ -3,15 +3,31 @@ package clause // From from clause type From struct { Tables []Table + Joins []Join +} + +type JoinType string + +const ( + CrossJoin JoinType = "CROSS" + InnerJoin = "INNER" + LeftJoin = "LEFT" + RightJoin = "RIGHT" +) + +// Join join clause for from +type Join struct { + Type JoinType + Table Table + ON Where + Using []string } // Name from clause name -func (From) Name() string { +func (from From) Name() string { return "FROM" } -var currentTable = Table{Table: CurrentTable} - // Build build from clause func (from From) Build(builder Builder) { if len(from.Tables) > 0 { @@ -25,11 +41,42 @@ func (from From) Build(builder Builder) { } else { builder.WriteQuoted(currentTable) } -} -// MergeExpression merge order by clauses -func (from From) MergeExpression(expr Expression) { - if v, ok := expr.(From); ok { - from.Tables = append(v.Tables, from.Tables...) + for _, join := range from.Joins { + builder.WriteByte(' ') + join.Build(builder) } } + +func (join Join) Build(builder Builder) { + if join.Type != "" { + builder.Write(string(join.Type)) + builder.WriteByte(' ') + } + + builder.Write("JOIN ") + builder.WriteQuoted(join.Table) + + if len(join.ON.Exprs) > 0 { + builder.Write(" ON ") + join.ON.Build(builder) + } else if len(join.Using) > 0 { + builder.Write(" USING (") + for idx, c := range join.Using { + if idx > 0 { + builder.WriteByte(',') + } + builder.WriteQuoted(c) + } + builder.WriteByte(')') + } +} + +// MergeClause merge from clause +func (from From) MergeClause(clause *Clause) { + if v, ok := clause.Expression.(From); ok { + from.Tables = append(v.Tables, from.Tables...) + from.Joins = append(v.Joins, from.Joins...) + } + clause.Expression = from +} diff --git a/clause/from_test.go b/clause/from_test.go new file mode 100644 index 00000000..4b7b0e18 --- /dev/null +++ b/clause/from_test.go @@ -0,0 +1,75 @@ +package clause_test + +import ( + "fmt" + "testing" + + "github.com/jinzhu/gorm/clause" +) + +func TestFrom(t *testing.T) { + results := []struct { + Clauses []clause.Interface + Result string + Vars []interface{} + }{ + { + []clause.Interface{clause.Select{}, clause.From{}}, + "SELECT * FROM `users`", nil, + }, + { + []clause.Interface{ + clause.Select{}, clause.From{ + Tables: []clause.Table{{Name: "users"}}, + Joins: []clause.Join{ + { + Type: clause.InnerJoin, + Table: clause.Table{Name: "articles"}, + ON: clause.Where{ + []clause.Expression{clause.Eq{clause.Column{Table: "articles", Name: "id"}, clause.PrimaryColumn}}, + }, + }, + }, + }, + }, + "SELECT * FROM `users` INNER JOIN `articles` ON `articles`.`id` = `users`.`id`", nil, + }, + { + []clause.Interface{ + clause.Select{}, clause.From{ + Tables: []clause.Table{{Name: "users"}}, + Joins: []clause.Join{ + { + Type: clause.InnerJoin, + Table: clause.Table{Name: "articles"}, + ON: clause.Where{ + []clause.Expression{clause.Eq{clause.Column{Table: "articles", Name: "id"}, clause.PrimaryColumn}}, + }, + }, { + Type: clause.LeftJoin, + Table: clause.Table{Name: "companies"}, + Using: []string{"company_name"}, + }, + }, + }, clause.From{ + Joins: []clause.Join{ + { + Type: clause.RightJoin, + Table: clause.Table{Name: "profiles"}, + ON: clause.Where{ + []clause.Expression{clause.Eq{clause.Column{Table: "profiles", Name: "email"}, clause.Column{Table: clause.CurrentTable, Name: "email"}}}, + }, + }, + }, + }, + }, + "SELECT * FROM `users` INNER JOIN `articles` ON `articles`.`id` = `users`.`id` LEFT JOIN `companies` USING (`company_name`) RIGHT JOIN `profiles` ON `profiles`.`email` = `users`.`email`", nil, + }, + } + + for idx, result := range results { + t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { + checkBuildClauses(t, result.Clauses, result.Result, result.Vars) + }) + } +} diff --git a/clause/group_by.go b/clause/group_by.go index bce94109..8d164731 100644 --- a/clause/group_by.go +++ b/clause/group_by.go @@ -2,5 +2,36 @@ package clause // GroupBy group by clause type GroupBy struct { - Having Where + Columns []Column + Having Where +} + +// Name from clause name +func (groupBy GroupBy) Name() string { + return "GROUP BY" +} + +// Build build group by clause +func (groupBy GroupBy) Build(builder Builder) { + for idx, column := range groupBy.Columns { + if idx > 0 { + builder.WriteByte(',') + } + + builder.WriteQuoted(column) + } + + if len(groupBy.Having.Exprs) > 0 { + builder.Write(" HAVING ") + groupBy.Having.Build(builder) + } +} + +// MergeClause merge group by clause +func (groupBy GroupBy) MergeClause(clause *Clause) { + if v, ok := clause.Expression.(GroupBy); ok { + groupBy.Columns = append(v.Columns, groupBy.Columns...) + groupBy.Having.Exprs = append(v.Having.Exprs, groupBy.Having.Exprs...) + } + clause.Expression = groupBy } diff --git a/clause/group_by_test.go b/clause/group_by_test.go new file mode 100644 index 00000000..35be84a4 --- /dev/null +++ b/clause/group_by_test.go @@ -0,0 +1,40 @@ +package clause_test + +import ( + "fmt" + "testing" + + "github.com/jinzhu/gorm/clause" +) + +func TestGroupBy(t *testing.T) { + results := []struct { + Clauses []clause.Interface + Result string + Vars []interface{} + }{ + { + []clause.Interface{clause.Select{}, clause.From{}, clause.GroupBy{ + Columns: []clause.Column{{Name: "role"}}, + Having: clause.Where{[]clause.Expression{clause.Eq{"role", "admin"}}}, + }}, + "SELECT * FROM `users` GROUP BY `role` HAVING `role` = ?", []interface{}{"admin"}, + }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.GroupBy{ + Columns: []clause.Column{{Name: "role"}}, + Having: clause.Where{[]clause.Expression{clause.Eq{"role", "admin"}}}, + }, clause.GroupBy{ + Columns: []clause.Column{{Name: "gender"}}, + Having: clause.Where{[]clause.Expression{clause.Neq{"gender", "U"}}}, + }}, + "SELECT * FROM `users` GROUP BY `role`,`gender` HAVING `role` = ? AND `gender` <> ?", []interface{}{"admin", "U"}, + }, + } + + for idx, result := range results { + t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { + checkBuildClauses(t, result.Clauses, result.Result, result.Vars) + }) + } +} diff --git a/clause/insert.go b/clause/insert.go index e056b35e..3f86c98f 100644 --- a/clause/insert.go +++ b/clause/insert.go @@ -2,7 +2,7 @@ package clause type Insert struct { Table Table - Priority string + Modifier string } // Name insert clause name @@ -12,23 +12,28 @@ func (insert Insert) Name() string { // Build build insert clause func (insert Insert) Build(builder Builder) { - if insert.Priority != "" { - builder.Write(insert.Priority) + if insert.Modifier != "" { + builder.Write(insert.Modifier) builder.WriteByte(' ') } builder.Write("INTO ") - builder.WriteQuoted(insert.Table) + if insert.Table.Name == "" { + builder.WriteQuoted(currentTable) + } else { + builder.WriteQuoted(insert.Table) + } } -// MergeExpression merge insert clauses -func (insert Insert) MergeExpression(expr Expression) { - if v, ok := expr.(Insert); ok { - if insert.Priority == "" { - insert.Priority = v.Priority +// MergeClause merge insert clause +func (insert Insert) MergeClause(clause *Clause) { + if v, ok := clause.Expression.(Insert); ok { + if insert.Modifier == "" { + insert.Modifier = v.Modifier } - if insert.Table.Table == "" { + if insert.Table.Name == "" { insert.Table = v.Table } } + clause.Expression = insert } diff --git a/clause/insert_test.go b/clause/insert_test.go new file mode 100644 index 00000000..b1a57803 --- /dev/null +++ b/clause/insert_test.go @@ -0,0 +1,35 @@ +package clause_test + +import ( + "fmt" + "testing" + + "github.com/jinzhu/gorm/clause" +) + +func TestInsert(t *testing.T) { + results := []struct { + Clauses []clause.Interface + Result string + Vars []interface{} + }{ + { + []clause.Interface{clause.Insert{}}, + "INSERT INTO `users`", nil, + }, + { + []clause.Interface{clause.Insert{Modifier: "LOW_PRIORITY"}}, + "INSERT LOW_PRIORITY INTO `users`", nil, + }, + { + []clause.Interface{clause.Insert{Table: clause.Table{Name: "products"}, Modifier: "LOW_PRIORITY"}}, + "INSERT LOW_PRIORITY INTO `products`", nil, + }, + } + + for idx, result := range results { + t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { + checkBuildClauses(t, result.Clauses, result.Result, result.Vars) + }) + } +} diff --git a/clause/join.go b/clause/join.go deleted file mode 100644 index 6b0e8f97..00000000 --- a/clause/join.go +++ /dev/null @@ -1,23 +0,0 @@ -package clause - -// Join join clause -type Join struct { - Table From // From - Type string // INNER, LEFT, RIGHT, FULL, CROSS JOIN - Using []Column - ON Where -} - -// TODO multiple joins - -func (join Join) Build(builder Builder) { - // TODO -} - -func (join Join) MergeExpression(expr Expression) { - // if j, ok := expr.(Join); ok { - // join.builders = append(join.builders, j.builders...) - // } else { - // join.builders = append(join.builders, expr) - // } -} diff --git a/clause/limit.go b/clause/limit.go index 8fbc0055..7b16f339 100644 --- a/clause/limit.go +++ b/clause/limit.go @@ -1,6 +1,44 @@ package clause +import "strconv" + // Limit limit clause type Limit struct { - Offset uint + Limit int + Offset int +} + +// Name where clause name +func (limit Limit) Name() string { + return "LIMIT" +} + +// Build build where clause +func (limit Limit) Build(builder Builder) { + if limit.Limit > 0 { + builder.Write("LIMIT ") + builder.Write(strconv.Itoa(limit.Limit)) + + if limit.Offset > 0 { + builder.Write(" OFFSET ") + builder.Write(strconv.Itoa(limit.Offset)) + } + } +} + +// MergeClause merge order by clauses +func (limit Limit) MergeClause(clause *Clause) { + clause.Name = "" + + if v, ok := clause.Expression.(Limit); ok { + if limit.Limit == 0 && v.Limit > 0 { + limit.Limit = v.Limit + } + + if limit.Offset == 0 && v.Offset > 0 { + limit.Offset = v.Offset + } + } + + clause.Expression = limit } diff --git a/clause/limit_test.go b/clause/limit_test.go new file mode 100644 index 00000000..7b76aaf4 --- /dev/null +++ b/clause/limit_test.go @@ -0,0 +1,46 @@ +package clause_test + +import ( + "fmt" + "testing" + + "github.com/jinzhu/gorm/clause" +) + +func TestLimit(t *testing.T) { + results := []struct { + Clauses []clause.Interface + Result string + Vars []interface{} + }{ + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{ + Limit: 10, + Offset: 20, + }}, + "SELECT * FROM `users` LIMIT 10 OFFSET 20", nil, + }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}}, + "SELECT * FROM `users` LIMIT 10 OFFSET 30", nil, + }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Offset: -10}}, + "SELECT * FROM `users` LIMIT 10", nil, + }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: -10}}, + "SELECT * FROM `users`", nil, + }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: 50}}, + "SELECT * FROM `users` LIMIT 50 OFFSET 30", nil, + }, + } + + for idx, result := range results { + t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { + checkBuildClauses(t, result.Clauses, result.Result, result.Vars) + }) + } +} diff --git a/clause/locking.go b/clause/locking.go new file mode 100644 index 00000000..48b84b34 --- /dev/null +++ b/clause/locking.go @@ -0,0 +1,48 @@ +package clause + +type For struct { + Lockings []Locking +} + +type Locking struct { + Strength string + Table Table + Options string +} + +// Name where clause name +func (f For) Name() string { + return "FOR" +} + +// Build build where clause +func (f For) Build(builder Builder) { + for idx, locking := range f.Lockings { + if idx > 0 { + builder.WriteByte(' ') + } + + builder.Write("FOR ") + builder.Write(locking.Strength) + if locking.Table.Name != "" { + builder.Write(" OF ") + builder.WriteQuoted(locking.Table) + } + + if locking.Options != "" { + builder.WriteByte(' ') + builder.Write(locking.Options) + } + } +} + +// MergeClause merge order by clauses +func (f For) MergeClause(clause *Clause) { + clause.Name = "" + + if v, ok := clause.Expression.(For); ok { + f.Lockings = append(v.Lockings, f.Lockings...) + } + + clause.Expression = f +} diff --git a/clause/locking_test.go b/clause/locking_test.go new file mode 100644 index 00000000..6b054404 --- /dev/null +++ b/clause/locking_test.go @@ -0,0 +1,43 @@ +package clause_test + +import ( + "fmt" + "testing" + + "github.com/jinzhu/gorm/clause" +) + +func TestFor(t *testing.T) { + results := []struct { + Clauses []clause.Interface + Result string + Vars []interface{} + }{ + { + []clause.Interface{clause.Select{}, clause.From{}, clause.For{ + Lockings: []clause.Locking{{Strength: "UPDATE"}}, + }}, + "SELECT * FROM `users` FOR UPDATE", nil, + }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.For{ + Lockings: []clause.Locking{{Strength: "UPDATE"}, {Strength: "SHARE", Table: clause.Table{Name: clause.CurrentTable}}}, + }}, + "SELECT * FROM `users` FOR UPDATE FOR SHARE OF `users`", nil, + }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.For{ + Lockings: []clause.Locking{{Strength: "UPDATE"}, {Strength: "SHARE", Table: clause.Table{Name: clause.CurrentTable}}}, + }, clause.For{ + Lockings: []clause.Locking{{Strength: "UPDATE", Options: "NOWAIT"}}, + }}, + "SELECT * FROM `users` FOR UPDATE FOR SHARE OF `users` FOR UPDATE NOWAIT", nil, + }, + } + + for idx, result := range results { + t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { + checkBuildClauses(t, result.Clauses, result.Result, result.Vars) + }) + } +} diff --git a/clause/on_conflict.go b/clause/on_conflict.go deleted file mode 100644 index 5cbe3dd7..00000000 --- a/clause/on_conflict.go +++ /dev/null @@ -1,6 +0,0 @@ -package clause - -type OnConflict struct { - ON string // duplicate key - Values *Values // update c=c+1 -} diff --git a/clause/order_by.go b/clause/order_by.go index 6025e1ba..2734f2bc 100644 --- a/clause/order_by.go +++ b/clause/order_by.go @@ -1,38 +1,47 @@ package clause -type OrderBy struct { +type OrderByColumn struct { Column Column Desc bool Reorder bool } -type OrderByClause struct { - Columns []OrderBy +type OrderBy struct { + Columns []OrderByColumn } // Name where clause name -func (orderBy OrderByClause) Name() string { +func (orderBy OrderBy) Name() string { return "ORDER BY" } // Build build where clause -func (orderBy OrderByClause) Build(builder Builder) { - for i := len(orderBy.Columns) - 1; i >= 0; i-- { - builder.WriteQuoted(orderBy.Columns[i].Column) +func (orderBy OrderBy) Build(builder Builder) { + for idx, column := range orderBy.Columns { + if idx > 0 { + builder.WriteByte(',') + } - if orderBy.Columns[i].Desc { + builder.WriteQuoted(column.Column) + if column.Desc { builder.Write(" DESC") } - - if orderBy.Columns[i].Reorder { - break - } } } -// MergeExpression merge order by clauses -func (orderBy OrderByClause) MergeExpression(expr Expression) { - if v, ok := expr.(OrderByClause); ok { +// MergeClause merge order by clauses +func (orderBy OrderBy) MergeClause(clause *Clause) { + if v, ok := clause.Expression.(OrderBy); ok { + for i := len(orderBy.Columns) - 1; i >= 0; i-- { + if orderBy.Columns[i].Reorder { + orderBy.Columns = orderBy.Columns[i:] + clause.Expression = orderBy + return + } + } + orderBy.Columns = append(v.Columns, orderBy.Columns...) } + + clause.Expression = orderBy } diff --git a/clause/order_by_test.go b/clause/order_by_test.go new file mode 100644 index 00000000..2c74a322 --- /dev/null +++ b/clause/order_by_test.go @@ -0,0 +1,49 @@ +package clause_test + +import ( + "fmt" + "testing" + + "github.com/jinzhu/gorm/clause" +) + +func TestOrderBy(t *testing.T) { + results := []struct { + Clauses []clause.Interface + Result string + Vars []interface{} + }{ + { + []clause.Interface{clause.Select{}, clause.From{}, clause.OrderBy{ + Columns: []clause.OrderByColumn{{Column: clause.PrimaryColumn, Desc: true}}, + }}, + "SELECT * FROM `users` ORDER BY `users`.`id` DESC", nil, + }, + { + []clause.Interface{ + clause.Select{}, clause.From{}, clause.OrderBy{ + Columns: []clause.OrderByColumn{{Column: clause.PrimaryColumn, Desc: true}}, + }, clause.OrderBy{ + Columns: []clause.OrderByColumn{{Column: clause.Column{Name: "name"}}}, + }, + }, + "SELECT * FROM `users` ORDER BY `users`.`id` DESC,`name`", nil, + }, + { + []clause.Interface{ + clause.Select{}, clause.From{}, clause.OrderBy{ + Columns: []clause.OrderByColumn{{Column: clause.PrimaryColumn, Desc: true}}, + }, clause.OrderBy{ + Columns: []clause.OrderByColumn{{Column: clause.Column{Name: "name"}, Reorder: true}}, + }, + }, + "SELECT * FROM `users` ORDER BY `name`", nil, + }, + } + + for idx, result := range results { + t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { + checkBuildClauses(t, result.Clauses, result.Result, result.Vars) + }) + } +} diff --git a/clause/query.go b/clause/query.go deleted file mode 100644 index ce609014..00000000 --- a/clause/query.go +++ /dev/null @@ -1,258 +0,0 @@ -package clause - -import "strings" - -//////////////////////////////////////////////////////////////////////////////// -// Query Expressions -//////////////////////////////////////////////////////////////////////////////// - -func Add(exprs ...Expression) AddConditions { - return AddConditions(exprs) -} - -func Or(exprs ...Expression) OrConditions { - return OrConditions(exprs) -} - -type AddConditions []Expression - -func (cs AddConditions) Build(builder Builder) { - for idx, c := range cs { - if idx > 0 { - builder.Write(" AND ") - } - c.Build(builder) - } -} - -type OrConditions []Expression - -func (cs OrConditions) Build(builder Builder) { - for idx, c := range cs { - if idx > 0 { - builder.Write(" OR ") - } - c.Build(builder) - } -} - -type NotConditions []Expression - -func (cs NotConditions) Build(builder Builder) { - for idx, c := range cs { - if idx > 0 { - builder.Write(" AND ") - } - - if negationBuilder, ok := c.(NegationExpressionBuilder); ok { - negationBuilder.NegationBuild(builder) - } else { - builder.Write(" NOT ") - c.Build(builder) - } - } -} - -// String raw sql for where -type String struct { - SQL string - Values []interface{} -} - -func (str String) Build(builder Builder) { - sql := str.SQL - for _, v := range str.Values { - sql = strings.Replace(sql, " ? ", " "+builder.AddVar(v)+" ", 1) - } - builder.Write(sql) -} - -// IN Whether a value is within a set of values -type IN struct { - Column interface{} - Values []interface{} -} - -func (in IN) Build(builder Builder) { - builder.WriteQuoted(in.Column) - - switch len(in.Values) { - case 0: - builder.Write(" IN (NULL)") - case 1: - builder.Write(" = ", builder.AddVar(in.Values...)) - default: - builder.Write(" IN (", builder.AddVar(in.Values...), ")") - } -} - -func (in IN) NegationBuild(builder Builder) { - switch len(in.Values) { - case 0: - case 1: - builder.Write(" <> ", builder.AddVar(in.Values...)) - default: - builder.Write(" NOT IN (", builder.AddVar(in.Values...), ")") - } -} - -// Eq equal to for where -type Eq struct { - Column interface{} - Value interface{} -} - -func (eq Eq) Build(builder Builder) { - builder.WriteQuoted(eq.Column) - - if eq.Value == nil { - builder.Write(" IS NULL") - } else { - builder.Write(" = ", builder.AddVar(eq.Value)) - } -} - -func (eq Eq) NegationBuild(builder Builder) { - Neq{eq.Column, eq.Value}.Build(builder) -} - -// Neq not equal to for where -type Neq struct { - Column interface{} - Value interface{} -} - -func (neq Neq) Build(builder Builder) { - builder.WriteQuoted(neq.Column) - - if neq.Value == nil { - builder.Write(" IS NOT NULL") - } else { - builder.Write(" <> ", builder.AddVar(neq.Value)) - } -} - -func (neq Neq) NegationBuild(builder Builder) { - Eq{neq.Column, neq.Value}.Build(builder) -} - -// Gt greater than for where -type Gt struct { - Column interface{} - Value interface{} -} - -func (gt Gt) Build(builder Builder) { - builder.WriteQuoted(gt.Column) - builder.Write(" > ", builder.AddVar(gt.Value)) -} - -func (gt Gt) NegationBuild(builder Builder) { - Lte{gt.Column, gt.Value}.Build(builder) -} - -// Gte greater than or equal to for where -type Gte struct { - Column interface{} - Value interface{} -} - -func (gte Gte) Build(builder Builder) { - builder.WriteQuoted(gte.Column) - builder.Write(" >= ", builder.AddVar(gte.Value)) -} - -func (gte Gte) NegationBuild(builder Builder) { - Lt{gte.Column, gte.Value}.Build(builder) -} - -// Lt less than for where -type Lt struct { - Column interface{} - Value interface{} -} - -func (lt Lt) Build(builder Builder) { - builder.WriteQuoted(lt.Column) - builder.Write(" < ", builder.AddVar(lt.Value)) -} - -func (lt Lt) NegationBuild(builder Builder) { - Gte{lt.Column, lt.Value}.Build(builder) -} - -// Lte less than or equal to for where -type Lte struct { - Column interface{} - Value interface{} -} - -func (lte Lte) Build(builder Builder) { - builder.WriteQuoted(lte.Column) - builder.Write(" <= ", builder.AddVar(lte.Value)) -} - -func (lte Lte) NegationBuild(builder Builder) { - Gt{lte.Column, lte.Value}.Build(builder) -} - -// Like whether string matches regular expression -type Like struct { - Column interface{} - Value interface{} -} - -func (like Like) Build(builder Builder) { - builder.WriteQuoted(like.Column) - builder.Write(" LIKE ", builder.AddVar(like.Value)) -} - -func (like Like) NegationBuild(builder Builder) { - builder.WriteQuoted(like.Column) - builder.Write(" NOT LIKE ", builder.AddVar(like.Value)) -} - -// Map -type Map map[interface{}]interface{} - -func (m Map) Build(builder Builder) { - // TODO -} - -func (m Map) NegationBuild(builder Builder) { - // TODO -} - -// Attrs -type Attrs struct { - Value interface{} - Select []string - Omit []string -} - -func (attrs Attrs) Build(builder Builder) { - // TODO - // builder.WriteQuoted(like.Column) - // builder.Write(" LIKE ", builder.AddVar(like.Value)) -} - -func (attrs Attrs) NegationBuild(builder Builder) { - // TODO -} - -// ID -type ID struct { - Value []interface{} -} - -func (id ID) Build(builder Builder) { - if len(id.Value) == 1 { - } - // TODO - // builder.WriteQuoted(like.Column) - // builder.Write(" LIKE ", builder.AddVar(like.Value)) -} - -func (id ID) NegationBuild(builder Builder) { - // TODO -} diff --git a/clause/returning.go b/clause/returning.go new file mode 100644 index 00000000..04bc96da --- /dev/null +++ b/clause/returning.go @@ -0,0 +1,30 @@ +package clause + +type Returning struct { + Columns []Column +} + +// Name where clause name +func (returning Returning) Name() string { + return "RETURNING" +} + +// Build build where clause +func (returning Returning) Build(builder Builder) { + for idx, column := range returning.Columns { + if idx > 0 { + builder.WriteByte(',') + } + + builder.WriteQuoted(column) + } +} + +// MergeClause merge order by clauses +func (returning Returning) MergeClause(clause *Clause) { + if v, ok := clause.Expression.(Returning); ok { + returning.Columns = append(v.Columns, returning.Columns...) + } + + clause.Expression = returning +} diff --git a/clause/returning_test.go b/clause/returning_test.go new file mode 100644 index 00000000..e9fed1cb --- /dev/null +++ b/clause/returning_test.go @@ -0,0 +1,36 @@ +package clause_test + +import ( + "fmt" + "testing" + + "github.com/jinzhu/gorm/clause" +) + +func TestReturning(t *testing.T) { + results := []struct { + Clauses []clause.Interface + Result string + Vars []interface{} + }{ + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Returning{ + []clause.Column{clause.PrimaryColumn}, + }}, + "SELECT * FROM `users` RETURNING `users`.`id`", nil, + }, { + []clause.Interface{clause.Select{}, clause.From{}, clause.Returning{ + []clause.Column{clause.PrimaryColumn}, + }, clause.Returning{ + []clause.Column{{Name: "name"}, {Name: "age"}}, + }}, + "SELECT * FROM `users` RETURNING `users`.`id`,`name`,`age`", nil, + }, + } + + for idx, result := range results { + t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { + checkBuildClauses(t, result.Clauses, result.Result, result.Vars) + }) + } +} diff --git a/clause/select.go b/clause/select.go index 7f0e4438..4bb1af8d 100644 --- a/clause/select.go +++ b/clause/select.go @@ -1,32 +1,18 @@ package clause -// SelectInterface select clause interface -type SelectInterface interface { - Selects() []Column - Omits() []Column -} - // Select select attrs when querying, updating, creating type Select struct { - SelectColumns []Column - OmitColumns []Column + Columns []Column + Omits []Column } func (s Select) Name() string { return "SELECT" } -func (s Select) Selects() []Column { - return s.SelectColumns -} - -func (s Select) Omits() []Column { - return s.OmitColumns -} - func (s Select) Build(builder Builder) { - if len(s.SelectColumns) > 0 { - for idx, column := range s.SelectColumns { + if len(s.Columns) > 0 { + for idx, column := range s.Columns { if idx > 0 { builder.WriteByte(',') } @@ -37,13 +23,10 @@ func (s Select) Build(builder Builder) { } } -func (s Select) MergeExpression(expr Expression) { - if v, ok := expr.(SelectInterface); ok { - if len(s.SelectColumns) == 0 { - s.SelectColumns = v.Selects() - } - if len(s.OmitColumns) == 0 { - s.OmitColumns = v.Omits() - } +func (s Select) MergeClause(clause *Clause) { + if v, ok := clause.Expression.(Select); ok { + s.Columns = append(v.Columns, s.Columns...) + s.Omits = append(v.Omits, s.Omits...) } + clause.Expression = s } diff --git a/clause/select_test.go b/clause/select_test.go new file mode 100644 index 00000000..8255e51b --- /dev/null +++ b/clause/select_test.go @@ -0,0 +1,41 @@ +package clause_test + +import ( + "fmt" + "testing" + + "github.com/jinzhu/gorm/clause" +) + +func TestSelect(t *testing.T) { + results := []struct { + Clauses []clause.Interface + Result string + Vars []interface{} + }{ + { + []clause.Interface{clause.Select{}, clause.From{}}, + "SELECT * FROM `users`", nil, + }, + { + []clause.Interface{clause.Select{ + Columns: []clause.Column{clause.PrimaryColumn}, + }, clause.From{}}, + "SELECT `users`.`id` FROM `users`", nil, + }, + { + []clause.Interface{clause.Select{ + Columns: []clause.Column{clause.PrimaryColumn}, + }, clause.Select{ + Columns: []clause.Column{{Name: "name"}}, + }, clause.From{}}, + "SELECT `users`.`id`,`name` FROM `users`", nil, + }, + } + + for idx, result := range results { + t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { + checkBuildClauses(t, result.Clauses, result.Result, result.Vars) + }) + } +} diff --git a/clause/set.go b/clause/set.go new file mode 100644 index 00000000..3b7e972d --- /dev/null +++ b/clause/set.go @@ -0,0 +1,37 @@ +package clause + +type Set []Assignment + +type Assignment struct { + Column Column + Value interface{} +} + +func (set Set) Name() string { + return "SET" +} + +func (set Set) Build(builder Builder) { + if len(set) > 0 { + for idx, assignment := range set { + if idx > 0 { + builder.WriteByte(',') + } + builder.WriteQuoted(assignment.Column) + builder.WriteByte('=') + builder.Write(builder.AddVar(assignment.Value)) + } + } else { + builder.WriteQuoted(PrimaryColumn) + builder.WriteByte('=') + builder.WriteQuoted(PrimaryColumn) + } +} + +// MergeClause merge assignments clauses +func (set Set) MergeClause(clause *Clause) { + if v, ok := clause.Expression.(Set); ok { + set = append(v, set...) + } + clause.Expression = set +} diff --git a/clause/set_test.go b/clause/set_test.go new file mode 100644 index 00000000..85754737 --- /dev/null +++ b/clause/set_test.go @@ -0,0 +1,38 @@ +package clause_test + +import ( + "fmt" + "testing" + + "github.com/jinzhu/gorm/clause" +) + +func TestSet(t *testing.T) { + results := []struct { + Clauses []clause.Interface + Result string + Vars []interface{} + }{ + { + []clause.Interface{ + clause.Update{}, + clause.Set([]clause.Assignment{{clause.PrimaryColumn, 1}}), + }, + "UPDATE `users` SET `users`.`id`=?", []interface{}{1}, + }, + { + []clause.Interface{ + clause.Update{}, + clause.Set([]clause.Assignment{{clause.PrimaryColumn, 1}}), + clause.Set([]clause.Assignment{{clause.Column{Name: "name"}, "jinzhu"}}), + }, + "UPDATE `users` SET `users`.`id`=?,`name`=?", []interface{}{1, "jinzhu"}, + }, + } + + for idx, result := range results { + t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { + checkBuildClauses(t, result.Clauses, result.Result, result.Vars) + }) + } +} diff --git a/clause/update.go b/clause/update.go new file mode 100644 index 00000000..c375b373 --- /dev/null +++ b/clause/update.go @@ -0,0 +1,38 @@ +package clause + +type Update struct { + Modifier string + Table Table +} + +// Name update clause name +func (update Update) Name() string { + return "UPDATE" +} + +// Build build update clause +func (update Update) Build(builder Builder) { + if update.Modifier != "" { + builder.Write(update.Modifier) + builder.WriteByte(' ') + } + + if update.Table.Name == "" { + builder.WriteQuoted(currentTable) + } else { + builder.WriteQuoted(update.Table) + } +} + +// MergeClause merge update clause +func (update Update) MergeClause(clause *Clause) { + if v, ok := clause.Expression.(Update); ok { + if update.Modifier == "" { + update.Modifier = v.Modifier + } + if update.Table.Name == "" { + update.Table = v.Table + } + } + clause.Expression = update +} diff --git a/clause/update_test.go b/clause/update_test.go new file mode 100644 index 00000000..adc48f03 --- /dev/null +++ b/clause/update_test.go @@ -0,0 +1,35 @@ +package clause_test + +import ( + "fmt" + "testing" + + "github.com/jinzhu/gorm/clause" +) + +func TestUpdate(t *testing.T) { + results := []struct { + Clauses []clause.Interface + Result string + Vars []interface{} + }{ + { + []clause.Interface{clause.Update{}}, + "UPDATE `users`", nil, + }, + { + []clause.Interface{clause.Update{Modifier: "LOW_PRIORITY"}}, + "UPDATE LOW_PRIORITY `users`", nil, + }, + { + []clause.Interface{clause.Update{Table: clause.Table{Name: "products"}, Modifier: "LOW_PRIORITY"}}, + "UPDATE LOW_PRIORITY `products`", nil, + }, + } + + for idx, result := range results { + t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { + checkBuildClauses(t, result.Clauses, result.Result, result.Vars) + }) + } +} diff --git a/clause/value.go b/clause/values.go similarity index 76% rename from clause/value.go rename to clause/values.go index 4de0d91e..594b92e2 100644 --- a/clause/value.go +++ b/clause/values.go @@ -25,11 +25,11 @@ func (values Values) Build(builder Builder) { builder.Write(" VALUES ") for idx, value := range values.Values { - builder.WriteByte('(') if idx > 0 { builder.WriteByte(',') } + builder.WriteByte('(') builder.Write(builder.AddVar(value...)) builder.WriteByte(')') } @@ -37,3 +37,11 @@ func (values Values) Build(builder Builder) { builder.Write("DEFAULT VALUES") } } + +// MergeClause merge values clauses +func (values Values) MergeClause(clause *Clause) { + if v, ok := clause.Expression.(Values); ok { + values.Values = append(v.Values, values.Values...) + } + clause.Expression = values +} diff --git a/clause/values_test.go b/clause/values_test.go new file mode 100644 index 00000000..ced4f1e6 --- /dev/null +++ b/clause/values_test.go @@ -0,0 +1,33 @@ +package clause_test + +import ( + "fmt" + "testing" + + "github.com/jinzhu/gorm/clause" +) + +func TestValues(t *testing.T) { + results := []struct { + Clauses []clause.Interface + Result string + Vars []interface{} + }{ + { + []clause.Interface{ + clause.Insert{}, + clause.Values{ + Columns: []clause.Column{{Name: "name"}, {Name: "age"}}, + Values: [][]interface{}{{"jinzhu", 18}, {"josh", 1}}, + }, + }, + "INSERT INTO `users` (`name`,`age`) VALUES (?,?),(?,?)", []interface{}{"jinzhu", 18, "josh", 1}, + }, + } + + for idx, result := range results { + t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { + checkBuildClauses(t, result.Clauses, result.Result, result.Vars) + }) + } +} diff --git a/clause/where.go b/clause/where.go index de82662c..d0f57ed1 100644 --- a/clause/where.go +++ b/clause/where.go @@ -2,9 +2,7 @@ package clause // Where where clause type Where struct { - AndConditions AddConditions - OrConditions []OrConditions - builders []Expression + Exprs []Expression } // Name where clause name @@ -14,64 +12,122 @@ func (where Where) Name() string { // Build build where clause func (where Where) Build(builder Builder) { - var withConditions bool - - if len(where.AndConditions) > 0 { - withConditions = true - where.AndConditions.Build(builder) - } - - if len(where.builders) > 0 { - for _, b := range where.builders { - if withConditions { - builder.Write(" AND ") + // Switch position if the first query expression is a single Or condition + for idx, expr := range where.Exprs { + if v, ok := expr.(OrConditions); (!ok && expr != nil) || len(v.Exprs) > 1 { + if idx != 0 { + where.Exprs[0], where.Exprs[idx] = where.Exprs[idx], where.Exprs[0] } - withConditions = true - b.Build(builder) + break } } - var singleOrConditions []OrConditions - for _, or := range where.OrConditions { - if len(or) == 1 { - if withConditions { - builder.Write(" OR ") - or.Build(builder) - } else { - singleOrConditions = append(singleOrConditions, or) + for idx, expr := range where.Exprs { + if expr != nil { + if idx > 0 { + if v, ok := expr.(OrConditions); ok && len(v.Exprs) == 1 { + builder.Write(" OR ") + } else { + builder.Write(" AND ") + } } - } else { - withConditions = true - builder.Write(" AND (") - or.Build(builder) - builder.WriteByte(')') - } - } - for _, or := range singleOrConditions { - if withConditions { - builder.Write(" AND ") - or.Build(builder) - } else { - withConditions = true - or.Build(builder) + expr.Build(builder) } } - if !withConditions { - builder.Write(" FALSE") - } - return } -// MergeExpression merge where clauses -func (where Where) MergeExpression(expr Expression) { - if w, ok := expr.(Where); ok { - where.AndConditions = append(where.AndConditions, w.AndConditions...) - where.OrConditions = append(where.OrConditions, w.OrConditions...) - where.builders = append(where.builders, w.builders...) - } else { - where.builders = append(where.builders, expr) +// MergeClause merge where clauses +func (where Where) MergeClause(clause *Clause) { + if w, ok := clause.Expression.(Where); ok { + where.Exprs = append(w.Exprs, where.Exprs...) + } + + clause.Expression = where +} + +func And(exprs ...Expression) Expression { + if len(exprs) == 0 { + return nil + } + return AndConditions{Exprs: exprs} +} + +type AndConditions struct { + Exprs []Expression +} + +func (and AndConditions) Build(builder Builder) { + if len(and.Exprs) > 1 { + builder.Write("(") + } + for idx, c := range and.Exprs { + if idx > 0 { + builder.Write(" AND ") + } + c.Build(builder) + } + if len(and.Exprs) > 1 { + builder.Write(")") + } +} + +func Or(exprs ...Expression) Expression { + if len(exprs) == 0 { + return nil + } + return OrConditions{Exprs: exprs} +} + +type OrConditions struct { + Exprs []Expression +} + +func (or OrConditions) Build(builder Builder) { + if len(or.Exprs) > 1 { + builder.Write("(") + } + for idx, c := range or.Exprs { + if idx > 0 { + builder.Write(" OR ") + } + c.Build(builder) + } + if len(or.Exprs) > 1 { + builder.Write(")") + } +} + +func Not(exprs ...Expression) Expression { + if len(exprs) == 0 { + return nil + } + return NotConditions{Exprs: exprs} +} + +type NotConditions struct { + Exprs []Expression +} + +func (not NotConditions) Build(builder Builder) { + if len(not.Exprs) > 1 { + builder.Write("(") + } + for idx, c := range not.Exprs { + if idx > 0 { + builder.Write(" AND ") + } + + if negationBuilder, ok := c.(NegationExpressionBuilder); ok { + negationBuilder.NegationBuild(builder) + } else { + builder.Write(" NOT ") + c.Build(builder) + } + } + if len(not.Exprs) > 1 { + builder.Write(")") } } diff --git a/clause/where_test.go b/clause/where_test.go new file mode 100644 index 00000000..450a0c89 --- /dev/null +++ b/clause/where_test.go @@ -0,0 +1,63 @@ +package clause_test + +import ( + "fmt" + "testing" + + "github.com/jinzhu/gorm/clause" +) + +func TestWhere(t *testing.T) { + results := []struct { + Clauses []clause.Interface + Result string + Vars []interface{} + }{ + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Where{ + Exprs: []clause.Expression{clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}, clause.Or(clause.Neq{Column: "name", Value: "jinzhu"})}, + }}, + "SELECT * FROM `users` WHERE `users`.`id` = ? AND `age` > ? OR `name` <> ?", []interface{}{"1", 18, "jinzhu"}, + }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Where{ + Exprs: []clause.Expression{clause.Or(clause.Neq{Column: "name", Value: "jinzhu"}), clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}}, + }}, + "SELECT * FROM `users` WHERE `users`.`id` = ? OR `name` <> ? AND `age` > ?", []interface{}{"1", "jinzhu", 18}, + }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Where{ + Exprs: []clause.Expression{clause.Or(), clause.Or(clause.Neq{Column: "name", Value: "jinzhu"}), clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}}, + }}, + "SELECT * FROM `users` WHERE `users`.`id` = ? OR `name` <> ? AND `age` > ?", []interface{}{"1", "jinzhu", 18}, + }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Where{ + Exprs: []clause.Expression{clause.Or(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}), clause.Or(clause.Neq{Column: "name", Value: "jinzhu"})}, + }}, + "SELECT * FROM `users` WHERE `users`.`id` = ? OR `name` <> ?", []interface{}{"1", "jinzhu"}, + }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Where{ + Exprs: []clause.Expression{clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}, clause.Or(clause.Neq{Column: "name", Value: "jinzhu"})}, + }, clause.Where{ + Exprs: []clause.Expression{clause.Or(clause.Gt{Column: "score", Value: 100}, clause.Like{Column: "name", Value: "%linus%"})}, + }}, + "SELECT * FROM `users` WHERE `users`.`id` = ? AND `age` > ? OR `name` <> ? AND (`score` > ? OR `name` LIKE ?)", []interface{}{"1", 18, "jinzhu", 100, "%linus%"}, + }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Where{ + Exprs: []clause.Expression{clause.Not(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}), clause.Or(clause.Neq{Column: "name", Value: "jinzhu"})}, + }, clause.Where{ + Exprs: []clause.Expression{clause.Or(clause.Not(clause.Gt{Column: "score", Value: 100}), clause.Like{Column: "name", Value: "%linus%"})}, + }}, + "SELECT * FROM `users` WHERE (`users`.`id` <> ? AND `age` <= ?) OR `name` <> ? AND (`score` <= ? OR `name` LIKE ?)", []interface{}{"1", 18, "jinzhu", 100, "%linus%"}, + }, + } + + for idx, result := range results { + t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { + checkBuildClauses(t, result.Clauses, result.Result, result.Vars) + }) + } +} diff --git a/finisher_api.go b/finisher_api.go index 06809651..5389ed6a 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -22,7 +22,7 @@ func (db *DB) Save(value interface{}) (tx *DB) { // First find first record that match given conditions, order by primary key func (db *DB) First(out interface{}, where ...interface{}) (tx *DB) { - tx = db.getInstance().Limit(1).Order(clause.OrderBy{ + tx = db.getInstance().Limit(1).Order(clause.OrderByColumn{ Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Desc: true, }) diff --git a/statement.go b/statement.go index bc07b6e4..5dd49623 100644 --- a/statement.go +++ b/statement.go @@ -5,7 +5,6 @@ import ( "database/sql" "database/sql/driver" "fmt" - "log" "strconv" "strings" "sync" @@ -26,7 +25,7 @@ func (instance *Instance) ToSQL(clauses ...string) (string, []interface{}) { if len(clauses) > 0 { instance.Statement.Build(clauses...) } - return instance.Statement.SQL.String(), instance.Statement.Vars + return strings.TrimSpace(instance.Statement.SQL.String()), instance.Statement.Vars } // AddError add error to instance @@ -85,10 +84,10 @@ func (stmt Statement) Quote(field interface{}) string { switch v := field.(type) { case clause.Table: - if v.Table == clause.CurrentTable { + if v.Name == clause.CurrentTable { str.WriteString(stmt.Table) } else { - str.WriteString(v.Table) + str.WriteString(v.Name) } if v.Alias != "" { @@ -126,7 +125,7 @@ func (stmt Statement) Quote(field interface{}) string { str.WriteByte(stmt.DB.quoteChars[1]) } default: - fmt.Sprint(field) + str.WriteString(fmt.Sprint(field)) } str.WriteByte(stmt.DB.quoteChars[1]) @@ -141,19 +140,28 @@ func (stmt *Statement) AddVar(vars ...interface{}) string { placeholders.WriteByte(',') } - if namedArg, ok := v.(sql.NamedArg); ok && len(namedArg.Name) > 0 { - stmt.NamedVars = append(stmt.NamedVars, namedArg) - placeholders.WriteByte('@') - placeholders.WriteString(namedArg.Name) - } else if arrs, ok := v.([]interface{}); ok { + switch v := v.(type) { + case sql.NamedArg: + if len(v.Name) > 0 { + stmt.NamedVars = append(stmt.NamedVars, v) + placeholders.WriteByte('@') + placeholders.WriteString(v.Name) + } else { + stmt.Vars = append(stmt.Vars, v.Value) + placeholders.WriteString(stmt.DB.Dialector.BindVar(stmt, v.Value)) + } + case clause.Column: + placeholders.WriteString(stmt.Quote(v)) + case []interface{}: placeholders.WriteByte('(') - if len(arrs) > 0 { - placeholders.WriteString(stmt.AddVar(arrs...)) + if len(v) > 0 { + placeholders.WriteString(stmt.AddVar(v...)) } else { placeholders.WriteString("NULL") } placeholders.WriteByte(')') - } else { + default: + stmt.Vars = append(stmt.Vars, v) placeholders.WriteString(stmt.DB.Dialector.BindVar(stmt, v)) } } @@ -166,42 +174,18 @@ func (stmt *Statement) AddClause(v clause.Interface) { optimizer.OptimizeStatement(stmt) } - c, _ := stmt.Clauses[v.Name()] - if namer, ok := v.(clause.OverrideNameInterface); ok { - c.Name = namer.OverrideName() - } else { + c, ok := stmt.Clauses[v.Name()] + if !ok { c.Name = v.Name() } - - if c.Expression != nil { - v.MergeExpression(c.Expression) - } - - c.Expression = v + v.MergeClause(&c) stmt.Clauses[v.Name()] = c } // AddClauseIfNotExists add clause if not exists func (stmt *Statement) AddClauseIfNotExists(v clause.Interface) { - if optimizer, ok := v.(StatementOptimizer); ok { - optimizer.OptimizeStatement(stmt) - } - - log.Println(v.Name()) - if c, ok := stmt.Clauses[v.Name()]; !ok { - if namer, ok := v.(clause.OverrideNameInterface); ok { - c.Name = namer.OverrideName() - } else { - c.Name = v.Name() - } - - if c.Expression != nil { - v.MergeExpression(c.Expression) - } - - c.Expression = v - stmt.Clauses[v.Name()] = c - log.Println(stmt.Clauses[v.Name()]) + if _, ok := stmt.Clauses[v.Name()]; !ok { + stmt.AddClause(v) } } @@ -211,7 +195,7 @@ func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (con if i, err := strconv.Atoi(sql); err != nil { query = i } else if len(args) == 0 || (len(args) > 0 && strings.Contains(sql, "?")) || strings.Contains(sql, "@") { - return []clause.Expression{clause.String{SQL: sql, Values: args}} + return []clause.Expression{clause.Expr{SQL: sql, Vars: args}} } } @@ -255,7 +239,7 @@ func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (con } if len(conditions) == 0 { - conditions = append(conditions, clause.ID{Value: args}) + conditions = append(conditions, clause.IN{Column: clause.PrimaryColumn, Values: args}) } return conditions From c1afe197289c4abb99f440af7ad003d6d6224f24 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 14 Feb 2020 00:09:44 +0800 Subject: [PATCH 230/881] Add benchmark tests for clause --- clause/benchmarks_test.go | 56 +++++++++++++++++++++++++++++++++++++++ clause/where.go | 12 ++++----- statement.go | 6 ++--- 3 files changed, 65 insertions(+), 9 deletions(-) create mode 100644 clause/benchmarks_test.go diff --git a/clause/benchmarks_test.go b/clause/benchmarks_test.go new file mode 100644 index 00000000..33d3430a --- /dev/null +++ b/clause/benchmarks_test.go @@ -0,0 +1,56 @@ +package clause_test + +import ( + "sync" + "testing" + + "github.com/jinzhu/gorm" + "github.com/jinzhu/gorm/clause" + "github.com/jinzhu/gorm/schema" + "github.com/jinzhu/gorm/tests" +) + +func BenchmarkSelect(b *testing.B) { + user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) + + for i := 0; i < b.N; i++ { + stmt := gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}} + clauses := []clause.Interface{clause.Select{}, clause.From{}, clause.Where{Exprs: []clause.Expression{clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}, clause.Or(clause.Neq{Column: "name", Value: "jinzhu"})}}} + + for _, clause := range clauses { + stmt.AddClause(clause) + } + + stmt.Build("SELECT", "FROM", "WHERE") + _ = stmt.SQL.String() + } +} + +func BenchmarkComplexSelect(b *testing.B) { + user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) + + for i := 0; i < b.N; i++ { + stmt := gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}} + clauses := []clause.Interface{ + clause.Select{}, clause.From{}, + clause.Where{Exprs: []clause.Expression{ + clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, + clause.Gt{Column: "age", Value: 18}, + clause.Or(clause.Neq{Column: "name", Value: "jinzhu"}), + }}, + clause.Where{Exprs: []clause.Expression{ + clause.Or(clause.Gt{Column: "score", Value: 100}, clause.Like{Column: "name", Value: "%linus%"}), + }}, + clause.GroupBy{Columns: []clause.Column{{Name: "role"}}, Having: clause.Where{[]clause.Expression{clause.Eq{"role", "admin"}}}}, + clause.Limit{Limit: 10, Offset: 20}, + clause.OrderBy{Columns: []clause.OrderByColumn{{Column: clause.PrimaryColumn, Desc: true}}}, + } + + for _, clause := range clauses { + stmt.AddClause(clause) + } + + stmt.Build("SELECT", "FROM", "WHERE", "GROUP BY", "LIMIT", "ORDER BY") + _ = stmt.SQL.String() + } +} diff --git a/clause/where.go b/clause/where.go index d0f57ed1..0ee1a141 100644 --- a/clause/where.go +++ b/clause/where.go @@ -61,7 +61,7 @@ type AndConditions struct { func (and AndConditions) Build(builder Builder) { if len(and.Exprs) > 1 { - builder.Write("(") + builder.WriteByte('(') } for idx, c := range and.Exprs { if idx > 0 { @@ -70,7 +70,7 @@ func (and AndConditions) Build(builder Builder) { c.Build(builder) } if len(and.Exprs) > 1 { - builder.Write(")") + builder.WriteByte(')') } } @@ -87,7 +87,7 @@ type OrConditions struct { func (or OrConditions) Build(builder Builder) { if len(or.Exprs) > 1 { - builder.Write("(") + builder.WriteByte('(') } for idx, c := range or.Exprs { if idx > 0 { @@ -96,7 +96,7 @@ func (or OrConditions) Build(builder Builder) { c.Build(builder) } if len(or.Exprs) > 1 { - builder.Write(")") + builder.WriteByte(')') } } @@ -113,7 +113,7 @@ type NotConditions struct { func (not NotConditions) Build(builder Builder) { if len(not.Exprs) > 1 { - builder.Write("(") + builder.WriteByte('(') } for idx, c := range not.Exprs { if idx > 0 { @@ -128,6 +128,6 @@ func (not NotConditions) Build(builder Builder) { } } if len(not.Exprs) > 1 { - builder.Write(")") + builder.WriteByte(')') } } diff --git a/statement.go b/statement.go index 5dd49623..1c3934c1 100644 --- a/statement.go +++ b/statement.go @@ -153,13 +153,13 @@ func (stmt *Statement) AddVar(vars ...interface{}) string { case clause.Column: placeholders.WriteString(stmt.Quote(v)) case []interface{}: - placeholders.WriteByte('(') if len(v) > 0 { + placeholders.WriteByte('(') placeholders.WriteString(stmt.AddVar(v...)) + placeholders.WriteByte(')') } else { - placeholders.WriteString("NULL") + placeholders.WriteString("(NULL)") } - placeholders.WriteByte(')') default: stmt.Vars = append(stmt.Vars, v) placeholders.WriteString(stmt.DB.Dialector.BindVar(stmt, v)) From 2cb88dc7c56b0eba123b8adf872d2520988bcfc7 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 15 Feb 2020 16:04:21 +0800 Subject: [PATCH 231/881] Add Field Valuer, Setter --- schema/field.go | 357 +++++++++++++++++++++++++++++++++++++++++++ schema/field_test.go | 64 ++++++++ schema/schema.go | 2 + 3 files changed, 423 insertions(+) create mode 100644 schema/field_test.go diff --git a/schema/field.go b/schema/field.go index 570b3c50..15e94279 100644 --- a/schema/field.go +++ b/schema/field.go @@ -1,11 +1,15 @@ package schema import ( + "database/sql" "database/sql/driver" + "fmt" "reflect" "strconv" "sync" "time" + + "github.com/jinzhu/now" ) type DataType string @@ -43,6 +47,9 @@ type Field struct { TagSettings map[string]string Schema *Schema EmbeddedSchema *Schema + ReflectValuer func(reflect.Value) reflect.Value + Valuer func(reflect.Value) interface{} + Setter func(reflect.Value, interface{}) error } func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { @@ -186,6 +193,12 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { for _, ef := range field.EmbeddedSchema.Fields { ef.Schema = schema ef.BindNames = append([]string{fieldStruct.Name}, ef.BindNames...) + // index is negative means is pointer + if field.FieldType.Kind() == reflect.Struct { + ef.StructField.Index = append([]int{fieldStruct.Index[0]}, ef.StructField.Index...) + } else { + ef.StructField.Index = append([]int{-fieldStruct.Index[0]}, ef.StructField.Index...) + } if prefix, ok := field.TagSettings["EMBEDDEDPREFIX"]; ok { ef.DBName = prefix + ef.DBName @@ -199,3 +212,347 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { return field } + +// ValueOf field value of +func (field *Field) ValueOf(value reflect.Value) interface{} { + if field != nil { + return field.Valuer(value) + } + return nil +} + +func (field *Field) Set(value reflect.Value, v interface{}) error { + if field != nil { + return field.Setter(value, v) + } + + return fmt.Errorf("failed to set field value: %v", field.Name) +} + +// create valuer, setter when parse struct +func (field *Field) setupValuerAndSetter() { + // Valuer + switch { + case len(field.StructField.Index) == 1: + field.Valuer = func(value reflect.Value) interface{} { + return value.Field(field.StructField.Index[0]).Interface() + } + case len(field.StructField.Index) == 2 && field.StructField.Index[0] >= 0: + field.Valuer = func(value reflect.Value) interface{} { + return value.Field(field.StructField.Index[0]).Field(field.StructField.Index[1]).Interface() + } + default: + field.Valuer = func(value reflect.Value) interface{} { + v := value.Field(field.StructField.Index[0]) + for _, idx := range field.StructField.Index[1:] { + if v.Kind() == reflect.Ptr { + if v.Type().Elem().Kind() == reflect.Struct { + if !v.IsNil() { + v = v.Elem().Field(-idx) + continue + } + } + return nil + } else { + v = v.Field(idx) + } + } + return v.Interface() + } + } + + // ReflectValuer + switch { + case len(field.StructField.Index) == 1: + if field.FieldType.Kind() == reflect.Ptr { + field.ReflectValuer = func(value reflect.Value) reflect.Value { + fieldValue := value.Field(field.StructField.Index[0]) + if fieldValue.IsNil() { + fieldValue.Set(reflect.New(field.FieldType.Elem())) + } + return fieldValue + } + } else { + field.ReflectValuer = func(value reflect.Value) reflect.Value { + return value.Field(field.StructField.Index[0]) + } + } + case len(field.StructField.Index) == 2 && field.StructField.Index[0] >= 0 && field.FieldType.Kind() != reflect.Ptr: + field.Valuer = func(value reflect.Value) interface{} { + return value.Field(field.StructField.Index[0]).Field(field.StructField.Index[1]).Interface() + } + default: + field.ReflectValuer = func(value reflect.Value) reflect.Value { + v := value.Field(field.StructField.Index[0]) + for _, idx := range field.StructField.Index[1:] { + if v.Kind() == reflect.Ptr { + if v.Type().Elem().Kind() == reflect.Struct { + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) + } + + if idx >= 0 { + v = v.Elem().Field(idx) + } else { + v = v.Elem().Field(-idx) + } + } + } else { + v = v.Field(idx) + } + } + return v + } + } + + // Setter + switch field.FieldType.Kind() { + case reflect.Bool: + field.Setter = func(value reflect.Value, v interface{}) error { + switch data := v.(type) { + case bool: + field.ReflectValuer(value).SetBool(data) + case *bool: + field.ReflectValuer(value).SetBool(*data) + default: + reflectV := reflect.ValueOf(v) + if reflectV.Type().ConvertibleTo(field.FieldType) { + field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType)) + } else { + field.ReflectValuer(value).SetBool(!reflect.ValueOf(v).IsZero()) + } + } + return nil + } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + field.Setter = func(value reflect.Value, v interface{}) error { + switch data := v.(type) { + case int64: + field.ReflectValuer(value).SetInt(data) + case int: + field.ReflectValuer(value).SetInt(int64(data)) + case int8: + field.ReflectValuer(value).SetInt(int64(data)) + case int16: + field.ReflectValuer(value).SetInt(int64(data)) + case int32: + field.ReflectValuer(value).SetInt(int64(data)) + case uint: + field.ReflectValuer(value).SetInt(int64(data)) + case uint8: + field.ReflectValuer(value).SetInt(int64(data)) + case uint16: + field.ReflectValuer(value).SetInt(int64(data)) + case uint32: + field.ReflectValuer(value).SetInt(int64(data)) + case uint64: + field.ReflectValuer(value).SetInt(int64(data)) + case float32: + field.ReflectValuer(value).SetInt(int64(data)) + case float64: + field.ReflectValuer(value).SetInt(int64(data)) + case []byte: + return field.Setter(value, string(data)) + case string: + if i, err := strconv.ParseInt(data, 0, 64); err == nil { + field.ReflectValuer(value).SetInt(i) + } else { + return err + } + default: + reflectV := reflect.ValueOf(v) + if reflectV.Type().ConvertibleTo(field.FieldType) { + field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType)) + } else if reflectV.Kind() == reflect.Ptr { + return field.Setter(value, reflectV.Elem().Interface()) + } else { + return fmt.Errorf("failed to set value %+v to field %v", v, field.Name) + } + } + return nil + } + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + field.Setter = func(value reflect.Value, v interface{}) error { + switch data := v.(type) { + case uint64: + field.ReflectValuer(value).SetUint(data) + case uint: + field.ReflectValuer(value).SetUint(uint64(data)) + case uint8: + field.ReflectValuer(value).SetUint(uint64(data)) + case uint16: + field.ReflectValuer(value).SetUint(uint64(data)) + case uint32: + field.ReflectValuer(value).SetUint(uint64(data)) + case int64: + field.ReflectValuer(value).SetUint(uint64(data)) + case int: + field.ReflectValuer(value).SetUint(uint64(data)) + case int8: + field.ReflectValuer(value).SetUint(uint64(data)) + case int16: + field.ReflectValuer(value).SetUint(uint64(data)) + case int32: + field.ReflectValuer(value).SetUint(uint64(data)) + case float32: + field.ReflectValuer(value).SetUint(uint64(data)) + case float64: + field.ReflectValuer(value).SetUint(uint64(data)) + case []byte: + return field.Setter(value, string(data)) + case string: + if i, err := strconv.ParseUint(data, 0, 64); err == nil { + field.ReflectValuer(value).SetUint(i) + } else { + return err + } + default: + reflectV := reflect.ValueOf(v) + if reflectV.Type().ConvertibleTo(field.FieldType) { + field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType)) + } else if reflectV.Kind() == reflect.Ptr { + return field.Setter(value, reflectV.Elem().Interface()) + } else { + return fmt.Errorf("failed to set value %+v to field %v", v, field.Name) + } + } + return nil + } + case reflect.Float32, reflect.Float64: + field.Setter = func(value reflect.Value, v interface{}) error { + switch data := v.(type) { + case float64: + field.ReflectValuer(value).SetFloat(data) + case float32: + field.ReflectValuer(value).SetFloat(float64(data)) + case int64: + field.ReflectValuer(value).SetFloat(float64(data)) + case int: + field.ReflectValuer(value).SetFloat(float64(data)) + case int8: + field.ReflectValuer(value).SetFloat(float64(data)) + case int16: + field.ReflectValuer(value).SetFloat(float64(data)) + case int32: + field.ReflectValuer(value).SetFloat(float64(data)) + case uint: + field.ReflectValuer(value).SetFloat(float64(data)) + case uint8: + field.ReflectValuer(value).SetFloat(float64(data)) + case uint16: + field.ReflectValuer(value).SetFloat(float64(data)) + case uint32: + field.ReflectValuer(value).SetFloat(float64(data)) + case uint64: + field.ReflectValuer(value).SetFloat(float64(data)) + case []byte: + return field.Setter(value, string(data)) + case string: + if i, err := strconv.ParseFloat(data, 64); err == nil { + field.ReflectValuer(value).SetFloat(i) + } else { + return err + } + default: + reflectV := reflect.ValueOf(v) + if reflectV.Type().ConvertibleTo(field.FieldType) { + field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType)) + } else if reflectV.Kind() == reflect.Ptr { + return field.Setter(value, reflectV.Elem().Interface()) + } else { + return fmt.Errorf("failed to set value %+v to field %v", v, field.Name) + } + } + return nil + } + case reflect.String: + field.Setter = func(value reflect.Value, v interface{}) error { + switch data := v.(type) { + case string: + field.ReflectValuer(value).SetString(data) + case []byte: + field.ReflectValuer(value).SetString(string(data)) + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: + field.ReflectValuer(value).SetString(fmt.Sprint(data)) + case float64, float32: + field.ReflectValuer(value).SetString(fmt.Sprintf("%."+strconv.Itoa(field.Precision)+"f", data)) + default: + reflectV := reflect.ValueOf(v) + if reflectV.Type().ConvertibleTo(field.FieldType) { + field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType)) + } else if reflectV.Kind() == reflect.Ptr { + return field.Setter(value, reflectV.Elem().Interface()) + } else { + return fmt.Errorf("failed to set value %+v to field %v", v, field.Name) + } + } + return nil + } + default: + fieldValue := reflect.New(field.FieldType) + switch fieldValue.Interface().(type) { + case time.Time: + field.Setter = func(value reflect.Value, v interface{}) error { + switch data := v.(type) { + case time.Time: + field.ReflectValuer(value).Set(reflect.ValueOf(v)) + case *time.Time: + field.ReflectValuer(value).Set(reflect.ValueOf(v).Elem()) + case string: + if t, err := now.Parse(data); err == nil { + field.ReflectValuer(value).Set(reflect.ValueOf(t)) + } else { + return fmt.Errorf("failed to set string %v to time.Time field %v, failed to parse it as time, got error %v", v, field.Name, err) + } + default: + return fmt.Errorf("failed to set value %+v to time.Time field %v", v, field.Name) + } + return nil + } + case *time.Time: + field.Setter = func(value reflect.Value, v interface{}) error { + switch data := v.(type) { + case time.Time: + field.ReflectValuer(value).Elem().Set(reflect.ValueOf(v)) + case *time.Time: + field.ReflectValuer(value).Set(reflect.ValueOf(v)) + case string: + if t, err := now.Parse(data); err == nil { + field.ReflectValuer(value).Elem().Set(reflect.ValueOf(t)) + } else { + return fmt.Errorf("failed to set string %v to time.Time field %v, failed to parse it as time, got error %v", v, field.Name, err) + } + default: + return fmt.Errorf("failed to set value %+v to time.Time field %v", v, field.Name) + } + return nil + } + default: + if fieldValue.CanAddr() { + if _, ok := fieldValue.Addr().Interface().(sql.Scanner); ok { + field.Setter = func(value reflect.Value, v interface{}) (err error) { + if valuer, ok := v.(driver.Valuer); ok { + if v, err = valuer.Value(); err == nil { + err = field.ReflectValuer(value).Addr().Interface().(sql.Scanner).Scan(v) + } + } else { + err = field.ReflectValuer(value).Addr().Interface().(sql.Scanner).Scan(v) + } + return + } + return + } + } + + field.Setter = func(value reflect.Value, v interface{}) (err error) { + reflectV := reflect.ValueOf(v) + if reflectV.Type().ConvertibleTo(field.FieldType) { + field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType)) + } else { + return fmt.Errorf("failed to set value %+v to field %v", v, field.Name) + } + return nil + } + } + } +} diff --git a/schema/field_test.go b/schema/field_test.go new file mode 100644 index 00000000..c7814fbf --- /dev/null +++ b/schema/field_test.go @@ -0,0 +1,64 @@ +package schema_test + +import ( + "reflect" + "sync" + "testing" + "time" + + "github.com/jinzhu/gorm" + "github.com/jinzhu/gorm/schema" + "github.com/jinzhu/gorm/tests" +) + +func TestFieldValuerAndSetter(t *testing.T) { + var ( + cacheMap = sync.Map{} + userSchema, _ = schema.Parse(&tests.User{}, &cacheMap, schema.NamingStrategy{}) + user = tests.User{ + Model: gorm.Model{ + ID: 10, + CreatedAt: time.Now(), + DeletedAt: tests.Now(), + }, + Name: "valuer_and_setter", + Age: 18, + Birthday: tests.Now(), + } + reflectValue = reflect.ValueOf(user) + ) + + values := map[string]interface{}{ + "name": user.Name, + "id": user.ID, + "created_at": user.CreatedAt, + "deleted_at": user.DeletedAt, + "age": user.Age, + "birthday": user.Birthday, + } + + for k, v := range values { + if rv := userSchema.FieldsByDBName[k].ValueOf(reflectValue); rv != v { + t.Errorf("user's %v value should equal %+v, but got %+v", k, v, rv) + } + } + + newValues := map[string]interface{}{ + "name": "valuer_and_setter_2", + "id": "2", + "created_at": time.Now(), + "deleted_at": tests.Now(), + "age": 20, + "birthday": time.Now(), + } + + for k, v := range newValues { + if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil { + t.Errorf("no error should happen when assign value to field %v", k) + } + + if rv := userSchema.FieldsByDBName[k].ValueOf(reflectValue); rv != v { + t.Errorf("user's %v value should equal %+v after assign new value, but got %+v", k, v, rv) + } + } +} diff --git a/schema/schema.go b/schema/schema.go index 53170e18..2f3cdf88 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -128,6 +128,8 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) if _, ok := schema.FieldsByName[field.Name]; !ok { schema.FieldsByName[field.Name] = field } + + field.setupValuerAndSetter() } if f := schema.LookUpField("id"); f != nil { From faee069a9fce8e919e05f54dd4a3a5b519803e7f Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 15 Feb 2020 19:45:27 +0800 Subject: [PATCH 232/881] Test Field Valuer, Setter --- schema/field.go | 184 +++++++++++++++++++++-------------- schema/field_test.go | 87 +++++++++++++++-- schema/relationship.go | 6 +- schema/schema_helper_test.go | 32 ++++++ schema/schema_test.go | 3 +- tests/model.go | 3 +- 6 files changed, 225 insertions(+), 90 deletions(-) diff --git a/schema/field.go b/schema/field.go index 15e94279..b4610436 100644 --- a/schema/field.go +++ b/schema/field.go @@ -25,52 +25,53 @@ const ( ) type Field struct { - Name string - DBName string - BindNames []string - DataType DataType - DBDataType string - PrimaryKey bool - AutoIncrement bool - Creatable bool - Updatable bool - HasDefaultValue bool - DefaultValue string - NotNull bool - Unique bool - Comment string - Size int - Precision int - FieldType reflect.Type - StructField reflect.StructField - Tag reflect.StructTag - TagSettings map[string]string - Schema *Schema - EmbeddedSchema *Schema - ReflectValuer func(reflect.Value) reflect.Value - Valuer func(reflect.Value) interface{} - Setter func(reflect.Value, interface{}) error + Name string + DBName string + BindNames []string + DataType DataType + DBDataType string + PrimaryKey bool + AutoIncrement bool + Creatable bool + Updatable bool + HasDefaultValue bool + DefaultValue string + NotNull bool + Unique bool + Comment string + Size int + Precision int + FieldType reflect.Type + IndirectFieldType reflect.Type + StructField reflect.StructField + Tag reflect.StructTag + TagSettings map[string]string + Schema *Schema + EmbeddedSchema *Schema + ReflectValuer func(reflect.Value) reflect.Value + Valuer func(reflect.Value) interface{} + Setter func(reflect.Value, interface{}) error } func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { field := &Field{ - Name: fieldStruct.Name, - BindNames: []string{fieldStruct.Name}, - FieldType: fieldStruct.Type, - StructField: fieldStruct, - Creatable: true, - Updatable: true, - Tag: fieldStruct.Tag, - TagSettings: ParseTagSetting(fieldStruct.Tag), - Schema: schema, + Name: fieldStruct.Name, + BindNames: []string{fieldStruct.Name}, + FieldType: fieldStruct.Type, + IndirectFieldType: fieldStruct.Type, + StructField: fieldStruct, + Creatable: true, + Updatable: true, + Tag: fieldStruct.Tag, + TagSettings: ParseTagSetting(fieldStruct.Tag), + Schema: schema, } - for field.FieldType.Kind() == reflect.Ptr { - field.FieldType = field.FieldType.Elem() + for field.IndirectFieldType.Kind() == reflect.Ptr { + field.IndirectFieldType = field.IndirectFieldType.Elem() } - fieldValue := reflect.New(field.FieldType) - + fieldValue := reflect.New(field.IndirectFieldType) // if field is valuer, used its value or first fields as data type if valuer, isValuer := fieldValue.Interface().(driver.Valuer); isValuer { var overrideFieldValue bool @@ -79,10 +80,10 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { fieldValue = reflect.ValueOf(v) } - if field.FieldType.Kind() == reflect.Struct { - for i := 0; i < field.FieldType.NumField(); i++ { + if field.IndirectFieldType.Kind() == reflect.Struct { + for i := 0; i < field.IndirectFieldType.NumField(); i++ { if !overrideFieldValue { - newFieldType := field.FieldType.Field(i).Type + newFieldType := field.IndirectFieldType.Field(i).Type for newFieldType.Kind() == reflect.Ptr { newFieldType = newFieldType.Elem() } @@ -92,7 +93,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } // copy tag settings from valuer - for key, value := range ParseTagSetting(field.FieldType.Field(i).Tag) { + for key, value := range ParseTagSetting(field.IndirectFieldType.Field(i).Tag) { if _, ok := field.TagSettings[key]; !ok { field.TagSettings[key] = value } @@ -197,7 +198,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { if field.FieldType.Kind() == reflect.Struct { ef.StructField.Index = append([]int{fieldStruct.Index[0]}, ef.StructField.Index...) } else { - ef.StructField.Index = append([]int{-fieldStruct.Index[0]}, ef.StructField.Index...) + ef.StructField.Index = append([]int{-fieldStruct.Index[0] - 1}, ef.StructField.Index...) } if prefix, ok := field.TagSettings["EMBEDDEDPREFIX"]; ok { @@ -235,26 +236,29 @@ func (field *Field) setupValuerAndSetter() { switch { case len(field.StructField.Index) == 1: field.Valuer = func(value reflect.Value) interface{} { - return value.Field(field.StructField.Index[0]).Interface() + return reflect.Indirect(value).Field(field.StructField.Index[0]).Interface() } case len(field.StructField.Index) == 2 && field.StructField.Index[0] >= 0: field.Valuer = func(value reflect.Value) interface{} { - return value.Field(field.StructField.Index[0]).Field(field.StructField.Index[1]).Interface() + return reflect.Indirect(value).Field(field.StructField.Index[0]).Field(field.StructField.Index[1]).Interface() } default: field.Valuer = func(value reflect.Value) interface{} { - v := value.Field(field.StructField.Index[0]) - for _, idx := range field.StructField.Index[1:] { - if v.Kind() == reflect.Ptr { + v := reflect.Indirect(value) + + for _, idx := range field.StructField.Index { + if idx >= 0 { + v = v.Field(idx) + } else { + v = v.Field(-idx - 1) + if v.Type().Elem().Kind() == reflect.Struct { if !v.IsNil() { - v = v.Elem().Field(-idx) - continue + v = v.Elem() } + } else { + return nil } - return nil - } else { - v = v.Field(idx) } } return v.Interface() @@ -266,7 +270,7 @@ func (field *Field) setupValuerAndSetter() { case len(field.StructField.Index) == 1: if field.FieldType.Kind() == reflect.Ptr { field.ReflectValuer = func(value reflect.Value) reflect.Value { - fieldValue := value.Field(field.StructField.Index[0]) + fieldValue := reflect.Indirect(value).Field(field.StructField.Index[0]) if fieldValue.IsNil() { fieldValue.Set(reflect.New(field.FieldType.Elem())) } @@ -274,31 +278,33 @@ func (field *Field) setupValuerAndSetter() { } } else { field.ReflectValuer = func(value reflect.Value) reflect.Value { - return value.Field(field.StructField.Index[0]) + return reflect.Indirect(value).Field(field.StructField.Index[0]) } } case len(field.StructField.Index) == 2 && field.StructField.Index[0] >= 0 && field.FieldType.Kind() != reflect.Ptr: - field.Valuer = func(value reflect.Value) interface{} { - return value.Field(field.StructField.Index[0]).Field(field.StructField.Index[1]).Interface() + field.ReflectValuer = func(value reflect.Value) reflect.Value { + return reflect.Indirect(value).Field(field.StructField.Index[0]).Field(field.StructField.Index[1]) } default: field.ReflectValuer = func(value reflect.Value) reflect.Value { - v := value.Field(field.StructField.Index[0]) - for _, idx := range field.StructField.Index[1:] { + v := reflect.Indirect(value) + for _, idx := range field.StructField.Index { + if idx >= 0 { + v = v.Field(idx) + } else { + v = v.Field(-idx - 1) + } + if v.Kind() == reflect.Ptr { if v.Type().Elem().Kind() == reflect.Struct { if v.IsNil() { v.Set(reflect.New(v.Type().Elem())) } - - if idx >= 0 { - v = v.Elem().Field(idx) - } else { - v = v.Elem().Field(-idx) - } } - } else { - v = v.Field(idx) + + if idx < len(field.StructField.Index)-1 { + v = v.Elem() + } } } return v @@ -490,7 +496,7 @@ func (field *Field) setupValuerAndSetter() { } default: fieldValue := reflect.New(field.FieldType) - switch fieldValue.Interface().(type) { + switch fieldValue.Elem().Interface().(type) { case time.Time: field.Setter = func(value reflect.Value, v interface{}) error { switch data := v.(type) { @@ -528,6 +534,20 @@ func (field *Field) setupValuerAndSetter() { return nil } default: + if _, ok := fieldValue.Interface().(sql.Scanner); ok { + field.Setter = func(value reflect.Value, v interface{}) (err error) { + if valuer, ok := v.(driver.Valuer); ok { + if v, err = valuer.Value(); err == nil { + err = field.ReflectValuer(value).Interface().(sql.Scanner).Scan(v) + } + } else { + err = field.ReflectValuer(value).Interface().(sql.Scanner).Scan(v) + } + return + } + return + } + if fieldValue.CanAddr() { if _, ok := fieldValue.Addr().Interface().(sql.Scanner); ok { field.Setter = func(value reflect.Value, v interface{}) (err error) { @@ -544,14 +564,28 @@ func (field *Field) setupValuerAndSetter() { } } - field.Setter = func(value reflect.Value, v interface{}) (err error) { - reflectV := reflect.ValueOf(v) - if reflectV.Type().ConvertibleTo(field.FieldType) { - field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType)) - } else { - return fmt.Errorf("failed to set value %+v to field %v", v, field.Name) + if field.FieldType.Kind() == reflect.Ptr { + field.Setter = func(value reflect.Value, v interface{}) (err error) { + reflectV := reflect.ValueOf(v) + if reflectV.Type().ConvertibleTo(field.FieldType) { + field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType)) + } else if reflectV.Type().ConvertibleTo(field.FieldType.Elem()) { + field.ReflectValuer(value).Elem().Set(reflectV.Convert(field.FieldType.Elem())) + } else { + return fmt.Errorf("failed to set value %+v to field %v", v, field.Name) + } + return nil + } + } else { + field.Setter = func(value reflect.Value, v interface{}) (err error) { + reflectV := reflect.ValueOf(v) + if reflectV.Type().ConvertibleTo(field.FieldType) { + field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType)) + } else { + return fmt.Errorf("failed to set value %+v to field %v", v, field.Name) + } + return nil } - return nil } } } diff --git a/schema/field_test.go b/schema/field_test.go index c7814fbf..065d6d05 100644 --- a/schema/field_test.go +++ b/schema/field_test.go @@ -24,10 +24,12 @@ func TestFieldValuerAndSetter(t *testing.T) { Name: "valuer_and_setter", Age: 18, Birthday: tests.Now(), + Active: true, } - reflectValue = reflect.ValueOf(user) + reflectValue = reflect.ValueOf(&user) ) + // test valuer values := map[string]interface{}{ "name": user.Name, "id": user.ID, @@ -35,30 +37,95 @@ func TestFieldValuerAndSetter(t *testing.T) { "deleted_at": user.DeletedAt, "age": user.Age, "birthday": user.Birthday, + "active": true, } + checkField(t, userSchema, reflectValue, values) - for k, v := range values { - if rv := userSchema.FieldsByDBName[k].ValueOf(reflectValue); rv != v { - t.Errorf("user's %v value should equal %+v, but got %+v", k, v, rv) - } - } - + // test setter newValues := map[string]interface{}{ "name": "valuer_and_setter_2", - "id": "2", + "id": 2, "created_at": time.Now(), "deleted_at": tests.Now(), "age": 20, "birthday": time.Now(), + "active": false, } for k, v := range newValues { if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil { t.Errorf("no error should happen when assign value to field %v", k) } + } + checkField(t, userSchema, reflectValue, newValues) +} - if rv := userSchema.FieldsByDBName[k].ValueOf(reflectValue); rv != v { - t.Errorf("user's %v value should equal %+v after assign new value, but got %+v", k, v, rv) +func TestPointerFieldValuerAndSetter(t *testing.T) { + var ( + cacheMap = sync.Map{} + userSchema, _ = schema.Parse(&User{}, &cacheMap, schema.NamingStrategy{}) + name = "pointer_field_valuer_and_setter" + age = 18 + active = true + user = User{ + Model: &gorm.Model{ + ID: 10, + CreatedAt: time.Now(), + DeletedAt: tests.Now(), + }, + Name: &name, + Age: &age, + Birthday: tests.Now(), + Active: &active, + } + reflectValue = reflect.ValueOf(&user) + ) + + // test valuer + values := map[string]interface{}{ + "name": user.Name, + "id": user.ID, + "created_at": user.CreatedAt, + "deleted_at": user.DeletedAt, + "age": user.Age, + "birthday": user.Birthday, + "active": true, + } + checkField(t, userSchema, reflectValue, values) + + // test setter + newValues := map[string]interface{}{ + "name": "valuer_and_setter_2", + "id": 2, + "created_at": time.Now(), + "deleted_at": tests.Now(), + "age": 20, + "birthday": time.Now(), + "active": false, + } + + for k, v := range newValues { + if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil { + t.Errorf("no error should happen when assign value to field %v, but got %v", k, err) } } + checkField(t, userSchema, reflectValue, newValues) +} + +type User struct { + *gorm.Model + Name *string + Age *int + Birthday *time.Time + Account *tests.Account + Pets []*tests.Pet + Toys []tests.Toy `gorm:"polymorphic:Owner"` + CompanyID *int + Company *tests.Company + ManagerID *int + Manager *User + Team []User `gorm:"foreignkey:ManagerID"` + Languages []tests.Language `gorm:"many2many:UserSpeak"` + Friends []*User `gorm:"many2many:user_friends"` + Active *bool } diff --git a/schema/relationship.go b/schema/relationship.go index b6aaefbd..671371fe 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -54,7 +54,7 @@ type Reference struct { func (schema *Schema) parseRelation(field *Field) { var ( err error - fieldValue = reflect.New(field.FieldType).Interface() + fieldValue = reflect.New(field.IndirectFieldType).Interface() relation = &Relationship{ Name: field.Name, Field: field, @@ -74,7 +74,7 @@ func (schema *Schema) parseRelation(field *Field) { } else if many2many, _ := field.TagSettings["MANY2MANY"]; many2many != "" { schema.buildMany2ManyRelation(relation, field, many2many) } else { - switch field.FieldType.Kind() { + switch field.IndirectFieldType.Kind() { case reflect.Struct, reflect.Slice: schema.guessRelation(relation, field, true) default: @@ -83,7 +83,7 @@ func (schema *Schema) parseRelation(field *Field) { } if relation.Type == "has" { - switch field.FieldType.Kind() { + switch field.IndirectFieldType.Kind() { case reflect.Struct: relation.Type = HasOne case reflect.Slice: diff --git a/schema/schema_helper_test.go b/schema/schema_helper_test.go index db38355d..4af0fc89 100644 --- a/schema/schema_helper_test.go +++ b/schema/schema_helper_test.go @@ -2,6 +2,7 @@ package schema_test import ( "fmt" + "reflect" "strings" "testing" @@ -189,3 +190,34 @@ func checkSchemaRelation(t *testing.T, s *schema.Schema, relation Relation) { } }) } + +func checkField(t *testing.T, s *schema.Schema, value reflect.Value, values map[string]interface{}) { + for k, v := range values { + t.Run("CheckField/"+k, func(t *testing.T) { + field := s.FieldsByDBName[k] + fv := field.ValueOf(value) + + if reflect.ValueOf(fv).Kind() == reflect.Ptr { + if reflect.ValueOf(v).Kind() == reflect.Ptr { + if fv != v { + t.Errorf("pointer expects: %p, but got %p", v, fv) + } + } else if fv == nil { + if v != nil { + t.Errorf("expects: %+v, but got nil", v) + } + } else if reflect.ValueOf(fv).Elem().Interface() != v { + t.Errorf("expects: %+v, but got %+v", v, fv) + } + } else if reflect.ValueOf(v).Kind() == reflect.Ptr { + if reflect.ValueOf(v).Elem().Interface() != fv { + t.Errorf("expects: %+v, but got %+v", v, fv) + } + } else if reflect.ValueOf(v).Type().ConvertibleTo(field.FieldType) { + if reflect.ValueOf(v).Convert(field.FieldType).Interface() != fv { + t.Errorf("expects: %+v, but got %+v", v, fv) + } + } + }) + } +} diff --git a/schema/schema_test.go b/schema/schema_test.go index 526a98bd..97da1d5d 100644 --- a/schema/schema_test.go +++ b/schema/schema_test.go @@ -29,7 +29,8 @@ func TestParseSchema(t *testing.T) { {Name: "Age", DBName: "age", BindNames: []string{"Age"}, DataType: schema.Uint}, {Name: "Birthday", DBName: "birthday", BindNames: []string{"Birthday"}, DataType: schema.Time}, {Name: "CompanyID", DBName: "company_id", BindNames: []string{"CompanyID"}, DataType: schema.Int}, - {Name: "ManagerID", DBName: "manager_id", BindNames: []string{"ManagerID"}, DataType: schema.Uint}, + {Name: "ManagerID", DBName: "manager_id", BindNames: []string{"ManagerID"}, DataType: schema.Int}, + {Name: "Active", DBName: "active", BindNames: []string{"Active"}, DataType: schema.Bool}, } for _, f := range fields { diff --git a/tests/model.go b/tests/model.go index 62000352..ac2156c7 100644 --- a/tests/model.go +++ b/tests/model.go @@ -21,11 +21,12 @@ type User struct { Toys []Toy `gorm:"polymorphic:Owner"` CompanyID *int Company Company - ManagerID uint + ManagerID int Manager *User Team []User `gorm:"foreignkey:ManagerID"` Languages []Language `gorm:"many2many:UserSpeak"` Friends []*User `gorm:"many2many:user_friends"` + Active bool } type Account struct { From 18236fa3d72c196d6a5c5ee4070626e305912645 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 16 Feb 2020 00:37:59 +0800 Subject: [PATCH 233/881] Add more tests for setter, valuer --- schema/field.go | 145 ++++++++++++++--------------------- schema/field_test.go | 137 +++++++++++++++++++++++++++------ schema/model_test.go | 41 ++++++++++ schema/schema_helper_test.go | 46 ++++++----- schema/schema_test.go | 45 ++++++++++- 5 files changed, 281 insertions(+), 133 deletions(-) create mode 100644 schema/model_test.go diff --git a/schema/field.go b/schema/field.go index b4610436..76f459ec 100644 --- a/schema/field.go +++ b/schema/field.go @@ -164,6 +164,8 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { case reflect.Struct: if _, ok := fieldValue.Interface().(*time.Time); ok { field.DataType = Time + } else if fieldValue.Type().ConvertibleTo(reflect.TypeOf(&time.Time{})) { + field.DataType = Time } case reflect.Array, reflect.Slice: if fieldValue.Type().Elem() == reflect.TypeOf(uint8(0)) { @@ -311,6 +313,24 @@ func (field *Field) setupValuerAndSetter() { } } + recoverFunc := func(value reflect.Value, v interface{}, setter func(reflect.Value, interface{}) error) (err error) { + reflectV := reflect.ValueOf(v) + if reflectV.Type().ConvertibleTo(field.FieldType) { + field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType)) + } else if valuer, ok := v.(driver.Valuer); ok { + if v, err = valuer.Value(); err == nil { + return setter(value, v) + } + } else if field.FieldType.Kind() == reflect.Ptr && reflectV.Type().ConvertibleTo(field.FieldType.Elem()) { + field.ReflectValuer(value).Elem().Set(reflectV.Convert(field.FieldType.Elem())) + } else if reflectV.Kind() == reflect.Ptr { + return field.Setter(value, reflectV.Elem().Interface()) + } else { + return fmt.Errorf("failed to set value %+v to field %v", v, field.Name) + } + return err + } + // Setter switch field.FieldType.Kind() { case reflect.Bool: @@ -321,17 +341,12 @@ func (field *Field) setupValuerAndSetter() { case *bool: field.ReflectValuer(value).SetBool(*data) default: - reflectV := reflect.ValueOf(v) - if reflectV.Type().ConvertibleTo(field.FieldType) { - field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType)) - } else { - field.ReflectValuer(value).SetBool(!reflect.ValueOf(v).IsZero()) - } + return recoverFunc(value, v, field.Setter) } return nil } case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - field.Setter = func(value reflect.Value, v interface{}) error { + field.Setter = func(value reflect.Value, v interface{}) (err error) { switch data := v.(type) { case int64: field.ReflectValuer(value).SetInt(data) @@ -366,19 +381,12 @@ func (field *Field) setupValuerAndSetter() { return err } default: - reflectV := reflect.ValueOf(v) - if reflectV.Type().ConvertibleTo(field.FieldType) { - field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType)) - } else if reflectV.Kind() == reflect.Ptr { - return field.Setter(value, reflectV.Elem().Interface()) - } else { - return fmt.Errorf("failed to set value %+v to field %v", v, field.Name) - } + return recoverFunc(value, v, field.Setter) } - return nil + return err } case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - field.Setter = func(value reflect.Value, v interface{}) error { + field.Setter = func(value reflect.Value, v interface{}) (err error) { switch data := v.(type) { case uint64: field.ReflectValuer(value).SetUint(data) @@ -413,19 +421,12 @@ func (field *Field) setupValuerAndSetter() { return err } default: - reflectV := reflect.ValueOf(v) - if reflectV.Type().ConvertibleTo(field.FieldType) { - field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType)) - } else if reflectV.Kind() == reflect.Ptr { - return field.Setter(value, reflectV.Elem().Interface()) - } else { - return fmt.Errorf("failed to set value %+v to field %v", v, field.Name) - } + return recoverFunc(value, v, field.Setter) } - return nil + return err } case reflect.Float32, reflect.Float64: - field.Setter = func(value reflect.Value, v interface{}) error { + field.Setter = func(value reflect.Value, v interface{}) (err error) { switch data := v.(type) { case float64: field.ReflectValuer(value).SetFloat(data) @@ -460,19 +461,12 @@ func (field *Field) setupValuerAndSetter() { return err } default: - reflectV := reflect.ValueOf(v) - if reflectV.Type().ConvertibleTo(field.FieldType) { - field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType)) - } else if reflectV.Kind() == reflect.Ptr { - return field.Setter(value, reflectV.Elem().Interface()) - } else { - return fmt.Errorf("failed to set value %+v to field %v", v, field.Name) - } + return recoverFunc(value, v, field.Setter) } - return nil + return err } case reflect.String: - field.Setter = func(value reflect.Value, v interface{}) error { + field.Setter = func(value reflect.Value, v interface{}) (err error) { switch data := v.(type) { case string: field.ReflectValuer(value).SetString(data) @@ -483,16 +477,9 @@ func (field *Field) setupValuerAndSetter() { case float64, float32: field.ReflectValuer(value).SetString(fmt.Sprintf("%."+strconv.Itoa(field.Precision)+"f", data)) default: - reflectV := reflect.ValueOf(v) - if reflectV.Type().ConvertibleTo(field.FieldType) { - field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType)) - } else if reflectV.Kind() == reflect.Ptr { - return field.Setter(value, reflectV.Elem().Interface()) - } else { - return fmt.Errorf("failed to set value %+v to field %v", v, field.Name) - } + return recoverFunc(value, v, field.Setter) } - return nil + return err } default: fieldValue := reflect.New(field.FieldType) @@ -511,7 +498,7 @@ func (field *Field) setupValuerAndSetter() { return fmt.Errorf("failed to set string %v to time.Time field %v, failed to parse it as time, got error %v", v, field.Name, err) } default: - return fmt.Errorf("failed to set value %+v to time.Time field %v", v, field.Name) + return recoverFunc(value, v, field.Setter) } return nil } @@ -529,14 +516,35 @@ func (field *Field) setupValuerAndSetter() { return fmt.Errorf("failed to set string %v to time.Time field %v, failed to parse it as time, got error %v", v, field.Name, err) } default: - return fmt.Errorf("failed to set value %+v to time.Time field %v", v, field.Name) + return recoverFunc(value, v, field.Setter) } return nil } default: if _, ok := fieldValue.Interface().(sql.Scanner); ok { + // struct scanner field.Setter = func(value reflect.Value, v interface{}) (err error) { - if valuer, ok := v.(driver.Valuer); ok { + reflectV := reflect.ValueOf(v) + if reflectV.Type().ConvertibleTo(field.FieldType) { + field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType)) + } else if valuer, ok := v.(driver.Valuer); ok { + if v, err = valuer.Value(); err == nil { + err = field.ReflectValuer(value).Addr().Interface().(sql.Scanner).Scan(v) + } + } else { + err = field.ReflectValuer(value).Addr().Interface().(sql.Scanner).Scan(v) + } + return + } + } else if _, ok := fieldValue.Elem().Interface().(sql.Scanner); ok { + // pointer scanner + field.Setter = func(value reflect.Value, v interface{}) (err error) { + reflectV := reflect.ValueOf(v) + if reflectV.Type().ConvertibleTo(field.FieldType) { + field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType)) + } else if reflectV.Type().ConvertibleTo(field.FieldType.Elem()) { + field.ReflectValuer(value).Elem().Set(reflectV.Convert(field.FieldType.Elem())) + } else if valuer, ok := v.(driver.Valuer); ok { if v, err = valuer.Value(); err == nil { err = field.ReflectValuer(value).Interface().(sql.Scanner).Scan(v) } @@ -545,46 +553,9 @@ func (field *Field) setupValuerAndSetter() { } return } - return - } - - if fieldValue.CanAddr() { - if _, ok := fieldValue.Addr().Interface().(sql.Scanner); ok { - field.Setter = func(value reflect.Value, v interface{}) (err error) { - if valuer, ok := v.(driver.Valuer); ok { - if v, err = valuer.Value(); err == nil { - err = field.ReflectValuer(value).Addr().Interface().(sql.Scanner).Scan(v) - } - } else { - err = field.ReflectValuer(value).Addr().Interface().(sql.Scanner).Scan(v) - } - return - } - return - } - } - - if field.FieldType.Kind() == reflect.Ptr { - field.Setter = func(value reflect.Value, v interface{}) (err error) { - reflectV := reflect.ValueOf(v) - if reflectV.Type().ConvertibleTo(field.FieldType) { - field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType)) - } else if reflectV.Type().ConvertibleTo(field.FieldType.Elem()) { - field.ReflectValuer(value).Elem().Set(reflectV.Convert(field.FieldType.Elem())) - } else { - return fmt.Errorf("failed to set value %+v to field %v", v, field.Name) - } - return nil - } } else { field.Setter = func(value reflect.Value, v interface{}) (err error) { - reflectV := reflect.ValueOf(v) - if reflectV.Type().ConvertibleTo(field.FieldType) { - field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType)) - } else { - return fmt.Errorf("failed to set value %+v to field %v", v, field.Name) - } - return nil + return recoverFunc(value, v, field.Setter) } } } diff --git a/schema/field_test.go b/schema/field_test.go index 065d6d05..15dfa41d 100644 --- a/schema/field_test.go +++ b/schema/field_test.go @@ -1,6 +1,7 @@ package schema_test import ( + "database/sql" "reflect" "sync" "testing" @@ -13,8 +14,7 @@ import ( func TestFieldValuerAndSetter(t *testing.T) { var ( - cacheMap = sync.Map{} - userSchema, _ = schema.Parse(&tests.User{}, &cacheMap, schema.NamingStrategy{}) + userSchema, _ = schema.Parse(&tests.User{}, &sync.Map{}, schema.NamingStrategy{}) user = tests.User{ Model: gorm.Model{ ID: 10, @@ -54,20 +54,38 @@ func TestFieldValuerAndSetter(t *testing.T) { for k, v := range newValues { if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil { - t.Errorf("no error should happen when assign value to field %v", k) + t.Errorf("no error should happen when assign value to field %v, but got %v", k, err) } } checkField(t, userSchema, reflectValue, newValues) + + // test valuer and other type + age := myint(10) + newValues2 := map[string]interface{}{ + "name": sql.NullString{String: "valuer_and_setter_3", Valid: true}, + "id": &sql.NullInt64{Int64: 3, Valid: true}, + "created_at": tests.Now(), + "deleted_at": time.Now(), + "age": &age, + "birthday": mytime(time.Now()), + "active": mybool(true), + } + + for k, v := range newValues2 { + if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil { + t.Errorf("no error should happen when assign value to field %v, but got %v", k, err) + } + } + checkField(t, userSchema, reflectValue, newValues2) } func TestPointerFieldValuerAndSetter(t *testing.T) { var ( - cacheMap = sync.Map{} - userSchema, _ = schema.Parse(&User{}, &cacheMap, schema.NamingStrategy{}) - name = "pointer_field_valuer_and_setter" - age = 18 - active = true - user = User{ + userSchema, _ = schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{}) + name = "pointer_field_valuer_and_setter" + age uint = 18 + active = true + user = User{ Model: &gorm.Model{ ID: 10, CreatedAt: time.Now(), @@ -110,22 +128,91 @@ func TestPointerFieldValuerAndSetter(t *testing.T) { } } checkField(t, userSchema, reflectValue, newValues) + + // test valuer and other type + age2 := myint(10) + newValues2 := map[string]interface{}{ + "name": sql.NullString{String: "valuer_and_setter_3", Valid: true}, + "id": &sql.NullInt64{Int64: 3, Valid: true}, + "created_at": tests.Now(), + "deleted_at": time.Now(), + "age": &age2, + "birthday": mytime(time.Now()), + "active": mybool(true), + } + + for k, v := range newValues2 { + if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil { + t.Errorf("no error should happen when assign value to field %v, but got %v", k, err) + } + } + checkField(t, userSchema, reflectValue, newValues2) } -type User struct { - *gorm.Model - Name *string - Age *int - Birthday *time.Time - Account *tests.Account - Pets []*tests.Pet - Toys []tests.Toy `gorm:"polymorphic:Owner"` - CompanyID *int - Company *tests.Company - ManagerID *int - Manager *User - Team []User `gorm:"foreignkey:ManagerID"` - Languages []tests.Language `gorm:"many2many:UserSpeak"` - Friends []*User `gorm:"many2many:user_friends"` - Active *bool +func TestAdvancedDataTypeValuerAndSetter(t *testing.T) { + var ( + userSchema, _ = schema.Parse(&AdvancedDataTypeUser{}, &sync.Map{}, schema.NamingStrategy{}) + name = "advanced_data_type_valuer_and_setter" + deletedAt = mytime(time.Now()) + isAdmin = mybool(false) + user = AdvancedDataTypeUser{ + ID: sql.NullInt64{Int64: 10, Valid: true}, + Name: &sql.NullString{String: name, Valid: true}, + Birthday: sql.NullTime{Time: time.Now(), Valid: true}, + RegisteredAt: mytime(time.Now()), + DeletedAt: &deletedAt, + Active: mybool(true), + Admin: &isAdmin, + } + reflectValue = reflect.ValueOf(&user) + ) + + // test valuer + values := map[string]interface{}{ + "id": user.ID, + "name": user.Name, + "birthday": user.Birthday, + "registered_at": user.RegisteredAt, + "deleted_at": user.DeletedAt, + "active": user.Active, + "admin": user.Admin, + } + checkField(t, userSchema, reflectValue, values) + + // test setter + newDeletedAt := mytime(time.Now()) + newIsAdmin := mybool(true) + newValues := map[string]interface{}{ + "id": sql.NullInt64{Int64: 1, Valid: true}, + "name": &sql.NullString{String: name + "rename", Valid: true}, + "birthday": time.Now(), + "registered_at": mytime(time.Now()), + "deleted_at": &newDeletedAt, + "active": mybool(false), + "admin": &newIsAdmin, + } + + for k, v := range newValues { + if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil { + t.Errorf("no error should happen when assign value to field %v, but got %v", k, err) + } + } + checkField(t, userSchema, reflectValue, newValues) + + newValues2 := map[string]interface{}{ + "id": 5, + "name": name + "rename2", + "birthday": time.Now(), + "registered_at": time.Now(), + "deleted_at": time.Now(), + "active": true, + "admin": false, + } + + for k, v := range newValues2 { + if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil { + t.Errorf("no error should happen when assign value to field %v, but got %v", k, err) + } + } + checkField(t, userSchema, reflectValue, newValues2) } diff --git a/schema/model_test.go b/schema/model_test.go new file mode 100644 index 00000000..aca7e617 --- /dev/null +++ b/schema/model_test.go @@ -0,0 +1,41 @@ +package schema_test + +import ( + "database/sql" + "time" + + "github.com/jinzhu/gorm" + "github.com/jinzhu/gorm/tests" +) + +type User struct { + *gorm.Model + Name *string + Age *uint + Birthday *time.Time + Account *tests.Account + Pets []*tests.Pet + Toys []*tests.Toy `gorm:"polymorphic:Owner"` + CompanyID *int + Company *tests.Company + ManagerID *int + Manager *User + Team []*User `gorm:"foreignkey:ManagerID"` + Languages []*tests.Language `gorm:"many2many:UserSpeak"` + Friends []*User `gorm:"many2many:user_friends"` + Active *bool +} + +type mytime time.Time +type myint int +type mybool = bool + +type AdvancedDataTypeUser struct { + ID sql.NullInt64 + Name *sql.NullString + Birthday sql.NullTime + RegisteredAt mytime + DeletedAt *mytime + Active mybool + Admin *mybool +} diff --git a/schema/schema_helper_test.go b/schema/schema_helper_test.go index 4af0fc89..8ac2f002 100644 --- a/schema/schema_helper_test.go +++ b/schema/schema_helper_test.go @@ -1,6 +1,7 @@ package schema_test import ( + "database/sql/driver" "fmt" "reflect" "strings" @@ -194,30 +195,39 @@ func checkSchemaRelation(t *testing.T, s *schema.Schema, relation Relation) { func checkField(t *testing.T, s *schema.Schema, value reflect.Value, values map[string]interface{}) { for k, v := range values { t.Run("CheckField/"+k, func(t *testing.T) { - field := s.FieldsByDBName[k] - fv := field.ValueOf(value) + var ( + checker func(fv interface{}, v interface{}) + field = s.FieldsByDBName[k] + fv = field.ValueOf(value) + ) - if reflect.ValueOf(fv).Kind() == reflect.Ptr { - if reflect.ValueOf(v).Kind() == reflect.Ptr { - if fv != v { - t.Errorf("pointer expects: %p, but got %p", v, fv) + checker = func(fv interface{}, v interface{}) { + if reflect.ValueOf(fv).Type() == reflect.ValueOf(v).Type() && fv != v { + t.Errorf("expects: %p, but got %p", v, fv) + } else if reflect.ValueOf(v).Type().ConvertibleTo(reflect.ValueOf(fv).Type()) { + if reflect.ValueOf(v).Convert(reflect.ValueOf(fv).Type()).Interface() != fv { + t.Errorf("expects: %p, but got %p", v, fv) } - } else if fv == nil { - if v != nil { - t.Errorf("expects: %+v, but got nil", v) + } else if reflect.ValueOf(fv).Type().ConvertibleTo(reflect.ValueOf(v).Type()) { + if reflect.ValueOf(fv).Convert(reflect.ValueOf(fv).Type()).Interface() != v { + t.Errorf("expects: %p, but got %p", v, fv) } - } else if reflect.ValueOf(fv).Elem().Interface() != v { - t.Errorf("expects: %+v, but got %+v", v, fv) - } - } else if reflect.ValueOf(v).Kind() == reflect.Ptr { - if reflect.ValueOf(v).Elem().Interface() != fv { - t.Errorf("expects: %+v, but got %+v", v, fv) - } - } else if reflect.ValueOf(v).Type().ConvertibleTo(field.FieldType) { - if reflect.ValueOf(v).Convert(field.FieldType).Interface() != fv { + } else if valuer, isValuer := fv.(driver.Valuer); isValuer { + valuerv, _ := valuer.Value() + checker(valuerv, v) + } else if valuer, isValuer := v.(driver.Valuer); isValuer { + valuerv, _ := valuer.Value() + checker(fv, valuerv) + } else if reflect.ValueOf(fv).Kind() == reflect.Ptr { + checker(reflect.ValueOf(fv).Elem().Interface(), v) + } else if reflect.ValueOf(v).Kind() == reflect.Ptr { + checker(fv, reflect.ValueOf(v).Elem().Interface()) + } else { t.Errorf("expects: %+v, but got %+v", v, fv) } } + + checker(fv, v) }) } } diff --git a/schema/schema_test.go b/schema/schema_test.go index 97da1d5d..4134c966 100644 --- a/schema/schema_test.go +++ b/schema/schema_test.go @@ -9,13 +9,24 @@ import ( ) func TestParseSchema(t *testing.T) { - cacheMap := sync.Map{} - - user, err := schema.Parse(&tests.User{}, &cacheMap, schema.NamingStrategy{}) + user, err := schema.Parse(&tests.User{}, &sync.Map{}, schema.NamingStrategy{}) if err != nil { t.Fatalf("failed to parse user, got error %v", err) } + checkUserSchema(t, user) +} + +func TestParseSchemaWithPointerFields(t *testing.T) { + user, err := schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("failed to parse pointer user, got error %v", err) + } + + checkUserSchema(t, user) +} + +func checkUserSchema(t *testing.T, user *schema.Schema) { // check schema checkSchema(t, user, schema.Schema{Name: "User", Table: "users"}, []string{"ID"}) @@ -101,3 +112,31 @@ func TestParseSchema(t *testing.T) { checkSchemaRelation(t, user, relation) } } + +func TestParseSchemaWithAdvancedDataType(t *testing.T) { + user, err := schema.Parse(&AdvancedDataTypeUser{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("failed to parse pointer user, got error %v", err) + } + + // check schema + checkSchema(t, user, schema.Schema{Name: "AdvancedDataTypeUser", Table: "advanced_data_type_users"}, []string{"ID"}) + + // check fields + fields := []schema.Field{ + {Name: "ID", DBName: "id", BindNames: []string{"ID"}, DataType: schema.Int, PrimaryKey: true}, + {Name: "Name", DBName: "name", BindNames: []string{"Name"}, DataType: schema.String}, + {Name: "Birthday", DBName: "birthday", BindNames: []string{"Birthday"}, DataType: schema.Time}, + {Name: "RegisteredAt", DBName: "registered_at", BindNames: []string{"RegisteredAt"}, DataType: schema.Time}, + {Name: "DeletedAt", DBName: "deleted_at", BindNames: []string{"DeletedAt"}, DataType: schema.Time}, + {Name: "Active", DBName: "active", BindNames: []string{"Active"}, DataType: schema.Bool}, + {Name: "Admin", DBName: "admin", BindNames: []string{"Admin"}, DataType: schema.Bool}, + } + + for _, f := range fields { + checkSchemaField(t, user, &f, func(f *schema.Field) { + f.Creatable = true + f.Updatable = true + }) + } +} From 98ad29f2c24bd5c358355c8daacf575dd888d6ca Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 16 Feb 2020 13:45:27 +0800 Subject: [PATCH 234/881] Add Selects, Omits for statement --- chainable_api.go | 72 ++++++++++++++++++++++++++++++++++--------- clause/select.go | 12 ++++---- clause/select_test.go | 2 +- dialects/mysql/go.mod | 7 ----- go.mod | 4 +-- helpers.go | 5 +++ statement.go | 2 ++ 7 files changed, 73 insertions(+), 31 deletions(-) delete mode 100644 dialects/mysql/go.mod diff --git a/chainable_api.go b/chainable_api.go index 432026cf..9aa08b54 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -2,6 +2,7 @@ package gorm import ( "fmt" + "strings" "github.com/jinzhu/gorm/clause" ) @@ -31,9 +32,7 @@ func (db *DB) Clauses(conds ...clause.Expression) (tx *DB) { } if len(whereConds) > 0 { - tx.Statement.AddClause(&clause.Where{ - tx.Statement.BuildCondtion(whereConds[0], whereConds[1:]...), - }) + tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondtion(whereConds[0], whereConds[1:]...)}) } return } @@ -48,38 +47,83 @@ func (db *DB) Table(name string) (tx *DB) { // Select specify fields that you want when querying, creating, updating func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) { tx = db.getInstance() + + switch v := query.(type) { + case []string: + tx.Statement.Selects = v + + for _, arg := range args { + switch arg := arg.(type) { + case string: + tx.Statement.Selects = append(tx.Statement.Selects, arg) + case []string: + tx.Statement.Selects = append(tx.Statement.Selects, arg...) + default: + tx.AddError(fmt.Errorf("unsupported select args %v %v", query, args)) + return + } + } + case string: + fields := strings.FieldsFunc(v, isChar) + + // normal field names + if len(fields) == 1 || (len(fields) == 3 && strings.ToUpper(fields[1]) == "AS") { + tx.Statement.Selects = fields + + for _, arg := range args { + switch arg := arg.(type) { + case string: + tx.Statement.Selects = append(tx.Statement.Selects, arg) + case []string: + tx.Statement.Selects = append(tx.Statement.Selects, arg...) + default: + tx.Statement.AddClause(clause.Select{ + Expression: clause.Expr{SQL: v, Vars: args}, + }) + return + } + } + } else { + tx.Statement.AddClause(clause.Select{ + Expression: clause.Expr{SQL: v, Vars: args}, + }) + } + default: + tx.AddError(fmt.Errorf("unsupported select args %v %v", query, args)) + } + return } // Omit specify fields that you want to ignore when creating, updating and querying func (db *DB) Omit(columns ...string) (tx *DB) { tx = db.getInstance() + + if len(columns) == 1 && strings.Contains(columns[0], ",") { + tx.Statement.Omits = strings.FieldsFunc(columns[0], isChar) + } else { + tx.Statement.Omits = columns + } return } func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) { tx = db.getInstance() - tx.Statement.AddClause(&clause.Where{ - tx.Statement.BuildCondtion(query, args...), - }) + tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondtion(query, args...)}) return } // Not add NOT condition func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) { tx = db.getInstance() - tx.Statement.AddClause(&clause.Where{ - []clause.Expression{clause.Not(tx.Statement.BuildCondtion(query, args...)...)}, - }) + tx.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.Not(tx.Statement.BuildCondtion(query, args...)...)}}) return } // Or add OR conditions func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) { tx = db.getInstance() - tx.Statement.AddClause(&clause.Where{ - []clause.Expression{clause.Or(tx.Statement.BuildCondtion(query, args...)...)}, - }) + tx.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(tx.Statement.BuildCondtion(query, args...)...)}}) return } @@ -110,11 +154,11 @@ func (db *DB) Order(value interface{}) (tx *DB) { switch v := value.(type) { case clause.OrderByColumn: - db.Statement.AddClause(clause.OrderBy{ + tx.Statement.AddClause(clause.OrderBy{ Columns: []clause.OrderByColumn{v}, }) default: - db.Statement.AddClause(clause.OrderBy{ + tx.Statement.AddClause(clause.OrderBy{ Columns: []clause.OrderByColumn{{ Column: clause.Column{Name: fmt.Sprint(value), Raw: true}, }}, diff --git a/clause/select.go b/clause/select.go index 4bb1af8d..20b17e07 100644 --- a/clause/select.go +++ b/clause/select.go @@ -2,8 +2,8 @@ package clause // Select select attrs when querying, updating, creating type Select struct { - Columns []Column - Omits []Column + Columns []Column + Expression Expression } func (s Select) Name() string { @@ -24,9 +24,9 @@ func (s Select) Build(builder Builder) { } func (s Select) MergeClause(clause *Clause) { - if v, ok := clause.Expression.(Select); ok { - s.Columns = append(v.Columns, s.Columns...) - s.Omits = append(v.Omits, s.Omits...) + if s.Expression != nil { + clause.Expression = s.Expression + } else { + clause.Expression = s } - clause.Expression = s } diff --git a/clause/select_test.go b/clause/select_test.go index 8255e51b..0863d086 100644 --- a/clause/select_test.go +++ b/clause/select_test.go @@ -29,7 +29,7 @@ func TestSelect(t *testing.T) { }, clause.Select{ Columns: []clause.Column{{Name: "name"}}, }, clause.From{}}, - "SELECT `users`.`id`,`name` FROM `users`", nil, + "SELECT `name` FROM `users`", nil, }, } diff --git a/dialects/mysql/go.mod b/dialects/mysql/go.mod deleted file mode 100644 index a1f29122..00000000 --- a/dialects/mysql/go.mod +++ /dev/null @@ -1,7 +0,0 @@ -module github.com/jinzhu/gorm/dialects/mysql - -go 1.13 - -require ( - github.com/go-sql-driver/mysql v1.5.0 -) diff --git a/go.mod b/go.mod index e47297fb..cdb7e574 100644 --- a/go.mod +++ b/go.mod @@ -3,8 +3,6 @@ module github.com/jinzhu/gorm go 1.13 require ( - github.com/go-sql-driver/mysql v1.5.0 // indirect github.com/jinzhu/inflection v1.0.0 - github.com/lib/pq v1.3.0 // indirect - github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect + github.com/jinzhu/now v1.1.1 ) diff --git a/helpers.go b/helpers.go index 77bbece8..2e5c8ed1 100644 --- a/helpers.go +++ b/helpers.go @@ -3,6 +3,7 @@ package gorm import ( "errors" "time" + "unicode" ) var ( @@ -27,3 +28,7 @@ type Model struct { UpdatedAt time.Time DeletedAt *time.Time `gorm:"index"` } + +func isChar(c rune) bool { + return !unicode.IsLetter(c) && !unicode.IsNumber(c) +} diff --git a/statement.go b/statement.go index 1c3934c1..b2626d95 100644 --- a/statement.go +++ b/statement.go @@ -43,6 +43,8 @@ type Statement struct { Model interface{} Dest interface{} Clauses map[string]clause.Clause + Selects []string // selected columns + Omits []string // omit columns Settings sync.Map DB *DB Schema *schema.Schema From cbbf8f3d497bd7c9064a48324701dbdb8947f8c5 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 18 Feb 2020 22:56:37 +0800 Subject: [PATCH 235/881] Update schema --- schema/field.go | 322 ++++++++++++++++++++--------------- schema/schema.go | 4 + schema/schema_helper_test.go | 2 +- 3 files changed, 188 insertions(+), 140 deletions(-) diff --git a/schema/field.go b/schema/field.go index 76f459ec..e4c80734 100644 --- a/schema/field.go +++ b/schema/field.go @@ -6,6 +6,7 @@ import ( "fmt" "reflect" "strconv" + "strings" "sync" "time" @@ -14,6 +15,13 @@ import ( type DataType string +type TimeType int64 + +const ( + UnixSecond TimeType = 1 + UnixNanosecond TimeType = 2 +) + const ( Bool DataType = "bool" Int = "int" @@ -25,32 +33,35 @@ const ( ) type Field struct { - Name string - DBName string - BindNames []string - DataType DataType - DBDataType string - PrimaryKey bool - AutoIncrement bool - Creatable bool - Updatable bool - HasDefaultValue bool - DefaultValue string - NotNull bool - Unique bool - Comment string - Size int - Precision int - FieldType reflect.Type - IndirectFieldType reflect.Type - StructField reflect.StructField - Tag reflect.StructTag - TagSettings map[string]string - Schema *Schema - EmbeddedSchema *Schema - ReflectValuer func(reflect.Value) reflect.Value - Valuer func(reflect.Value) interface{} - Setter func(reflect.Value, interface{}) error + Name string + DBName string + BindNames []string + DataType DataType + DBDataType string + PrimaryKey bool + AutoIncrement bool + Creatable bool + Updatable bool + HasDefaultValue bool + AutoCreateTime TimeType + AutoUpdateTime TimeType + DefaultValue string + DefaultValueInterface interface{} + NotNull bool + Unique bool + Comment string + Size int + Precision int + FieldType reflect.Type + IndirectFieldType reflect.Type + StructField reflect.StructField + Tag reflect.StructTag + TagSettings map[string]string + Schema *Schema + EmbeddedSchema *Schema + ReflectValueOf func(reflect.Value) reflect.Value + ValueOf func(reflect.Value) (value interface{}, zero bool) + Set func(reflect.Value, interface{}) error } func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { @@ -73,7 +84,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { fieldValue := reflect.New(field.IndirectFieldType) // if field is valuer, used its value or first fields as data type - if valuer, isValuer := fieldValue.Interface().(driver.Valuer); isValuer { + if valuer, isValueOf := fieldValue.Interface().(driver.Valuer); isValueOf { var overrideFieldValue bool if v, err := valuer.Value(); v != nil && err == nil { overrideFieldValue = true @@ -150,17 +161,48 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { field.DBDataType = val } + if v, ok := field.TagSettings["AUTOCREATETIME"]; ok || (field.Name == "CreatedAt" && (field.DataType == Time || field.DataType == Int)) { + if strings.ToUpper(v) == "NANO" { + field.AutoCreateTime = UnixNanosecond + } else { + field.AutoCreateTime = UnixSecond + } + } + + if v, ok := field.TagSettings["AUTOUPDATETIME"]; ok || (field.Name == "UpdatedAt" && (field.DataType == Time || field.DataType == Int)) { + if strings.ToUpper(v) == "NANO" { + field.AutoUpdateTime = UnixNanosecond + } else { + field.AutoUpdateTime = UnixSecond + } + } + switch fieldValue.Elem().Kind() { case reflect.Bool: field.DataType = Bool + if field.HasDefaultValue { + field.DefaultValueInterface, _ = strconv.ParseBool(field.DefaultValue) + } case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: field.DataType = Int + if field.HasDefaultValue { + field.DefaultValueInterface, _ = strconv.ParseInt(field.DefaultValue, 0, 64) + } case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: field.DataType = Uint + if field.HasDefaultValue { + field.DefaultValueInterface, _ = strconv.ParseUint(field.DefaultValue, 0, 64) + } case reflect.Float32, reflect.Float64: field.DataType = Float + if field.HasDefaultValue { + field.DefaultValueInterface, _ = strconv.ParseFloat(field.DefaultValue, 64) + } case reflect.String: field.DataType = String + if field.HasDefaultValue { + field.DefaultValueInterface = field.DefaultValue + } case reflect.Struct: if _, ok := fieldValue.Interface().(*time.Time); ok { field.DataType = Time @@ -216,36 +258,22 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { return field } -// ValueOf field value of -func (field *Field) ValueOf(value reflect.Value) interface{} { - if field != nil { - return field.Valuer(value) - } - return nil -} - -func (field *Field) Set(value reflect.Value, v interface{}) error { - if field != nil { - return field.Setter(value, v) - } - - return fmt.Errorf("failed to set field value: %v", field.Name) -} - // create valuer, setter when parse struct func (field *Field) setupValuerAndSetter() { - // Valuer + // ValueOf switch { case len(field.StructField.Index) == 1: - field.Valuer = func(value reflect.Value) interface{} { - return reflect.Indirect(value).Field(field.StructField.Index[0]).Interface() + field.ValueOf = func(value reflect.Value) (interface{}, bool) { + fieldValue := reflect.Indirect(value).Field(field.StructField.Index[0]) + return fieldValue.Interface(), fieldValue.IsZero() } case len(field.StructField.Index) == 2 && field.StructField.Index[0] >= 0: - field.Valuer = func(value reflect.Value) interface{} { - return reflect.Indirect(value).Field(field.StructField.Index[0]).Field(field.StructField.Index[1]).Interface() + field.ValueOf = func(value reflect.Value) (interface{}, bool) { + fieldValue := reflect.Indirect(value).Field(field.StructField.Index[0]).Field(field.StructField.Index[1]) + return fieldValue.Interface(), fieldValue.IsZero() } default: - field.Valuer = func(value reflect.Value) interface{} { + field.ValueOf = func(value reflect.Value) (interface{}, bool) { v := reflect.Indirect(value) for _, idx := range field.StructField.Index { @@ -259,19 +287,19 @@ func (field *Field) setupValuerAndSetter() { v = v.Elem() } } else { - return nil + return nil, true } } } - return v.Interface() + return v.Interface(), v.IsZero() } } - // ReflectValuer + // ReflectValueOf switch { case len(field.StructField.Index) == 1: if field.FieldType.Kind() == reflect.Ptr { - field.ReflectValuer = func(value reflect.Value) reflect.Value { + field.ReflectValueOf = func(value reflect.Value) reflect.Value { fieldValue := reflect.Indirect(value).Field(field.StructField.Index[0]) if fieldValue.IsNil() { fieldValue.Set(reflect.New(field.FieldType.Elem())) @@ -279,16 +307,16 @@ func (field *Field) setupValuerAndSetter() { return fieldValue } } else { - field.ReflectValuer = func(value reflect.Value) reflect.Value { + field.ReflectValueOf = func(value reflect.Value) reflect.Value { return reflect.Indirect(value).Field(field.StructField.Index[0]) } } case len(field.StructField.Index) == 2 && field.StructField.Index[0] >= 0 && field.FieldType.Kind() != reflect.Ptr: - field.ReflectValuer = func(value reflect.Value) reflect.Value { + field.ReflectValueOf = func(value reflect.Value) reflect.Value { return reflect.Indirect(value).Field(field.StructField.Index[0]).Field(field.StructField.Index[1]) } default: - field.ReflectValuer = func(value reflect.Value) reflect.Value { + field.ReflectValueOf = func(value reflect.Value) reflect.Value { v := reflect.Indirect(value) for _, idx := range field.StructField.Index { if idx >= 0 { @@ -316,168 +344,184 @@ func (field *Field) setupValuerAndSetter() { recoverFunc := func(value reflect.Value, v interface{}, setter func(reflect.Value, interface{}) error) (err error) { reflectV := reflect.ValueOf(v) if reflectV.Type().ConvertibleTo(field.FieldType) { - field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType)) + field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType)) } else if valuer, ok := v.(driver.Valuer); ok { if v, err = valuer.Value(); err == nil { return setter(value, v) } } else if field.FieldType.Kind() == reflect.Ptr && reflectV.Type().ConvertibleTo(field.FieldType.Elem()) { - field.ReflectValuer(value).Elem().Set(reflectV.Convert(field.FieldType.Elem())) + field.ReflectValueOf(value).Elem().Set(reflectV.Convert(field.FieldType.Elem())) } else if reflectV.Kind() == reflect.Ptr { - return field.Setter(value, reflectV.Elem().Interface()) + return field.Set(value, reflectV.Elem().Interface()) } else { return fmt.Errorf("failed to set value %+v to field %v", v, field.Name) } return err } - // Setter + // Set switch field.FieldType.Kind() { case reflect.Bool: - field.Setter = func(value reflect.Value, v interface{}) error { + field.Set = func(value reflect.Value, v interface{}) error { switch data := v.(type) { case bool: - field.ReflectValuer(value).SetBool(data) + field.ReflectValueOf(value).SetBool(data) case *bool: - field.ReflectValuer(value).SetBool(*data) + field.ReflectValueOf(value).SetBool(*data) default: - return recoverFunc(value, v, field.Setter) + return recoverFunc(value, v, field.Set) } return nil } case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - field.Setter = func(value reflect.Value, v interface{}) (err error) { + field.Set = func(value reflect.Value, v interface{}) (err error) { switch data := v.(type) { case int64: - field.ReflectValuer(value).SetInt(data) + field.ReflectValueOf(value).SetInt(data) case int: - field.ReflectValuer(value).SetInt(int64(data)) + field.ReflectValueOf(value).SetInt(int64(data)) case int8: - field.ReflectValuer(value).SetInt(int64(data)) + field.ReflectValueOf(value).SetInt(int64(data)) case int16: - field.ReflectValuer(value).SetInt(int64(data)) + field.ReflectValueOf(value).SetInt(int64(data)) case int32: - field.ReflectValuer(value).SetInt(int64(data)) + field.ReflectValueOf(value).SetInt(int64(data)) case uint: - field.ReflectValuer(value).SetInt(int64(data)) + field.ReflectValueOf(value).SetInt(int64(data)) case uint8: - field.ReflectValuer(value).SetInt(int64(data)) + field.ReflectValueOf(value).SetInt(int64(data)) case uint16: - field.ReflectValuer(value).SetInt(int64(data)) + field.ReflectValueOf(value).SetInt(int64(data)) case uint32: - field.ReflectValuer(value).SetInt(int64(data)) + field.ReflectValueOf(value).SetInt(int64(data)) case uint64: - field.ReflectValuer(value).SetInt(int64(data)) + field.ReflectValueOf(value).SetInt(int64(data)) case float32: - field.ReflectValuer(value).SetInt(int64(data)) + field.ReflectValueOf(value).SetInt(int64(data)) case float64: - field.ReflectValuer(value).SetInt(int64(data)) + field.ReflectValueOf(value).SetInt(int64(data)) case []byte: - return field.Setter(value, string(data)) + return field.Set(value, string(data)) case string: if i, err := strconv.ParseInt(data, 0, 64); err == nil { - field.ReflectValuer(value).SetInt(i) + field.ReflectValueOf(value).SetInt(i) } else { return err } + case time.Time: + if field.AutoCreateTime == UnixNanosecond { + field.ReflectValueOf(value).SetInt(data.UnixNano()) + } else { + field.ReflectValueOf(value).SetInt(data.Unix()) + } + case *time.Time: + if data != nil { + if field.AutoCreateTime == UnixNanosecond { + field.ReflectValueOf(value).SetInt(data.UnixNano()) + } else { + field.ReflectValueOf(value).SetInt(data.Unix()) + } + } else { + field.ReflectValueOf(value).SetInt(0) + } default: - return recoverFunc(value, v, field.Setter) + return recoverFunc(value, v, field.Set) } return err } case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - field.Setter = func(value reflect.Value, v interface{}) (err error) { + field.Set = func(value reflect.Value, v interface{}) (err error) { switch data := v.(type) { case uint64: - field.ReflectValuer(value).SetUint(data) + field.ReflectValueOf(value).SetUint(data) case uint: - field.ReflectValuer(value).SetUint(uint64(data)) + field.ReflectValueOf(value).SetUint(uint64(data)) case uint8: - field.ReflectValuer(value).SetUint(uint64(data)) + field.ReflectValueOf(value).SetUint(uint64(data)) case uint16: - field.ReflectValuer(value).SetUint(uint64(data)) + field.ReflectValueOf(value).SetUint(uint64(data)) case uint32: - field.ReflectValuer(value).SetUint(uint64(data)) + field.ReflectValueOf(value).SetUint(uint64(data)) case int64: - field.ReflectValuer(value).SetUint(uint64(data)) + field.ReflectValueOf(value).SetUint(uint64(data)) case int: - field.ReflectValuer(value).SetUint(uint64(data)) + field.ReflectValueOf(value).SetUint(uint64(data)) case int8: - field.ReflectValuer(value).SetUint(uint64(data)) + field.ReflectValueOf(value).SetUint(uint64(data)) case int16: - field.ReflectValuer(value).SetUint(uint64(data)) + field.ReflectValueOf(value).SetUint(uint64(data)) case int32: - field.ReflectValuer(value).SetUint(uint64(data)) + field.ReflectValueOf(value).SetUint(uint64(data)) case float32: - field.ReflectValuer(value).SetUint(uint64(data)) + field.ReflectValueOf(value).SetUint(uint64(data)) case float64: - field.ReflectValuer(value).SetUint(uint64(data)) + field.ReflectValueOf(value).SetUint(uint64(data)) case []byte: - return field.Setter(value, string(data)) + return field.Set(value, string(data)) case string: if i, err := strconv.ParseUint(data, 0, 64); err == nil { - field.ReflectValuer(value).SetUint(i) + field.ReflectValueOf(value).SetUint(i) } else { return err } default: - return recoverFunc(value, v, field.Setter) + return recoverFunc(value, v, field.Set) } return err } case reflect.Float32, reflect.Float64: - field.Setter = func(value reflect.Value, v interface{}) (err error) { + field.Set = func(value reflect.Value, v interface{}) (err error) { switch data := v.(type) { case float64: - field.ReflectValuer(value).SetFloat(data) + field.ReflectValueOf(value).SetFloat(data) case float32: - field.ReflectValuer(value).SetFloat(float64(data)) + field.ReflectValueOf(value).SetFloat(float64(data)) case int64: - field.ReflectValuer(value).SetFloat(float64(data)) + field.ReflectValueOf(value).SetFloat(float64(data)) case int: - field.ReflectValuer(value).SetFloat(float64(data)) + field.ReflectValueOf(value).SetFloat(float64(data)) case int8: - field.ReflectValuer(value).SetFloat(float64(data)) + field.ReflectValueOf(value).SetFloat(float64(data)) case int16: - field.ReflectValuer(value).SetFloat(float64(data)) + field.ReflectValueOf(value).SetFloat(float64(data)) case int32: - field.ReflectValuer(value).SetFloat(float64(data)) + field.ReflectValueOf(value).SetFloat(float64(data)) case uint: - field.ReflectValuer(value).SetFloat(float64(data)) + field.ReflectValueOf(value).SetFloat(float64(data)) case uint8: - field.ReflectValuer(value).SetFloat(float64(data)) + field.ReflectValueOf(value).SetFloat(float64(data)) case uint16: - field.ReflectValuer(value).SetFloat(float64(data)) + field.ReflectValueOf(value).SetFloat(float64(data)) case uint32: - field.ReflectValuer(value).SetFloat(float64(data)) + field.ReflectValueOf(value).SetFloat(float64(data)) case uint64: - field.ReflectValuer(value).SetFloat(float64(data)) + field.ReflectValueOf(value).SetFloat(float64(data)) case []byte: - return field.Setter(value, string(data)) + return field.Set(value, string(data)) case string: if i, err := strconv.ParseFloat(data, 64); err == nil { - field.ReflectValuer(value).SetFloat(i) + field.ReflectValueOf(value).SetFloat(i) } else { return err } default: - return recoverFunc(value, v, field.Setter) + return recoverFunc(value, v, field.Set) } return err } case reflect.String: - field.Setter = func(value reflect.Value, v interface{}) (err error) { + field.Set = func(value reflect.Value, v interface{}) (err error) { switch data := v.(type) { case string: - field.ReflectValuer(value).SetString(data) + field.ReflectValueOf(value).SetString(data) case []byte: - field.ReflectValuer(value).SetString(string(data)) + field.ReflectValueOf(value).SetString(string(data)) case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: - field.ReflectValuer(value).SetString(fmt.Sprint(data)) + field.ReflectValueOf(value).SetString(fmt.Sprint(data)) case float64, float32: - field.ReflectValuer(value).SetString(fmt.Sprintf("%."+strconv.Itoa(field.Precision)+"f", data)) + field.ReflectValueOf(value).SetString(fmt.Sprintf("%."+strconv.Itoa(field.Precision)+"f", data)) default: - return recoverFunc(value, v, field.Setter) + return recoverFunc(value, v, field.Set) } return err } @@ -485,77 +529,77 @@ func (field *Field) setupValuerAndSetter() { fieldValue := reflect.New(field.FieldType) switch fieldValue.Elem().Interface().(type) { case time.Time: - field.Setter = func(value reflect.Value, v interface{}) error { + field.Set = func(value reflect.Value, v interface{}) error { switch data := v.(type) { case time.Time: - field.ReflectValuer(value).Set(reflect.ValueOf(v)) + field.ReflectValueOf(value).Set(reflect.ValueOf(v)) case *time.Time: - field.ReflectValuer(value).Set(reflect.ValueOf(v).Elem()) + field.ReflectValueOf(value).Set(reflect.ValueOf(v).Elem()) case string: if t, err := now.Parse(data); err == nil { - field.ReflectValuer(value).Set(reflect.ValueOf(t)) + field.ReflectValueOf(value).Set(reflect.ValueOf(t)) } else { return fmt.Errorf("failed to set string %v to time.Time field %v, failed to parse it as time, got error %v", v, field.Name, err) } default: - return recoverFunc(value, v, field.Setter) + return recoverFunc(value, v, field.Set) } return nil } case *time.Time: - field.Setter = func(value reflect.Value, v interface{}) error { + field.Set = func(value reflect.Value, v interface{}) error { switch data := v.(type) { case time.Time: - field.ReflectValuer(value).Elem().Set(reflect.ValueOf(v)) + field.ReflectValueOf(value).Elem().Set(reflect.ValueOf(v)) case *time.Time: - field.ReflectValuer(value).Set(reflect.ValueOf(v)) + field.ReflectValueOf(value).Set(reflect.ValueOf(v)) case string: if t, err := now.Parse(data); err == nil { - field.ReflectValuer(value).Elem().Set(reflect.ValueOf(t)) + field.ReflectValueOf(value).Elem().Set(reflect.ValueOf(t)) } else { return fmt.Errorf("failed to set string %v to time.Time field %v, failed to parse it as time, got error %v", v, field.Name, err) } default: - return recoverFunc(value, v, field.Setter) + return recoverFunc(value, v, field.Set) } return nil } default: if _, ok := fieldValue.Interface().(sql.Scanner); ok { // struct scanner - field.Setter = func(value reflect.Value, v interface{}) (err error) { + field.Set = func(value reflect.Value, v interface{}) (err error) { reflectV := reflect.ValueOf(v) if reflectV.Type().ConvertibleTo(field.FieldType) { - field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType)) + field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType)) } else if valuer, ok := v.(driver.Valuer); ok { if v, err = valuer.Value(); err == nil { - err = field.ReflectValuer(value).Addr().Interface().(sql.Scanner).Scan(v) + err = field.ReflectValueOf(value).Addr().Interface().(sql.Scanner).Scan(v) } } else { - err = field.ReflectValuer(value).Addr().Interface().(sql.Scanner).Scan(v) + err = field.ReflectValueOf(value).Addr().Interface().(sql.Scanner).Scan(v) } return } } else if _, ok := fieldValue.Elem().Interface().(sql.Scanner); ok { // pointer scanner - field.Setter = func(value reflect.Value, v interface{}) (err error) { + field.Set = func(value reflect.Value, v interface{}) (err error) { reflectV := reflect.ValueOf(v) if reflectV.Type().ConvertibleTo(field.FieldType) { - field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType)) + field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType)) } else if reflectV.Type().ConvertibleTo(field.FieldType.Elem()) { - field.ReflectValuer(value).Elem().Set(reflectV.Convert(field.FieldType.Elem())) + field.ReflectValueOf(value).Elem().Set(reflectV.Convert(field.FieldType.Elem())) } else if valuer, ok := v.(driver.Valuer); ok { if v, err = valuer.Value(); err == nil { - err = field.ReflectValuer(value).Interface().(sql.Scanner).Scan(v) + err = field.ReflectValueOf(value).Interface().(sql.Scanner).Scan(v) } } else { - err = field.ReflectValuer(value).Interface().(sql.Scanner).Scan(v) + err = field.ReflectValueOf(value).Interface().(sql.Scanner).Scan(v) } return } } else { - field.Setter = func(value reflect.Value, v interface{}) (err error) { - return recoverFunc(value, v, field.Setter) + field.Set = func(value reflect.Value, v interface{}) (err error) { + return recoverFunc(value, v, field.Set) } } } diff --git a/schema/schema.go b/schema/schema.go index 2f3cdf88..63e388f5 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -18,6 +18,7 @@ type Schema struct { ModelType reflect.Type Table string PrioritizedPrimaryField *Field + DBNames []string PrimaryFields []*Field Fields []*Field FieldsByName map[string]*Field @@ -99,6 +100,9 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) if field.DBName != "" { // nonexistence or shortest path or first appear prioritized if has permission if v, ok := schema.FieldsByDBName[field.DBName]; !ok || (field.Creatable && len(field.BindNames) < len(v.BindNames)) { + if _, ok := schema.FieldsByDBName[field.DBName]; !ok { + schema.DBNames = append(schema.DBNames, field.DBName) + } schema.FieldsByDBName[field.DBName] = field schema.FieldsByName[field.Name] = field diff --git a/schema/schema_helper_test.go b/schema/schema_helper_test.go index 8ac2f002..60e51543 100644 --- a/schema/schema_helper_test.go +++ b/schema/schema_helper_test.go @@ -198,7 +198,7 @@ func checkField(t *testing.T, s *schema.Schema, value reflect.Value, values map[ var ( checker func(fv interface{}, v interface{}) field = s.FieldsByDBName[k] - fv = field.ValueOf(value) + fv, _ = field.ValueOf(value) ) checker = func(fv interface{}, v interface{}) { From 15ce5b3cdd8b256ce070245b3a41a1ca7d4ca0fb Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 19 Feb 2020 12:53:46 +0800 Subject: [PATCH 236/881] Add create value converter --- callbacks/create.go | 87 +++++++++++++++++++++++++++++++++++++++- callbacks/helper.go | 97 +++++++++++++++++++++++++++++++++++++++++++++ chainable_api.go | 2 +- clause/values.go | 3 +- 4 files changed, 186 insertions(+), 3 deletions(-) create mode 100644 callbacks/helper.go diff --git a/callbacks/create.go b/callbacks/create.go index 58256085..8dba8a5f 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -2,6 +2,7 @@ package callbacks import ( "fmt" + "reflect" "github.com/jinzhu/gorm" "github.com/jinzhu/gorm/clause" @@ -19,11 +20,15 @@ func SaveBeforeAssociations(db *gorm.DB) { func Create(db *gorm.DB) { db.Statement.AddClauseIfNotExists(clause.Insert{ - Table: clause.Table{Table: db.Statement.Table}, + Table: clause.Table{Name: db.Statement.Table}, }) + values, _ := ConvertToCreateValues(db.Statement) + db.Statement.AddClause(values) db.Statement.Build("INSERT", "VALUES", "ON_CONFLICT") result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) + + fmt.Printf("%+v\n", values) fmt.Println(err) fmt.Println(result) fmt.Println(db.Statement.SQL.String(), db.Statement.Vars) @@ -36,3 +41,83 @@ func AfterCreate(db *gorm.DB) { // after save // after create } + +// ConvertToCreateValues convert to create values +func ConvertToCreateValues(stmt *gorm.Statement) (clause.Values, []map[string]interface{}) { + switch value := stmt.Dest.(type) { + case map[string]interface{}: + return ConvertMapToValues(stmt, value), nil + case []map[string]interface{}: + return ConvertSliceOfMapToValues(stmt, value), nil + default: + var ( + values = clause.Values{} + selectColumns, restricted = SelectAndOmitColumns(stmt) + curTime = stmt.DB.NowFunc() + isZero = false + returnningValues []map[string]interface{} + ) + + for _, db := range stmt.Schema.DBNames { + if v, ok := selectColumns[db]; (ok && v) || (!ok && !restricted) { + values.Columns = append(values.Columns, clause.Column{Name: db}) + } + } + + reflectValue := reflect.Indirect(reflect.ValueOf(stmt.Dest)) + switch reflectValue.Kind() { + case reflect.Slice, reflect.Array: + values.Values = make([][]interface{}, reflectValue.Len()) + for i := 0; i < reflectValue.Len(); i++ { + rv := reflect.Indirect(reflectValue.Index(i)) + values.Values[i] = make([]interface{}, len(values.Columns)) + for idx, column := range values.Columns { + field := stmt.Schema.FieldsByDBName[column.Name] + if values.Values[i][idx], isZero = field.ValueOf(rv); isZero { + if field.DefaultValueInterface != nil { + values.Values[i][idx] = field.DefaultValueInterface + field.Set(rv, field.DefaultValueInterface) + } else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 { + field.Set(rv, curTime) + values.Values[i][idx], _ = field.ValueOf(rv) + } else if field.HasDefaultValue { + if len(returnningValues) == 0 { + returnningValues = make([]map[string]interface{}, reflectValue.Len()) + } + + if returnningValues[i] == nil { + returnningValues[i] = map[string]interface{}{} + } + + // FIXME + returnningValues[i][column.Name] = field.ReflectValueOf(reflectValue).Addr().Interface() + } + } + } + } + case reflect.Struct: + values.Values = [][]interface{}{make([]interface{}, len(values.Columns))} + for idx, column := range values.Columns { + field := stmt.Schema.FieldsByDBName[column.Name] + if values.Values[0][idx], _ = field.ValueOf(reflectValue); isZero { + if field.DefaultValueInterface != nil { + values.Values[0][idx] = field.DefaultValueInterface + field.Set(reflectValue, field.DefaultValueInterface) + } else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 { + field.Set(reflectValue, curTime) + values.Values[0][idx], _ = field.ValueOf(reflectValue) + } else if field.HasDefaultValue { + if len(returnningValues) == 0 { + returnningValues = make([]map[string]interface{}, 1) + } + + values.Values[0][idx] = clause.Expr{SQL: "DEFAULT"} + returnningValues[0][column.Name] = field.ReflectValueOf(reflectValue).Addr().Interface() + } else if field.PrimaryKey { + } + } + } + } + return values, returnningValues + } +} diff --git a/callbacks/helper.go b/callbacks/helper.go new file mode 100644 index 00000000..56c0767d --- /dev/null +++ b/callbacks/helper.go @@ -0,0 +1,97 @@ +package callbacks + +import ( + "sort" + + "github.com/jinzhu/gorm" + "github.com/jinzhu/gorm/clause" +) + +// SelectAndOmitColumns get select and omit columns, select -> true, omit -> false +func SelectAndOmitColumns(stmt *gorm.Statement) (map[string]bool, bool) { + results := map[string]bool{} + + // select columns + for _, column := range stmt.Selects { + if field := stmt.Schema.LookUpField(column); field != nil { + results[field.DBName] = true + } else { + results[column] = true + } + } + + // omit columns + for _, omit := range stmt.Omits { + if field := stmt.Schema.LookUpField(omit); field != nil { + results[field.DBName] = false + } else { + results[omit] = false + } + } + + return results, len(stmt.Selects) > 0 +} + +// ConvertMapToValues convert map to values +func ConvertMapToValues(stmt *gorm.Statement, mapValue map[string]interface{}) (values clause.Values) { + columns := make([]string, 0, len(mapValue)) + selectColumns, restricted := SelectAndOmitColumns(stmt) + + var keys []string + for k, _ := range mapValue { + keys = append(keys, k) + } + sort.Strings(keys) + + for _, k := range keys { + if field := stmt.Schema.LookUpField(k); field != nil { + k = field.DBName + } + + if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) { + columns = append(columns, k) + values.Values[0] = append(values.Values[0], mapValue[k]) + } + } + return +} + +// ConvertSliceOfMapToValues convert slice of map to values +func ConvertSliceOfMapToValues(stmt *gorm.Statement, mapValues []map[string]interface{}) (values clause.Values) { + var ( + columns = []string{} + result = map[string][]interface{}{} + selectColumns, restricted = SelectAndOmitColumns(stmt) + ) + + for idx, mapValue := range mapValues { + for k, v := range mapValue { + if field := stmt.Schema.LookUpField(k); field != nil { + k = field.DBName + } + + if _, ok := result[k]; !ok { + if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) { + result[k] = make([]interface{}, len(mapValues)) + columns = append(columns, k) + } else { + continue + } + } + + result[k][idx] = v + } + } + + sort.Strings(columns) + values.Values = make([][]interface{}, len(mapValues)) + for idx, column := range columns { + for i, v := range result[column] { + if i == 0 { + values.Values[i] = make([]interface{}, len(columns)) + } + values.Values[i][idx] = v + } + } + return +} diff --git a/chainable_api.go b/chainable_api.go index 9aa08b54..a57deb63 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -99,7 +99,7 @@ func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) { func (db *DB) Omit(columns ...string) (tx *DB) { tx = db.getInstance() - if len(columns) == 1 && strings.Contains(columns[0], ",") { + if len(columns) == 1 && strings.ContainsRune(columns[0], ',') { tx.Statement.Omits = strings.FieldsFunc(columns[0], isChar) } else { tx.Statement.Omits = columns diff --git a/clause/values.go b/clause/values.go index 594b92e2..2c8dcf89 100644 --- a/clause/values.go +++ b/clause/values.go @@ -7,7 +7,7 @@ type Values struct { // Name from clause name func (Values) Name() string { - return "" + return "VALUES" } // Build build from clause @@ -40,6 +40,7 @@ func (values Values) Build(builder Builder) { // MergeClause merge values clauses func (values Values) MergeClause(clause *Clause) { + clause.Name = "" if v, ok := clause.Expression.(Values); ok { values.Values = append(v.Values, values.Values...) } From 43ce0b8af2513b86a6b39ab68c7912dc373db6dc Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 20 Feb 2020 10:13:26 +0800 Subject: [PATCH 237/881] Handle create with default db values --- callbacks/create.go | 58 +++++++++++++++++++++++++++++---------------- schema/schema.go | 41 ++++++++++++++++++++++---------- 2 files changed, 65 insertions(+), 34 deletions(-) diff --git a/callbacks/create.go b/callbacks/create.go index 8dba8a5f..95afc854 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -59,8 +59,10 @@ func ConvertToCreateValues(stmt *gorm.Statement) (clause.Values, []map[string]in ) for _, db := range stmt.Schema.DBNames { - if v, ok := selectColumns[db]; (ok && v) || (!ok && !restricted) { - values.Columns = append(values.Columns, clause.Column{Name: db}) + if stmt.Schema.FieldsWithDefaultDBValue[db] == nil { + if v, ok := selectColumns[db]; (ok && v) || (!ok && !restricted) { + values.Columns = append(values.Columns, clause.Column{Name: db}) + } } } @@ -68,6 +70,7 @@ func ConvertToCreateValues(stmt *gorm.Statement) (clause.Values, []map[string]in switch reflectValue.Kind() { case reflect.Slice, reflect.Array: values.Values = make([][]interface{}, reflectValue.Len()) + defaultValueFieldsHavingValue := map[string][]interface{}{} for i := 0; i < reflectValue.Len(); i++ { rv := reflect.Indirect(reflectValue.Index(i)) values.Values[i] = make([]interface{}, len(values.Columns)) @@ -80,44 +83,57 @@ func ConvertToCreateValues(stmt *gorm.Statement) (clause.Values, []map[string]in } else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 { field.Set(rv, curTime) values.Values[i][idx], _ = field.ValueOf(rv) - } else if field.HasDefaultValue { - if len(returnningValues) == 0 { - returnningValues = make([]map[string]interface{}, reflectValue.Len()) - } - - if returnningValues[i] == nil { - returnningValues[i] = map[string]interface{}{} - } - - // FIXME - returnningValues[i][column.Name] = field.ReflectValueOf(reflectValue).Addr().Interface() } } } + + for db, field := range stmt.Schema.FieldsWithDefaultDBValue { + if v, ok := selectColumns[db]; (ok && v) || (!ok && !restricted) { + if v, isZero := field.ValueOf(rv); !isZero { + if len(defaultValueFieldsHavingValue[db]) == 0 { + defaultValueFieldsHavingValue[db] = make([]interface{}, reflectValue.Len()) + } + defaultValueFieldsHavingValue[db][i] = v + } + } + } + } + + for db, vs := range defaultValueFieldsHavingValue { + values.Columns = append(values.Columns, clause.Column{Name: db}) + for idx := range values.Values { + if vs[idx] == nil { + values.Values[idx] = append(values.Values[idx], clause.Expr{SQL: "DEFAULT"}) + } else { + values.Values[idx] = append(values.Values[idx], vs[idx]) + } + } } case reflect.Struct: values.Values = [][]interface{}{make([]interface{}, len(values.Columns))} for idx, column := range values.Columns { field := stmt.Schema.FieldsByDBName[column.Name] - if values.Values[0][idx], _ = field.ValueOf(reflectValue); isZero { + if values.Values[0][idx], isZero = field.ValueOf(reflectValue); isZero { if field.DefaultValueInterface != nil { values.Values[0][idx] = field.DefaultValueInterface field.Set(reflectValue, field.DefaultValueInterface) } else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 { field.Set(reflectValue, curTime) values.Values[0][idx], _ = field.ValueOf(reflectValue) - } else if field.HasDefaultValue { - if len(returnningValues) == 0 { - returnningValues = make([]map[string]interface{}, 1) - } + } + } + } - values.Values[0][idx] = clause.Expr{SQL: "DEFAULT"} - returnningValues[0][column.Name] = field.ReflectValueOf(reflectValue).Addr().Interface() - } else if field.PrimaryKey { + for db, field := range stmt.Schema.FieldsWithDefaultDBValue { + if v, ok := selectColumns[db]; (ok && v) || (!ok && !restricted) { + if v, isZero := field.ValueOf(reflectValue); !isZero { + values.Columns = append(values.Columns, clause.Column{Name: db}) + values.Values[0] = append(values.Values[0], v) } } } } + return values, returnningValues } } diff --git a/schema/schema.go b/schema/schema.go index 63e388f5..acf6ff52 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -14,19 +14,20 @@ import ( var ErrUnsupportedDataType = errors.New("unsupported data type") type Schema struct { - Name string - ModelType reflect.Type - Table string - PrioritizedPrimaryField *Field - DBNames []string - PrimaryFields []*Field - Fields []*Field - FieldsByName map[string]*Field - FieldsByDBName map[string]*Field - Relationships Relationships - err error - namer Namer - cacheStore *sync.Map + Name string + ModelType reflect.Type + Table string + PrioritizedPrimaryField *Field + DBNames []string + PrimaryFields []*Field + Fields []*Field + FieldsByName map[string]*Field + FieldsByDBName map[string]*Field + FieldsWithDefaultDBValue map[string]*Field // fields with default value assigned by database + Relationships Relationships + err error + namer Namer + cacheStore *sync.Map } func (schema Schema) String() string { @@ -146,6 +147,20 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) } } + schema.FieldsWithDefaultDBValue = map[string]*Field{} + for db, field := range schema.FieldsByDBName { + if field.HasDefaultValue && field.DefaultValueInterface == nil { + schema.FieldsWithDefaultDBValue[db] = field + } + } + + if schema.PrioritizedPrimaryField != nil { + switch schema.PrioritizedPrimaryField.DataType { + case Int, Uint: + schema.FieldsWithDefaultDBValue[schema.PrioritizedPrimaryField.DBName] = schema.PrioritizedPrimaryField + } + } + cacheStore.Store(modelType, schema) // parse relations for unidentified fields From 62dcd7896accb4cedfd9428a03a99332281da2a0 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 20 Feb 2020 23:04:03 +0800 Subject: [PATCH 238/881] Add Migrator --- callbacks.go | 5 +- helpers.go | 2 + migrator.go | 12 +++- migrator/migrator.go | 153 ++++++++++++++++++++++++++++++++++++++++++- statement.go | 7 ++ 5 files changed, 172 insertions(+), 7 deletions(-) diff --git a/callbacks.go b/callbacks.go index 8546ae16..4f19a681 100644 --- a/callbacks.go +++ b/callbacks.go @@ -75,13 +75,10 @@ func (p *processor) Execute(db *DB) { } if stmt.Model != nil { - var err error - stmt.Schema, err = schema.Parse(stmt.Model, db.cacheStore, db.NamingStrategy) + err := stmt.Parse(stmt.Model) if err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || stmt.Table == "") { db.AddError(err) - } else if stmt.Table == "" && stmt.Schema != nil { - stmt.Table = stmt.Schema.Table } } } diff --git a/helpers.go b/helpers.go index 2e5c8ed1..d7177ba7 100644 --- a/helpers.go +++ b/helpers.go @@ -15,6 +15,8 @@ var ( ErrInvalidTransaction = errors.New("no valid transaction") // ErrUnaddressable unaddressable value ErrUnaddressable = errors.New("using unaddressable value") + // ErrNotImplemented not implemented + ErrNotImplemented = errors.New("not implemented") ) // Model a basic GoLang struct which includes the following fields: ID, CreatedAt, UpdatedAt, DeletedAt diff --git a/migrator.go b/migrator.go index c21cda42..b6d273e7 100644 --- a/migrator.go +++ b/migrator.go @@ -4,6 +4,11 @@ import ( "database/sql" ) +// Migrator returns migrator +func (db *DB) Migrator() Migrator { + return db.Dialector.Migrator() +} + // ViewOption view option type ViewOption struct { Replace bool @@ -15,10 +20,13 @@ type Migrator interface { // AutoMigrate AutoMigrate(dst ...interface{}) error + // Database + CurrentDatabase() string + // Tables CreateTable(dst ...interface{}) error DropTable(dst ...interface{}) error - HasTable(dst ...interface{}) error + HasTable(dst ...interface{}) bool RenameTable(oldName, newName string) error // Columns @@ -39,6 +47,6 @@ type Migrator interface { // Indexes CreateIndex(dst interface{}, name string) error DropIndex(dst interface{}, name string) error - HasIndex(dst interface{}, name string) error + HasIndex(dst interface{}, name string) bool RenameIndex(dst interface{}, oldName, newName string) error } diff --git a/migrator/migrator.go b/migrator/migrator.go index 0ff83ac1..e9725935 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -1,6 +1,11 @@ package migrator -import "github.com/jinzhu/gorm" +import ( + "database/sql" + "fmt" + + "github.com/jinzhu/gorm" +) // Migrator migrator struct type Migrator struct { @@ -12,3 +17,149 @@ type Config struct { CheckExistsBeforeDropping bool DB *gorm.DB } + +func (migrator Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement) error) error { + stmt := migrator.DB.Statement + if stmt == nil { + stmt = &gorm.Statement{DB: migrator.DB} + } + + if err := stmt.Parse(value); err != nil { + return err + } + + return fc(stmt) +} + +// AutoMigrate +func (migrator Migrator) AutoMigrate(values ...interface{}) error { + return gorm.ErrNotImplemented +} + +func (migrator Migrator) CreateTable(values ...interface{}) error { + return gorm.ErrNotImplemented +} + +func (migrator Migrator) DropTable(values ...interface{}) error { + for _, value := range values { + if err := migrator.RunWithValue(value, func(stmt *gorm.Statement) error { + return migrator.DB.Exec("DROP TABLE " + stmt.Quote(stmt.Table)).Error + }); err != nil { + return err + } + } + return nil +} + +func (migrator Migrator) HasTable(values ...interface{}) bool { + var count int64 + for _, value := range values { + err := migrator.RunWithValue(value, func(stmt *gorm.Statement) error { + currentDatabase := migrator.DB.Migrator().CurrentDatabase() + return migrator.DB.Raw("SELECT count(*) FROM information_schema.tables WHERE table_schema = ? AND table_name = ? AND table_type = ?", currentDatabase, stmt.Table, "BASE TABLE").Scan(&count).Error + }) + + if err != nil || count == 0 { + return false + } + } + + return true +} + +func (migrator Migrator) RenameTable(oldName, newName string) error { + return migrator.DB.Exec("RENAME TABLE ? TO ?", oldName, newName).Error +} + +func (migrator Migrator) AddColumn(value interface{}, field string) error { + return migrator.RunWithValue(value, func(stmt *gorm.Statement) error { + if field := stmt.Schema.LookUpField(field); field != nil { + return migrator.DB.Exec(fmt.Sprintf("ALTER TABLE ? ADD ? %s", field.DBDataType), stmt.Table, field.DBName).Error + } + return fmt.Errorf("failed to look up field with name: %s", field) + }) +} + +func (migrator Migrator) DropColumn(value interface{}, field string) error { + return migrator.RunWithValue(value, func(stmt *gorm.Statement) error { + if field := stmt.Schema.LookUpField(field); field != nil { + return migrator.DB.Exec("ALTER TABLE ? DROP COLUMN ?", stmt.Table, field.DBName).Error + } + return fmt.Errorf("failed to look up field with name: %s", field) + }) +} + +func (migrator Migrator) AlterColumn(value interface{}, field string) error { + return migrator.RunWithValue(value, func(stmt *gorm.Statement) error { + if field := stmt.Schema.LookUpField(field); field != nil { + return migrator.DB.Exec(fmt.Sprintf("ALTER TABLE ? ALTER COLUMN ? TYPE %s", field.DBDataType), stmt.Table, field.DBName).Error + } + return fmt.Errorf("failed to look up field with name: %s", field) + }) +} + +func (migrator Migrator) RenameColumn(value interface{}, oldName, field string) error { + return migrator.RunWithValue(value, func(stmt *gorm.Statement) error { + if field := stmt.Schema.LookUpField(field); field != nil { + oldName = migrator.DB.NamingStrategy.ColumnName(stmt.Table, oldName) + return migrator.DB.Exec("ALTER TABLE ? RENAME COLUMN ? TO ?", stmt.Table, oldName, field.DBName).Error + } + return fmt.Errorf("failed to look up field with name: %s", field) + }) +} + +func (migrator Migrator) ColumnTypes(value interface{}) ([]*sql.ColumnType, error) { + return nil, gorm.ErrNotImplemented +} + +func (migrator Migrator) CreateView(name string, option gorm.ViewOption) error { + return gorm.ErrNotImplemented +} + +func (migrator Migrator) DropView(name string) error { + return gorm.ErrNotImplemented +} + +func (migrator Migrator) CreateConstraint(value interface{}, name string) error { + return gorm.ErrNotImplemented +} + +func (migrator Migrator) DropConstraint(value interface{}, name string) error { + return migrator.RunWithValue(value, func(stmt *gorm.Statement) error { + return migrator.DB.Raw("ALTER TABLE ? DROP CONSTRAINT ?", stmt.Table, name).Error + }) +} + +func (migrator Migrator) CreateIndex(value interface{}, name string) error { + return gorm.ErrNotImplemented +} + +func (migrator Migrator) DropIndex(value interface{}, name string) error { + return migrator.RunWithValue(value, func(stmt *gorm.Statement) error { + return migrator.DB.Raw("DROP INDEX ? ON ?", name, stmt.Table).Error + }) +} + +func (migrator Migrator) HasIndex(value interface{}, name string) bool { + var count int64 + migrator.RunWithValue(value, func(stmt *gorm.Statement) error { + currentDatabase := migrator.DB.Migrator().CurrentDatabase() + return migrator.DB.Raw("SELECT count(*) FROM information_schema.statistics WHERE table_schema = ? AND table_name = ? AND index_name = ?", currentDatabase, stmt.Table, name).Scan(&count).Error + }) + + if count != 0 { + return true + } + return false +} + +func (migrator Migrator) RenameIndex(value interface{}, oldName, newName string) error { + return migrator.RunWithValue(value, func(stmt *gorm.Statement) error { + return migrator.DB.Exec("ALTER TABLE ? RENAME INDEX ? TO ?", stmt.Table, oldName, newName).Error + }) +} + +func (migrator Migrator) CurrentDatabase() (name string) { + migrator.DB.Raw("SELECT DATABASE()").Scan(&name) + return +} diff --git a/statement.go b/statement.go index b2626d95..8c75c90d 100644 --- a/statement.go +++ b/statement.go @@ -267,3 +267,10 @@ func (stmt *Statement) Build(clauses ...string) { } // TODO handle named vars } + +func (stmt *Statement) Parse(value interface{}) (err error) { + if stmt.Schema, err = schema.Parse(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil && stmt.Table == "" { + stmt.Table = stmt.Schema.Table + } + return err +} From ad419855e96e405ee6597516d26e80524c786640 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 21 Feb 2020 23:51:38 +0800 Subject: [PATCH 239/881] Parse Indexes --- schema/index.go | 116 +++++++++++++++++++++++++++++++++++++++++++ schema/index_test.go | 96 +++++++++++++++++++++++++++++++++++ schema/naming.go | 20 +++++++- 3 files changed, 230 insertions(+), 2 deletions(-) create mode 100644 schema/index.go create mode 100644 schema/index_test.go diff --git a/schema/index.go b/schema/index.go new file mode 100644 index 00000000..ea3a68f5 --- /dev/null +++ b/schema/index.go @@ -0,0 +1,116 @@ +package schema + +import ( + "strconv" + "strings" +) + +type Index struct { + Name string + Class string // UNIQUE | FULLTEXT | SPATIAL + Fields []IndexOption +} + +type IndexOption struct { + *Field + Expression string + Sort string // DESC, ASC + Collate string + Length int + Type string // btree, hash, gist, spgist, gin, and brin + Where string + Comment string +} + +// ParseIndexes parse schema indexes +func (schema *Schema) ParseIndexes() map[string]Index { + var indexes = map[string]Index{} + + for _, field := range schema.FieldsByDBName { + if field.TagSettings["INDEX"] != "" || field.TagSettings["UNIQUE_INDEX"] != "" { + for _, index := range parseFieldIndexes(field) { + idx := indexes[index.Name] + idx.Name = index.Name + if idx.Class == "" { + idx.Class = index.Class + } + idx.Fields = append(idx.Fields, index.Fields...) + indexes[index.Name] = idx + } + } + } + + return indexes +} + +func parseFieldIndexes(field *Field) (indexes []Index) { + for _, value := range strings.Split(field.Tag.Get("gorm"), ";") { + if value != "" { + v := strings.Split(value, ":") + k := strings.TrimSpace(strings.ToUpper(v[0])) + if k == "INDEX" || k == "UNIQUE_INDEX" { + var ( + name string + tag = strings.Join(v[1:], ":") + settings = map[string]string{} + ) + + names := strings.Split(tag, ",") + for i := 0; i < len(names); i++ { + if len(names[i]) > 0 { + j := i + for { + if names[j][len(names[j])-1] == '\\' { + i++ + names[j] = names[j][0:len(names[j])-1] + names[i] + names[i] = "" + } else { + break + } + } + } + + if i == 0 { + name = names[0] + } + + values := strings.Split(names[i], ":") + k := strings.TrimSpace(strings.ToUpper(values[0])) + + if len(values) >= 2 { + settings[k] = strings.Join(values[1:], ":") + } else if k != "" { + settings[k] = k + } + } + + if name == "" { + name = field.Schema.namer.IndexName(field.Schema.Table, field.Name) + } + + length, _ := strconv.Atoi(settings["LENGTH"]) + + if (k == "UNIQUE_INDEX") || settings["UNIQUE"] != "" { + settings["CLASS"] = "UNIQUE" + } + + indexes = append(indexes, Index{ + Name: name, + Class: settings["CLASS"], + Fields: []IndexOption{{ + Field: field, + Expression: settings["EXPRESSION"], + Sort: settings["SORT"], + Collate: settings["COLLATE"], + Type: settings["TYPE"], + Length: length, + Where: settings["WHERE"], + Comment: settings["COMMENT"], + }}, + }) + } + } + } + + return +} diff --git a/schema/index_test.go b/schema/index_test.go new file mode 100644 index 00000000..8c2cb9fe --- /dev/null +++ b/schema/index_test.go @@ -0,0 +1,96 @@ +package schema_test + +import ( + "reflect" + "sync" + "testing" + + "github.com/jinzhu/gorm/schema" +) + +type UserIndex struct { + Name string `gorm:"index"` + Name2 string `gorm:"index:idx_name,unique"` + Name3 string `gorm:"index:,sort:desc,collate:utf8,type:btree,length:10,where:name3 != 'jinzhu'"` + Name4 string `gorm:"unique_index"` + Name5 int64 `gorm:"index:,class:FULLTEXT,comment:hello \\, world,where:age > 10"` + Name6 int64 `gorm:"index:profile,comment:hello \\, world,where:age > 10"` + Age int64 `gorm:"index:profile,expression:(age+10)"` +} + +func TestParseIndex(t *testing.T) { + user, err := schema.Parse(&UserIndex{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("failed to parse user index index, got error %v", err) + } + + results := map[string]schema.Index{ + "idx_user_indices_name": { + Name: "idx_user_indices_name", + Fields: []schema.IndexOption{{}}, + }, + "idx_name": { + Name: "idx_name", + Class: "UNIQUE", + Fields: []schema.IndexOption{{}}, + }, + "idx_user_indices_name3": { + Name: "idx_user_indices_name3", + Fields: []schema.IndexOption{{ + Sort: "desc", + Collate: "utf8", + Length: 10, + Type: "btree", + Where: "name3 != 'jinzhu'", + }}, + }, + "idx_user_indices_name4": { + Name: "idx_user_indices_name4", + Class: "UNIQUE", + Fields: []schema.IndexOption{{}}, + }, + "idx_user_indices_name5": { + Name: "idx_user_indices_name5", + Class: "FULLTEXT", + Fields: []schema.IndexOption{{ + Comment: "hello , world", + Where: "age > 10", + }}, + }, + "profile": { + Name: "profile", + Fields: []schema.IndexOption{{ + Comment: "hello , world", + Where: "age > 10", + }, { + Expression: "(age+10)", + }}, + }, + } + + indices := user.ParseIndexes() + + for k, result := range results { + v, ok := indices[k] + if !ok { + t.Errorf("Failed to found index %v from parsed indices %+v", k, indices) + } + + if result.Name != v.Name { + t.Errorf("index %v name should equal, expects %v, got %v", k, result.Name, v.Name) + } + + if result.Class != v.Class { + t.Errorf("index %v Class should equal, expects %v, got %v", k, result.Class, v.Class) + } + + for idx, ef := range result.Fields { + rf := v.Fields[idx] + for _, name := range []string{"Expression", "Sort", "Collate", "Length", "Type", "Where"} { + if reflect.ValueOf(ef).FieldByName(name).Interface() != reflect.ValueOf(rf).FieldByName(name).Interface() { + t.Errorf("index %v field #%v's %v should equal, expects %v, got %v", k, idx+1, name, reflect.ValueOf(ef).FieldByName(name).Interface(), reflect.ValueOf(rf).FieldByName(name).Interface()) + } + } + } + } +} diff --git a/schema/naming.go b/schema/naming.go index e6a5625e..80af4277 100644 --- a/schema/naming.go +++ b/schema/naming.go @@ -1,9 +1,11 @@ package schema import ( + "crypto/sha1" "fmt" "strings" "sync" + "unicode/utf8" "github.com/jinzhu/inflection" ) @@ -12,6 +14,7 @@ import ( type Namer interface { TableName(table string) string ColumnName(table, column string) string + IndexName(table, column string) string JoinTableName(table string) string } @@ -30,8 +33,21 @@ func (ns NamingStrategy) TableName(str string) string { } // ColumnName convert string to column name -func (ns NamingStrategy) ColumnName(table, str string) string { - return toDBName(str) +func (ns NamingStrategy) ColumnName(table, column string) string { + return toDBName(column) +} + +func (ns NamingStrategy) IndexName(table, column string) string { + idxName := fmt.Sprintf("idx_%v_%v", table, toDBName(column)) + + if utf8.RuneCountInString(idxName) > 64 { + h := sha1.New() + h.Write([]byte(idxName)) + bs := h.Sum(nil) + + idxName = fmt.Sprintf("idx%v%v", table, column)[0:56] + string(bs)[:8] + } + return idxName } // JoinTableName convert string to join table name From ea0b13f7a3aa58efcdea56566ef205e05a6d5867 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 22 Feb 2020 00:02:05 +0800 Subject: [PATCH 240/881] Refactor ParseTagSetting --- schema/field.go | 4 +-- schema/index.go | 70 ++++++++++++++---------------------- schema/index_test.go | 46 ++++++++++++------------ schema/schema_helper_test.go | 2 +- schema/utils.go | 37 ++++++++++++------- 5 files changed, 79 insertions(+), 80 deletions(-) diff --git a/schema/field.go b/schema/field.go index e4c80734..60cfc2ab 100644 --- a/schema/field.go +++ b/schema/field.go @@ -74,7 +74,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { Creatable: true, Updatable: true, Tag: fieldStruct.Tag, - TagSettings: ParseTagSetting(fieldStruct.Tag), + TagSettings: ParseTagSetting(fieldStruct.Tag.Get("gorm"), ";"), Schema: schema, } @@ -104,7 +104,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } // copy tag settings from valuer - for key, value := range ParseTagSetting(field.IndirectFieldType.Field(i).Tag) { + for key, value := range ParseTagSetting(field.IndirectFieldType.Field(i).Tag.Get("gorm"), ";") { if _, ok := field.TagSettings[key]; !ok { field.TagSettings[key] = value } diff --git a/schema/index.go b/schema/index.go index ea3a68f5..26c7a558 100644 --- a/schema/index.go +++ b/schema/index.go @@ -6,9 +6,12 @@ import ( ) type Index struct { - Name string - Class string // UNIQUE | FULLTEXT | SPATIAL - Fields []IndexOption + Name string + Class string // UNIQUE | FULLTEXT | SPATIAL + Type string // btree, hash, gist, spgist, gin, and brin + Where string + Comment string + Fields []IndexOption } type IndexOption struct { @@ -17,9 +20,6 @@ type IndexOption struct { Sort string // DESC, ASC Collate string Length int - Type string // btree, hash, gist, spgist, gin, and brin - Where string - Comment string } // ParseIndexes parse schema indexes @@ -34,6 +34,15 @@ func (schema *Schema) ParseIndexes() map[string]Index { if idx.Class == "" { idx.Class = index.Class } + if idx.Type == "" { + idx.Type = index.Type + } + if idx.Where == "" { + idx.Where = index.Where + } + if idx.Comment == "" { + idx.Comment = index.Comment + } idx.Fields = append(idx.Fields, index.Fields...) indexes[index.Name] = idx } @@ -50,62 +59,37 @@ func parseFieldIndexes(field *Field) (indexes []Index) { k := strings.TrimSpace(strings.ToUpper(v[0])) if k == "INDEX" || k == "UNIQUE_INDEX" { var ( - name string - tag = strings.Join(v[1:], ":") - settings = map[string]string{} + name string + tag = strings.Join(v[1:], ":") + idx = strings.Index(tag, ",") + settings = ParseTagSetting(tag, ",") + length, _ = strconv.Atoi(settings["LENGTH"]) ) - names := strings.Split(tag, ",") - for i := 0; i < len(names); i++ { - if len(names[i]) > 0 { - j := i - for { - if names[j][len(names[j])-1] == '\\' { - i++ - names[j] = names[j][0:len(names[j])-1] + names[i] - names[i] = "" - } else { - break - } - } - } - - if i == 0 { - name = names[0] - } - - values := strings.Split(names[i], ":") - k := strings.TrimSpace(strings.ToUpper(values[0])) - - if len(values) >= 2 { - settings[k] = strings.Join(values[1:], ":") - } else if k != "" { - settings[k] = k - } + if idx != -1 { + name = tag[0:idx] } if name == "" { name = field.Schema.namer.IndexName(field.Schema.Table, field.Name) } - length, _ := strconv.Atoi(settings["LENGTH"]) - if (k == "UNIQUE_INDEX") || settings["UNIQUE"] != "" { settings["CLASS"] = "UNIQUE" } indexes = append(indexes, Index{ - Name: name, - Class: settings["CLASS"], + Name: name, + Class: settings["CLASS"], + Type: settings["TYPE"], + Where: settings["WHERE"], + Comment: settings["COMMENT"], Fields: []IndexOption{{ Field: field, Expression: settings["EXPRESSION"], Sort: settings["SORT"], Collate: settings["COLLATE"], - Type: settings["TYPE"], Length: length, - Where: settings["WHERE"], - Comment: settings["COMMENT"], }}, }) } diff --git a/schema/index_test.go b/schema/index_test.go index 8c2cb9fe..d9595ae6 100644 --- a/schema/index_test.go +++ b/schema/index_test.go @@ -35,13 +35,13 @@ func TestParseIndex(t *testing.T) { Fields: []schema.IndexOption{{}}, }, "idx_user_indices_name3": { - Name: "idx_user_indices_name3", + Name: "idx_user_indices_name3", + Type: "btree", + Where: "name3 != 'jinzhu'", Fields: []schema.IndexOption{{ Sort: "desc", Collate: "utf8", Length: 10, - Type: "btree", - Where: "name3 != 'jinzhu'", }}, }, "idx_user_indices_name4": { @@ -50,19 +50,17 @@ func TestParseIndex(t *testing.T) { Fields: []schema.IndexOption{{}}, }, "idx_user_indices_name5": { - Name: "idx_user_indices_name5", - Class: "FULLTEXT", - Fields: []schema.IndexOption{{ - Comment: "hello , world", - Where: "age > 10", - }}, + Name: "idx_user_indices_name5", + Class: "FULLTEXT", + Comment: "hello , world", + Where: "age > 10", + Fields: []schema.IndexOption{{}}, }, "profile": { - Name: "profile", - Fields: []schema.IndexOption{{ - Comment: "hello , world", - Where: "age > 10", - }, { + Name: "profile", + Comment: "hello , world", + Where: "age > 10", + Fields: []schema.IndexOption{{}, { Expression: "(age+10)", }}, }, @@ -76,19 +74,23 @@ func TestParseIndex(t *testing.T) { t.Errorf("Failed to found index %v from parsed indices %+v", k, indices) } - if result.Name != v.Name { - t.Errorf("index %v name should equal, expects %v, got %v", k, result.Name, v.Name) - } - - if result.Class != v.Class { - t.Errorf("index %v Class should equal, expects %v, got %v", k, result.Class, v.Class) + for _, name := range []string{"Name", "Class", "Type", "Where", "Comment"} { + if reflect.ValueOf(result).FieldByName(name).Interface() != reflect.ValueOf(v).FieldByName(name).Interface() { + t.Errorf( + "index %v %v should equal, expects %v, got %v", + k, name, reflect.ValueOf(result).FieldByName(name).Interface(), reflect.ValueOf(v).FieldByName(name).Interface(), + ) + } } for idx, ef := range result.Fields { rf := v.Fields[idx] - for _, name := range []string{"Expression", "Sort", "Collate", "Length", "Type", "Where"} { + for _, name := range []string{"Expression", "Sort", "Collate", "Length"} { if reflect.ValueOf(ef).FieldByName(name).Interface() != reflect.ValueOf(rf).FieldByName(name).Interface() { - t.Errorf("index %v field #%v's %v should equal, expects %v, got %v", k, idx+1, name, reflect.ValueOf(ef).FieldByName(name).Interface(), reflect.ValueOf(rf).FieldByName(name).Interface()) + t.Errorf( + "index %v field #%v's %v should equal, expects %v, got %v", k, idx+1, name, + reflect.ValueOf(ef).FieldByName(name).Interface(), reflect.ValueOf(rf).FieldByName(name).Interface(), + ) } } } diff --git a/schema/schema_helper_test.go b/schema/schema_helper_test.go index 60e51543..196d19c4 100644 --- a/schema/schema_helper_test.go +++ b/schema/schema_helper_test.go @@ -44,7 +44,7 @@ func checkSchemaField(t *testing.T, s *schema.Schema, f *schema.Field, fc func(* if f.TagSettings == nil { if f.Tag != "" { - f.TagSettings = schema.ParseTagSetting(f.Tag) + f.TagSettings = schema.ParseTagSetting(f.Tag.Get("gorm"), ";") } else { f.TagSettings = map[string]string{} } diff --git a/schema/utils.go b/schema/utils.go index 4774fd75..d7572d3d 100644 --- a/schema/utils.go +++ b/schema/utils.go @@ -6,22 +6,35 @@ import ( "strings" ) -func ParseTagSetting(tags reflect.StructTag) map[string]string { - setting := map[string]string{} +func ParseTagSetting(str string, sep string) map[string]string { + settings := map[string]string{} + names := strings.Split(str, sep) - for _, value := range strings.Split(tags.Get("gorm"), ";") { - if value != "" { - v := strings.Split(value, ":") - k := strings.TrimSpace(strings.ToUpper(v[0])) - - if len(v) >= 2 { - setting[k] = strings.Join(v[1:], ":") - } else { - setting[k] = k + for i := 0; i < len(names); i++ { + j := i + if len(names[j]) > 0 { + for { + if names[j][len(names[j])-1] == '\\' { + i++ + names[j] = names[j][0:len(names[j])-1] + sep + names[i] + names[i] = "" + } else { + break + } } } + + values := strings.Split(names[j], ":") + k := strings.TrimSpace(strings.ToUpper(values[0])) + + if len(values) >= 2 { + settings[k] = strings.Join(values[1:], ":") + } else if k != "" { + settings[k] = k + } } - return setting + + return settings } func checkTruth(val string) bool { From 0be4817ff9cb1c79eb0d8aa800f59e0c11df7b9d Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 22 Feb 2020 11:15:51 +0800 Subject: [PATCH 241/881] Finish CreateConstraint --- clause/expression.go | 2 +- migrator/migrator.go | 152 ++++++++++++++++++++++++++++++++++++++--- schema/check.go | 29 ++++++++ schema/index_test.go | 4 +- schema/naming.go | 25 +++++-- schema/relationship.go | 49 +++++++++++++ 6 files changed, 241 insertions(+), 20 deletions(-) create mode 100644 schema/check.go diff --git a/clause/expression.go b/clause/expression.go index 048b0980..6b3575df 100644 --- a/clause/expression.go +++ b/clause/expression.go @@ -22,7 +22,7 @@ type Expr struct { func (expr Expr) Build(builder Builder) { sql := expr.SQL for _, v := range expr.Vars { - sql = strings.Replace(sql, " ? ", " "+builder.AddVar(v)+" ", 1) + sql = strings.Replace(sql, " ?", " "+builder.AddVar(v), 1) } builder.Write(sql) } diff --git a/migrator/migrator.go b/migrator/migrator.go index e9725935..fc93954e 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -5,6 +5,7 @@ import ( "fmt" "github.com/jinzhu/gorm" + "github.com/jinzhu/gorm/clause" ) // Migrator migrator struct @@ -33,17 +34,25 @@ func (migrator Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement // AutoMigrate func (migrator Migrator) AutoMigrate(values ...interface{}) error { + // if has table + // not -> create table + // check columns -> add column, change column type + // check foreign keys -> create indexes + // check indexes -> create indexes + return gorm.ErrNotImplemented } func (migrator Migrator) CreateTable(values ...interface{}) error { + // migrate + // create join table return gorm.ErrNotImplemented } func (migrator Migrator) DropTable(values ...interface{}) error { for _, value := range values { if err := migrator.RunWithValue(value, func(stmt *gorm.Statement) error { - return migrator.DB.Exec("DROP TABLE " + stmt.Quote(stmt.Table)).Error + return migrator.DB.Exec("DROP TABLE ?", clause.Table{Name: stmt.Table}).Error }); err != nil { return err } @@ -74,7 +83,10 @@ func (migrator Migrator) RenameTable(oldName, newName string) error { func (migrator Migrator) AddColumn(value interface{}, field string) error { return migrator.RunWithValue(value, func(stmt *gorm.Statement) error { if field := stmt.Schema.LookUpField(field); field != nil { - return migrator.DB.Exec(fmt.Sprintf("ALTER TABLE ? ADD ? %s", field.DBDataType), stmt.Table, field.DBName).Error + return migrator.DB.Exec( + "ALTER TABLE ? ADD ? ?", + clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, clause.Expr{SQL: field.DBDataType}, + ).Error } return fmt.Errorf("failed to look up field with name: %s", field) }) @@ -83,7 +95,9 @@ func (migrator Migrator) AddColumn(value interface{}, field string) error { func (migrator Migrator) DropColumn(value interface{}, field string) error { return migrator.RunWithValue(value, func(stmt *gorm.Statement) error { if field := stmt.Schema.LookUpField(field); field != nil { - return migrator.DB.Exec("ALTER TABLE ? DROP COLUMN ?", stmt.Table, field.DBName).Error + return migrator.DB.Exec( + "ALTER TABLE ? DROP COLUMN ?", clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, + ).Error } return fmt.Errorf("failed to look up field with name: %s", field) }) @@ -92,7 +106,10 @@ func (migrator Migrator) DropColumn(value interface{}, field string) error { func (migrator Migrator) AlterColumn(value interface{}, field string) error { return migrator.RunWithValue(value, func(stmt *gorm.Statement) error { if field := stmt.Schema.LookUpField(field); field != nil { - return migrator.DB.Exec(fmt.Sprintf("ALTER TABLE ? ALTER COLUMN ? TYPE %s", field.DBDataType), stmt.Table, field.DBName).Error + return migrator.DB.Exec( + "ALTER TABLE ? ALTER COLUMN ? TYPE ?", + clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, clause.Expr{SQL: field.DBDataType}, + ).Error } return fmt.Errorf("failed to look up field with name: %s", field) }) @@ -102,7 +119,10 @@ func (migrator Migrator) RenameColumn(value interface{}, oldName, field string) return migrator.RunWithValue(value, func(stmt *gorm.Statement) error { if field := stmt.Schema.LookUpField(field); field != nil { oldName = migrator.DB.NamingStrategy.ColumnName(stmt.Table, oldName) - return migrator.DB.Exec("ALTER TABLE ? RENAME COLUMN ? TO ?", stmt.Table, oldName, field.DBName).Error + return migrator.DB.Exec( + "ALTER TABLE ? RENAME COLUMN ? TO ?", + clause.Table{Name: stmt.Table}, clause.Column{Name: oldName}, clause.Column{Name: field.DBName}, + ).Error } return fmt.Errorf("failed to look up field with name: %s", field) }) @@ -121,22 +141,126 @@ func (migrator Migrator) DropView(name string) error { } func (migrator Migrator) CreateConstraint(value interface{}, name string) error { - return gorm.ErrNotImplemented + return migrator.RunWithValue(value, func(stmt *gorm.Statement) error { + checkConstraints := stmt.Schema.ParseCheckConstraints() + if chk, ok := checkConstraints[name]; ok { + return migrator.DB.Exec( + "ALTER TABLE ? ADD CONSTRAINT ? CHECK ?", + clause.Table{Name: stmt.Table}, clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint}, + ).Error + } + + for _, rel := range stmt.Schema.Relationships.Relations { + if constraint := rel.ParseConstraint(); constraint != nil && constraint.Name == name { + sql := "ALTER TABLE ? ADD CONSTRAINT ? FOREIGN KEY ? REFERENCES ??" + if constraint.OnDelete != "" { + sql += " ON DELETE " + constraint.OnDelete + } + + if constraint.OnUpdate != "" { + sql += " ON UPDATE " + constraint.OnUpdate + } + var foreignKeys, references []interface{} + for _, field := range constraint.ForeignKeys { + foreignKeys = append(foreignKeys, clause.Column{Name: field.DBName}) + } + + for _, field := range constraint.References { + references = append(references, clause.Column{Name: field.DBName}) + } + + return migrator.DB.Exec( + sql, clause.Table{Name: stmt.Table}, clause.Column{Name: constraint.Name}, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references, + ).Error + } + } + + err := fmt.Errorf("failed to create constraint with name %v", name) + if field := stmt.Schema.LookUpField(name); field != nil { + for _, cc := range checkConstraints { + if err = migrator.CreateIndex(value, cc.Name); err != nil { + return err + } + } + + for _, rel := range stmt.Schema.Relationships.Relations { + if constraint := rel.ParseConstraint(); constraint != nil && constraint.Field == field { + if err = migrator.CreateIndex(value, constraint.Name); err != nil { + return err + } + } + } + } + + return err + }) } func (migrator Migrator) DropConstraint(value interface{}, name string) error { return migrator.RunWithValue(value, func(stmt *gorm.Statement) error { - return migrator.DB.Raw("ALTER TABLE ? DROP CONSTRAINT ?", stmt.Table, name).Error + return migrator.DB.Exec( + "ALTER TABLE ? DROP CONSTRAINT ?", + clause.Table{Name: stmt.Table}, clause.Column{Name: name}, + ).Error }) } func (migrator Migrator) CreateIndex(value interface{}, name string) error { - return gorm.ErrNotImplemented + return migrator.RunWithValue(value, func(stmt *gorm.Statement) error { + err := fmt.Errorf("failed to create index with name %v", name) + indexes := stmt.Schema.ParseIndexes() + + if idx, ok := indexes[name]; ok { + fields := []interface{}{} + for _, field := range idx.Fields { + str := stmt.Quote(field.DBName) + if field.Expression != "" { + str = field.Expression + } else if field.Length > 0 { + str += fmt.Sprintf("(%d)", field.Length) + } + + if field.Sort != "" { + str += " " + field.Sort + } + fields = append(fields, clause.Expr{SQL: str}) + } + values := []interface{}{clause.Column{Name: idx.Name}, clause.Table{Name: stmt.Table}, fields} + + createIndexSQL := "CREATE " + if idx.Class != "" { + createIndexSQL += idx.Class + " " + } + createIndexSQL += "INDEX ? ON ??" + + if idx.Comment != "" { + values = append(values, idx.Comment) + createIndexSQL += " COMMENT ?" + } + + if idx.Type != "" { + createIndexSQL += " USING " + idx.Type + } + + return migrator.DB.Raw(createIndexSQL, values...).Error + } else if field := stmt.Schema.LookUpField(name); field != nil { + for _, idx := range indexes { + for _, idxOpt := range idx.Fields { + if idxOpt.Field == field { + if err = migrator.CreateIndex(value, idx.Name); err != nil { + return err + } + } + } + } + } + return err + }) } func (migrator Migrator) DropIndex(value interface{}, name string) error { return migrator.RunWithValue(value, func(stmt *gorm.Statement) error { - return migrator.DB.Raw("DROP INDEX ? ON ?", name, stmt.Table).Error + return migrator.DB.Raw("DROP INDEX ? ON ?", clause.Column{Name: name}, clause.Table{Name: stmt.Table}).Error }) } @@ -144,7 +268,10 @@ func (migrator Migrator) HasIndex(value interface{}, name string) bool { var count int64 migrator.RunWithValue(value, func(stmt *gorm.Statement) error { currentDatabase := migrator.DB.Migrator().CurrentDatabase() - return migrator.DB.Raw("SELECT count(*) FROM information_schema.statistics WHERE table_schema = ? AND table_name = ? AND index_name = ?", currentDatabase, stmt.Table, name).Scan(&count).Error + return migrator.DB.Raw( + "SELECT count(*) FROM information_schema.statistics WHERE table_schema = ? AND table_name = ? AND index_name = ?", + currentDatabase, stmt.Table, name, + ).Scan(&count).Error }) if count != 0 { @@ -155,7 +282,10 @@ func (migrator Migrator) HasIndex(value interface{}, name string) bool { func (migrator Migrator) RenameIndex(value interface{}, oldName, newName string) error { return migrator.RunWithValue(value, func(stmt *gorm.Statement) error { - return migrator.DB.Exec("ALTER TABLE ? RENAME INDEX ? TO ?", stmt.Table, oldName, newName).Error + return migrator.DB.Exec( + "ALTER TABLE ? RENAME INDEX ? TO ?", + clause.Table{Name: stmt.Table}, clause.Column{Name: oldName}, clause.Column{Name: newName}, + ).Error }) } diff --git a/schema/check.go b/schema/check.go new file mode 100644 index 00000000..a06ac67b --- /dev/null +++ b/schema/check.go @@ -0,0 +1,29 @@ +package schema + +import ( + "regexp" + "strings" +) + +type Check struct { + Name string + Constraint string // length(phone) >= 10 + *Field +} + +// ParseCheckConstraints parse schema check constraints +func (schema *Schema) ParseCheckConstraints() map[string]Check { + var checks = map[string]Check{} + for _, field := range schema.FieldsByDBName { + if chk := field.TagSettings["CHECK"]; chk != "" { + names := strings.Split(chk, ",") + if len(names) > 1 && regexp.MustCompile("^[A-Za-z]+$").MatchString(names[0]) { + checks[names[0]] = Check{Name: names[0], Constraint: strings.Join(names[1:], ","), Field: field} + } else { + name := schema.namer.CheckerName(schema.Table, field.DBName) + checks[name] = Check{Name: name, Constraint: chk, Field: field} + } + } + } + return checks +} diff --git a/schema/index_test.go b/schema/index_test.go index d9595ae6..1409b9c4 100644 --- a/schema/index_test.go +++ b/schema/index_test.go @@ -15,7 +15,7 @@ type UserIndex struct { Name4 string `gorm:"unique_index"` Name5 int64 `gorm:"index:,class:FULLTEXT,comment:hello \\, world,where:age > 10"` Name6 int64 `gorm:"index:profile,comment:hello \\, world,where:age > 10"` - Age int64 `gorm:"index:profile,expression:(age+10)"` + Age int64 `gorm:"index:profile,expression:ABS(age)"` } func TestParseIndex(t *testing.T) { @@ -61,7 +61,7 @@ func TestParseIndex(t *testing.T) { Comment: "hello , world", Where: "age > 10", Fields: []schema.IndexOption{{}, { - Expression: "(age+10)", + Expression: "ABS(age)", }}, }, } diff --git a/schema/naming.go b/schema/naming.go index 80af4277..d6f26e9f 100644 --- a/schema/naming.go +++ b/schema/naming.go @@ -14,8 +14,10 @@ import ( type Namer interface { TableName(table string) string ColumnName(table, column string) string - IndexName(table, column string) string JoinTableName(table string) string + RelationshipFKName(Relationship) string + CheckerName(table, column string) string + IndexName(table, column string) string } // NamingStrategy tables, columns naming strategy @@ -37,6 +39,22 @@ func (ns NamingStrategy) ColumnName(table, column string) string { return toDBName(column) } +// JoinTableName convert string to join table name +func (ns NamingStrategy) JoinTableName(str string) string { + return ns.TablePrefix + inflection.Plural(toDBName(str)) +} + +// RelationshipFKName generate fk name for relation +func (ns NamingStrategy) RelationshipFKName(rel Relationship) string { + return fmt.Sprintf("fk_%s_%s", rel.Schema.Table, rel.FieldSchema.Table) +} + +// CheckerName generate checker name +func (ns NamingStrategy) CheckerName(table, column string) string { + return fmt.Sprintf("chk_%s_%s", table, column) +} + +// IndexName generate index name func (ns NamingStrategy) IndexName(table, column string) string { idxName := fmt.Sprintf("idx_%v_%v", table, toDBName(column)) @@ -50,11 +68,6 @@ func (ns NamingStrategy) IndexName(table, column string) string { return idxName } -// JoinTableName convert string to join table name -func (ns NamingStrategy) JoinTableName(str string) string { - return ns.TablePrefix + inflection.Plural(toDBName(str)) -} - var ( smap sync.Map // https://github.com/golang/lint/blob/master/lint.go#L770 diff --git a/schema/relationship.go b/schema/relationship.go index 671371fe..8081b0e7 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -3,6 +3,7 @@ package schema import ( "fmt" "reflect" + "regexp" "strings" "github.com/jinzhu/inflection" @@ -292,3 +293,51 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, guessH relation.Type = BelongsTo } } + +type Constraint struct { + Name string + Field *Field + Schema *Schema + ForeignKeys []*Field + ReferenceSchema *Schema + References []*Field + OnDelete string + OnUpdate string +} + +func (rel *Relationship) ParseConstraint() *Constraint { + str := rel.Field.TagSettings["CONSTRAINT"] + if str == "-" { + return nil + } + + var ( + name string + idx = strings.Index(str, ",") + settings = ParseTagSetting(str, ",") + ) + + if idx != -1 && regexp.MustCompile("^[A-Za-z]+$").MatchString(str[0:idx]) { + name = str[0:idx] + } else { + name = rel.Schema.namer.RelationshipFKName(*rel) + } + + constraint := Constraint{ + Name: name, + Field: rel.Field, + OnUpdate: settings["ONUPDATE"], + OnDelete: settings["ONDELETE"], + Schema: rel.Schema, + } + + for _, ref := range rel.References { + if ref.PrimaryKey != nil && !ref.OwnPrimaryKey { + constraint.ForeignKeys = append(constraint.ForeignKeys, ref.ForeignKey) + constraint.References = append(constraint.References, ref.PrimaryKey) + constraint.ReferenceSchema = ref.PrimaryKey.Schema + } + } + + return &constraint +} From 0801cdf164acccb50892ee3f27d1e55db51289e8 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 22 Feb 2020 13:09:57 +0800 Subject: [PATCH 242/881] Almost finish Migrator --- migrator.go | 2 + migrator/migrator.go | 250 +++++++++++++++++++++++++++++++++++-------- 2 files changed, 208 insertions(+), 44 deletions(-) diff --git a/migrator.go b/migrator.go index b6d273e7..a5ea4d8f 100644 --- a/migrator.go +++ b/migrator.go @@ -33,6 +33,7 @@ type Migrator interface { AddColumn(dst interface{}, field string) error DropColumn(dst interface{}, field string) error AlterColumn(dst interface{}, field string) error + HasColumn(dst interface{}, field string) bool RenameColumn(dst interface{}, oldName, field string) error ColumnTypes(dst interface{}) ([]*sql.ColumnType, error) @@ -43,6 +44,7 @@ type Migrator interface { // Constraints CreateConstraint(dst interface{}, name string) error DropConstraint(dst interface{}, name string) error + HasConstraint(dst interface{}, name string) bool // Indexes CreateIndex(dst interface{}, name string) error diff --git a/migrator/migrator.go b/migrator/migrator.go index fc93954e..7e749037 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -3,9 +3,12 @@ package migrator import ( "database/sql" "fmt" + "reflect" + "strings" "github.com/jinzhu/gorm" "github.com/jinzhu/gorm/clause" + "github.com/jinzhu/gorm/schema" ) // Migrator migrator struct @@ -34,19 +37,133 @@ func (migrator Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement // AutoMigrate func (migrator Migrator) AutoMigrate(values ...interface{}) error { - // if has table - // not -> create table - // check columns -> add column, change column type - // check foreign keys -> create indexes - // check indexes -> create indexes + // TODO smart migrate data type - return gorm.ErrNotImplemented + for _, value := range values { + if !migrator.DB.Migrator().HasTable(value) { + if err := migrator.DB.Migrator().CreateTable(value); err != nil { + return err + } + } else { + if err := migrator.RunWithValue(value, func(stmt *gorm.Statement) error { + for _, field := range stmt.Schema.FieldsByDBName { + if !migrator.DB.Migrator().HasColumn(value, field.DBName) { + if err := migrator.DB.Migrator().AddColumn(value, field.DBName); err != nil { + return err + } + } + } + + for _, rel := range stmt.Schema.Relationships.Relations { + if constraint := rel.ParseConstraint(); constraint != nil { + if !migrator.DB.Migrator().HasConstraint(value, constraint.Name) { + if err := migrator.DB.Migrator().CreateConstraint(value, constraint.Name); err != nil { + return err + } + } + } + + for _, chk := range stmt.Schema.ParseCheckConstraints() { + if !migrator.DB.Migrator().HasConstraint(value, chk.Name) { + if err := migrator.DB.Migrator().CreateConstraint(value, chk.Name); err != nil { + return err + } + } + } + + // create join table + joinValue := reflect.New(rel.JoinTable.ModelType).Interface() + if !migrator.DB.Migrator().HasTable(joinValue) { + defer migrator.DB.Migrator().CreateTable(joinValue) + } + } + return nil + }); err != nil { + return err + } + } + } + + return nil } func (migrator Migrator) CreateTable(values ...interface{}) error { - // migrate - // create join table - return gorm.ErrNotImplemented + for _, value := range values { + if err := migrator.RunWithValue(value, func(stmt *gorm.Statement) error { + var ( + createTableSQL = "CREATE TABLE ? (" + values = []interface{}{clause.Table{Name: stmt.Table}} + hasPrimaryKeyInDataType bool + ) + + for _, dbName := range stmt.Schema.DBNames { + field := stmt.Schema.FieldsByDBName[dbName] + createTableSQL += fmt.Sprintf("? ?") + hasPrimaryKeyInDataType = hasPrimaryKeyInDataType || strings.Contains(strings.ToUpper(field.DBDataType), "PRIMARY KEY") + values = append(values, clause.Column{Name: dbName}, clause.Expr{SQL: field.DBDataType}) + + if field.AutoIncrement { + createTableSQL += " AUTO_INCREMENT" + } + + if field.NotNull { + createTableSQL += " NOT NULL" + } + + if field.Unique { + createTableSQL += " UNIQUE" + } + + if field.DefaultValue != "" { + createTableSQL += " DEFAULT ?" + values = append(values, clause.Expr{SQL: field.DefaultValue}) + } + createTableSQL += "," + } + + if !hasPrimaryKeyInDataType { + createTableSQL += "PRIMARY KEY ?," + primaryKeys := []interface{}{} + for _, field := range stmt.Schema.PrimaryFields { + primaryKeys = append(primaryKeys, clause.Column{Name: field.DBName}) + } + + values = append(values, primaryKeys) + } + + for _, idx := range stmt.Schema.ParseIndexes() { + createTableSQL += "INDEX ? ?," + values = append(values, clause.Expr{SQL: idx.Name}, buildIndexOptions(idx.Fields, stmt)) + } + + for _, rel := range stmt.Schema.Relationships.Relations { + if constraint := rel.ParseConstraint(); constraint != nil { + sql, vars := buildConstraint(constraint) + createTableSQL += sql + "," + values = append(values, vars...) + } + + // create join table + joinValue := reflect.New(rel.JoinTable.ModelType).Interface() + if !migrator.DB.Migrator().HasTable(joinValue) { + defer migrator.DB.Migrator().CreateTable(joinValue) + } + } + + for _, chk := range stmt.Schema.ParseCheckConstraints() { + createTableSQL += "CONSTRAINT ? CHECK ?," + values = append(values, clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint}) + } + + createTableSQL = strings.TrimSuffix(createTableSQL, ",") + + createTableSQL += ")" + return migrator.DB.Exec(createTableSQL, values...).Error + }); err != nil { + return err + } + } + return nil } func (migrator Migrator) DropTable(values ...interface{}) error { @@ -115,6 +232,27 @@ func (migrator Migrator) AlterColumn(value interface{}, field string) error { }) } +func (migrator Migrator) HasColumn(value interface{}, field string) bool { + var count int64 + migrator.RunWithValue(value, func(stmt *gorm.Statement) error { + currentDatabase := migrator.DB.Migrator().CurrentDatabase() + name := field + if field := stmt.Schema.LookUpField(field); field != nil { + name = field.DBName + } + + return migrator.DB.Raw( + "SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_schema = ? AND table_name = ? AND column_name = ?", + currentDatabase, stmt.Table, name, + ).Scan(&count).Error + }) + + if count != 0 { + return true + } + return false +} + func (migrator Migrator) RenameColumn(value interface{}, oldName, field string) error { return migrator.RunWithValue(value, func(stmt *gorm.Statement) error { if field := stmt.Schema.LookUpField(field); field != nil { @@ -140,6 +278,28 @@ func (migrator Migrator) DropView(name string) error { return gorm.ErrNotImplemented } +func buildConstraint(constraint *schema.Constraint) (sql string, results []interface{}) { + sql = "CONSTRAINT ? FOREIGN KEY ? REFERENCES ??" + if constraint.OnDelete != "" { + sql += " ON DELETE " + constraint.OnDelete + } + + if constraint.OnUpdate != "" { + sql += " ON UPDATE " + constraint.OnUpdate + } + + var foreignKeys, references []interface{} + for _, field := range constraint.ForeignKeys { + foreignKeys = append(foreignKeys, clause.Column{Name: field.DBName}) + } + + for _, field := range constraint.References { + references = append(references, clause.Column{Name: field.DBName}) + } + results = append(results, constraint.Name, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references) + return +} + func (migrator Migrator) CreateConstraint(value interface{}, name string) error { return migrator.RunWithValue(value, func(stmt *gorm.Statement) error { checkConstraints := stmt.Schema.ParseCheckConstraints() @@ -152,26 +312,8 @@ func (migrator Migrator) CreateConstraint(value interface{}, name string) error for _, rel := range stmt.Schema.Relationships.Relations { if constraint := rel.ParseConstraint(); constraint != nil && constraint.Name == name { - sql := "ALTER TABLE ? ADD CONSTRAINT ? FOREIGN KEY ? REFERENCES ??" - if constraint.OnDelete != "" { - sql += " ON DELETE " + constraint.OnDelete - } - - if constraint.OnUpdate != "" { - sql += " ON UPDATE " + constraint.OnUpdate - } - var foreignKeys, references []interface{} - for _, field := range constraint.ForeignKeys { - foreignKeys = append(foreignKeys, clause.Column{Name: field.DBName}) - } - - for _, field := range constraint.References { - references = append(references, clause.Column{Name: field.DBName}) - } - - return migrator.DB.Exec( - sql, clause.Table{Name: stmt.Table}, clause.Column{Name: constraint.Name}, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references, - ).Error + sql, values := buildConstraint(constraint) + return migrator.DB.Exec("ALTER TABLE ? ADD "+sql, append([]interface{}{clause.Table{Name: stmt.Table}}, values...)...).Error } } @@ -205,27 +347,47 @@ func (migrator Migrator) DropConstraint(value interface{}, name string) error { }) } +func (migrator Migrator) HasConstraint(value interface{}, name string) bool { + var count int64 + migrator.RunWithValue(value, func(stmt *gorm.Statement) error { + currentDatabase := migrator.DB.Migrator().CurrentDatabase() + return migrator.DB.Raw( + "SELECT count(*) FROM INFORMATION_SCHEMA.referential_constraints WHERE constraint_schema = ? AND table_name = ? AND constraint_name = ?", + currentDatabase, stmt.Table, name, + ).Scan(&count).Error + }) + + if count != 0 { + return true + } + return false +} + +func buildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) { + for _, opt := range opts { + str := stmt.Quote(opt.DBName) + if opt.Expression != "" { + str = opt.Expression + } else if opt.Length > 0 { + str += fmt.Sprintf("(%d)", opt.Length) + } + + if opt.Sort != "" { + str += " " + opt.Sort + } + results = append(results, clause.Expr{SQL: str}) + } + return +} + func (migrator Migrator) CreateIndex(value interface{}, name string) error { return migrator.RunWithValue(value, func(stmt *gorm.Statement) error { err := fmt.Errorf("failed to create index with name %v", name) indexes := stmt.Schema.ParseIndexes() if idx, ok := indexes[name]; ok { - fields := []interface{}{} - for _, field := range idx.Fields { - str := stmt.Quote(field.DBName) - if field.Expression != "" { - str = field.Expression - } else if field.Length > 0 { - str += fmt.Sprintf("(%d)", field.Length) - } - - if field.Sort != "" { - str += " " + field.Sort - } - fields = append(fields, clause.Expr{SQL: str}) - } - values := []interface{}{clause.Column{Name: idx.Name}, clause.Table{Name: stmt.Table}, fields} + opts := buildIndexOptions(idx.Fields, stmt) + values := []interface{}{clause.Column{Name: idx.Name}, clause.Table{Name: stmt.Table}, opts} createIndexSQL := "CREATE " if idx.Class != "" { From fab7d96da5d0308a77684acb9b39eb558b6ea58e Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 22 Feb 2020 17:53:57 +0800 Subject: [PATCH 243/881] Add DataTypeOf for dialector --- dialects/mssql/migrator.go | 37 ++++++ dialects/mssql/mssql.go | 75 ++++++++++++ dialects/mysql/migrator.go | 43 +++++++ dialects/mysql/mysql.go | 83 ++++++++++++- dialects/postgres/migrator.go | 89 ++++++++++++++ dialects/postgres/postgres.go | 51 +++++++- dialects/sqlite/migrator.go | 122 +++++++++++++++++++ dialects/sqlite/sqlite.go | 32 ++++- interfaces.go | 5 +- migrator.go | 4 +- migrator/migrator.go | 223 +++++++++++++++++----------------- schema/field.go | 5 +- 12 files changed, 640 insertions(+), 129 deletions(-) create mode 100644 dialects/mssql/migrator.go create mode 100644 dialects/mssql/mssql.go create mode 100644 dialects/mysql/migrator.go create mode 100644 dialects/postgres/migrator.go create mode 100644 dialects/sqlite/migrator.go diff --git a/dialects/mssql/migrator.go b/dialects/mssql/migrator.go new file mode 100644 index 00000000..43eaf573 --- /dev/null +++ b/dialects/mssql/migrator.go @@ -0,0 +1,37 @@ +package mssql + +import ( + "github.com/jinzhu/gorm" + "github.com/jinzhu/gorm/migrator" +) + +type Migrator struct { + migrator.Migrator +} + +func (m Migrator) HasIndex(value interface{}, name string) bool { + var count int + m.RunWithValue(value, func(stmt *gorm.Statement) error { + return m.DB.Raw( + "SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", + name, stmt.Table, + ).Row().Scan(&count) + }) + return count > 0 +} + +func (m Migrator) HasConstraint(value interface{}, name string) bool { + var count int64 + m.RunWithValue(value, func(stmt *gorm.Statement) error { + return m.DB.Raw( + `SELECT count(*) FROM sys.foreign_keys as F inner join sys.tables as T on F.parent_object_id=T.object_id inner join information_schema.tables as I on I.TABLE_NAME = T.name WHERE F.name = ? AND T.Name = ? AND I.TABLE_CATALOG = ?;`, + name, stmt.Table, m.CurrentDatabase(), + ).Row().Scan(&count) + }) + return count > 0 +} + +func (m Migrator) CurrentDatabase() (name string) { + m.DB.Raw("SELECT DB_NAME() AS [Current Database]").Row().Scan(&name) + return +} diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go new file mode 100644 index 00000000..bdca667d --- /dev/null +++ b/dialects/mssql/mssql.go @@ -0,0 +1,75 @@ +package mssql + +import ( + "database/sql" + "fmt" + + _ "github.com/denisenkom/go-mssqldb" + "github.com/jinzhu/gorm" + "github.com/jinzhu/gorm/callbacks" + "github.com/jinzhu/gorm/migrator" + "github.com/jinzhu/gorm/schema" +) + +type Dialector struct { + DSN string +} + +func Open(dsn string) gorm.Dialector { + return &Dialector{DSN: dsn} +} + +func (dialector Dialector) Initialize(db *gorm.DB) (err error) { + // register callbacks + callbacks.RegisterDefaultCallbacks(db) + + db.DB, err = sql.Open("sqlserver", dialector.DSN) + return +} + +func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator { + return Migrator{migrator.Migrator{Config: migrator.Config{DB: db}}} +} + +func (dialector Dialector) BindVar(stmt *gorm.Statement, v interface{}) string { + return "?" +} + +func (dialector Dialector) QuoteChars() [2]byte { + return [2]byte{'[', ']'} // `name` +} + +func (dialector Dialector) DataTypeOf(field *schema.Field) string { + switch field.DataType { + case schema.Bool: + return "bit" + case schema.Int, schema.Uint: + var sqlType string + switch { + case field.Size < 16: + sqlType = "smallint" + case field.Size < 31: + sqlType = "int" + default: + sqlType = "bigint" + } + + if field.AutoIncrement { + return sqlType + " IDENTITY(1,1)" + } + return sqlType + case schema.Float: + return "decimal" + case schema.String: + if field.Size > 0 && field.Size <= 4000 { + return fmt.Sprintf("nvarchar(%d)", field.Size) + } + return "ntext" + case schema.Time: + return "datetimeoffset" + case schema.Bytes: + return "binary" + } + + return "" +} diff --git a/dialects/mysql/migrator.go b/dialects/mysql/migrator.go new file mode 100644 index 00000000..2c11af94 --- /dev/null +++ b/dialects/mysql/migrator.go @@ -0,0 +1,43 @@ +package mysql + +import ( + "fmt" + + "github.com/jinzhu/gorm" + "github.com/jinzhu/gorm/clause" + "github.com/jinzhu/gorm/migrator" +) + +type Migrator struct { + migrator.Migrator +} + +func (m Migrator) AlterColumn(value interface{}, field string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + if field := stmt.Schema.LookUpField(field); field != nil { + return m.DB.Exec( + "ALTER TABLE ? MODIFY COLUMN ? TYPE ?", + clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, clause.Expr{SQL: field.DBDataType}, + ).Error + } + return fmt.Errorf("failed to look up field with name: %s", field) + }) +} + +func (m Migrator) DropConstraint(value interface{}, name string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + for _, chk := range stmt.Schema.ParseCheckConstraints() { + if chk.Name == name { + return m.DB.Exec( + "ALTER TABLE ? DROP CHECK ?", + clause.Table{Name: stmt.Table}, clause.Column{Name: name}, + ).Error + } + } + + return m.DB.Exec( + "ALTER TABLE ? DROP FOREIGN KEY ?", + clause.Table{Name: stmt.Table}, clause.Column{Name: name}, + ).Error + }) +} diff --git a/dialects/mysql/mysql.go b/dialects/mysql/mysql.go index b402ef95..e2fea53c 100644 --- a/dialects/mysql/mysql.go +++ b/dialects/mysql/mysql.go @@ -1,33 +1,104 @@ package mysql import ( + "database/sql" + "fmt" + "math" + _ "github.com/go-sql-driver/mysql" "github.com/jinzhu/gorm" "github.com/jinzhu/gorm/callbacks" + "github.com/jinzhu/gorm/migrator" + "github.com/jinzhu/gorm/schema" ) type Dialector struct { + DSN string } func Open(dsn string) gorm.Dialector { - return &Dialector{} + return &Dialector{DSN: dsn} } -func (Dialector) Initialize(db *gorm.DB) error { +func (dialector Dialector) Initialize(db *gorm.DB) (err error) { // register callbacks callbacks.RegisterDefaultCallbacks(db) + db.DB, err = sql.Open("sqlite3", dialector.DSN) return nil } -func (Dialector) Migrator() gorm.Migrator { - return nil +func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator { + return Migrator{migrator.Migrator{Config: migrator.Config{DB: db}}} } -func (Dialector) BindVar(stmt gorm.Statement, v interface{}) string { +func (dialector Dialector) BindVar(stmt *gorm.Statement, v interface{}) string { return "?" } -func (Dialector) QuoteChars() [2]byte { +func (dialector Dialector) QuoteChars() [2]byte { return [2]byte{'`', '`'} // `name` } + +func (dialector Dialector) DataTypeOf(field *schema.Field) string { + switch field.DataType { + case schema.Bool: + return "boolean" + case schema.Int, schema.Uint: + sqlType := "int" + switch { + case field.Size <= 8: + sqlType = "tinyint" + case field.Size <= 16: + sqlType = "smallint" + case field.Size <= 32: + sqlType = "int" + default: + sqlType = "bigint" + } + + if field.DataType == schema.Uint { + sqlType += " unsigned" + } + + if field.AutoIncrement { + sqlType += " AUTO_INCREMENT" + } + return sqlType + case schema.Float: + if field.Size <= 32 { + return "float" + } + return "double" + case schema.String: + size := field.Size + if size >= 65536 && size <= int(math.Pow(2, 24)) { + return "mediumtext" + } else if size > int(math.Pow(2, 24)) || size < 0 { + return "longtext" + } + return fmt.Sprintf("varchar(%d)", size) + case schema.Time: + precision := "" + if field.Precision > 0 { + precision = fmt.Sprintf("(%d)", field.Precision) + } + + if field.NotNull || field.PrimaryKey { + return "datetime" + precision + } + return "datetime" + precision + " NULL" + case schema.Bytes: + if field.Size > 0 && field.Size < 65536 { + return fmt.Sprintf("varbinary(%d)", field.Size) + } + + if field.Size >= 65536 && field.Size <= int(math.Pow(2, 24)) { + return "mediumblob" + } + + return "longblob" + } + + return "" +} diff --git a/dialects/postgres/migrator.go b/dialects/postgres/migrator.go new file mode 100644 index 00000000..35101bf3 --- /dev/null +++ b/dialects/postgres/migrator.go @@ -0,0 +1,89 @@ +package postgres + +import ( + "fmt" + + "github.com/jinzhu/gorm" + "github.com/jinzhu/gorm/clause" + "github.com/jinzhu/gorm/migrator" + "github.com/jinzhu/gorm/schema" +) + +type Migrator struct { + migrator.Migrator +} + +func (m Migrator) CurrentDatabase() (name string) { + m.DB.Raw("SELECT CURRENT_DATABASE()").Row().Scan(&name) + return +} + +func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) { + for _, opt := range opts { + str := stmt.Quote(opt.DBName) + if opt.Expression != "" { + str = opt.Expression + } + + if opt.Collate != "" { + str += " COLLATE " + opt.Collate + } + + if opt.Sort != "" { + str += " " + opt.Sort + } + results = append(results, clause.Expr{SQL: str}) + } + return +} + +func (m Migrator) HasIndex(value interface{}, indexName string) bool { + var count int64 + m.RunWithValue(value, func(stmt *gorm.Statement) error { + return m.DB.Raw( + "SELECT count(*) FROM pg_indexes WHERE tablename = ? AND indexname = ? AND schemaname = CURRENT_SCHEMA()", stmt.Table, indexName, + ).Row().Scan(&count) + }) + + return count > 0 +} + +func (m Migrator) CreateIndex(value interface{}, name string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + err := fmt.Errorf("failed to create index with name %v", name) + indexes := stmt.Schema.ParseIndexes() + + if idx, ok := indexes[name]; ok { + opts := m.BuildIndexOptions(idx.Fields, stmt) + values := []interface{}{clause.Column{Name: idx.Name}, clause.Table{Name: stmt.Table}, opts} + + createIndexSQL := "CREATE " + if idx.Class != "" { + createIndexSQL += idx.Class + " " + } + createIndexSQL += "INDEX ?" + + if idx.Type != "" { + createIndexSQL += " USING " + idx.Type + } + createIndexSQL += " ON ??" + + if idx.Where != "" { + createIndexSQL += " WHERE " + idx.Where + } + + return m.DB.Exec(createIndexSQL, values...).Error + } else if field := stmt.Schema.LookUpField(name); field != nil { + for _, idx := range indexes { + for _, idxOpt := range idx.Fields { + if idxOpt.Field == field { + if err = m.CreateIndex(value, idx.Name); err != nil { + return err + } + } + } + } + } + return err + }) +} diff --git a/dialects/postgres/postgres.go b/dialects/postgres/postgres.go index 9ea0048a..a3eeefb9 100644 --- a/dialects/postgres/postgres.go +++ b/dialects/postgres/postgres.go @@ -2,9 +2,12 @@ package postgres import ( "database/sql" + "fmt" "github.com/jinzhu/gorm" "github.com/jinzhu/gorm/callbacks" + "github.com/jinzhu/gorm/migrator" + "github.com/jinzhu/gorm/schema" _ "github.com/lib/pq" ) @@ -24,14 +27,54 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) { return } -func (Dialector) Migrator() gorm.Migrator { - return nil +func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator { + return Migrator{migrator.Migrator{Config: migrator.Config{DB: db}}} } -func (Dialector) BindVar(stmt *gorm.Statement, v interface{}) string { +func (dialector Dialector) BindVar(stmt *gorm.Statement, v interface{}) string { return "?" } -func (Dialector) QuoteChars() [2]byte { +func (dialector Dialector) QuoteChars() [2]byte { return [2]byte{'"', '"'} // "name" } + +func (dialector Dialector) DataTypeOf(field *schema.Field) string { + switch field.DataType { + case schema.Bool: + return "boolean" + case schema.Int, schema.Uint: + if field.AutoIncrement { + switch { + case field.Size < 16: + return "smallserial" + case field.Size < 31: + return "serial" + default: + return "bigserial" + } + } else { + switch { + case field.Size < 16: + return "smallint" + case field.Size < 31: + return "integer" + default: + return "bigint" + } + } + case schema.Float: + return "decimal" + case schema.String: + if field.Size > 0 { + return fmt.Sprintf("varchar(%d)", field.Size) + } + return "text" + case schema.Time: + return "timestamp with time zone" + case schema.Bytes: + return "bytea" + } + + return "" +} diff --git a/dialects/sqlite/migrator.go b/dialects/sqlite/migrator.go new file mode 100644 index 00000000..07e189ad --- /dev/null +++ b/dialects/sqlite/migrator.go @@ -0,0 +1,122 @@ +package sqlite + +import ( + "fmt" + + "github.com/jinzhu/gorm" + "github.com/jinzhu/gorm/clause" + "github.com/jinzhu/gorm/migrator" + "github.com/jinzhu/gorm/schema" +) + +type Migrator struct { + migrator.Migrator +} + +func (m Migrator) HasTable(value interface{}) bool { + var count int + m.Migrator.RunWithValue(value, func(stmt *gorm.Statement) error { + return m.DB.Raw("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", stmt.Table).Row().Scan(&count) + }) + return count > 0 +} + +func (m Migrator) HasColumn(value interface{}, field string) bool { + var count int + m.Migrator.RunWithValue(value, func(stmt *gorm.Statement) error { + name := field + if field := stmt.Schema.LookUpField(field); field != nil { + name = field.DBName + } + + return m.DB.Raw( + "SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE ? OR sql LIKE ?)", + stmt.Table, `%"`+name+`" %`, `%`+name+` %`, + ).Row().Scan(&count) + }) + return count > 0 +} + +func (m Migrator) HasIndex(value interface{}, name string) bool { + var count int + m.RunWithValue(value, func(stmt *gorm.Statement) error { + return m.DB.Raw( + "SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE ?", + stmt.Table, "%INDEX "+name+" ON%", + ).Row().Scan(&count) + }) + return count > 0 +} + +func (m Migrator) CreateConstraint(interface{}, string) error { + return gorm.ErrNotImplemented +} + +func (m Migrator) DropConstraint(interface{}, string) error { + return gorm.ErrNotImplemented +} + +func (m Migrator) CurrentDatabase() (name string) { + var null interface{} + m.DB.Raw("PRAGMA database_list").Row().Scan(&null, &name, &null) + return +} + +func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) { + for _, opt := range opts { + str := stmt.Quote(opt.DBName) + if opt.Expression != "" { + str = opt.Expression + } + + if opt.Collate != "" { + str += " COLLATE " + opt.Collate + } + + if opt.Sort != "" { + str += " " + opt.Sort + } + results = append(results, clause.Expr{SQL: str}) + } + return +} + +func (m Migrator) CreateIndex(value interface{}, name string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + err := fmt.Errorf("failed to create index with name %v", name) + indexes := stmt.Schema.ParseIndexes() + + if idx, ok := indexes[name]; ok { + opts := m.BuildIndexOptions(idx.Fields, stmt) + values := []interface{}{clause.Column{Name: idx.Name}, clause.Table{Name: stmt.Table}, opts} + + createIndexSQL := "CREATE " + if idx.Class != "" { + createIndexSQL += idx.Class + " " + } + createIndexSQL += "INDEX ?" + + if idx.Type != "" { + createIndexSQL += " USING " + idx.Type + } + createIndexSQL += " ON ??" + + if idx.Where != "" { + createIndexSQL += " WHERE " + idx.Where + } + + return m.DB.Exec(createIndexSQL, values...).Error + } else if field := stmt.Schema.LookUpField(name); field != nil { + for _, idx := range indexes { + for _, idxOpt := range idx.Fields { + if idxOpt.Field == field { + if err = m.CreateIndex(value, idx.Name); err != nil { + return err + } + } + } + } + } + return err + }) +} diff --git a/dialects/sqlite/sqlite.go b/dialects/sqlite/sqlite.go index 80a18cfb..b77226db 100644 --- a/dialects/sqlite/sqlite.go +++ b/dialects/sqlite/sqlite.go @@ -5,6 +5,8 @@ import ( "github.com/jinzhu/gorm" "github.com/jinzhu/gorm/callbacks" + "github.com/jinzhu/gorm/migrator" + "github.com/jinzhu/gorm/schema" _ "github.com/mattn/go-sqlite3" ) @@ -24,14 +26,36 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) { return } -func (Dialector) Migrator() gorm.Migrator { - return nil +func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator { + return Migrator{migrator.Migrator{Config: migrator.Config{DB: db}}} } -func (Dialector) BindVar(stmt *gorm.Statement, v interface{}) string { +func (dialector Dialector) BindVar(stmt *gorm.Statement, v interface{}) string { return "?" } -func (Dialector) QuoteChars() [2]byte { +func (dialector Dialector) QuoteChars() [2]byte { return [2]byte{'`', '`'} // `name` } + +func (dialector Dialector) DataTypeOf(field *schema.Field) string { + switch field.DataType { + case schema.Bool: + return "NUMERIC" + case schema.Int, schema.Uint: + if field.AutoIncrement { + // https://www.sqlite.org/autoinc.html + return "INTEGER PRIMARY KEY AUTOINCREMENT" + } else { + return "INTEGER" + } + case schema.Float: + return "REAL" + case schema.String, schema.Time: + return "TEXT" + case schema.Bytes: + return "BLOB" + } + + return "" +} diff --git a/interfaces.go b/interfaces.go index 71522455..8f0f3085 100644 --- a/interfaces.go +++ b/interfaces.go @@ -3,12 +3,15 @@ package gorm import ( "context" "database/sql" + + "github.com/jinzhu/gorm/schema" ) // Dialector GORM database dialector type Dialector interface { Initialize(*DB) error - Migrator() Migrator + Migrator(db *DB) Migrator + DataTypeOf(*schema.Field) string BindVar(stmt *Statement, v interface{}) string QuoteChars() [2]byte } diff --git a/migrator.go b/migrator.go index a5ea4d8f..d90c362f 100644 --- a/migrator.go +++ b/migrator.go @@ -6,7 +6,7 @@ import ( // Migrator returns migrator func (db *DB) Migrator() Migrator { - return db.Dialector.Migrator() + return db.Dialector.Migrator(db) } // ViewOption view option @@ -26,7 +26,7 @@ type Migrator interface { // Tables CreateTable(dst ...interface{}) error DropTable(dst ...interface{}) error - HasTable(dst ...interface{}) bool + HasTable(dst interface{}) bool RenameTable(oldName, newName string) error // Columns diff --git a/migrator/migrator.go b/migrator/migrator.go index 7e749037..9e94cc68 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -11,21 +11,21 @@ import ( "github.com/jinzhu/gorm/schema" ) -// Migrator migrator struct +// Migrator m struct type Migrator struct { - *Config + Config } // Config schema config type Config struct { - CheckExistsBeforeDropping bool - DB *gorm.DB + DB *gorm.DB + gorm.Dialector } -func (migrator Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement) error) error { - stmt := migrator.DB.Statement +func (m Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement) error) error { + stmt := m.DB.Statement if stmt == nil { - stmt = &gorm.Statement{DB: migrator.DB} + stmt = &gorm.Statement{DB: m.DB} } if err := stmt.Parse(value); err != nil { @@ -35,20 +35,28 @@ func (migrator Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement return fc(stmt) } +func (m Migrator) DataTypeOf(field *schema.Field) string { + if field.DBDataType != "" { + return field.DBDataType + } + + return m.Dialector.DataTypeOf(field) +} + // AutoMigrate -func (migrator Migrator) AutoMigrate(values ...interface{}) error { +func (m Migrator) AutoMigrate(values ...interface{}) error { // TODO smart migrate data type for _, value := range values { - if !migrator.DB.Migrator().HasTable(value) { - if err := migrator.DB.Migrator().CreateTable(value); err != nil { + if !m.DB.Migrator().HasTable(value) { + if err := m.DB.Migrator().CreateTable(value); err != nil { return err } } else { - if err := migrator.RunWithValue(value, func(stmt *gorm.Statement) error { + if err := m.RunWithValue(value, func(stmt *gorm.Statement) error { for _, field := range stmt.Schema.FieldsByDBName { - if !migrator.DB.Migrator().HasColumn(value, field.DBName) { - if err := migrator.DB.Migrator().AddColumn(value, field.DBName); err != nil { + if !m.DB.Migrator().HasColumn(value, field.DBName) { + if err := m.DB.Migrator().AddColumn(value, field.DBName); err != nil { return err } } @@ -56,16 +64,16 @@ func (migrator Migrator) AutoMigrate(values ...interface{}) error { for _, rel := range stmt.Schema.Relationships.Relations { if constraint := rel.ParseConstraint(); constraint != nil { - if !migrator.DB.Migrator().HasConstraint(value, constraint.Name) { - if err := migrator.DB.Migrator().CreateConstraint(value, constraint.Name); err != nil { + if !m.DB.Migrator().HasConstraint(value, constraint.Name) { + if err := m.DB.Migrator().CreateConstraint(value, constraint.Name); err != nil { return err } } } for _, chk := range stmt.Schema.ParseCheckConstraints() { - if !migrator.DB.Migrator().HasConstraint(value, chk.Name) { - if err := migrator.DB.Migrator().CreateConstraint(value, chk.Name); err != nil { + if !m.DB.Migrator().HasConstraint(value, chk.Name) { + if err := m.DB.Migrator().CreateConstraint(value, chk.Name); err != nil { return err } } @@ -73,8 +81,8 @@ func (migrator Migrator) AutoMigrate(values ...interface{}) error { // create join table joinValue := reflect.New(rel.JoinTable.ModelType).Interface() - if !migrator.DB.Migrator().HasTable(joinValue) { - defer migrator.DB.Migrator().CreateTable(joinValue) + if !m.DB.Migrator().HasTable(joinValue) { + defer m.DB.Migrator().CreateTable(joinValue) } } return nil @@ -87,9 +95,9 @@ func (migrator Migrator) AutoMigrate(values ...interface{}) error { return nil } -func (migrator Migrator) CreateTable(values ...interface{}) error { +func (m Migrator) CreateTable(values ...interface{}) error { for _, value := range values { - if err := migrator.RunWithValue(value, func(stmt *gorm.Statement) error { + if err := m.RunWithValue(value, func(stmt *gorm.Statement) error { var ( createTableSQL = "CREATE TABLE ? (" values = []interface{}{clause.Table{Name: stmt.Table}} @@ -100,7 +108,7 @@ func (migrator Migrator) CreateTable(values ...interface{}) error { field := stmt.Schema.FieldsByDBName[dbName] createTableSQL += fmt.Sprintf("? ?") hasPrimaryKeyInDataType = hasPrimaryKeyInDataType || strings.Contains(strings.ToUpper(field.DBDataType), "PRIMARY KEY") - values = append(values, clause.Column{Name: dbName}, clause.Expr{SQL: field.DBDataType}) + values = append(values, clause.Column{Name: dbName}, clause.Expr{SQL: m.DataTypeOf(field)}) if field.AutoIncrement { createTableSQL += " AUTO_INCREMENT" @@ -133,7 +141,7 @@ func (migrator Migrator) CreateTable(values ...interface{}) error { for _, idx := range stmt.Schema.ParseIndexes() { createTableSQL += "INDEX ? ?," - values = append(values, clause.Expr{SQL: idx.Name}, buildIndexOptions(idx.Fields, stmt)) + values = append(values, clause.Expr{SQL: idx.Name}, m.DB.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt)) } for _, rel := range stmt.Schema.Relationships.Relations { @@ -145,8 +153,8 @@ func (migrator Migrator) CreateTable(values ...interface{}) error { // create join table joinValue := reflect.New(rel.JoinTable.ModelType).Interface() - if !migrator.DB.Migrator().HasTable(joinValue) { - defer migrator.DB.Migrator().CreateTable(joinValue) + if !m.DB.Migrator().HasTable(joinValue) { + defer m.DB.Migrator().CreateTable(joinValue) } } @@ -158,7 +166,7 @@ func (migrator Migrator) CreateTable(values ...interface{}) error { createTableSQL = strings.TrimSuffix(createTableSQL, ",") createTableSQL += ")" - return migrator.DB.Exec(createTableSQL, values...).Error + return m.DB.Exec(createTableSQL, values...).Error }); err != nil { return err } @@ -166,10 +174,10 @@ func (migrator Migrator) CreateTable(values ...interface{}) error { return nil } -func (migrator Migrator) DropTable(values ...interface{}) error { +func (m Migrator) DropTable(values ...interface{}) error { for _, value := range values { - if err := migrator.RunWithValue(value, func(stmt *gorm.Statement) error { - return migrator.DB.Exec("DROP TABLE ?", clause.Table{Name: stmt.Table}).Error + if err := m.RunWithValue(value, func(stmt *gorm.Statement) error { + return m.DB.Exec("DROP TABLE ?", clause.Table{Name: stmt.Table}).Error }); err != nil { return err } @@ -177,42 +185,36 @@ func (migrator Migrator) DropTable(values ...interface{}) error { return nil } -func (migrator Migrator) HasTable(values ...interface{}) bool { +func (m Migrator) HasTable(value interface{}) bool { var count int64 - for _, value := range values { - err := migrator.RunWithValue(value, func(stmt *gorm.Statement) error { - currentDatabase := migrator.DB.Migrator().CurrentDatabase() - return migrator.DB.Raw("SELECT count(*) FROM information_schema.tables WHERE table_schema = ? AND table_name = ? AND table_type = ?", currentDatabase, stmt.Table, "BASE TABLE").Scan(&count).Error - }) + m.RunWithValue(value, func(stmt *gorm.Statement) error { + currentDatabase := m.DB.Migrator().CurrentDatabase() + return m.DB.Raw("SELECT count(*) FROM information_schema.tables WHERE table_schema = ? AND table_name = ? AND table_type = ?", currentDatabase, stmt.Table, "BASE TABLE").Row().Scan(&count) + }) - if err != nil || count == 0 { - return false - } - } - - return true + return count > 0 } -func (migrator Migrator) RenameTable(oldName, newName string) error { - return migrator.DB.Exec("RENAME TABLE ? TO ?", oldName, newName).Error +func (m Migrator) RenameTable(oldName, newName string) error { + return m.DB.Exec("RENAME TABLE ? TO ?", oldName, newName).Error } -func (migrator Migrator) AddColumn(value interface{}, field string) error { - return migrator.RunWithValue(value, func(stmt *gorm.Statement) error { +func (m Migrator) AddColumn(value interface{}, field string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { if field := stmt.Schema.LookUpField(field); field != nil { - return migrator.DB.Exec( + return m.DB.Exec( "ALTER TABLE ? ADD ? ?", - clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, clause.Expr{SQL: field.DBDataType}, + clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, clause.Expr{SQL: m.DataTypeOf(field)}, ).Error } return fmt.Errorf("failed to look up field with name: %s", field) }) } -func (migrator Migrator) DropColumn(value interface{}, field string) error { - return migrator.RunWithValue(value, func(stmt *gorm.Statement) error { +func (m Migrator) DropColumn(value interface{}, field string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { if field := stmt.Schema.LookUpField(field); field != nil { - return migrator.DB.Exec( + return m.DB.Exec( "ALTER TABLE ? DROP COLUMN ?", clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, ).Error } @@ -220,44 +222,41 @@ func (migrator Migrator) DropColumn(value interface{}, field string) error { }) } -func (migrator Migrator) AlterColumn(value interface{}, field string) error { - return migrator.RunWithValue(value, func(stmt *gorm.Statement) error { +func (m Migrator) AlterColumn(value interface{}, field string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { if field := stmt.Schema.LookUpField(field); field != nil { - return migrator.DB.Exec( + return m.DB.Exec( "ALTER TABLE ? ALTER COLUMN ? TYPE ?", - clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, clause.Expr{SQL: field.DBDataType}, + clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, clause.Expr{SQL: m.DataTypeOf(field)}, ).Error } return fmt.Errorf("failed to look up field with name: %s", field) }) } -func (migrator Migrator) HasColumn(value interface{}, field string) bool { +func (m Migrator) HasColumn(value interface{}, field string) bool { var count int64 - migrator.RunWithValue(value, func(stmt *gorm.Statement) error { - currentDatabase := migrator.DB.Migrator().CurrentDatabase() + m.RunWithValue(value, func(stmt *gorm.Statement) error { + currentDatabase := m.DB.Migrator().CurrentDatabase() name := field if field := stmt.Schema.LookUpField(field); field != nil { name = field.DBName } - return migrator.DB.Raw( + return m.DB.Raw( "SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_schema = ? AND table_name = ? AND column_name = ?", currentDatabase, stmt.Table, name, - ).Scan(&count).Error + ).Row().Scan(&count) }) - if count != 0 { - return true - } - return false + return count > 0 } -func (migrator Migrator) RenameColumn(value interface{}, oldName, field string) error { - return migrator.RunWithValue(value, func(stmt *gorm.Statement) error { +func (m Migrator) RenameColumn(value interface{}, oldName, field string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { if field := stmt.Schema.LookUpField(field); field != nil { - oldName = migrator.DB.NamingStrategy.ColumnName(stmt.Table, oldName) - return migrator.DB.Exec( + oldName = m.DB.NamingStrategy.ColumnName(stmt.Table, oldName) + return m.DB.Exec( "ALTER TABLE ? RENAME COLUMN ? TO ?", clause.Table{Name: stmt.Table}, clause.Column{Name: oldName}, clause.Column{Name: field.DBName}, ).Error @@ -266,15 +265,15 @@ func (migrator Migrator) RenameColumn(value interface{}, oldName, field string) }) } -func (migrator Migrator) ColumnTypes(value interface{}) ([]*sql.ColumnType, error) { +func (m Migrator) ColumnTypes(value interface{}) ([]*sql.ColumnType, error) { return nil, gorm.ErrNotImplemented } -func (migrator Migrator) CreateView(name string, option gorm.ViewOption) error { +func (m Migrator) CreateView(name string, option gorm.ViewOption) error { return gorm.ErrNotImplemented } -func (migrator Migrator) DropView(name string) error { +func (m Migrator) DropView(name string) error { return gorm.ErrNotImplemented } @@ -300,11 +299,11 @@ func buildConstraint(constraint *schema.Constraint) (sql string, results []inter return } -func (migrator Migrator) CreateConstraint(value interface{}, name string) error { - return migrator.RunWithValue(value, func(stmt *gorm.Statement) error { +func (m Migrator) CreateConstraint(value interface{}, name string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { checkConstraints := stmt.Schema.ParseCheckConstraints() if chk, ok := checkConstraints[name]; ok { - return migrator.DB.Exec( + return m.DB.Exec( "ALTER TABLE ? ADD CONSTRAINT ? CHECK ?", clause.Table{Name: stmt.Table}, clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint}, ).Error @@ -313,21 +312,21 @@ func (migrator Migrator) CreateConstraint(value interface{}, name string) error for _, rel := range stmt.Schema.Relationships.Relations { if constraint := rel.ParseConstraint(); constraint != nil && constraint.Name == name { sql, values := buildConstraint(constraint) - return migrator.DB.Exec("ALTER TABLE ? ADD "+sql, append([]interface{}{clause.Table{Name: stmt.Table}}, values...)...).Error + return m.DB.Exec("ALTER TABLE ? ADD "+sql, append([]interface{}{clause.Table{Name: stmt.Table}}, values...)...).Error } } err := fmt.Errorf("failed to create constraint with name %v", name) if field := stmt.Schema.LookUpField(name); field != nil { for _, cc := range checkConstraints { - if err = migrator.CreateIndex(value, cc.Name); err != nil { + if err = m.CreateIndex(value, cc.Name); err != nil { return err } } for _, rel := range stmt.Schema.Relationships.Relations { if constraint := rel.ParseConstraint(); constraint != nil && constraint.Field == field { - if err = migrator.CreateIndex(value, constraint.Name); err != nil { + if err = m.CreateIndex(value, constraint.Name); err != nil { return err } } @@ -338,32 +337,29 @@ func (migrator Migrator) CreateConstraint(value interface{}, name string) error }) } -func (migrator Migrator) DropConstraint(value interface{}, name string) error { - return migrator.RunWithValue(value, func(stmt *gorm.Statement) error { - return migrator.DB.Exec( +func (m Migrator) DropConstraint(value interface{}, name string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + return m.DB.Exec( "ALTER TABLE ? DROP CONSTRAINT ?", clause.Table{Name: stmt.Table}, clause.Column{Name: name}, ).Error }) } -func (migrator Migrator) HasConstraint(value interface{}, name string) bool { +func (m Migrator) HasConstraint(value interface{}, name string) bool { var count int64 - migrator.RunWithValue(value, func(stmt *gorm.Statement) error { - currentDatabase := migrator.DB.Migrator().CurrentDatabase() - return migrator.DB.Raw( + m.RunWithValue(value, func(stmt *gorm.Statement) error { + currentDatabase := m.DB.Migrator().CurrentDatabase() + return m.DB.Raw( "SELECT count(*) FROM INFORMATION_SCHEMA.referential_constraints WHERE constraint_schema = ? AND table_name = ? AND constraint_name = ?", currentDatabase, stmt.Table, name, - ).Scan(&count).Error + ).Row().Scan(&count) }) - if count != 0 { - return true - } - return false + return count > 0 } -func buildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) { +func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) { for _, opt := range opts { str := stmt.Quote(opt.DBName) if opt.Expression != "" { @@ -372,6 +368,10 @@ func buildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results str += fmt.Sprintf("(%d)", opt.Length) } + if opt.Collate != "" { + str += " COLLATE " + opt.Collate + } + if opt.Sort != "" { str += " " + opt.Sort } @@ -380,13 +380,17 @@ func buildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results return } -func (migrator Migrator) CreateIndex(value interface{}, name string) error { - return migrator.RunWithValue(value, func(stmt *gorm.Statement) error { +type BuildIndexOptionsInterface interface { + BuildIndexOptions([]schema.IndexOption, *gorm.Statement) []interface{} +} + +func (m Migrator) CreateIndex(value interface{}, name string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { err := fmt.Errorf("failed to create index with name %v", name) indexes := stmt.Schema.ParseIndexes() if idx, ok := indexes[name]; ok { - opts := buildIndexOptions(idx.Fields, stmt) + opts := m.DB.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt) values := []interface{}{clause.Column{Name: idx.Name}, clause.Table{Name: stmt.Table}, opts} createIndexSQL := "CREATE " @@ -404,12 +408,12 @@ func (migrator Migrator) CreateIndex(value interface{}, name string) error { createIndexSQL += " USING " + idx.Type } - return migrator.DB.Raw(createIndexSQL, values...).Error + return m.DB.Exec(createIndexSQL, values...).Error } else if field := stmt.Schema.LookUpField(name); field != nil { for _, idx := range indexes { for _, idxOpt := range idx.Fields { if idxOpt.Field == field { - if err = migrator.CreateIndex(value, idx.Name); err != nil { + if err = m.CreateIndex(value, idx.Name); err != nil { return err } } @@ -420,38 +424,35 @@ func (migrator Migrator) CreateIndex(value interface{}, name string) error { }) } -func (migrator Migrator) DropIndex(value interface{}, name string) error { - return migrator.RunWithValue(value, func(stmt *gorm.Statement) error { - return migrator.DB.Raw("DROP INDEX ? ON ?", clause.Column{Name: name}, clause.Table{Name: stmt.Table}).Error +func (m Migrator) DropIndex(value interface{}, name string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + return m.DB.Exec("DROP INDEX ? ON ?", clause.Column{Name: name}, clause.Table{Name: stmt.Table}).Error }) } -func (migrator Migrator) HasIndex(value interface{}, name string) bool { +func (m Migrator) HasIndex(value interface{}, name string) bool { var count int64 - migrator.RunWithValue(value, func(stmt *gorm.Statement) error { - currentDatabase := migrator.DB.Migrator().CurrentDatabase() - return migrator.DB.Raw( + m.RunWithValue(value, func(stmt *gorm.Statement) error { + currentDatabase := m.DB.Migrator().CurrentDatabase() + return m.DB.Raw( "SELECT count(*) FROM information_schema.statistics WHERE table_schema = ? AND table_name = ? AND index_name = ?", currentDatabase, stmt.Table, name, - ).Scan(&count).Error + ).Row().Scan(&count) }) - if count != 0 { - return true - } - return false + return count > 0 } -func (migrator Migrator) RenameIndex(value interface{}, oldName, newName string) error { - return migrator.RunWithValue(value, func(stmt *gorm.Statement) error { - return migrator.DB.Exec( +func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + return m.DB.Exec( "ALTER TABLE ? RENAME INDEX ? TO ?", clause.Table{Name: stmt.Table}, clause.Column{Name: oldName}, clause.Column{Name: newName}, ).Error }) } -func (migrator Migrator) CurrentDatabase() (name string) { - migrator.DB.Raw("SELECT DATABASE()").Scan(&name) +func (m Migrator) CurrentDatabase() (name string) { + m.DB.Raw("SELECT DATABASE()").Row().Scan(&name) return } diff --git a/schema/field.go b/schema/field.go index 60cfc2ab..ea4e6a40 100644 --- a/schema/field.go +++ b/schema/field.go @@ -138,7 +138,10 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } if num, ok := field.TagSettings["SIZE"]; ok { - field.Size, _ = strconv.Atoi(num) + var err error + if field.Size, err = strconv.Atoi(num); err != nil { + field.Size = -1 + } } if p, ok := field.TagSettings["PRECISION"]; ok { From 215f5e77650349aa888c83b481c3e36e2722669e Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 22 Feb 2020 19:41:01 +0800 Subject: [PATCH 244/881] Add Raw, Row, Rows --- callbacks/callbacks.go | 3 ++ callbacks/raw.go | 11 +++++++ callbacks/row.go | 19 ++++++++++++ chainable_api.go | 3 ++ dialects/mssql/mssql.go | 5 +++- dialects/mysql/mysql.go | 5 +++- dialects/postgres/postgres.go | 5 +++- dialects/sqlite/sqlite.go | 5 +++- dialects/sqlite/sqlite_test.go | 6 +++- finisher_api.go | 9 ++++-- gorm.go | 5 ++++ migrator/migrator.go | 11 +++++-- schema/check.go | 5 +++- schema/check_test.go | 55 ++++++++++++++++++++++++++++++++++ schema/index_test.go | 2 +- schema/relationship.go | 6 +++- tests/migrate.go | 19 ++++++++++++ 17 files changed, 162 insertions(+), 12 deletions(-) create mode 100644 callbacks/raw.go create mode 100644 callbacks/row.go create mode 100644 schema/check_test.go create mode 100644 tests/migrate.go diff --git a/callbacks/callbacks.go b/callbacks/callbacks.go index f9d5543d..0a48ada6 100644 --- a/callbacks/callbacks.go +++ b/callbacks/callbacks.go @@ -38,4 +38,7 @@ func RegisterDefaultCallbacks(db *gorm.DB) { updateCallback.Register("gorm:save_after_associations", SaveAfterAssociations) updateCallback.Register("gorm:after_update", AfterUpdate) updateCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) + + db.Callback().Row().Register("gorm:raw", RowQuery) + db.Callback().Raw().Register("gorm:raw", RawExec) } diff --git a/callbacks/raw.go b/callbacks/raw.go new file mode 100644 index 00000000..6d0a5aac --- /dev/null +++ b/callbacks/raw.go @@ -0,0 +1,11 @@ +package callbacks + +import "github.com/jinzhu/gorm" + +func RawExec(db *gorm.DB) { + result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) + db.RowsAffected, _ = result.RowsAffected() + if err != nil { + db.AddError(err) + } +} diff --git a/callbacks/row.go b/callbacks/row.go new file mode 100644 index 00000000..04fe4f48 --- /dev/null +++ b/callbacks/row.go @@ -0,0 +1,19 @@ +package callbacks + +import ( + "github.com/jinzhu/gorm" + "github.com/jinzhu/gorm/clause" +) + +func RowQuery(db *gorm.DB) { + db.Statement.AddClauseIfNotExists(clause.Select{}) + db.Statement.AddClauseIfNotExists(clause.From{}) + + db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR") + + if _, ok := db.Get("rows"); ok { + db.Statement.Dest, db.Error = db.DB.QueryContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) + } else { + db.Statement.Dest = db.DB.QueryRowContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) + } +} diff --git a/chainable_api.go b/chainable_api.go index a57deb63..ccd61716 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -222,5 +222,8 @@ func (db *DB) Unscoped() (tx *DB) { func (db *DB) Raw(sql string, values ...interface{}) (tx *DB) { tx = db.getInstance() + stmt := tx.Statement + stmt.SQL = strings.Builder{} + clause.Expr{SQL: sql, Vars: values}.Build(stmt) return } diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go index bdca667d..78c048b4 100644 --- a/dialects/mssql/mssql.go +++ b/dialects/mssql/mssql.go @@ -28,7 +28,10 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) { } func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator { - return Migrator{migrator.Migrator{Config: migrator.Config{DB: db}}} + return Migrator{migrator.Migrator{Config: migrator.Config{ + DB: db, + Dialector: dialector, + }}} } func (dialector Dialector) BindVar(stmt *gorm.Statement, v interface{}) string { diff --git a/dialects/mysql/mysql.go b/dialects/mysql/mysql.go index e2fea53c..3b456891 100644 --- a/dialects/mysql/mysql.go +++ b/dialects/mysql/mysql.go @@ -29,7 +29,10 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) { } func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator { - return Migrator{migrator.Migrator{Config: migrator.Config{DB: db}}} + return Migrator{migrator.Migrator{Config: migrator.Config{ + DB: db, + Dialector: dialector, + }}} } func (dialector Dialector) BindVar(stmt *gorm.Statement, v interface{}) string { diff --git a/dialects/postgres/postgres.go b/dialects/postgres/postgres.go index a3eeefb9..4ffc4204 100644 --- a/dialects/postgres/postgres.go +++ b/dialects/postgres/postgres.go @@ -28,7 +28,10 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) { } func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator { - return Migrator{migrator.Migrator{Config: migrator.Config{DB: db}}} + return Migrator{migrator.Migrator{Config: migrator.Config{ + DB: db, + Dialector: dialector, + }}} } func (dialector Dialector) BindVar(stmt *gorm.Statement, v interface{}) string { diff --git a/dialects/sqlite/sqlite.go b/dialects/sqlite/sqlite.go index b77226db..804016a5 100644 --- a/dialects/sqlite/sqlite.go +++ b/dialects/sqlite/sqlite.go @@ -27,7 +27,10 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) { } func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator { - return Migrator{migrator.Migrator{Config: migrator.Config{DB: db}}} + return Migrator{migrator.Migrator{Config: migrator.Config{ + DB: db, + Dialector: dialector, + }}} } func (dialector Dialector) BindVar(stmt *gorm.Statement, v interface{}) string { diff --git a/dialects/sqlite/sqlite_test.go b/dialects/sqlite/sqlite_test.go index 51c1def0..a42bc8ee 100644 --- a/dialects/sqlite/sqlite_test.go +++ b/dialects/sqlite/sqlite_test.go @@ -22,6 +22,10 @@ func init() { } } -func TestSqlite(t *testing.T) { +func TestCURD(t *testing.T) { tests.RunTestsSuit(t, DB) } + +func TestMigrate(t *testing.T) { + tests.TestMigrate(t, DB) +} diff --git a/finisher_api.go b/finisher_api.go index 5389ed6a..8b824d12 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -108,11 +108,15 @@ func (db *DB) Count(value interface{}) (tx *DB) { } func (db *DB) Row() *sql.Row { - return nil + tx := db.getInstance() + tx.callbacks.Row().Execute(tx) + return tx.Statement.Dest.(*sql.Row) } func (db *DB) Rows() (*sql.Rows, error) { - return nil, nil + tx := db.Set("rows", true) + tx.callbacks.Row().Execute(tx) + return tx.Statement.Dest.(*sql.Rows), tx.Error } // Scan scan value to a struct @@ -162,5 +166,6 @@ func (db *DB) Rollback() (tx *DB) { func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) { tx = db.getInstance() + tx.callbacks.Raw().Execute(tx) return } diff --git a/gorm.go b/gorm.go index 23f812d1..2f10be60 100644 --- a/gorm.go +++ b/gorm.go @@ -138,6 +138,11 @@ func (db *DB) Callback() *callbacks { return db.callbacks } +// AutoMigrate run auto migration for given models +func (db *DB) AutoMigrate(dst ...interface{}) error { + return db.Migrator().AutoMigrate(dst...) +} + func (db *DB) getInstance() *DB { if db.clone { ctx := db.Instance.Context diff --git a/migrator/migrator.go b/migrator/migrator.go index 9e94cc68..5debc600 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -265,8 +265,15 @@ func (m Migrator) RenameColumn(value interface{}, oldName, field string) error { }) } -func (m Migrator) ColumnTypes(value interface{}) ([]*sql.ColumnType, error) { - return nil, gorm.ErrNotImplemented +func (m Migrator) ColumnTypes(value interface{}) (columnTypes []*sql.ColumnType, err error) { + err = m.RunWithValue(value, func(stmt *gorm.Statement) error { + rows, err := m.DB.Raw("select * from ?", clause.Table{Name: stmt.Table}).Rows() + if err == nil { + columnTypes, err = rows.ColumnTypes() + } + return err + }) + return } func (m Migrator) CreateView(name string, option gorm.ViewOption) error { diff --git a/schema/check.go b/schema/check.go index a06ac67b..7d31ec70 100644 --- a/schema/check.go +++ b/schema/check.go @@ -17,9 +17,12 @@ func (schema *Schema) ParseCheckConstraints() map[string]Check { for _, field := range schema.FieldsByDBName { if chk := field.TagSettings["CHECK"]; chk != "" { names := strings.Split(chk, ",") - if len(names) > 1 && regexp.MustCompile("^[A-Za-z]+$").MatchString(names[0]) { + if len(names) > 1 && regexp.MustCompile("^[A-Za-z-_]+$").MatchString(names[0]) { checks[names[0]] = Check{Name: names[0], Constraint: strings.Join(names[1:], ","), Field: field} } else { + if names[0] == "" { + chk = strings.Join(names[1:], ",") + } name := schema.namer.CheckerName(schema.Table, field.DBName) checks[name] = Check{Name: name, Constraint: chk, Field: field} } diff --git a/schema/check_test.go b/schema/check_test.go new file mode 100644 index 00000000..e4bc9ebe --- /dev/null +++ b/schema/check_test.go @@ -0,0 +1,55 @@ +package schema_test + +import ( + "reflect" + "sync" + "testing" + + "github.com/jinzhu/gorm/schema" +) + +type UserCheck struct { + Name string `gorm:"check:name_checker,name <> 'jinzhu'"` + Name2 string `gorm:"check:name <> 'jinzhu'"` + Name3 string `gorm:"check:,name <> 'jinzhu'"` +} + +func TestParseCheck(t *testing.T) { + user, err := schema.Parse(&UserCheck{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("failed to parse user check, got error %v", err) + } + + results := map[string]schema.Check{ + "name_checker": { + Name: "name_checker", + Constraint: "name <> 'jinzhu'", + }, + "chk_user_checks_name2": { + Name: "chk_user_checks_name2", + Constraint: "name <> 'jinzhu'", + }, + "chk_user_checks_name3": { + Name: "chk_user_checks_name3", + Constraint: "name <> 'jinzhu'", + }, + } + + checks := user.ParseCheckConstraints() + + for k, result := range results { + v, ok := checks[k] + if !ok { + t.Errorf("Failed to found check %v from parsed checks %+v", k, checks) + } + + for _, name := range []string{"Name", "Constraint"} { + if reflect.ValueOf(result).FieldByName(name).Interface() != reflect.ValueOf(v).FieldByName(name).Interface() { + t.Errorf( + "check %v %v should equal, expects %v, got %v", + k, name, reflect.ValueOf(result).FieldByName(name).Interface(), reflect.ValueOf(v).FieldByName(name).Interface(), + ) + } + } + } +} diff --git a/schema/index_test.go b/schema/index_test.go index 1409b9c4..d0e8dfe0 100644 --- a/schema/index_test.go +++ b/schema/index_test.go @@ -21,7 +21,7 @@ type UserIndex struct { func TestParseIndex(t *testing.T) { user, err := schema.Parse(&UserIndex{}, &sync.Map{}, schema.NamingStrategy{}) if err != nil { - t.Fatalf("failed to parse user index index, got error %v", err) + t.Fatalf("failed to parse user index, got error %v", err) } results := map[string]schema.Index{ diff --git a/schema/relationship.go b/schema/relationship.go index 8081b0e7..6606d77e 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -317,7 +317,7 @@ func (rel *Relationship) ParseConstraint() *Constraint { settings = ParseTagSetting(str, ",") ) - if idx != -1 && regexp.MustCompile("^[A-Za-z]+$").MatchString(str[0:idx]) { + if idx != -1 && regexp.MustCompile("^[A-Za-z-_]+$").MatchString(str[0:idx]) { name = str[0:idx] } else { name = rel.Schema.namer.RelationshipFKName(*rel) @@ -339,5 +339,9 @@ func (rel *Relationship) ParseConstraint() *Constraint { } } + if constraint.ReferenceSchema == nil { + return nil + } + return &constraint } diff --git a/tests/migrate.go b/tests/migrate.go new file mode 100644 index 00000000..0466fe11 --- /dev/null +++ b/tests/migrate.go @@ -0,0 +1,19 @@ +package tests + +import ( + "testing" + + "github.com/jinzhu/gorm" +) + +func TestMigrate(t *testing.T, db *gorm.DB) { + allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Toy{}, &Company{}, &Language{}} + + db.AutoMigrate(allModels...) + + for _, m := range allModels { + if !db.Migrator().HasTable(m) { + t.Errorf("Failed to create table for %+v", m) + } + } +} From 6d58b62fd457ccfd8daef962f679e266b6844a2a Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 22 Feb 2020 20:57:29 +0800 Subject: [PATCH 245/881] Add sqlite migration tests --- callbacks/query.go | 9 ++++++--- callbacks/raw.go | 7 +++++-- callbacks/row.go | 8 +++++--- chainable_api.go | 5 ++--- clause/expression.go | 6 ++++-- clause/expression_test.go | 35 +++++++++++++++++++++++++++++++++++ dialects/sqlite/migrator.go | 4 ++-- dialects/sqlite/sqlite.go | 17 +++++++++-------- finisher_api.go | 3 +++ go.mod | 1 + migrator/migrator.go | 33 +++++++++++++++++++++------------ schema/naming.go | 2 +- schema/relationship.go | 2 +- statement.go | 5 ++++- tests/dummy_dialecter.go | 7 ++++++- tests/migrate.go | 14 ++++++++++++-- 16 files changed, 117 insertions(+), 41 deletions(-) create mode 100644 clause/expression_test.go diff --git a/callbacks/query.go b/callbacks/query.go index 8d13095e..a4ed3adb 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -8,10 +8,13 @@ import ( ) func Query(db *gorm.DB) { - db.Statement.AddClauseIfNotExists(clause.Select{}) - db.Statement.AddClauseIfNotExists(clause.From{}) + if db.Statement.SQL.String() == "" { + db.Statement.AddClauseIfNotExists(clause.Select{}) + db.Statement.AddClauseIfNotExists(clause.From{}) + + db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR") + } - db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR") result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) fmt.Println(err) fmt.Println(result) diff --git a/callbacks/raw.go b/callbacks/raw.go index 6d0a5aac..e8cad25d 100644 --- a/callbacks/raw.go +++ b/callbacks/raw.go @@ -1,11 +1,14 @@ package callbacks -import "github.com/jinzhu/gorm" +import ( + "github.com/jinzhu/gorm" +) func RawExec(db *gorm.DB) { result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) - db.RowsAffected, _ = result.RowsAffected() if err != nil { db.AddError(err) + } else { + db.RowsAffected, _ = result.RowsAffected() } } diff --git a/callbacks/row.go b/callbacks/row.go index 04fe4f48..f7d6752d 100644 --- a/callbacks/row.go +++ b/callbacks/row.go @@ -6,10 +6,12 @@ import ( ) func RowQuery(db *gorm.DB) { - db.Statement.AddClauseIfNotExists(clause.Select{}) - db.Statement.AddClauseIfNotExists(clause.From{}) + if db.Statement.SQL.String() == "" { + db.Statement.AddClauseIfNotExists(clause.Select{}) + db.Statement.AddClauseIfNotExists(clause.From{}) - db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR") + db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR") + } if _, ok := db.Get("rows"); ok { db.Statement.Dest, db.Error = db.DB.QueryContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) diff --git a/chainable_api.go b/chainable_api.go index ccd61716..770b2236 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -222,8 +222,7 @@ func (db *DB) Unscoped() (tx *DB) { func (db *DB) Raw(sql string, values ...interface{}) (tx *DB) { tx = db.getInstance() - stmt := tx.Statement - stmt.SQL = strings.Builder{} - clause.Expr{SQL: sql, Vars: values}.Build(stmt) + tx.Statement.SQL = strings.Builder{} + clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement) return } diff --git a/clause/expression.go b/clause/expression.go index 6b3575df..d72db08d 100644 --- a/clause/expression.go +++ b/clause/expression.go @@ -1,6 +1,8 @@ package clause -import "strings" +import ( + "strings" +) // Expression expression interface type Expression interface { @@ -22,7 +24,7 @@ type Expr struct { func (expr Expr) Build(builder Builder) { sql := expr.SQL for _, v := range expr.Vars { - sql = strings.Replace(sql, " ?", " "+builder.AddVar(v), 1) + sql = strings.Replace(sql, "?", builder.AddVar(v), 1) } builder.Write(sql) } diff --git a/clause/expression_test.go b/clause/expression_test.go new file mode 100644 index 00000000..e51d189e --- /dev/null +++ b/clause/expression_test.go @@ -0,0 +1,35 @@ +package clause_test + +import ( + "fmt" + "sync" + "testing" + + "github.com/jinzhu/gorm" + "github.com/jinzhu/gorm/clause" + "github.com/jinzhu/gorm/schema" + "github.com/jinzhu/gorm/tests" +) + +func TestExpr(t *testing.T) { + results := []struct { + SQL string + Result string + Vars []interface{} + }{{ + SQL: "create table ? (? ?, ? ?)", + Vars: []interface{}{clause.Table{Name: "users"}, clause.Column{Name: "id"}, clause.Expr{SQL: "int"}, clause.Column{Name: "name"}, clause.Expr{SQL: "text"}}, + Result: "create table `users` (`id` int, `name` text)", + }} + + for idx, result := range results { + t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { + user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) + stmt := &gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}} + clause.Expr{SQL: result.SQL, Vars: result.Vars}.Build(stmt) + if stmt.SQL.String() != result.Result { + t.Errorf("generated SQL is not equal, expects %v, but got %v", result.Result, stmt.SQL.String()) + } + }) + } +} diff --git a/dialects/sqlite/migrator.go b/dialects/sqlite/migrator.go index 07e189ad..4ddcbb5d 100644 --- a/dialects/sqlite/migrator.go +++ b/dialects/sqlite/migrator.go @@ -30,8 +30,8 @@ func (m Migrator) HasColumn(value interface{}, field string) bool { } return m.DB.Raw( - "SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE ? OR sql LIKE ?)", - stmt.Table, `%"`+name+`" %`, `%`+name+` %`, + "SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE ? OR sql LIKE ? OR sql LIKE ?)", + stmt.Table, `%"`+name+`" %`, `%`+name+` %`, "%`"+name+"`%", ).Row().Scan(&count) }) return count > 0 diff --git a/dialects/sqlite/sqlite.go b/dialects/sqlite/sqlite.go index 804016a5..38cd760b 100644 --- a/dialects/sqlite/sqlite.go +++ b/dialects/sqlite/sqlite.go @@ -28,8 +28,9 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) { func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator { return Migrator{migrator.Migrator{Config: migrator.Config{ - DB: db, - Dialector: dialector, + DB: db, + Dialector: dialector, + CreateIndexAfterCreateTable: true, }}} } @@ -44,20 +45,20 @@ func (dialector Dialector) QuoteChars() [2]byte { func (dialector Dialector) DataTypeOf(field *schema.Field) string { switch field.DataType { case schema.Bool: - return "NUMERIC" + return "numeric" case schema.Int, schema.Uint: if field.AutoIncrement { // https://www.sqlite.org/autoinc.html - return "INTEGER PRIMARY KEY AUTOINCREMENT" + return "integer PRIMARY KEY AUTOINCREMENT" } else { - return "INTEGER" + return "integer" } case schema.Float: - return "REAL" + return "real" case schema.String, schema.Time: - return "TEXT" + return "text" case schema.Bytes: - return "BLOB" + return "blob" } return "" diff --git a/finisher_api.go b/finisher_api.go index 8b824d12..c9b58861 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -2,6 +2,7 @@ package gorm import ( "database/sql" + "strings" "github.com/jinzhu/gorm/clause" ) @@ -166,6 +167,8 @@ func (db *DB) Rollback() (tx *DB) { func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) { tx = db.getInstance() + tx.Statement.SQL = strings.Builder{} + clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement) tx.callbacks.Raw().Execute(tx) return } diff --git a/go.mod b/go.mod index cdb7e574..9046ea99 100644 --- a/go.mod +++ b/go.mod @@ -5,4 +5,5 @@ go 1.13 require ( github.com/jinzhu/inflection v1.0.0 github.com/jinzhu/now v1.1.1 + github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect ) diff --git a/migrator/migrator.go b/migrator/migrator.go index 5debc600..e3097abd 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -18,7 +18,8 @@ type Migrator struct { // Config schema config type Config struct { - DB *gorm.DB + CreateIndexAfterCreateTable bool + DB *gorm.DB gorm.Dialector } @@ -80,9 +81,11 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { } // create join table - joinValue := reflect.New(rel.JoinTable.ModelType).Interface() - if !m.DB.Migrator().HasTable(joinValue) { - defer m.DB.Migrator().CreateTable(joinValue) + if rel.JoinTable != nil { + joinValue := reflect.New(rel.JoinTable.ModelType).Interface() + if !m.DB.Migrator().HasTable(joinValue) { + defer m.DB.Migrator().CreateTable(joinValue) + } } } return nil @@ -140,8 +143,12 @@ func (m Migrator) CreateTable(values ...interface{}) error { } for _, idx := range stmt.Schema.ParseIndexes() { - createTableSQL += "INDEX ? ?," - values = append(values, clause.Expr{SQL: idx.Name}, m.DB.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt)) + if m.CreateIndexAfterCreateTable { + m.DB.Migrator().CreateIndex(value, idx.Name) + } else { + createTableSQL += "INDEX ? ?," + values = append(values, clause.Expr{SQL: idx.Name}, m.DB.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt)) + } } for _, rel := range stmt.Schema.Relationships.Relations { @@ -152,9 +159,11 @@ func (m Migrator) CreateTable(values ...interface{}) error { } // create join table - joinValue := reflect.New(rel.JoinTable.ModelType).Interface() - if !m.DB.Migrator().HasTable(joinValue) { - defer m.DB.Migrator().CreateTable(joinValue) + if rel.JoinTable != nil { + joinValue := reflect.New(rel.JoinTable.ModelType).Interface() + if !m.DB.Migrator().HasTable(joinValue) { + defer m.DB.Migrator().CreateTable(joinValue) + } } } @@ -302,7 +311,7 @@ func buildConstraint(constraint *schema.Constraint) (sql string, results []inter for _, field := range constraint.References { references = append(references, clause.Column{Name: field.DBName}) } - results = append(results, constraint.Name, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references) + results = append(results, clause.Table{Name: constraint.Name}, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references) return } @@ -326,14 +335,14 @@ func (m Migrator) CreateConstraint(value interface{}, name string) error { err := fmt.Errorf("failed to create constraint with name %v", name) if field := stmt.Schema.LookUpField(name); field != nil { for _, cc := range checkConstraints { - if err = m.CreateIndex(value, cc.Name); err != nil { + if err = m.DB.Migrator().CreateIndex(value, cc.Name); err != nil { return err } } for _, rel := range stmt.Schema.Relationships.Relations { if constraint := rel.ParseConstraint(); constraint != nil && constraint.Field == field { - if err = m.CreateIndex(value, constraint.Name); err != nil { + if err = m.DB.Migrator().CreateIndex(value, constraint.Name); err != nil { return err } } diff --git a/schema/naming.go b/schema/naming.go index d6f26e9f..f7c82f32 100644 --- a/schema/naming.go +++ b/schema/naming.go @@ -46,7 +46,7 @@ func (ns NamingStrategy) JoinTableName(str string) string { // RelationshipFKName generate fk name for relation func (ns NamingStrategy) RelationshipFKName(rel Relationship) string { - return fmt.Sprintf("fk_%s_%s", rel.Schema.Table, rel.FieldSchema.Table) + return fmt.Sprintf("fk_%s_%s", rel.Schema.Table, toDBName(rel.Field.Name)) } // CheckerName generate checker name diff --git a/schema/relationship.go b/schema/relationship.go index 6606d77e..4ffea8b3 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -339,7 +339,7 @@ func (rel *Relationship) ParseConstraint() *Constraint { } } - if constraint.ReferenceSchema == nil { + if rel.JoinTable != nil || constraint.ReferenceSchema == nil { return nil } diff --git a/statement.go b/statement.go index 8c75c90d..d486a1c7 100644 --- a/statement.go +++ b/statement.go @@ -152,8 +152,11 @@ func (stmt *Statement) AddVar(vars ...interface{}) string { stmt.Vars = append(stmt.Vars, v.Value) placeholders.WriteString(stmt.DB.Dialector.BindVar(stmt, v.Value)) } - case clause.Column: + case clause.Column, clause.Table: placeholders.WriteString(stmt.Quote(v)) + case clause.Expr: + placeholders.WriteString(v.SQL) + stmt.Vars = append(stmt.Vars, v.Vars...) case []interface{}: if len(v) > 0 { placeholders.WriteByte('(') diff --git a/tests/dummy_dialecter.go b/tests/dummy_dialecter.go index e2cda8fc..b4e3361b 100644 --- a/tests/dummy_dialecter.go +++ b/tests/dummy_dialecter.go @@ -2,6 +2,7 @@ package tests import ( "github.com/jinzhu/gorm" + "github.com/jinzhu/gorm/schema" ) type DummyDialector struct { @@ -11,7 +12,7 @@ func (DummyDialector) Initialize(*gorm.DB) error { return nil } -func (DummyDialector) Migrator() gorm.Migrator { +func (DummyDialector) Migrator(*gorm.DB) gorm.Migrator { return nil } @@ -22,3 +23,7 @@ func (DummyDialector) BindVar(stmt *gorm.Statement, v interface{}) string { func (DummyDialector) QuoteChars() [2]byte { return [2]byte{'`', '`'} // `name` } + +func (DummyDialector) DataTypeOf(*schema.Field) string { + return "" +} diff --git a/tests/migrate.go b/tests/migrate.go index 0466fe11..9f7e2d67 100644 --- a/tests/migrate.go +++ b/tests/migrate.go @@ -9,11 +9,21 @@ import ( func TestMigrate(t *testing.T, db *gorm.DB) { allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Toy{}, &Company{}, &Language{}} - db.AutoMigrate(allModels...) + for _, m := range allModels { + if db.Migrator().HasTable(m) { + if err := db.Migrator().DropTable(m); err != nil { + t.Errorf("Failed to drop table, got error %v", err) + } + } + } + + if err := db.AutoMigrate(allModels...); err != nil { + t.Errorf("Failed to auto migrate, but got error %v", err) + } for _, m := range allModels { if !db.Migrator().HasTable(m) { - t.Errorf("Failed to create table for %+v", m) + t.Errorf("Failed to create table for %#v", m) } } } From 1895d281bf7a183e5d679c1962737eb74ab19546 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 22 Feb 2020 23:08:20 +0800 Subject: [PATCH 246/881] Add migrator tests for mysql --- dialects/mysql/mysql.go | 11 ++++++---- dialects/mysql/mysql_test.go | 21 ++++++++++++++++++ dialects/sqlite/sqlite.go | 1 - finisher_api.go | 2 +- migrator/migrator.go | 41 +++++++++++++++++++----------------- tests/migrate.go | 2 +- tests/model.go | 4 ++-- 7 files changed, 54 insertions(+), 28 deletions(-) diff --git a/dialects/mysql/mysql.go b/dialects/mysql/mysql.go index 3b456891..5fcc2d69 100644 --- a/dialects/mysql/mysql.go +++ b/dialects/mysql/mysql.go @@ -23,9 +23,8 @@ func Open(dsn string) gorm.Dialector { func (dialector Dialector) Initialize(db *gorm.DB) (err error) { // register callbacks callbacks.RegisterDefaultCallbacks(db) - db.DB, err = sql.Open("sqlite3", dialector.DSN) - - return nil + db.DB, err = sql.Open("mysql", dialector.DSN) + return } func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator { @@ -75,9 +74,13 @@ func (dialector Dialector) DataTypeOf(field *schema.Field) string { return "double" case schema.String: size := field.Size + if field.PrimaryKey { + size = 256 + } + if size >= 65536 && size <= int(math.Pow(2, 24)) { return "mediumtext" - } else if size > int(math.Pow(2, 24)) || size < 0 { + } else if size > int(math.Pow(2, 24)) || size <= 0 { return "longtext" } return fmt.Sprintf("varchar(%d)", size) diff --git a/dialects/mysql/mysql_test.go b/dialects/mysql/mysql_test.go index 49c26915..7fd5e373 100644 --- a/dialects/mysql/mysql_test.go +++ b/dialects/mysql/mysql_test.go @@ -1,12 +1,33 @@ package mysql_test import ( + "fmt" "testing" "github.com/jinzhu/gorm" "github.com/jinzhu/gorm/dialects/mysql" + "github.com/jinzhu/gorm/tests" ) func TestOpen(t *testing.T) { gorm.Open(mysql.Open("gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True"), nil) } + +var ( + DB *gorm.DB + err error +) + +func init() { + if DB, err = gorm.Open(mysql.Open("gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True"), &gorm.Config{}); err != nil { + panic(fmt.Sprintf("failed to initialize database, got error %v", err)) + } +} + +func TestCURD(t *testing.T) { + tests.RunTestsSuit(t, DB) +} + +func TestMigrate(t *testing.T) { + tests.TestMigrate(t, DB) +} diff --git a/dialects/sqlite/sqlite.go b/dialects/sqlite/sqlite.go index 38cd760b..54fa7de0 100644 --- a/dialects/sqlite/sqlite.go +++ b/dialects/sqlite/sqlite.go @@ -21,7 +21,6 @@ func Open(dsn string) gorm.Dialector { func (dialector Dialector) Initialize(db *gorm.DB) (err error) { // register callbacks callbacks.RegisterDefaultCallbacks(db) - db.DB, err = sql.Open("sqlite3", dialector.DSN) return } diff --git a/finisher_api.go b/finisher_api.go index c9b58861..2c5d4f65 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -140,7 +140,7 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er } }() - err = fc(tx) + err = fc(tx.Session(&Session{})) if err == nil { err = tx.Commit().Error diff --git a/migrator/migrator.go b/migrator/migrator.go index e3097abd..a5ec1a62 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -18,8 +18,9 @@ type Migrator struct { // Config schema config type Config struct { - CreateIndexAfterCreateTable bool - DB *gorm.DB + CreateIndexAfterCreateTable bool + AllowDeferredConstraintsWhenAutoMigrate bool + DB *gorm.DB gorm.Dialector } @@ -47,17 +48,17 @@ func (m Migrator) DataTypeOf(field *schema.Field) string { // AutoMigrate func (m Migrator) AutoMigrate(values ...interface{}) error { // TODO smart migrate data type - for _, value := range values { - if !m.DB.Migrator().HasTable(value) { - if err := m.DB.Migrator().CreateTable(value); err != nil { + tx := m.DB.Session(&gorm.Session{}) + if !tx.Migrator().HasTable(value) { + if err := tx.Migrator().CreateTable(value); err != nil { return err } } else { if err := m.RunWithValue(value, func(stmt *gorm.Statement) error { for _, field := range stmt.Schema.FieldsByDBName { - if !m.DB.Migrator().HasColumn(value, field.DBName) { - if err := m.DB.Migrator().AddColumn(value, field.DBName); err != nil { + if !tx.Migrator().HasColumn(value, field.DBName) { + if err := tx.Migrator().AddColumn(value, field.DBName); err != nil { return err } } @@ -65,16 +66,16 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { for _, rel := range stmt.Schema.Relationships.Relations { if constraint := rel.ParseConstraint(); constraint != nil { - if !m.DB.Migrator().HasConstraint(value, constraint.Name) { - if err := m.DB.Migrator().CreateConstraint(value, constraint.Name); err != nil { + if !tx.Migrator().HasConstraint(value, constraint.Name) { + if err := tx.Migrator().CreateConstraint(value, constraint.Name); err != nil { return err } } } for _, chk := range stmt.Schema.ParseCheckConstraints() { - if !m.DB.Migrator().HasConstraint(value, chk.Name) { - if err := m.DB.Migrator().CreateConstraint(value, chk.Name); err != nil { + if !tx.Migrator().HasConstraint(value, chk.Name) { + if err := tx.Migrator().CreateConstraint(value, chk.Name); err != nil { return err } } @@ -83,8 +84,8 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { // create join table if rel.JoinTable != nil { joinValue := reflect.New(rel.JoinTable.ModelType).Interface() - if !m.DB.Migrator().HasTable(joinValue) { - defer m.DB.Migrator().CreateTable(joinValue) + if !tx.Migrator().HasTable(joinValue) { + defer tx.Migrator().CreateTable(joinValue) } } } @@ -100,6 +101,7 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { func (m Migrator) CreateTable(values ...interface{}) error { for _, value := range values { + tx := m.DB.Session(&gorm.Session{}) if err := m.RunWithValue(value, func(stmt *gorm.Statement) error { var ( createTableSQL = "CREATE TABLE ? (" @@ -144,10 +146,10 @@ func (m Migrator) CreateTable(values ...interface{}) error { for _, idx := range stmt.Schema.ParseIndexes() { if m.CreateIndexAfterCreateTable { - m.DB.Migrator().CreateIndex(value, idx.Name) + tx.Migrator().CreateIndex(value, idx.Name) } else { createTableSQL += "INDEX ? ?," - values = append(values, clause.Expr{SQL: idx.Name}, m.DB.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt)) + values = append(values, clause.Expr{SQL: idx.Name}, tx.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt)) } } @@ -161,8 +163,8 @@ func (m Migrator) CreateTable(values ...interface{}) error { // create join table if rel.JoinTable != nil { joinValue := reflect.New(rel.JoinTable.ModelType).Interface() - if !m.DB.Migrator().HasTable(joinValue) { - defer m.DB.Migrator().CreateTable(joinValue) + if !tx.Migrator().HasTable(joinValue) { + defer tx.Migrator().CreateTable(joinValue) } } } @@ -175,7 +177,7 @@ func (m Migrator) CreateTable(values ...interface{}) error { createTableSQL = strings.TrimSuffix(createTableSQL, ",") createTableSQL += ")" - return m.DB.Exec(createTableSQL, values...).Error + return tx.Exec(createTableSQL, values...).Error }); err != nil { return err } @@ -185,8 +187,9 @@ func (m Migrator) CreateTable(values ...interface{}) error { func (m Migrator) DropTable(values ...interface{}) error { for _, value := range values { + tx := m.DB.Session(&gorm.Session{}) if err := m.RunWithValue(value, func(stmt *gorm.Statement) error { - return m.DB.Exec("DROP TABLE ?", clause.Table{Name: stmt.Table}).Error + return tx.Exec("DROP TABLE ?", clause.Table{Name: stmt.Table}).Error }); err != nil { return err } diff --git a/tests/migrate.go b/tests/migrate.go index 9f7e2d67..477f0ad6 100644 --- a/tests/migrate.go +++ b/tests/migrate.go @@ -7,7 +7,7 @@ import ( ) func TestMigrate(t *testing.T, db *gorm.DB) { - allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Toy{}, &Company{}, &Language{}} + allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}} for _, m := range allModels { if db.Migrator().HasTable(m) { diff --git a/tests/model.go b/tests/model.go index ac2156c7..b2d5efe1 100644 --- a/tests/model.go +++ b/tests/model.go @@ -21,7 +21,7 @@ type User struct { Toys []Toy `gorm:"polymorphic:Owner"` CompanyID *int Company Company - ManagerID int + ManagerID uint Manager *User Team []User `gorm:"foreignkey:ManagerID"` Languages []Language `gorm:"many2many:UserSpeak"` @@ -49,7 +49,7 @@ type Toy struct { } type Company struct { - ID uint + ID int Name string } From d3c63a03cbed09c07d4c5a19189d25768f3204ea Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 23 Feb 2020 00:18:12 +0800 Subject: [PATCH 247/881] Handle constraint dependencies smartly --- migrator/migrator.go | 77 ++++++++++++++++++++++++++++++++++++++++++-- tests/migrate.go | 12 +++---- 2 files changed, 80 insertions(+), 9 deletions(-) diff --git a/migrator/migrator.go b/migrator/migrator.go index a5ec1a62..318c2fb8 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -48,7 +48,7 @@ func (m Migrator) DataTypeOf(field *schema.Field) string { // AutoMigrate func (m Migrator) AutoMigrate(values ...interface{}) error { // TODO smart migrate data type - for _, value := range values { + for _, value := range m.ReorderModels(values, true) { tx := m.DB.Session(&gorm.Session{}) if !tx.Migrator().HasTable(value) { if err := tx.Migrator().CreateTable(value); err != nil { @@ -100,7 +100,7 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { } func (m Migrator) CreateTable(values ...interface{}) error { - for _, value := range values { + for _, value := range m.ReorderModels(values, false) { tx := m.DB.Session(&gorm.Session{}) if err := m.RunWithValue(value, func(stmt *gorm.Statement) error { var ( @@ -186,7 +186,9 @@ func (m Migrator) CreateTable(values ...interface{}) error { } func (m Migrator) DropTable(values ...interface{}) error { - for _, value := range values { + values = m.ReorderModels(values, false) + for i := len(values) - 1; i >= 0; i-- { + value := values[i] tx := m.DB.Session(&gorm.Session{}) if err := m.RunWithValue(value, func(stmt *gorm.Statement) error { return tx.Exec("DROP TABLE ?", clause.Table{Name: stmt.Table}).Error @@ -475,3 +477,72 @@ func (m Migrator) CurrentDatabase() (name string) { m.DB.Raw("SELECT DATABASE()").Row().Scan(&name) return } + +// ReorderModels reorder models according to constraint dependencies +func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []interface{}) { + type Dependency struct { + Table string + Depends []*schema.Schema + } + + var ( + modelNames, orderedModelNames []string + orderedModelNamesMap = map[string]bool{} + valuesMap = map[string]*gorm.Statement{} + dependencies = map[string]Dependency{} + insertIntoOrderedMap func(name string) + ) + + parseDependence := func(value interface{}, addToMap bool) { + stmt := &gorm.Statement{DB: m.DB, Dest: value} + stmt.Parse(value) + dep := Dependency{Table: stmt.Schema.Table} + + for _, rel := range stmt.Schema.Relationships.Relations { + if constraint := rel.ParseConstraint(); constraint != nil { + dep.Depends = append(dep.Depends, constraint.ReferenceSchema) + } + } + dependencies[stmt.Schema.Table] = dep + + if addToMap { + modelNames = append(modelNames, stmt.Schema.Table) + valuesMap[stmt.Schema.Table] = stmt + } + } + + for _, value := range values { + parseDependence(value, true) + } + + insertIntoOrderedMap = func(name string) { + // avoid loop + if _, ok := orderedModelNamesMap[name]; ok { + return + } + + dep := dependencies[name] + for _, d := range dep.Depends { + if _, ok := valuesMap[d.Table]; ok { + if _, ok := orderedModelNamesMap[d.Table]; !ok && name != d.Table { + insertIntoOrderedMap(d.Table) + } + } else if autoAdd { + parseDependence(reflect.New(d.ModelType).Interface(), autoAdd) + insertIntoOrderedMap(d.Table) + } + } + + orderedModelNames = append(orderedModelNames, name) + orderedModelNamesMap[name] = true + } + + for _, name := range modelNames { + insertIntoOrderedMap(name) + } + + for _, name := range orderedModelNames { + results = append(results, valuesMap[name].Dest) + } + return +} diff --git a/tests/migrate.go b/tests/migrate.go index 477f0ad6..fa8a89e8 100644 --- a/tests/migrate.go +++ b/tests/migrate.go @@ -1,20 +1,20 @@ package tests import ( + "math/rand" "testing" + "time" "github.com/jinzhu/gorm" ) func TestMigrate(t *testing.T, db *gorm.DB) { allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}} + rand.Seed(time.Now().UnixNano()) + rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] }) - for _, m := range allModels { - if db.Migrator().HasTable(m) { - if err := db.Migrator().DropTable(m); err != nil { - t.Errorf("Failed to drop table, got error %v", err) - } - } + if err := db.Migrator().DropTable(allModels...); err != nil { + t.Errorf("Failed to drop table, got error %v", err) } if err := db.AutoMigrate(allModels...); err != nil { From ce84e82c9e3d9d6beecfeba5f22a425a4aebc02b Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 23 Feb 2020 00:40:06 +0800 Subject: [PATCH 248/881] Add migrator tests for postgres --- dialects/mysql/mysql_test.go | 4 ---- dialects/postgres/migrator.go | 26 ++++++++++++++++++++++++++ dialects/postgres/postgres.go | 8 +++++--- dialects/postgres/postgres_test.go | 29 +++++++++++++++++++++++++++++ 4 files changed, 60 insertions(+), 7 deletions(-) create mode 100644 dialects/postgres/postgres_test.go diff --git a/dialects/mysql/mysql_test.go b/dialects/mysql/mysql_test.go index 7fd5e373..f079ad60 100644 --- a/dialects/mysql/mysql_test.go +++ b/dialects/mysql/mysql_test.go @@ -9,10 +9,6 @@ import ( "github.com/jinzhu/gorm/tests" ) -func TestOpen(t *testing.T) { - gorm.Open(mysql.Open("gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True"), nil) -} - var ( DB *gorm.DB err error diff --git a/dialects/postgres/migrator.go b/dialects/postgres/migrator.go index 35101bf3..f06af25f 100644 --- a/dialects/postgres/migrator.go +++ b/dialects/postgres/migrator.go @@ -87,3 +87,29 @@ func (m Migrator) CreateIndex(value interface{}, name string) error { return err }) } + +func (m Migrator) HasTable(value interface{}) bool { + var count int64 + m.RunWithValue(value, func(stmt *gorm.Statement) error { + return m.DB.Raw("SELECT count(*) FROM information_schema.tables WHERE table_schema = CURRENT_SCHEMA() AND table_name = ? AND table_type = ?", stmt.Table, "BASE TABLE").Row().Scan(&count) + }) + + return count > 0 +} + +func (m Migrator) HasColumn(value interface{}, field string) bool { + var count int64 + m.RunWithValue(value, func(stmt *gorm.Statement) error { + name := field + if field := stmt.Schema.LookUpField(field); field != nil { + name = field.DBName + } + + return m.DB.Raw( + "SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_schema = CURRENT_SCHEMA() AND table_name = ? AND column_name = ?", + stmt.Table, name, + ).Row().Scan(&count) + }) + + return count > 0 +} diff --git a/dialects/postgres/postgres.go b/dialects/postgres/postgres.go index 4ffc4204..bb9726a8 100644 --- a/dialects/postgres/postgres.go +++ b/dialects/postgres/postgres.go @@ -3,6 +3,7 @@ package postgres import ( "database/sql" "fmt" + "strconv" "github.com/jinzhu/gorm" "github.com/jinzhu/gorm/callbacks" @@ -29,13 +30,14 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) { func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator { return Migrator{migrator.Migrator{Config: migrator.Config{ - DB: db, - Dialector: dialector, + DB: db, + Dialector: dialector, + CreateIndexAfterCreateTable: true, }}} } func (dialector Dialector) BindVar(stmt *gorm.Statement, v interface{}) string { - return "?" + return "$" + strconv.Itoa(len(stmt.Vars)) } func (dialector Dialector) QuoteChars() [2]byte { diff --git a/dialects/postgres/postgres_test.go b/dialects/postgres/postgres_test.go new file mode 100644 index 00000000..84c0fe53 --- /dev/null +++ b/dialects/postgres/postgres_test.go @@ -0,0 +1,29 @@ +package postgres_test + +import ( + "fmt" + "testing" + + "github.com/jinzhu/gorm" + "github.com/jinzhu/gorm/dialects/postgres" + "github.com/jinzhu/gorm/tests" +) + +var ( + DB *gorm.DB + err error +) + +func init() { + if DB, err = gorm.Open(postgres.Open("user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable"), &gorm.Config{}); err != nil { + panic(fmt.Sprintf("failed to initialize database, got error %v", err)) + } +} + +func TestCURD(t *testing.T) { + tests.RunTestsSuit(t, DB) +} + +func TestMigrate(t *testing.T) { + tests.TestMigrate(t, DB) +} From 1d803dfdd9fa106f329ff6247433e893d44cb152 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 23 Feb 2020 01:02:07 +0800 Subject: [PATCH 249/881] Add migrator tests for mssql --- dialects/mssql/migrator.go | 11 +++++++++++ dialects/mssql/mssql.go | 18 ++++++++++++------ dialects/mssql/mssql_test.go | 29 +++++++++++++++++++++++++++++ migrator/migrator.go | 12 +++++++----- 4 files changed, 59 insertions(+), 11 deletions(-) create mode 100644 dialects/mssql/mssql_test.go diff --git a/dialects/mssql/migrator.go b/dialects/mssql/migrator.go index 43eaf573..412d86c6 100644 --- a/dialects/mssql/migrator.go +++ b/dialects/mssql/migrator.go @@ -9,6 +9,17 @@ type Migrator struct { migrator.Migrator } +func (m Migrator) HasTable(value interface{}) bool { + var count int + m.RunWithValue(value, func(stmt *gorm.Statement) error { + return m.DB.Raw( + "SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", + stmt.Table, m.CurrentDatabase(), + ).Row().Scan(&count) + }) + return count > 0 +} + func (m Migrator) HasIndex(value interface{}, name string) bool { var count int m.RunWithValue(value, func(stmt *gorm.Statement) error { diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go index 78c048b4..ded49aae 100644 --- a/dialects/mssql/mssql.go +++ b/dialects/mssql/mssql.go @@ -3,6 +3,7 @@ package mssql import ( "database/sql" "fmt" + "strconv" _ "github.com/denisenkom/go-mssqldb" "github.com/jinzhu/gorm" @@ -29,17 +30,18 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) { func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator { return Migrator{migrator.Migrator{Config: migrator.Config{ - DB: db, - Dialector: dialector, + DB: db, + Dialector: dialector, + CreateIndexAfterCreateTable: true, }}} } func (dialector Dialector) BindVar(stmt *gorm.Statement, v interface{}) string { - return "?" + return "@p" + strconv.Itoa(len(stmt.Vars)) } func (dialector Dialector) QuoteChars() [2]byte { - return [2]byte{'[', ']'} // `name` + return [2]byte{'"', '"'} // `name` } func (dialector Dialector) DataTypeOf(field *schema.Field) string { @@ -64,8 +66,12 @@ func (dialector Dialector) DataTypeOf(field *schema.Field) string { case schema.Float: return "decimal" case schema.String: - if field.Size > 0 && field.Size <= 4000 { - return fmt.Sprintf("nvarchar(%d)", field.Size) + size := field.Size + if field.PrimaryKey { + size = 256 + } + if size > 0 && size <= 4000 { + return fmt.Sprintf("nvarchar(%d)", size) } return "ntext" case schema.Time: diff --git a/dialects/mssql/mssql_test.go b/dialects/mssql/mssql_test.go new file mode 100644 index 00000000..b56e7369 --- /dev/null +++ b/dialects/mssql/mssql_test.go @@ -0,0 +1,29 @@ +package mssql_test + +import ( + "fmt" + "testing" + + "github.com/jinzhu/gorm" + "github.com/jinzhu/gorm/dialects/mssql" + "github.com/jinzhu/gorm/tests" +) + +var ( + DB *gorm.DB + err error +) + +func init() { + if DB, err = gorm.Open(mssql.Open("sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm"), &gorm.Config{}); err != nil { + panic(fmt.Sprintf("failed to initialize database, got error %v", err)) + } +} + +func TestCURD(t *testing.T) { + tests.RunTestsSuit(t, DB) +} + +func TestMigrate(t *testing.T) { + tests.TestMigrate(t, DB) +} diff --git a/migrator/migrator.go b/migrator/migrator.go index 318c2fb8..4b52193f 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -189,11 +189,13 @@ func (m Migrator) DropTable(values ...interface{}) error { values = m.ReorderModels(values, false) for i := len(values) - 1; i >= 0; i-- { value := values[i] - tx := m.DB.Session(&gorm.Session{}) - if err := m.RunWithValue(value, func(stmt *gorm.Statement) error { - return tx.Exec("DROP TABLE ?", clause.Table{Name: stmt.Table}).Error - }); err != nil { - return err + if m.DB.Migrator().HasTable(value) { + tx := m.DB.Session(&gorm.Session{}) + if err := m.RunWithValue(value, func(stmt *gorm.Statement) error { + return tx.Exec("DROP TABLE ?", clause.Table{Name: stmt.Table}).Error + }); err != nil { + return err + } } } return nil From a67be2a1f12503c69fa3de5d3f5a97ddec5a4025 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 23 Feb 2020 08:29:59 +0800 Subject: [PATCH 250/881] Refactor reorder migrator models --- dialects/mssql/mssql.go | 2 +- dialects/mysql/mysql.go | 2 +- go.mod | 1 - migrator/migrator.go | 55 +++++++++++++++++++---------------------- 4 files changed, 28 insertions(+), 32 deletions(-) diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go index ded49aae..79e36385 100644 --- a/dialects/mssql/mssql.go +++ b/dialects/mssql/mssql.go @@ -67,7 +67,7 @@ func (dialector Dialector) DataTypeOf(field *schema.Field) string { return "decimal" case schema.String: size := field.Size - if field.PrimaryKey { + if field.PrimaryKey && size == 0 { size = 256 } if size > 0 && size <= 4000 { diff --git a/dialects/mysql/mysql.go b/dialects/mysql/mysql.go index 5fcc2d69..629b89df 100644 --- a/dialects/mysql/mysql.go +++ b/dialects/mysql/mysql.go @@ -74,7 +74,7 @@ func (dialector Dialector) DataTypeOf(field *schema.Field) string { return "double" case schema.String: size := field.Size - if field.PrimaryKey { + if field.PrimaryKey && size == 0 { size = 256 } diff --git a/go.mod b/go.mod index 9046ea99..cdb7e574 100644 --- a/go.mod +++ b/go.mod @@ -5,5 +5,4 @@ go 1.13 require ( github.com/jinzhu/inflection v1.0.0 github.com/jinzhu/now v1.1.1 - github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect ) diff --git a/migrator/migrator.go b/migrator/migrator.go index 4b52193f..730e8cfe 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -483,55 +483,48 @@ func (m Migrator) CurrentDatabase() (name string) { // ReorderModels reorder models according to constraint dependencies func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []interface{}) { type Dependency struct { - Table string + *gorm.Statement Depends []*schema.Schema } var ( modelNames, orderedModelNames []string orderedModelNamesMap = map[string]bool{} - valuesMap = map[string]*gorm.Statement{} - dependencies = map[string]Dependency{} - insertIntoOrderedMap func(name string) + valuesMap = map[string]Dependency{} + insertIntoOrderedList func(name string) ) - parseDependence := func(value interface{}, addToMap bool) { - stmt := &gorm.Statement{DB: m.DB, Dest: value} - stmt.Parse(value) - dep := Dependency{Table: stmt.Schema.Table} + parseDependence := func(value interface{}, addToList bool) { + dep := Dependency{ + Statement: &gorm.Statement{DB: m.DB, Dest: value}, + } + dep.Parse(value) - for _, rel := range stmt.Schema.Relationships.Relations { - if constraint := rel.ParseConstraint(); constraint != nil { - dep.Depends = append(dep.Depends, constraint.ReferenceSchema) + for _, rel := range dep.Schema.Relationships.Relations { + if c := rel.ParseConstraint(); c != nil && c.Schema != c.ReferenceSchema { + dep.Depends = append(dep.Depends, c.ReferenceSchema) } } - dependencies[stmt.Schema.Table] = dep - if addToMap { - modelNames = append(modelNames, stmt.Schema.Table) - valuesMap[stmt.Schema.Table] = stmt + valuesMap[dep.Schema.Table] = dep + + if addToList { + modelNames = append(modelNames, dep.Schema.Table) } } - for _, value := range values { - parseDependence(value, true) - } - - insertIntoOrderedMap = func(name string) { - // avoid loop + insertIntoOrderedList = func(name string) { if _, ok := orderedModelNamesMap[name]; ok { - return + return // avoid loop } - dep := dependencies[name] + dep := valuesMap[name] for _, d := range dep.Depends { if _, ok := valuesMap[d.Table]; ok { - if _, ok := orderedModelNamesMap[d.Table]; !ok && name != d.Table { - insertIntoOrderedMap(d.Table) - } + insertIntoOrderedList(d.Table) } else if autoAdd { parseDependence(reflect.New(d.ModelType).Interface(), autoAdd) - insertIntoOrderedMap(d.Table) + insertIntoOrderedList(d.Table) } } @@ -539,12 +532,16 @@ func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []i orderedModelNamesMap[name] = true } + for _, value := range values { + parseDependence(value, true) + } + for _, name := range modelNames { - insertIntoOrderedMap(name) + insertIntoOrderedList(name) } for _, name := range orderedModelNames { - results = append(results, valuesMap[name].Dest) + results = append(results, valuesMap[name].Statement.Dest) } return } From fe24c3f105762bf780f5ab5d1d63d2f11a930886 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 23 Feb 2020 09:38:48 +0800 Subject: [PATCH 251/881] Setup tests script --- schema/model_test.go | 2 +- schema/schema_test.go | 2 +- tests/README.md | 10 +++ tests/docker-compose.yml | 30 +++++++++ tests/tests_all.sh | 25 ++++++++ wercker.yml | 132 +++++++++++++++++++++++++++++++++++++++ 6 files changed, 199 insertions(+), 2 deletions(-) create mode 100644 tests/README.md create mode 100644 tests/docker-compose.yml create mode 100755 tests/tests_all.sh create mode 100644 wercker.yml diff --git a/schema/model_test.go b/schema/model_test.go index aca7e617..343e324e 100644 --- a/schema/model_test.go +++ b/schema/model_test.go @@ -18,7 +18,7 @@ type User struct { Toys []*tests.Toy `gorm:"polymorphic:Owner"` CompanyID *int Company *tests.Company - ManagerID *int + ManagerID *uint Manager *User Team []*User `gorm:"foreignkey:ManagerID"` Languages []*tests.Language `gorm:"many2many:UserSpeak"` diff --git a/schema/schema_test.go b/schema/schema_test.go index 4134c966..ce225010 100644 --- a/schema/schema_test.go +++ b/schema/schema_test.go @@ -40,7 +40,7 @@ func checkUserSchema(t *testing.T, user *schema.Schema) { {Name: "Age", DBName: "age", BindNames: []string{"Age"}, DataType: schema.Uint}, {Name: "Birthday", DBName: "birthday", BindNames: []string{"Birthday"}, DataType: schema.Time}, {Name: "CompanyID", DBName: "company_id", BindNames: []string{"CompanyID"}, DataType: schema.Int}, - {Name: "ManagerID", DBName: "manager_id", BindNames: []string{"ManagerID"}, DataType: schema.Int}, + {Name: "ManagerID", DBName: "manager_id", BindNames: []string{"ManagerID"}, DataType: schema.Uint}, {Name: "Active", DBName: "active", BindNames: []string{"Active"}, DataType: schema.Bool}, } diff --git a/tests/README.md b/tests/README.md new file mode 100644 index 00000000..6ae3337f --- /dev/null +++ b/tests/README.md @@ -0,0 +1,10 @@ +# Test Guide + +```bash +cd tests +# prepare test databases +docker-compose up + +# run all tests +./tests_all.sh +``` diff --git a/tests/docker-compose.yml b/tests/docker-compose.yml new file mode 100644 index 00000000..79bf5fc3 --- /dev/null +++ b/tests/docker-compose.yml @@ -0,0 +1,30 @@ +version: '3' + +services: + mysql: + image: 'mysql:latest' + ports: + - 9910:3306 + environment: + - MYSQL_DATABASE=gorm + - MYSQL_USER=gorm + - MYSQL_PASSWORD=gorm + - MYSQL_RANDOM_ROOT_PASSWORD="yes" + postgres: + image: 'postgres:latest' + ports: + - 9920:5432 + environment: + - POSTGRES_USER=gorm + - POSTGRES_DB=gorm + - POSTGRES_PASSWORD=gorm + mssql: + image: 'mcmoe/mssqldocker:latest' + ports: + - 9930:1433 + environment: + - ACCEPT_EULA=Y + - SA_PASSWORD=LoremIpsum86 + - MSSQL_DB=gorm + - MSSQL_USER=gorm + - MSSQL_PASSWORD=LoremIpsum86 diff --git a/tests/tests_all.sh b/tests/tests_all.sh new file mode 100755 index 00000000..91d415f1 --- /dev/null +++ b/tests/tests_all.sh @@ -0,0 +1,25 @@ +dialects=("postgres" "mysql" "mssql" "sqlite") + +if [[ $(pwd) == *"gorm/tests"* ]]; then + cd .. +fi + +for dialect in "${dialects[@]}" ; do + if [ "$GORM_DIALECT" = "" ] || [ "$GORM_DIALECT" = "${dialect}" ] + then + if [ "$GORM_VERBOSE" = "" ] + then + cd dialects/${dialect} + DEBUG=false GORM_DIALECT=${dialect} go test -race ./... + cd ../.. + + DEBUG=false GORM_DIALECT=${dialect} go test -race ./... + else + cd dialects/${dialect} + DEBUG=false GORM_DIALECT=${dialect} go test -race ./... + cd ../.. + + DEBUG=false GORM_DIALECT=${dialect} go test -race -v ./... + fi + fi +done diff --git a/wercker.yml b/wercker.yml new file mode 100644 index 00000000..54d80be0 --- /dev/null +++ b/wercker.yml @@ -0,0 +1,132 @@ +# use the default golang container from Docker Hub +box: golang + +services: + - name: mariadb + id: mariadb:latest + env: + MYSQL_DATABASE: gorm + MYSQL_USER: gorm + MYSQL_PASSWORD: gorm + MYSQL_RANDOM_ROOT_PASSWORD: "yes" + - name: mysql + id: mysql:latest + env: + MYSQL_DATABASE: gorm + MYSQL_USER: gorm + MYSQL_PASSWORD: gorm + MYSQL_RANDOM_ROOT_PASSWORD: "yes" + - name: mysql57 + id: mysql:5.7 + env: + MYSQL_DATABASE: gorm + MYSQL_USER: gorm + MYSQL_PASSWORD: gorm + MYSQL_RANDOM_ROOT_PASSWORD: "yes" + - name: mysql56 + id: mysql:5.6 + env: + MYSQL_DATABASE: gorm + MYSQL_USER: gorm + MYSQL_PASSWORD: gorm + MYSQL_RANDOM_ROOT_PASSWORD: "yes" + - name: postgres + id: postgres:latest + env: + POSTGRES_USER: gorm + POSTGRES_PASSWORD: gorm + POSTGRES_DB: gorm + - name: postgres11 + id: postgres:11 + env: + POSTGRES_USER: gorm + POSTGRES_PASSWORD: gorm + POSTGRES_DB: gorm + - name: postgres10 + id: postgres:10 + env: + POSTGRES_USER: gorm + POSTGRES_PASSWORD: gorm + POSTGRES_DB: gorm + - name: mssql + id: mcmoe/mssqldocker:latest + env: + ACCEPT_EULA: Y + SA_PASSWORD: LoremIpsum86 + MSSQL_DB: gorm + MSSQL_USER: gorm + MSSQL_PASSWORD: LoremIpsum86 + +# The steps that will be executed in the build pipeline +build: + # The steps that will be executed on build + steps: + # Sets the go workspace and places you package + # at the right place in the workspace tree + - setup-go-workspace + + # Gets the dependencies + - script: + name: go get + code: | + cd $WERCKER_SOURCE_DIR + go version + go get -t -v ./... + + # Build the project + - script: + name: go build + code: | + go build ./... + + # Test the project + - script: + name: test sqlite + code: | + GORM_DIALECT=sqlite $GORM_VERBOSE=true ./tests/tests_all.sh + + - script: + name: test mariadb + code: | + GORM_DIALECT=mysql $GORM_VERBOSE=true GORM_DSN="gorm:gorm@tcp(mariadb:3306)/gorm?charset=utf8&parseTime=True" ./tests/tests_all.sh + + - script: + name: test mysql + code: | + GORM_DIALECT=mysql $GORM_VERBOSE=true GORM_DSN="gorm:gorm@tcp(mysql:3306)/gorm?charset=utf8&parseTime=True" ./tests/tests_all.sh + + - script: + name: test mysql5.7 + code: | + GORM_DIALECT=mysql $GORM_VERBOSE=true GORM_DSN="gorm:gorm@tcp(mysql57:3306)/gorm?charset=utf8&parseTime=True" ./tests/tests_all.sh + + - script: + name: test mysql5.6 + code: | + GORM_DIALECT=mysql $GORM_VERBOSE=true GORM_DSN="gorm:gorm@tcp(mysql56:3306)/gorm?charset=utf8&parseTime=True" ./tests/tests_all.sh + + - script: + name: test postgres + code: | + GORM_DIALECT=postgres $GORM_VERBOSE=true GORM_DSN="host=postgres user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" ./tests/tests_all.sh + + - script: + name: test postgres11 + code: | + GORM_DIALECT=postgres $GORM_VERBOSE=true GORM_DSN="host=postgres96 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" ./tests/tests_all.sh + + - script: + name: test postgres10 + code: | + GORM_DIALECT=postgres $GORM_VERBOSE=true GORM_DSN="host=postgres95 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" ./tests/tests_all.sh + + - script: + name: test mssql + code: | + GORM_DIALECT=mssql $GORM_VERBOSE=true GORM_DSN="sqlserver://gorm:LoremIpsum86@mssql:1433?database=gorm" ./tests/tests_all.sh + + - script: + name: codecov + code: | + go test -race -coverprofile=coverage.txt -covermode=atomic ./... + bash <(curl -s https://codecov.io/bash) From bc5ceff82ff17b72081cc40bb7711489312349c4 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 23 Feb 2020 12:39:26 +0800 Subject: [PATCH 252/881] Explain SQL for dialects --- callbacks.go | 8 ++++ dialects/mssql/mssql.go | 8 ++++ dialects/mssql/mssql_test.go | 8 +++- dialects/mysql/mysql.go | 5 +++ dialects/mysql/mysql_test.go | 8 +++- dialects/postgres/postgres.go | 8 ++++ dialects/postgres/postgres_test.go | 8 +++- dialects/sqlite/sqlite.go | 5 +++ interfaces.go | 1 + logger/logger.go | 17 +++++--- logger/sql.go | 68 ++++++++++++++++++++++++++++++ logger/sql_test.go | 45 ++++++++++++++++++++ tests/tests_all.sh | 2 +- 13 files changed, 182 insertions(+), 9 deletions(-) create mode 100644 logger/sql.go create mode 100644 logger/sql_test.go diff --git a/callbacks.go b/callbacks.go index 4f19a681..41951168 100644 --- a/callbacks.go +++ b/callbacks.go @@ -3,6 +3,7 @@ package gorm import ( "errors" "fmt" + "time" "github.com/jinzhu/gorm/logger" "github.com/jinzhu/gorm/schema" @@ -69,6 +70,7 @@ func (cs *callbacks) Raw() *processor { } func (p *processor) Execute(db *DB) { + curTime := time.Now() if stmt := db.Statement; stmt != nil { if stmt.Model == nil { stmt.Model = stmt.Dest @@ -86,6 +88,12 @@ func (p *processor) Execute(db *DB) { for _, f := range p.fns { f(db) } + + if stmt := db.Statement; stmt != nil { + db.Logger.RunWith(logger.Info, func() { + db.Logger.Info(db.Dialector.Explain(stmt.SQL.String(), stmt.Vars)) + }) + } } func (p *processor) Get(name string) func(*DB) { diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go index 79e36385..b93cc8f6 100644 --- a/dialects/mssql/mssql.go +++ b/dialects/mssql/mssql.go @@ -3,11 +3,13 @@ package mssql import ( "database/sql" "fmt" + "regexp" "strconv" _ "github.com/denisenkom/go-mssqldb" "github.com/jinzhu/gorm" "github.com/jinzhu/gorm/callbacks" + "github.com/jinzhu/gorm/logger" "github.com/jinzhu/gorm/migrator" "github.com/jinzhu/gorm/schema" ) @@ -44,6 +46,12 @@ func (dialector Dialector) QuoteChars() [2]byte { return [2]byte{'"', '"'} // `name` } +var numericPlaceholder = regexp.MustCompile("@p(\\d+)") + +func (dialector Dialector) Explain(sql string, vars ...interface{}) string { + return logger.ExplainSQL(sql, numericPlaceholder, `'`, vars...) +} + func (dialector Dialector) DataTypeOf(field *schema.Field) string { switch field.DataType { case schema.Bool: diff --git a/dialects/mssql/mssql_test.go b/dialects/mssql/mssql_test.go index b56e7369..49b3cd6a 100644 --- a/dialects/mssql/mssql_test.go +++ b/dialects/mssql/mssql_test.go @@ -2,6 +2,7 @@ package mssql_test import ( "fmt" + "os" "testing" "github.com/jinzhu/gorm" @@ -15,7 +16,12 @@ var ( ) func init() { - if DB, err = gorm.Open(mssql.Open("sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm"), &gorm.Config{}); err != nil { + dsn := "sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm" + if os.Getenv("GORM_DSN") != "" { + dsn = os.Getenv("GORM_DSN") + } + + if DB, err = gorm.Open(mssql.Open(dsn), &gorm.Config{}); err != nil { panic(fmt.Sprintf("failed to initialize database, got error %v", err)) } } diff --git a/dialects/mysql/mysql.go b/dialects/mysql/mysql.go index 629b89df..e1bf985a 100644 --- a/dialects/mysql/mysql.go +++ b/dialects/mysql/mysql.go @@ -8,6 +8,7 @@ import ( _ "github.com/go-sql-driver/mysql" "github.com/jinzhu/gorm" "github.com/jinzhu/gorm/callbacks" + "github.com/jinzhu/gorm/logger" "github.com/jinzhu/gorm/migrator" "github.com/jinzhu/gorm/schema" ) @@ -42,6 +43,10 @@ func (dialector Dialector) QuoteChars() [2]byte { return [2]byte{'`', '`'} // `name` } +func (dialector Dialector) Explain(sql string, vars ...interface{}) string { + return logger.ExplainSQL(sql, nil, `"`, vars...) +} + func (dialector Dialector) DataTypeOf(field *schema.Field) string { switch field.DataType { case schema.Bool: diff --git a/dialects/mysql/mysql_test.go b/dialects/mysql/mysql_test.go index f079ad60..5bc1debd 100644 --- a/dialects/mysql/mysql_test.go +++ b/dialects/mysql/mysql_test.go @@ -2,6 +2,7 @@ package mysql_test import ( "fmt" + "os" "testing" "github.com/jinzhu/gorm" @@ -15,7 +16,12 @@ var ( ) func init() { - if DB, err = gorm.Open(mysql.Open("gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True"), &gorm.Config{}); err != nil { + dsn := "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True" + if os.Getenv("GORM_DSN") != "" { + dsn = os.Getenv("GORM_DSN") + } + + if DB, err = gorm.Open(mysql.Open(dsn), &gorm.Config{}); err != nil { panic(fmt.Sprintf("failed to initialize database, got error %v", err)) } } diff --git a/dialects/postgres/postgres.go b/dialects/postgres/postgres.go index bb9726a8..3ee4ba9f 100644 --- a/dialects/postgres/postgres.go +++ b/dialects/postgres/postgres.go @@ -3,10 +3,12 @@ package postgres import ( "database/sql" "fmt" + "regexp" "strconv" "github.com/jinzhu/gorm" "github.com/jinzhu/gorm/callbacks" + "github.com/jinzhu/gorm/logger" "github.com/jinzhu/gorm/migrator" "github.com/jinzhu/gorm/schema" _ "github.com/lib/pq" @@ -44,6 +46,12 @@ func (dialector Dialector) QuoteChars() [2]byte { return [2]byte{'"', '"'} // "name" } +var numericPlaceholder = regexp.MustCompile("\\$(\\d+)") + +func (dialector Dialector) Explain(sql string, vars ...interface{}) string { + return logger.ExplainSQL(sql, numericPlaceholder, `'`, vars...) +} + func (dialector Dialector) DataTypeOf(field *schema.Field) string { switch field.DataType { case schema.Bool: diff --git a/dialects/postgres/postgres_test.go b/dialects/postgres/postgres_test.go index 84c0fe53..a1252d92 100644 --- a/dialects/postgres/postgres_test.go +++ b/dialects/postgres/postgres_test.go @@ -2,6 +2,7 @@ package postgres_test import ( "fmt" + "os" "testing" "github.com/jinzhu/gorm" @@ -15,7 +16,12 @@ var ( ) func init() { - if DB, err = gorm.Open(postgres.Open("user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable"), &gorm.Config{}); err != nil { + dsn := "user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable" + if os.Getenv("GORM_DSN") != "" { + dsn = os.Getenv("GORM_DSN") + } + + if DB, err = gorm.Open(postgres.Open(dsn), &gorm.Config{}); err != nil { panic(fmt.Sprintf("failed to initialize database, got error %v", err)) } } diff --git a/dialects/sqlite/sqlite.go b/dialects/sqlite/sqlite.go index 54fa7de0..a6aba066 100644 --- a/dialects/sqlite/sqlite.go +++ b/dialects/sqlite/sqlite.go @@ -5,6 +5,7 @@ import ( "github.com/jinzhu/gorm" "github.com/jinzhu/gorm/callbacks" + "github.com/jinzhu/gorm/logger" "github.com/jinzhu/gorm/migrator" "github.com/jinzhu/gorm/schema" _ "github.com/mattn/go-sqlite3" @@ -41,6 +42,10 @@ func (dialector Dialector) QuoteChars() [2]byte { return [2]byte{'`', '`'} // `name` } +func (dialector Dialector) Explain(sql string, vars ...interface{}) string { + return logger.ExplainSQL(sql, nil, `"`, vars...) +} + func (dialector Dialector) DataTypeOf(field *schema.Field) string { switch field.DataType { case schema.Bool: diff --git a/interfaces.go b/interfaces.go index 8f0f3085..bf1aab46 100644 --- a/interfaces.go +++ b/interfaces.go @@ -14,6 +14,7 @@ type Dialector interface { DataTypeOf(*schema.Field) string BindVar(stmt *Statement, v interface{}) string QuoteChars() [2]byte + Explain(sql string, vars ...interface{}) string } // CommonDB common db interface diff --git a/logger/logger.go b/logger/logger.go index cad9be16..049b724d 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -11,9 +11,9 @@ type LogLevel int var Default Interface = Logger{Writer: log.New(os.Stdout, "\r\n", log.LstdFlags)} const ( - Info LogLevel = iota + 1 + Error LogLevel = iota + 1 Warn - Error + Info ) // Interface logger interface @@ -22,6 +22,7 @@ type Interface interface { Info(string, ...interface{}) Warn(string, ...interface{}) Error(string, ...interface{}) + RunWith(LogLevel, func()) } // Writer log writer interface @@ -40,21 +41,27 @@ func (logger Logger) LogMode(level LogLevel) Interface { // Info print info func (logger Logger) Info(msg string, data ...interface{}) { - if logger.logLevel <= Info { + if logger.logLevel >= Info { logger.Print("[info] " + fmt.Sprintf(msg, data...)) } } // Warn print warn messages func (logger Logger) Warn(msg string, data ...interface{}) { - if logger.logLevel <= Warn { + if logger.logLevel >= Warn { logger.Print("[warn] " + fmt.Sprintf(msg, data...)) } } // Error print error messages func (logger Logger) Error(msg string, data ...interface{}) { - if logger.logLevel <= Error { + if logger.logLevel >= Error { logger.Print("[error] " + fmt.Sprintf(msg, data...)) } } + +func (logger Logger) RunWith(logLevel LogLevel, fc func()) { + if logger.logLevel >= logLevel { + fc() + } +} diff --git a/logger/sql.go b/logger/sql.go new file mode 100644 index 00000000..b0e11027 --- /dev/null +++ b/logger/sql.go @@ -0,0 +1,68 @@ +package logger + +import ( + "database/sql/driver" + "fmt" + "regexp" + "strconv" + "strings" + "time" + "unicode" +) + +func isPrintable(s []byte) bool { + for _, r := range s { + if !unicode.IsPrint(rune(r)) { + return false + } + } + return true +} + +func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, vars ...interface{}) string { + for idx, v := range vars { + if valuer, ok := v.(driver.Valuer); ok { + v, _ = valuer.Value() + } + + switch v := v.(type) { + case bool: + vars[idx] = fmt.Sprint(v) + case time.Time: + vars[idx] = escaper + v.Format("2006-01-02 15:04:05") + escaper + case *time.Time: + vars[idx] = escaper + v.Format("2006-01-02 15:04:05") + escaper + case []byte: + if isPrintable(v) { + vars[idx] = escaper + strings.Replace(string(v), escaper, "\\"+escaper, -1) + escaper + } else { + vars[idx] = escaper + "" + escaper + } + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: + vars[idx] = fmt.Sprintf("%d", v) + case float64, float32: + vars[idx] = fmt.Sprintf("%.6f", v) + case string: + vars[idx] = escaper + strings.Replace(v, escaper, "\\"+escaper, -1) + escaper + default: + if v == nil { + vars[idx] = "NULL" + } else { + vars[idx] = escaper + strings.Replace(fmt.Sprint(v), escaper, "\\"+escaper, -1) + escaper + } + } + } + + if numericPlaceholder == nil { + for _, v := range vars { + sql = strings.Replace(sql, "?", v.(string), 1) + } + } else { + sql = numericPlaceholder.ReplaceAllString(sql, "$$$$$1") + for idx, v := range vars { + sql = strings.Replace(sql, "$$"+strconv.Itoa(idx), v.(string), 1) + } + } + + return sql +} diff --git a/logger/sql_test.go b/logger/sql_test.go new file mode 100644 index 00000000..d98e19b3 --- /dev/null +++ b/logger/sql_test.go @@ -0,0 +1,45 @@ +package logger_test + +import ( + "regexp" + "testing" + + "github.com/jinzhu/gorm/logger" + "github.com/jinzhu/now" +) + +func TestExplainSQL(t *testing.T) { + tt := now.MustParse("2020-02-23 11:10:10") + + results := []struct { + SQL string + NumericRegexp *regexp.Regexp + Vars []interface{} + Result string + }{ + { + SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at) values (?, ?, ?, ?, ?, ?, ?, ?)", + NumericRegexp: nil, + Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil}, + Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL)`, + }, + { + SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at) values (@p0, @p1, @p2, @p3, @p4, @p5, @p6, @p7)", + NumericRegexp: regexp.MustCompile("@p(\\d+)"), + Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil}, + Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL)`, + }, + { + SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at) values ($2, $3, $0, $1, $6, $7, $4, $5)", + NumericRegexp: regexp.MustCompile("\\$(\\d+)"), + Vars: []interface{}{999.99, true, "jinzhu", 1, &tt, nil, []byte("12345"), tt}, + Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL)`, + }, + } + + for idx, r := range results { + if result := logger.ExplainSQL(r.SQL, r.NumericRegexp, `"`, r.Vars...); result != r.Result { + t.Errorf("Explain SQL #%v expects %v, but got %v", idx, r.Result, result) + } + } +} diff --git a/tests/tests_all.sh b/tests/tests_all.sh index 91d415f1..cd42e1e0 100755 --- a/tests/tests_all.sh +++ b/tests/tests_all.sh @@ -1,4 +1,4 @@ -dialects=("postgres" "mysql" "mssql" "sqlite") +dialects=("sqlite" "mysql" "postgres" "mssql") if [[ $(pwd) == *"gorm/tests"* ]]; then cd .. From c3b798aec869da7b8c513c45e275a4310dfede31 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 23 Feb 2020 13:22:08 +0800 Subject: [PATCH 253/881] Refactor SQL Explainer --- logger/sql.go | 33 ++++++++++++++++++++++++--------- logger/sql_test.go | 32 ++++++++++++++++++++++---------- 2 files changed, 46 insertions(+), 19 deletions(-) diff --git a/logger/sql.go b/logger/sql.go index b0e11027..f63dc160 100644 --- a/logger/sql.go +++ b/logger/sql.go @@ -3,6 +3,7 @@ package logger import ( "database/sql/driver" "fmt" + "reflect" "regexp" "strconv" "strings" @@ -19,19 +20,17 @@ func isPrintable(s []byte) bool { return true } -func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, vars ...interface{}) string { - for idx, v := range vars { - if valuer, ok := v.(driver.Valuer); ok { - v, _ = valuer.Value() - } +var convertableTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeOf(false), reflect.TypeOf([]byte{})} +func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, vars ...interface{}) string { + var convertParams func(interface{}, int) + + convertParams = func(v interface{}, idx int) { switch v := v.(type) { case bool: vars[idx] = fmt.Sprint(v) case time.Time: vars[idx] = escaper + v.Format("2006-01-02 15:04:05") + escaper - case *time.Time: - vars[idx] = escaper + v.Format("2006-01-02 15:04:05") + escaper case []byte: if isPrintable(v) { vars[idx] = escaper + strings.Replace(string(v), escaper, "\\"+escaper, -1) + escaper @@ -48,19 +47,35 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, v if v == nil { vars[idx] = "NULL" } else { + rv := reflect.Indirect(reflect.ValueOf(v)) + for _, t := range convertableTypes { + if rv.Type().ConvertibleTo(t) { + convertParams(rv.Convert(t).Interface(), idx) + return + } + } + vars[idx] = escaper + strings.Replace(fmt.Sprint(v), escaper, "\\"+escaper, -1) + escaper } } } + for idx, v := range vars { + if valuer, ok := v.(driver.Valuer); ok { + v, _ = valuer.Value() + } + + convertParams(v, idx) + } + if numericPlaceholder == nil { for _, v := range vars { sql = strings.Replace(sql, "?", v.(string), 1) } } else { - sql = numericPlaceholder.ReplaceAllString(sql, "$$$$$1") + sql = numericPlaceholder.ReplaceAllString(sql, "$$$1$$") for idx, v := range vars { - sql = strings.Replace(sql, "$$"+strconv.Itoa(idx), v.(string), 1) + sql = strings.Replace(sql, "$"+strconv.Itoa(idx)+"$", v.(string), 1) } } diff --git a/logger/sql_test.go b/logger/sql_test.go index d98e19b3..829d6302 100644 --- a/logger/sql_test.go +++ b/logger/sql_test.go @@ -9,7 +9,13 @@ import ( ) func TestExplainSQL(t *testing.T) { - tt := now.MustParse("2020-02-23 11:10:10") + type role string + type password []byte + var ( + tt = now.MustParse("2020-02-23 11:10:10") + myrole = role("admin") + pwd = password([]byte("pass")) + ) results := []struct { SQL string @@ -18,22 +24,28 @@ func TestExplainSQL(t *testing.T) { Result string }{ { - SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at) values (?, ?, ?, ?, ?, ?, ?, ?)", + SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", NumericRegexp: nil, - Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil}, - Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL)`, + Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.com", myrole, pwd}, + Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`, }, { - SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at) values (@p0, @p1, @p2, @p3, @p4, @p5, @p6, @p7)", + SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (@p0, @p1, @p2, @p3, @p4, @p5, @p6, @p7, @p8, @p9, @p10)", NumericRegexp: regexp.MustCompile("@p(\\d+)"), - Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil}, - Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL)`, + Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.com", myrole, pwd}, + Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`, }, { - SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at) values ($2, $3, $0, $1, $6, $7, $4, $5)", + SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ($2, $3, $0, $1, $6, $7, $4, $5, $8, $9, $10)", NumericRegexp: regexp.MustCompile("\\$(\\d+)"), - Vars: []interface{}{999.99, true, "jinzhu", 1, &tt, nil, []byte("12345"), tt}, - Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL)`, + Vars: []interface{}{999.99, true, "jinzhu", 1, &tt, nil, []byte("12345"), tt, "w@g.com", myrole, pwd}, + Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`, + }, + { + SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (@p0, @p10, @p1, @p2, @p3, @p4, @p5, @p6, @p7, @p8, @p9)", + NumericRegexp: regexp.MustCompile("@p(\\d+)"), + Vars: []interface{}{"jinzhu", 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.com", myrole, pwd, 1}, + Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`, }, } From 27cb613871e07c1646033c7ef35590a0dfee4f0b Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 23 Feb 2020 15:07:47 +0800 Subject: [PATCH 254/881] Refactor logger --- callbacks.go | 6 +-- logger/logger.go | 110 +++++++++++++++++++++++++++++++-------- tests/dummy_dialecter.go | 5 ++ utils/utils.go | 4 +- 4 files changed, 97 insertions(+), 28 deletions(-) diff --git a/callbacks.go b/callbacks.go index 41951168..573d7a8e 100644 --- a/callbacks.go +++ b/callbacks.go @@ -90,9 +90,9 @@ func (p *processor) Execute(db *DB) { } if stmt := db.Statement; stmt != nil { - db.Logger.RunWith(logger.Info, func() { - db.Logger.Info(db.Dialector.Explain(stmt.SQL.String(), stmt.Vars)) - }) + db.Logger.Trace(curTime, func() (string, int64) { + return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars), db.RowsAffected + }, db.Error) } } diff --git a/logger/logger.go b/logger/logger.go index 049b724d..5656a86f 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -1,14 +1,29 @@ package logger import ( - "fmt" "log" "os" + "time" + + "github.com/jinzhu/gorm/utils" ) -type LogLevel int +// Colors +const ( + Reset = "\033[0m" + Red = "\033[31m" + Green = "\033[32m" + Yellow = "\033[33m" + Blue = "\033[34m" + Magenta = "\033[35m" + Cyan = "\033[36m" + White = "\033[37m" + Redbold = "\033[31;1m" + YellowBold = "\033[33;1m" +) -var Default Interface = Logger{Writer: log.New(os.Stdout, "\r\n", log.LstdFlags)} +// LogLevel +type LogLevel int const ( Error LogLevel = iota + 1 @@ -16,52 +31,101 @@ const ( Info ) +// Writer log writer interface +type Writer interface { + Printf(string, ...interface{}) +} + +type Config struct { + SlowThreshold time.Duration + Colorful bool + LogLevel LogLevel +} + // Interface logger interface type Interface interface { LogMode(LogLevel) Interface Info(string, ...interface{}) Warn(string, ...interface{}) Error(string, ...interface{}) - RunWith(LogLevel, func()) + Trace(begin time.Time, fc func() (string, int64), err error) } -// Writer log writer interface -type Writer interface { - Print(...interface{}) +var Default = New(log.New(os.Stdout, "\r\n", log.LstdFlags), Config{ + SlowThreshold: 100 * time.Millisecond, + Colorful: true, +}) + +func New(writer Writer, config Config) Interface { + var ( + infoPrefix = "%s\n[info] " + warnPrefix = "%s\n[warn] " + errPrefix = "%s\n[error] " + tracePrefix = "%s\n[%v] [rows:%d] %s" + traceErrPrefix = "%s\n[%v] [rows:%d] %s" + ) + + if config.Colorful { + infoPrefix = Green + "%s\n" + Reset + Green + "[info]" + Reset + warnPrefix = Blue + "%s\n" + Reset + Magenta + "[warn]" + Reset + errPrefix = Magenta + "%s\n" + Reset + Red + "[error]" + Reset + tracePrefix = Green + "%s\n" + Reset + YellowBold + "[%.3fms] " + Green + "[rows:%d]" + Reset + " %s" + traceErrPrefix = Magenta + "%s\n" + Reset + Redbold + "[%.3fms] " + Yellow + "[rows:%d]" + Reset + " %s" + } + + return logger{ + Writer: writer, + Config: config, + infoPrefix: infoPrefix, + warnPrefix: warnPrefix, + errPrefix: errPrefix, + tracePrefix: tracePrefix, + traceErrPrefix: traceErrPrefix, + } } -type Logger struct { +type logger struct { Writer - logLevel LogLevel + Config + infoPrefix, warnPrefix, errPrefix string + tracePrefix, traceErrPrefix string } -func (logger Logger) LogMode(level LogLevel) Interface { - return Logger{Writer: logger.Writer, logLevel: level} +// LogMode log mode +func (l logger) LogMode(level LogLevel) Interface { + config := l.Config + config.LogLevel = level + return logger{Writer: l.Writer, Config: config} } // Info print info -func (logger Logger) Info(msg string, data ...interface{}) { - if logger.logLevel >= Info { - logger.Print("[info] " + fmt.Sprintf(msg, data...)) +func (l logger) Info(msg string, data ...interface{}) { + if l.LogLevel >= Info { + l.Printf(l.infoPrefix+msg, append([]interface{}{utils.FileWithLineNum()}, data...)) } } // Warn print warn messages -func (logger Logger) Warn(msg string, data ...interface{}) { - if logger.logLevel >= Warn { - logger.Print("[warn] " + fmt.Sprintf(msg, data...)) +func (l logger) Warn(msg string, data ...interface{}) { + if l.LogLevel >= Warn { + l.Printf(l.warnPrefix+msg, append([]interface{}{utils.FileWithLineNum()}, data...)) } } // Error print error messages -func (logger Logger) Error(msg string, data ...interface{}) { - if logger.logLevel >= Error { - logger.Print("[error] " + fmt.Sprintf(msg, data...)) +func (l logger) Error(msg string, data ...interface{}) { + if l.LogLevel >= Error { + l.Printf(l.errPrefix+msg, append([]interface{}{utils.FileWithLineNum()}, data...)) } } -func (logger Logger) RunWith(logLevel LogLevel, fc func()) { - if logger.logLevel >= logLevel { - fc() +// Trace print sql message +func (l logger) Trace(begin time.Time, fc func() (string, int64), err error) { + if elapsed := time.Now().Sub(begin); err != nil || elapsed > l.SlowThreshold { + sql, rows := fc() + l.Printf(l.traceErrPrefix, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, rows, sql) + } else if l.LogLevel >= Info { + sql, rows := fc() + l.Printf(l.tracePrefix, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, rows, sql) } } diff --git a/tests/dummy_dialecter.go b/tests/dummy_dialecter.go index b4e3361b..04d6248d 100644 --- a/tests/dummy_dialecter.go +++ b/tests/dummy_dialecter.go @@ -2,6 +2,7 @@ package tests import ( "github.com/jinzhu/gorm" + "github.com/jinzhu/gorm/logger" "github.com/jinzhu/gorm/schema" ) @@ -24,6 +25,10 @@ func (DummyDialector) QuoteChars() [2]byte { return [2]byte{'`', '`'} // `name` } +func (DummyDialector) Explain(sql string, vars ...interface{}) string { + return logger.ExplainSQL(sql, nil, `"`, vars...) +} + func (DummyDialector) DataTypeOf(*schema.Field) string { return "" } diff --git a/utils/utils.go b/utils/utils.go index 81ac8b30..315ba930 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -6,8 +6,8 @@ import ( "runtime" ) -var goSrcRegexp = regexp.MustCompile(`jinzhu/gorm/.*.go`) -var goTestRegexp = regexp.MustCompile(`jinzhu/gorm/.*test.go`) +var goSrcRegexp = regexp.MustCompile(`/gorm/.*.go`) +var goTestRegexp = regexp.MustCompile(`/gorm/.*test.go`) func FileWithLineNum() string { for i := 2; i < 15; i++ { From 868ae052a1bb22309dcf8b8f6bd507c3ad849b02 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 23 Feb 2020 15:16:40 +0800 Subject: [PATCH 255/881] Add escape sql params test --- logger/sql_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/logger/sql_test.go b/logger/sql_test.go index 829d6302..aee064d8 100644 --- a/logger/sql_test.go +++ b/logger/sql_test.go @@ -26,8 +26,8 @@ func TestExplainSQL(t *testing.T) { { SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", NumericRegexp: nil, - Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.com", myrole, pwd}, - Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`, + Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd}, + Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass")`, }, { SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (@p0, @p1, @p2, @p3, @p4, @p5, @p6, @p7, @p8, @p9, @p10)", From fa22807e120606aca6f9da994f03bff5d2187a8a Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 23 Feb 2020 19:41:29 +0800 Subject: [PATCH 256/881] Make inesrt into db works --- callbacks.go | 2 +- callbacks/create.go | 58 ++++++++++++++++++++++++------------------ callbacks/query.go | 8 ++---- logger/logger.go | 23 +++++++++-------- logger/sql.go | 11 +++++++- schema/field.go | 2 +- schema/relationship.go | 4 +-- schema/schema.go | 15 ++++++----- statement.go | 32 +++++++++++++---------- tests/tests.go | 3 +++ 10 files changed, 92 insertions(+), 66 deletions(-) diff --git a/callbacks.go b/callbacks.go index 573d7a8e..3aed2d37 100644 --- a/callbacks.go +++ b/callbacks.go @@ -91,7 +91,7 @@ func (p *processor) Execute(db *DB) { if stmt := db.Statement; stmt != nil { db.Logger.Trace(curTime, func() (string, int64) { - return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars), db.RowsAffected + return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...), db.RowsAffected }, db.Error) } } diff --git a/callbacks/create.go b/callbacks/create.go index 95afc854..3866ddb0 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -1,7 +1,6 @@ package callbacks import ( - "fmt" "reflect" "github.com/jinzhu/gorm" @@ -11,8 +10,6 @@ import ( func BeforeCreate(db *gorm.DB) { // before save // before create - - // assign timestamp } func SaveBeforeAssociations(db *gorm.DB) { @@ -22,16 +19,29 @@ func Create(db *gorm.DB) { db.Statement.AddClauseIfNotExists(clause.Insert{ Table: clause.Table{Name: db.Statement.Table}, }) - values, _ := ConvertToCreateValues(db.Statement) - db.Statement.AddClause(values) + db.Statement.AddClause(ConvertToCreateValues(db.Statement)) db.Statement.Build("INSERT", "VALUES", "ON_CONFLICT") result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) - fmt.Printf("%+v\n", values) - fmt.Println(err) - fmt.Println(result) - fmt.Println(db.Statement.SQL.String(), db.Statement.Vars) + if err == nil { + if db.Statement.Schema != nil { + if insertID, err := result.LastInsertId(); err == nil { + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- { + db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID) + insertID-- + } + case reflect.Struct: + db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID) + } + } + } + db.RowsAffected, _ = result.RowsAffected() + } else { + db.AddError(err) + } } func SaveAfterAssociations(db *gorm.DB) { @@ -43,19 +53,18 @@ func AfterCreate(db *gorm.DB) { } // ConvertToCreateValues convert to create values -func ConvertToCreateValues(stmt *gorm.Statement) (clause.Values, []map[string]interface{}) { +func ConvertToCreateValues(stmt *gorm.Statement) clause.Values { switch value := stmt.Dest.(type) { case map[string]interface{}: - return ConvertMapToValues(stmt, value), nil + return ConvertMapToValues(stmt, value) case []map[string]interface{}: - return ConvertSliceOfMapToValues(stmt, value), nil + return ConvertSliceOfMapToValues(stmt, value) default: var ( values = clause.Values{} selectColumns, restricted = SelectAndOmitColumns(stmt) curTime = stmt.DB.NowFunc() isZero = false - returnningValues []map[string]interface{} ) for _, db := range stmt.Schema.DBNames { @@ -66,13 +75,12 @@ func ConvertToCreateValues(stmt *gorm.Statement) (clause.Values, []map[string]in } } - reflectValue := reflect.Indirect(reflect.ValueOf(stmt.Dest)) - switch reflectValue.Kind() { + switch stmt.ReflectValue.Kind() { case reflect.Slice, reflect.Array: - values.Values = make([][]interface{}, reflectValue.Len()) + values.Values = make([][]interface{}, stmt.ReflectValue.Len()) defaultValueFieldsHavingValue := map[string][]interface{}{} - for i := 0; i < reflectValue.Len(); i++ { - rv := reflect.Indirect(reflectValue.Index(i)) + for i := 0; i < stmt.ReflectValue.Len(); i++ { + rv := reflect.Indirect(stmt.ReflectValue.Index(i)) values.Values[i] = make([]interface{}, len(values.Columns)) for idx, column := range values.Columns { field := stmt.Schema.FieldsByDBName[column.Name] @@ -91,7 +99,7 @@ func ConvertToCreateValues(stmt *gorm.Statement) (clause.Values, []map[string]in if v, ok := selectColumns[db]; (ok && v) || (!ok && !restricted) { if v, isZero := field.ValueOf(rv); !isZero { if len(defaultValueFieldsHavingValue[db]) == 0 { - defaultValueFieldsHavingValue[db] = make([]interface{}, reflectValue.Len()) + defaultValueFieldsHavingValue[db] = make([]interface{}, stmt.ReflectValue.Len()) } defaultValueFieldsHavingValue[db][i] = v } @@ -113,20 +121,20 @@ func ConvertToCreateValues(stmt *gorm.Statement) (clause.Values, []map[string]in values.Values = [][]interface{}{make([]interface{}, len(values.Columns))} for idx, column := range values.Columns { field := stmt.Schema.FieldsByDBName[column.Name] - if values.Values[0][idx], isZero = field.ValueOf(reflectValue); isZero { + if values.Values[0][idx], isZero = field.ValueOf(stmt.ReflectValue); isZero { if field.DefaultValueInterface != nil { values.Values[0][idx] = field.DefaultValueInterface - field.Set(reflectValue, field.DefaultValueInterface) + field.Set(stmt.ReflectValue, field.DefaultValueInterface) } else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 { - field.Set(reflectValue, curTime) - values.Values[0][idx], _ = field.ValueOf(reflectValue) + field.Set(stmt.ReflectValue, curTime) + values.Values[0][idx], _ = field.ValueOf(stmt.ReflectValue) } } } for db, field := range stmt.Schema.FieldsWithDefaultDBValue { if v, ok := selectColumns[db]; (ok && v) || (!ok && !restricted) { - if v, isZero := field.ValueOf(reflectValue); !isZero { + if v, isZero := field.ValueOf(stmt.ReflectValue); !isZero { values.Columns = append(values.Columns, clause.Column{Name: db}) values.Values[0] = append(values.Values[0], v) } @@ -134,6 +142,6 @@ func ConvertToCreateValues(stmt *gorm.Statement) (clause.Values, []map[string]in } } - return values, returnningValues + return values } } diff --git a/callbacks/query.go b/callbacks/query.go index a4ed3adb..195709fe 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -1,8 +1,6 @@ package callbacks import ( - "fmt" - "github.com/jinzhu/gorm" "github.com/jinzhu/gorm/clause" ) @@ -15,10 +13,8 @@ func Query(db *gorm.DB) { db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR") } - result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) - fmt.Println(err) - fmt.Println(result) - fmt.Println(db.Statement.SQL.String(), db.Statement.Vars) + rows, err := db.DB.QueryContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) + db.AddError(err) } func Preload(db *gorm.DB) { diff --git a/logger/logger.go b/logger/logger.go index 5656a86f..568ddd57 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -66,9 +66,9 @@ func New(writer Writer, config Config) Interface { ) if config.Colorful { - infoPrefix = Green + "%s\n" + Reset + Green + "[info]" + Reset - warnPrefix = Blue + "%s\n" + Reset + Magenta + "[warn]" + Reset - errPrefix = Magenta + "%s\n" + Reset + Red + "[error]" + Reset + infoPrefix = Green + "%s\n" + Reset + Green + "[info] " + Reset + warnPrefix = Blue + "%s\n" + Reset + Magenta + "[warn] " + Reset + errPrefix = Magenta + "%s\n" + Reset + Red + "[error] " + Reset tracePrefix = Green + "%s\n" + Reset + YellowBold + "[%.3fms] " + Green + "[rows:%d]" + Reset + " %s" traceErrPrefix = Magenta + "%s\n" + Reset + Redbold + "[%.3fms] " + Yellow + "[rows:%d]" + Reset + " %s" } @@ -93,29 +93,28 @@ type logger struct { // LogMode log mode func (l logger) LogMode(level LogLevel) Interface { - config := l.Config - config.LogLevel = level - return logger{Writer: l.Writer, Config: config} + l.LogLevel = level + return l } // Info print info func (l logger) Info(msg string, data ...interface{}) { if l.LogLevel >= Info { - l.Printf(l.infoPrefix+msg, append([]interface{}{utils.FileWithLineNum()}, data...)) + l.Printf(l.infoPrefix+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...) } } // Warn print warn messages func (l logger) Warn(msg string, data ...interface{}) { if l.LogLevel >= Warn { - l.Printf(l.warnPrefix+msg, append([]interface{}{utils.FileWithLineNum()}, data...)) + l.Printf(l.warnPrefix+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...) } } // Error print error messages func (l logger) Error(msg string, data ...interface{}) { if l.LogLevel >= Error { - l.Printf(l.errPrefix+msg, append([]interface{}{utils.FileWithLineNum()}, data...)) + l.Printf(l.errPrefix+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...) } } @@ -123,7 +122,11 @@ func (l logger) Error(msg string, data ...interface{}) { func (l logger) Trace(begin time.Time, fc func() (string, int64), err error) { if elapsed := time.Now().Sub(begin); err != nil || elapsed > l.SlowThreshold { sql, rows := fc() - l.Printf(l.traceErrPrefix, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, rows, sql) + fileline := utils.FileWithLineNum() + if err != nil { + fileline += " " + err.Error() + } + l.Printf(l.traceErrPrefix, fileline, float64(elapsed.Nanoseconds())/1e6, rows, sql) } else if l.LogLevel >= Info { sql, rows := fc() l.Printf(l.tracePrefix, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, rows, sql) diff --git a/logger/sql.go b/logger/sql.go index f63dc160..eec72d47 100644 --- a/logger/sql.go +++ b/logger/sql.go @@ -30,7 +30,11 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, v case bool: vars[idx] = fmt.Sprint(v) case time.Time: - vars[idx] = escaper + v.Format("2006-01-02 15:04:05") + escaper + if v.IsZero() { + vars[idx] = escaper + "0000-00-00 00:00:00" + escaper + } else { + vars[idx] = escaper + v.Format("2006-01-02 15:04:05") + escaper + } case []byte: if isPrintable(v) { vars[idx] = escaper + strings.Replace(string(v), escaper, "\\"+escaper, -1) + escaper @@ -48,6 +52,11 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, v vars[idx] = "NULL" } else { rv := reflect.Indirect(reflect.ValueOf(v)) + if !rv.IsValid() { + vars[idx] = "NULL" + return + } + for _, t := range convertableTypes { if rv.Type().ConvertibleTo(t) { convertParams(rv.Convert(t).Interface(), idx) diff --git a/schema/field.go b/schema/field.go index ea4e6a40..f640ec3b 100644 --- a/schema/field.go +++ b/schema/field.go @@ -235,7 +235,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { var err error field.Creatable = false field.Updatable = false - if field.EmbeddedSchema, err = Parse(fieldValue.Interface(), &sync.Map{}, schema.namer); err != nil { + if field.EmbeddedSchema, _, err = Parse(fieldValue.Interface(), &sync.Map{}, schema.namer); err != nil { schema.err = err } for _, ef := range field.EmbeddedSchema.Fields { diff --git a/schema/relationship.go b/schema/relationship.go index 4ffea8b3..3b9d692a 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -65,7 +65,7 @@ func (schema *Schema) parseRelation(field *Field) { } ) - if relation.FieldSchema, err = Parse(fieldValue, schema.cacheStore, schema.namer); err != nil { + if relation.FieldSchema, _, err = Parse(fieldValue, schema.cacheStore, schema.namer); err != nil { schema.err = err return } @@ -192,7 +192,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel } } - if relation.JoinTable, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, schema.namer); err != nil { + if relation.JoinTable, _, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, schema.namer); err != nil { schema.err = err } relation.JoinTable.Name = many2many diff --git a/schema/schema.go b/schema/schema.go index acf6ff52..c3ac2bd9 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -48,21 +48,22 @@ func (schema Schema) LookUpField(name string) *Field { } // get data type from dialector -func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) { - modelType := reflect.ValueOf(dest).Type() +func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, reflect.Value, error) { + reflectValue := reflect.ValueOf(dest) + modelType := reflectValue.Type() for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Ptr { modelType = modelType.Elem() } if modelType.Kind() != reflect.Struct { if modelType.PkgPath() == "" { - return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) + return nil, reflectValue, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) } - return nil, fmt.Errorf("%w: %v.%v", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name()) + return nil, reflectValue, fmt.Errorf("%w: %v.%v", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name()) } if v, ok := cacheStore.Load(modelType); ok { - return v.(*Schema), nil + return v.(*Schema), reflectValue, nil } schema := &Schema{ @@ -167,10 +168,10 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) for _, field := range schema.Fields { if field.DataType == "" && field.Creatable { if schema.parseRelation(field); schema.err != nil { - return schema, schema.err + return schema, reflectValue, schema.err } } } - return schema, schema.err + return schema, reflectValue, schema.err } diff --git a/statement.go b/statement.go index d486a1c7..91f45b2b 100644 --- a/statement.go +++ b/statement.go @@ -5,6 +5,7 @@ import ( "database/sql" "database/sql/driver" "fmt" + "reflect" "strconv" "strings" "sync" @@ -32,22 +33,23 @@ func (instance *Instance) ToSQL(clauses ...string) (string, []interface{}) { func (inst *Instance) AddError(err error) { if inst.Error == nil { inst.Error = err - } else { + } else if err != nil { inst.Error = fmt.Errorf("%v; %w", inst.Error, err) } } // Statement statement type Statement struct { - Table string - Model interface{} - Dest interface{} - Clauses map[string]clause.Clause - Selects []string // selected columns - Omits []string // omit columns - Settings sync.Map - DB *DB - Schema *schema.Schema + Table string + Model interface{} + Dest interface{} + ReflectValue reflect.Value + Clauses map[string]clause.Clause + Selects []string // selected columns + Omits []string // omit columns + Settings sync.Map + DB *DB + Schema *schema.Schema // SQL Builder SQL strings.Builder @@ -197,7 +199,7 @@ func (stmt *Statement) AddClauseIfNotExists(v clause.Interface) { // BuildCondtion build condition func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (conditions []clause.Expression) { if sql, ok := query.(string); ok { - if i, err := strconv.Atoi(sql); err != nil { + if i, err := strconv.Atoi(sql); err == nil { query = i } else if len(args) == 0 || (len(args) > 0 && strings.Contains(sql, "?")) || strings.Contains(sql, "@") { return []clause.Expression{clause.Expr{SQL: sql, Vars: args}} @@ -272,8 +274,12 @@ func (stmt *Statement) Build(clauses ...string) { } func (stmt *Statement) Parse(value interface{}) (err error) { - if stmt.Schema, err = schema.Parse(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil && stmt.Table == "" { - stmt.Table = stmt.Schema.Table + if stmt.Schema, stmt.ReflectValue, err = schema.Parse(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil { + stmt.ReflectValue = reflect.Indirect(stmt.ReflectValue) + + if stmt.Table == "" { + stmt.Table = stmt.Schema.Table + } } return err } diff --git a/tests/tests.go b/tests/tests.go index b3246a79..53700710 100644 --- a/tests/tests.go +++ b/tests/tests.go @@ -17,6 +17,9 @@ func RunTestsSuit(t *testing.T, db *gorm.DB) { } func TestCreate(t *testing.T, db *gorm.DB) { + db.AutoMigrate(&User{}) + db = db.Debug() + t.Run("Create", func(t *testing.T) { var user = User{ Name: "create", From e2a360b9faa72efb3f35f3edca4ed6e293d9185e Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 23 Feb 2020 21:22:35 +0800 Subject: [PATCH 257/881] Add Before/After callbacks --- callbacks/create.go | 64 ++++++++++++++++++++++++++++++++++--- callbacks/delete.go | 50 ++++++++++++++++++++++++++++- callbacks/query.go | 27 ++++++++++++++-- callbacks/update.go | 66 ++++++++++++++++++++++++++++++++++++++- clause/benchmarks_test.go | 4 +-- clause/clause_test.go | 2 +- clause/expression_test.go | 2 +- interfaces.go | 36 +++++++++++++++++++++ schema/callbacks_test.go | 38 ++++++++++++++++++++++ schema/check_test.go | 2 +- schema/field_test.go | 24 +++++++------- schema/index_test.go | 2 +- schema/schema.go | 45 +++++++++++++++++--------- schema/schema_test.go | 6 ++-- 14 files changed, 325 insertions(+), 43 deletions(-) create mode 100644 schema/callbacks_test.go diff --git a/callbacks/create.go b/callbacks/create.go index 3866ddb0..2e1b3381 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -8,8 +8,36 @@ import ( ) func BeforeCreate(db *gorm.DB) { - // before save - // before create + if db.Statement.Schema != nil && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeCreate) { + callMethod := func(value interface{}) bool { + var ok bool + if db.Statement.Schema.BeforeSave { + if i, ok := value.(gorm.BeforeSaveInterface); ok { + ok = true + i.BeforeSave(db) + } + } + + if db.Statement.Schema.BeforeCreate { + if i, ok := value.(gorm.BeforeCreateInterface); ok { + ok = true + i.BeforeCreate(db) + } + } + return ok + } + + if ok := callMethod(db.Statement.Dest); !ok { + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + for i := 0; i <= db.Statement.ReflectValue.Len(); i++ { + callMethod(db.Statement.ReflectValue.Index(i).Interface()) + } + case reflect.Struct: + callMethod(db.Statement.ReflectValue.Interface()) + } + } + } } func SaveBeforeAssociations(db *gorm.DB) { @@ -48,8 +76,36 @@ func SaveAfterAssociations(db *gorm.DB) { } func AfterCreate(db *gorm.DB) { - // after save - // after create + if db.Statement.Schema != nil && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) { + callMethod := func(value interface{}) bool { + var ok bool + if db.Statement.Schema.AfterSave { + if i, ok := value.(gorm.AfterSaveInterface); ok { + ok = true + i.AfterSave(db) + } + } + + if db.Statement.Schema.AfterCreate { + if i, ok := value.(gorm.AfterCreateInterface); ok { + ok = true + i.AfterCreate(db) + } + } + return ok + } + + if ok := callMethod(db.Statement.Dest); !ok { + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + for i := 0; i <= db.Statement.ReflectValue.Len(); i++ { + callMethod(db.Statement.ReflectValue.Index(i).Interface()) + } + case reflect.Struct: + callMethod(db.Statement.ReflectValue.Interface()) + } + } + } } // ConvertToCreateValues convert to create values diff --git a/callbacks/delete.go b/callbacks/delete.go index 96c392f2..d79f88fc 100644 --- a/callbacks/delete.go +++ b/callbacks/delete.go @@ -1,12 +1,60 @@ package callbacks -import "github.com/jinzhu/gorm" +import ( + "reflect" + + "github.com/jinzhu/gorm" +) func BeforeDelete(db *gorm.DB) { + if db.Statement.Schema != nil && db.Statement.Schema.BeforeDelete { + callMethod := func(value interface{}) bool { + if db.Statement.Schema.BeforeDelete { + if i, ok := value.(gorm.BeforeDeleteInterface); ok { + i.BeforeDelete(db) + return true + } + } + return false + } + + if ok := callMethod(db.Statement.Dest); !ok { + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + for i := 0; i <= db.Statement.ReflectValue.Len(); i++ { + callMethod(db.Statement.ReflectValue.Index(i).Interface()) + } + case reflect.Struct: + callMethod(db.Statement.ReflectValue.Interface()) + } + } + } } func Delete(db *gorm.DB) { } func AfterDelete(db *gorm.DB) { + if db.Statement.Schema != nil && db.Statement.Schema.AfterDelete { + callMethod := func(value interface{}) bool { + if db.Statement.Schema.AfterDelete { + if i, ok := value.(gorm.AfterDeleteInterface); ok { + i.AfterDelete(db) + return true + } + } + return false + } + + if ok := callMethod(db.Statement.Dest); !ok { + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + for i := 0; i <= db.Statement.ReflectValue.Len(); i++ { + callMethod(db.Statement.ReflectValue.Index(i).Interface()) + } + case reflect.Struct: + callMethod(db.Statement.ReflectValue.Interface()) + } + } + } } diff --git a/callbacks/query.go b/callbacks/query.go index 195709fe..d8785057 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -1,6 +1,8 @@ package callbacks import ( + "reflect" + "github.com/jinzhu/gorm" "github.com/jinzhu/gorm/clause" ) @@ -13,7 +15,7 @@ func Query(db *gorm.DB) { db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR") } - rows, err := db.DB.QueryContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) + _, err := db.DB.QueryContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) db.AddError(err) } @@ -21,5 +23,26 @@ func Preload(db *gorm.DB) { } func AfterQuery(db *gorm.DB) { - // after find + if db.Statement.Schema != nil && db.Statement.Schema.AfterFind { + callMethod := func(value interface{}) bool { + if db.Statement.Schema.AfterFind { + if i, ok := value.(gorm.AfterFindInterface); ok { + i.AfterFind(db) + return true + } + } + return false + } + + if ok := callMethod(db.Statement.Dest); !ok { + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + for i := 0; i <= db.Statement.ReflectValue.Len(); i++ { + callMethod(db.Statement.ReflectValue.Index(i).Interface()) + } + case reflect.Struct: + callMethod(db.Statement.ReflectValue.Interface()) + } + } + } } diff --git a/callbacks/update.go b/callbacks/update.go index 8e504403..82df3e81 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -1,12 +1,76 @@ package callbacks -import "github.com/jinzhu/gorm" +import ( + "reflect" + + "github.com/jinzhu/gorm" +) func BeforeUpdate(db *gorm.DB) { + if db.Statement.Schema != nil && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) { + callMethod := func(value interface{}) bool { + var ok bool + if db.Statement.Schema.BeforeSave { + if i, ok := value.(gorm.BeforeSaveInterface); ok { + ok = true + i.BeforeSave(db) + } + } + + if db.Statement.Schema.BeforeUpdate { + if i, ok := value.(gorm.BeforeUpdateInterface); ok { + ok = true + i.BeforeUpdate(db) + } + } + return ok + } + + if ok := callMethod(db.Statement.Dest); !ok { + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + for i := 0; i <= db.Statement.ReflectValue.Len(); i++ { + callMethod(db.Statement.ReflectValue.Index(i).Interface()) + } + case reflect.Struct: + callMethod(db.Statement.ReflectValue.Interface()) + } + } + } } func Update(db *gorm.DB) { } func AfterUpdate(db *gorm.DB) { + if db.Statement.Schema != nil && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) { + callMethod := func(value interface{}) bool { + var ok bool + if db.Statement.Schema.AfterSave { + if i, ok := value.(gorm.AfterSaveInterface); ok { + ok = true + i.AfterSave(db) + } + } + + if db.Statement.Schema.AfterUpdate { + if i, ok := value.(gorm.AfterUpdateInterface); ok { + ok = true + i.AfterUpdate(db) + } + } + return ok + } + + if ok := callMethod(db.Statement.Dest); !ok { + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + for i := 0; i <= db.Statement.ReflectValue.Len(); i++ { + callMethod(db.Statement.ReflectValue.Index(i).Interface()) + } + case reflect.Struct: + callMethod(db.Statement.ReflectValue.Interface()) + } + } + } } diff --git a/clause/benchmarks_test.go b/clause/benchmarks_test.go index 33d3430a..3813fd8e 100644 --- a/clause/benchmarks_test.go +++ b/clause/benchmarks_test.go @@ -11,7 +11,7 @@ import ( ) func BenchmarkSelect(b *testing.B) { - user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) + user, _, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) for i := 0; i < b.N; i++ { stmt := gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}} @@ -27,7 +27,7 @@ func BenchmarkSelect(b *testing.B) { } func BenchmarkComplexSelect(b *testing.B) { - user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) + user, _, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) for i := 0; i < b.N; i++ { stmt := gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}} diff --git a/clause/clause_test.go b/clause/clause_test.go index 30ea9343..8e458043 100644 --- a/clause/clause_test.go +++ b/clause/clause_test.go @@ -18,7 +18,7 @@ func checkBuildClauses(t *testing.T, clauses []clause.Interface, result string, var ( buildNames []string buildNamesMap = map[string]bool{} - user, _ = schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) + user, _, _ = schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) stmt = gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}} ) diff --git a/clause/expression_test.go b/clause/expression_test.go index e51d189e..363b4047 100644 --- a/clause/expression_test.go +++ b/clause/expression_test.go @@ -24,7 +24,7 @@ func TestExpr(t *testing.T) { for idx, result := range results { t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { - user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) + user, _, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) stmt := &gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}} clause.Expr{SQL: result.SQL, Vars: result.Vars}.Build(stmt) if stmt.SQL.String() != result.Result { diff --git a/interfaces.go b/interfaces.go index bf1aab46..21563b7d 100644 --- a/interfaces.go +++ b/interfaces.go @@ -24,3 +24,39 @@ type CommonDB interface { QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row } + +type BeforeCreateInterface interface { + BeforeCreate(*DB) +} + +type AfterCreateInterface interface { + AfterCreate(*DB) +} + +type BeforeUpdateInterface interface { + BeforeUpdate(*DB) +} + +type AfterUpdateInterface interface { + AfterUpdate(*DB) +} + +type BeforeSaveInterface interface { + BeforeSave(*DB) +} + +type AfterSaveInterface interface { + AfterSave(*DB) +} + +type BeforeDeleteInterface interface { + BeforeDelete(*DB) +} + +type AfterDeleteInterface interface { + AfterDelete(*DB) +} + +type AfterFindInterface interface { + AfterFind(*DB) +} diff --git a/schema/callbacks_test.go b/schema/callbacks_test.go new file mode 100644 index 00000000..34c0e687 --- /dev/null +++ b/schema/callbacks_test.go @@ -0,0 +1,38 @@ +package schema_test + +import ( + "reflect" + "sync" + "testing" + + "github.com/jinzhu/gorm" + "github.com/jinzhu/gorm/schema" +) + +type UserWithCallback struct { +} + +func (UserWithCallback) BeforeSave(*gorm.DB) { +} + +func (UserWithCallback) AfterCreate(*gorm.DB) { +} + +func TestCallback(t *testing.T) { + user, _, err := schema.Parse(&UserWithCallback{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("failed to parse user with callback, got error %v", err) + } + + for _, str := range []string{"BeforeSave", "AfterCreate"} { + if !reflect.Indirect(reflect.ValueOf(user)).FieldByName(str).Interface().(bool) { + t.Errorf("%v should be true", str) + } + } + + for _, str := range []string{"BeforeCreate", "BeforeUpdate", "AfterUpdate", "AfterSave", "BeforeDelete", "AfterDelete", "AfterFind"} { + if reflect.Indirect(reflect.ValueOf(user)).FieldByName(str).Interface().(bool) { + t.Errorf("%v should be false", str) + } + } +} diff --git a/schema/check_test.go b/schema/check_test.go index e4bc9ebe..f0ba553c 100644 --- a/schema/check_test.go +++ b/schema/check_test.go @@ -15,7 +15,7 @@ type UserCheck struct { } func TestParseCheck(t *testing.T) { - user, err := schema.Parse(&UserCheck{}, &sync.Map{}, schema.NamingStrategy{}) + user, _, err := schema.Parse(&UserCheck{}, &sync.Map{}, schema.NamingStrategy{}) if err != nil { t.Fatalf("failed to parse user check, got error %v", err) } diff --git a/schema/field_test.go b/schema/field_test.go index 15dfa41d..02e6aec0 100644 --- a/schema/field_test.go +++ b/schema/field_test.go @@ -14,8 +14,8 @@ import ( func TestFieldValuerAndSetter(t *testing.T) { var ( - userSchema, _ = schema.Parse(&tests.User{}, &sync.Map{}, schema.NamingStrategy{}) - user = tests.User{ + userSchema, _, _ = schema.Parse(&tests.User{}, &sync.Map{}, schema.NamingStrategy{}) + user = tests.User{ Model: gorm.Model{ ID: 10, CreatedAt: time.Now(), @@ -81,11 +81,11 @@ func TestFieldValuerAndSetter(t *testing.T) { func TestPointerFieldValuerAndSetter(t *testing.T) { var ( - userSchema, _ = schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{}) - name = "pointer_field_valuer_and_setter" - age uint = 18 - active = true - user = User{ + userSchema, _, _ = schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{}) + name = "pointer_field_valuer_and_setter" + age uint = 18 + active = true + user = User{ Model: &gorm.Model{ ID: 10, CreatedAt: time.Now(), @@ -151,11 +151,11 @@ func TestPointerFieldValuerAndSetter(t *testing.T) { func TestAdvancedDataTypeValuerAndSetter(t *testing.T) { var ( - userSchema, _ = schema.Parse(&AdvancedDataTypeUser{}, &sync.Map{}, schema.NamingStrategy{}) - name = "advanced_data_type_valuer_and_setter" - deletedAt = mytime(time.Now()) - isAdmin = mybool(false) - user = AdvancedDataTypeUser{ + userSchema, _, _ = schema.Parse(&AdvancedDataTypeUser{}, &sync.Map{}, schema.NamingStrategy{}) + name = "advanced_data_type_valuer_and_setter" + deletedAt = mytime(time.Now()) + isAdmin = mybool(false) + user = AdvancedDataTypeUser{ ID: sql.NullInt64{Int64: 10, Valid: true}, Name: &sql.NullString{String: name, Valid: true}, Birthday: sql.NullTime{Time: time.Now(), Valid: true}, diff --git a/schema/index_test.go b/schema/index_test.go index d0e8dfe0..03d75b97 100644 --- a/schema/index_test.go +++ b/schema/index_test.go @@ -19,7 +19,7 @@ type UserIndex struct { } func TestParseIndex(t *testing.T) { - user, err := schema.Parse(&UserIndex{}, &sync.Map{}, schema.NamingStrategy{}) + user, _, err := schema.Parse(&UserIndex{}, &sync.Map{}, schema.NamingStrategy{}) if err != nil { t.Fatalf("failed to parse user index, got error %v", err) } diff --git a/schema/schema.go b/schema/schema.go index c3ac2bd9..c56932ad 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -14,20 +14,25 @@ import ( var ErrUnsupportedDataType = errors.New("unsupported data type") type Schema struct { - Name string - ModelType reflect.Type - Table string - PrioritizedPrimaryField *Field - DBNames []string - PrimaryFields []*Field - Fields []*Field - FieldsByName map[string]*Field - FieldsByDBName map[string]*Field - FieldsWithDefaultDBValue map[string]*Field // fields with default value assigned by database - Relationships Relationships - err error - namer Namer - cacheStore *sync.Map + Name string + ModelType reflect.Type + Table string + PrioritizedPrimaryField *Field + DBNames []string + PrimaryFields []*Field + Fields []*Field + FieldsByName map[string]*Field + FieldsByDBName map[string]*Field + FieldsWithDefaultDBValue map[string]*Field // fields with default value assigned by database + Relationships Relationships + BeforeCreate, AfterCreate bool + BeforeUpdate, AfterUpdate bool + BeforeDelete, AfterDelete bool + BeforeSave, AfterSave bool + AfterFind bool + err error + namer Namer + cacheStore *sync.Map } func (schema Schema) String() string { @@ -162,6 +167,18 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, reflec } } + callbacks := []string{"BeforeCreate", "AfterCreate", "BeforeUpdate", "AfterUpdate", "BeforeSave", "AfterSave", "BeforeDelete", "AfterDelete", "AfterFind"} + for _, name := range callbacks { + if methodValue := reflectValue.MethodByName(name); methodValue.IsValid() { + switch methodValue.Type().String() { + case "func(*gorm.DB)": // TODO hack + reflect.Indirect(reflect.ValueOf(schema)).FieldByName(name).SetBool(true) + default: + logger.Default.Warn("Model %v don't match %vInterface, should be %v(*gorm.DB)", schema, name, name) + } + } + } + cacheStore.Store(modelType, schema) // parse relations for unidentified fields diff --git a/schema/schema_test.go b/schema/schema_test.go index ce225010..04cd9d82 100644 --- a/schema/schema_test.go +++ b/schema/schema_test.go @@ -9,7 +9,7 @@ import ( ) func TestParseSchema(t *testing.T) { - user, err := schema.Parse(&tests.User{}, &sync.Map{}, schema.NamingStrategy{}) + user, _, err := schema.Parse(&tests.User{}, &sync.Map{}, schema.NamingStrategy{}) if err != nil { t.Fatalf("failed to parse user, got error %v", err) } @@ -18,7 +18,7 @@ func TestParseSchema(t *testing.T) { } func TestParseSchemaWithPointerFields(t *testing.T) { - user, err := schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{}) + user, _, err := schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{}) if err != nil { t.Fatalf("failed to parse pointer user, got error %v", err) } @@ -114,7 +114,7 @@ func checkUserSchema(t *testing.T, user *schema.Schema) { } func TestParseSchemaWithAdvancedDataType(t *testing.T) { - user, err := schema.Parse(&AdvancedDataTypeUser{}, &sync.Map{}, schema.NamingStrategy{}) + user, _, err := schema.Parse(&AdvancedDataTypeUser{}, &sync.Map{}, schema.NamingStrategy{}) if err != nil { t.Fatalf("failed to parse pointer user, got error %v", err) } From 5ccd76f76cf21722289615333a0b2a8615d95ed9 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 23 Feb 2020 23:28:35 +0800 Subject: [PATCH 258/881] Setup Transaction --- association.go | 4 ++++ callbacks/query.go | 5 +++-- finisher_api.go | 56 +++++++++++++++++++++++++++++++++------------- interfaces.go | 9 ++++++++ logger/logger.go | 1 + 5 files changed, 57 insertions(+), 18 deletions(-) diff --git a/association.go b/association.go index 17f8f4a5..14bc54b6 100644 --- a/association.go +++ b/association.go @@ -3,3 +3,7 @@ package gorm // Association Mode contains some helper methods to handle relationship things easily. type Association struct { } + +func (db *DB) Association(column string) *Association { + return nil +} diff --git a/callbacks/query.go b/callbacks/query.go index d8785057..baacbd24 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -11,12 +11,13 @@ func Query(db *gorm.DB) { if db.Statement.SQL.String() == "" { db.Statement.AddClauseIfNotExists(clause.Select{}) db.Statement.AddClauseIfNotExists(clause.From{}) - db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR") } - _, err := db.DB.QueryContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) + rows, err := db.DB.QueryContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) db.AddError(err) + _ = rows + // scan rows } func Preload(db *gorm.DB) { diff --git a/finisher_api.go b/finisher_api.go index 2c5d4f65..72c3d2aa 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -23,6 +23,7 @@ func (db *DB) Save(value interface{}) (tx *DB) { // First find first record that match given conditions, order by primary key func (db *DB) First(out interface{}, where ...interface{}) (tx *DB) { + // TODO handle where tx = db.getInstance().Limit(1).Order(clause.OrderByColumn{ Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Desc: true, @@ -35,12 +36,18 @@ func (db *DB) First(out interface{}, where ...interface{}) (tx *DB) { // Take return a record that match given conditions, the order will depend on the database implementation func (db *DB) Take(out interface{}, where ...interface{}) (tx *DB) { tx = db.getInstance() + tx.Statement.Dest = out + tx.callbacks.Query().Execute(tx) return } // Last find last record that match given conditions, order by primary key func (db *DB) Last(out interface{}, where ...interface{}) (tx *DB) { - tx = db.getInstance() + tx = db.getInstance().Limit(1).Order(clause.OrderByColumn{ + Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, + }) + tx.Statement.Dest = out + tx.callbacks.Query().Execute(tx) return } @@ -88,21 +95,12 @@ func (db *DB) Delete(value interface{}, where ...interface{}) (tx *DB) { return } -func (db *DB) Related(value interface{}, foreignKeys ...string) (tx *DB) { - tx = db.getInstance() - return -} - //Preloads only preloads relations, don`t touch out func (db *DB) Preloads(out interface{}) (tx *DB) { tx = db.getInstance() return } -func (db *DB) Association(column string) *Association { - return nil -} - func (db *DB) Count(value interface{}) (tx *DB) { tx = db.getInstance() return @@ -130,6 +128,7 @@ func (db *DB) ScanRows(rows *sql.Rows, result interface{}) error { return nil } +// Transaction start a transaction as a block, return error will rollback, otherwise to commit. func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err error) { panicked := true tx := db.Begin(opts...) @@ -150,21 +149,46 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er return } +// Begin begins a transaction func (db *DB) Begin(opts ...*sql.TxOptions) (tx *DB) { tx = db.getInstance() + if beginner, ok := tx.DB.(TxBeginner); ok { + var opt *sql.TxOptions + var err error + if len(opts) > 0 { + opt = opts[0] + } + + if tx.DB, err = beginner.BeginTx(db.Context, opt); err != nil { + tx.AddError(err) + } + } else { + tx.AddError(ErrInvalidTransaction) + } return } -func (db *DB) Commit() (tx *DB) { - tx = db.getInstance() - return +// Commit commit a transaction +func (db *DB) Commit() *DB { + if comminter, ok := db.DB.(TxCommiter); ok && comminter != nil { + db.AddError(comminter.Commit()) + } else { + db.AddError(ErrInvalidTransaction) + } + return db } -func (db *DB) Rollback() (tx *DB) { - tx = db.getInstance() - return +// Rollback rollback a transaction +func (db *DB) Rollback() *DB { + if comminter, ok := db.DB.(TxCommiter); ok && comminter != nil { + db.AddError(comminter.Rollback()) + } else { + db.AddError(ErrInvalidTransaction) + } + return db } +// Exec execute raw sql func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.SQL = strings.Builder{} diff --git a/interfaces.go b/interfaces.go index 21563b7d..f0d14dd8 100644 --- a/interfaces.go +++ b/interfaces.go @@ -25,6 +25,15 @@ type CommonDB interface { QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row } +type TxBeginner interface { + BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) +} + +type TxCommiter interface { + Commit() error + Rollback() error +} + type BeforeCreateInterface interface { BeforeCreate(*DB) } diff --git a/logger/logger.go b/logger/logger.go index 568ddd57..d3b97b9d 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -53,6 +53,7 @@ type Interface interface { var Default = New(log.New(os.Stdout, "\r\n", log.LstdFlags), Config{ SlowThreshold: 100 * time.Millisecond, + LogLevel: Warn, Colorful: true, }) From 04adbaf7f6fcacc5adde7a66649537cdccab74fd Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 24 Feb 2020 08:51:35 +0800 Subject: [PATCH 259/881] Fix parse stmt ReflectValue --- callbacks.go | 6 +++--- logger/sql.go | 2 +- schema/callbacks_test.go | 2 +- schema/check_test.go | 2 +- schema/field.go | 2 +- schema/field_test.go | 24 ++++++++++++------------ schema/index_test.go | 2 +- schema/relationship.go | 4 ++-- schema/schema.go | 16 ++++++++-------- schema/schema_test.go | 6 +++--- statement.go | 8 ++------ 11 files changed, 35 insertions(+), 39 deletions(-) diff --git a/callbacks.go b/callbacks.go index 3aed2d37..db8261c4 100644 --- a/callbacks.go +++ b/callbacks.go @@ -3,6 +3,7 @@ package gorm import ( "errors" "fmt" + "reflect" "time" "github.com/jinzhu/gorm/logger" @@ -77,12 +78,11 @@ func (p *processor) Execute(db *DB) { } if stmt.Model != nil { - err := stmt.Parse(stmt.Model) - - if err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || stmt.Table == "") { + if err := stmt.Parse(stmt.Model); err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || stmt.Table == "") { db.AddError(err) } } + stmt.ReflectValue = reflect.Indirect(reflect.ValueOf(stmt.Dest)) } for _, f := range p.fns { diff --git a/logger/sql.go b/logger/sql.go index eec72d47..cb50ccf6 100644 --- a/logger/sql.go +++ b/logger/sql.go @@ -84,7 +84,7 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, v } else { sql = numericPlaceholder.ReplaceAllString(sql, "$$$1$$") for idx, v := range vars { - sql = strings.Replace(sql, "$"+strconv.Itoa(idx)+"$", v.(string), 1) + sql = strings.Replace(sql, "$"+strconv.Itoa(idx+1)+"$", v.(string), 1) } } diff --git a/schema/callbacks_test.go b/schema/callbacks_test.go index 34c0e687..720c9a5b 100644 --- a/schema/callbacks_test.go +++ b/schema/callbacks_test.go @@ -19,7 +19,7 @@ func (UserWithCallback) AfterCreate(*gorm.DB) { } func TestCallback(t *testing.T) { - user, _, err := schema.Parse(&UserWithCallback{}, &sync.Map{}, schema.NamingStrategy{}) + user, err := schema.Parse(&UserWithCallback{}, &sync.Map{}, schema.NamingStrategy{}) if err != nil { t.Fatalf("failed to parse user with callback, got error %v", err) } diff --git a/schema/check_test.go b/schema/check_test.go index f0ba553c..e4bc9ebe 100644 --- a/schema/check_test.go +++ b/schema/check_test.go @@ -15,7 +15,7 @@ type UserCheck struct { } func TestParseCheck(t *testing.T) { - user, _, err := schema.Parse(&UserCheck{}, &sync.Map{}, schema.NamingStrategy{}) + user, err := schema.Parse(&UserCheck{}, &sync.Map{}, schema.NamingStrategy{}) if err != nil { t.Fatalf("failed to parse user check, got error %v", err) } diff --git a/schema/field.go b/schema/field.go index f640ec3b..ea4e6a40 100644 --- a/schema/field.go +++ b/schema/field.go @@ -235,7 +235,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { var err error field.Creatable = false field.Updatable = false - if field.EmbeddedSchema, _, err = Parse(fieldValue.Interface(), &sync.Map{}, schema.namer); err != nil { + if field.EmbeddedSchema, err = Parse(fieldValue.Interface(), &sync.Map{}, schema.namer); err != nil { schema.err = err } for _, ef := range field.EmbeddedSchema.Fields { diff --git a/schema/field_test.go b/schema/field_test.go index 02e6aec0..15dfa41d 100644 --- a/schema/field_test.go +++ b/schema/field_test.go @@ -14,8 +14,8 @@ import ( func TestFieldValuerAndSetter(t *testing.T) { var ( - userSchema, _, _ = schema.Parse(&tests.User{}, &sync.Map{}, schema.NamingStrategy{}) - user = tests.User{ + userSchema, _ = schema.Parse(&tests.User{}, &sync.Map{}, schema.NamingStrategy{}) + user = tests.User{ Model: gorm.Model{ ID: 10, CreatedAt: time.Now(), @@ -81,11 +81,11 @@ func TestFieldValuerAndSetter(t *testing.T) { func TestPointerFieldValuerAndSetter(t *testing.T) { var ( - userSchema, _, _ = schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{}) - name = "pointer_field_valuer_and_setter" - age uint = 18 - active = true - user = User{ + userSchema, _ = schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{}) + name = "pointer_field_valuer_and_setter" + age uint = 18 + active = true + user = User{ Model: &gorm.Model{ ID: 10, CreatedAt: time.Now(), @@ -151,11 +151,11 @@ func TestPointerFieldValuerAndSetter(t *testing.T) { func TestAdvancedDataTypeValuerAndSetter(t *testing.T) { var ( - userSchema, _, _ = schema.Parse(&AdvancedDataTypeUser{}, &sync.Map{}, schema.NamingStrategy{}) - name = "advanced_data_type_valuer_and_setter" - deletedAt = mytime(time.Now()) - isAdmin = mybool(false) - user = AdvancedDataTypeUser{ + userSchema, _ = schema.Parse(&AdvancedDataTypeUser{}, &sync.Map{}, schema.NamingStrategy{}) + name = "advanced_data_type_valuer_and_setter" + deletedAt = mytime(time.Now()) + isAdmin = mybool(false) + user = AdvancedDataTypeUser{ ID: sql.NullInt64{Int64: 10, Valid: true}, Name: &sql.NullString{String: name, Valid: true}, Birthday: sql.NullTime{Time: time.Now(), Valid: true}, diff --git a/schema/index_test.go b/schema/index_test.go index 03d75b97..d0e8dfe0 100644 --- a/schema/index_test.go +++ b/schema/index_test.go @@ -19,7 +19,7 @@ type UserIndex struct { } func TestParseIndex(t *testing.T) { - user, _, err := schema.Parse(&UserIndex{}, &sync.Map{}, schema.NamingStrategy{}) + user, err := schema.Parse(&UserIndex{}, &sync.Map{}, schema.NamingStrategy{}) if err != nil { t.Fatalf("failed to parse user index, got error %v", err) } diff --git a/schema/relationship.go b/schema/relationship.go index 3b9d692a..4ffea8b3 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -65,7 +65,7 @@ func (schema *Schema) parseRelation(field *Field) { } ) - if relation.FieldSchema, _, err = Parse(fieldValue, schema.cacheStore, schema.namer); err != nil { + if relation.FieldSchema, err = Parse(fieldValue, schema.cacheStore, schema.namer); err != nil { schema.err = err return } @@ -192,7 +192,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel } } - if relation.JoinTable, _, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, schema.namer); err != nil { + if relation.JoinTable, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, schema.namer); err != nil { schema.err = err } relation.JoinTable.Name = many2many diff --git a/schema/schema.go b/schema/schema.go index c56932ad..2ac6d312 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -53,22 +53,21 @@ func (schema Schema) LookUpField(name string) *Field { } // get data type from dialector -func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, reflect.Value, error) { - reflectValue := reflect.ValueOf(dest) - modelType := reflectValue.Type() +func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) { + modelType := reflect.ValueOf(dest).Type() for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Ptr { modelType = modelType.Elem() } if modelType.Kind() != reflect.Struct { if modelType.PkgPath() == "" { - return nil, reflectValue, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) + return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) } - return nil, reflectValue, fmt.Errorf("%w: %v.%v", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name()) + return nil, fmt.Errorf("%w: %v.%v", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name()) } if v, ok := cacheStore.Load(modelType); ok { - return v.(*Schema), reflectValue, nil + return v.(*Schema), nil } schema := &Schema{ @@ -167,6 +166,7 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, reflec } } + reflectValue := reflect.Indirect(reflect.New(modelType)) callbacks := []string{"BeforeCreate", "AfterCreate", "BeforeUpdate", "AfterUpdate", "BeforeSave", "AfterSave", "BeforeDelete", "AfterDelete", "AfterFind"} for _, name := range callbacks { if methodValue := reflectValue.MethodByName(name); methodValue.IsValid() { @@ -185,10 +185,10 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, reflec for _, field := range schema.Fields { if field.DataType == "" && field.Creatable { if schema.parseRelation(field); schema.err != nil { - return schema, reflectValue, schema.err + return schema, schema.err } } } - return schema, reflectValue, schema.err + return schema, schema.err } diff --git a/schema/schema_test.go b/schema/schema_test.go index 04cd9d82..ce225010 100644 --- a/schema/schema_test.go +++ b/schema/schema_test.go @@ -9,7 +9,7 @@ import ( ) func TestParseSchema(t *testing.T) { - user, _, err := schema.Parse(&tests.User{}, &sync.Map{}, schema.NamingStrategy{}) + user, err := schema.Parse(&tests.User{}, &sync.Map{}, schema.NamingStrategy{}) if err != nil { t.Fatalf("failed to parse user, got error %v", err) } @@ -18,7 +18,7 @@ func TestParseSchema(t *testing.T) { } func TestParseSchemaWithPointerFields(t *testing.T) { - user, _, err := schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{}) + user, err := schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{}) if err != nil { t.Fatalf("failed to parse pointer user, got error %v", err) } @@ -114,7 +114,7 @@ func checkUserSchema(t *testing.T, user *schema.Schema) { } func TestParseSchemaWithAdvancedDataType(t *testing.T) { - user, _, err := schema.Parse(&AdvancedDataTypeUser{}, &sync.Map{}, schema.NamingStrategy{}) + user, err := schema.Parse(&AdvancedDataTypeUser{}, &sync.Map{}, schema.NamingStrategy{}) if err != nil { t.Fatalf("failed to parse pointer user, got error %v", err) } diff --git a/statement.go b/statement.go index 91f45b2b..ad30ed08 100644 --- a/statement.go +++ b/statement.go @@ -274,12 +274,8 @@ func (stmt *Statement) Build(clauses ...string) { } func (stmt *Statement) Parse(value interface{}) (err error) { - if stmt.Schema, stmt.ReflectValue, err = schema.Parse(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil { - stmt.ReflectValue = reflect.Indirect(stmt.ReflectValue) - - if stmt.Table == "" { - stmt.Table = stmt.Schema.Table - } + if stmt.Schema, err = schema.Parse(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil && stmt.Table == "" { + stmt.Table = stmt.Schema.Table } return err } From 9fcc546a69d014a81a5c459879f2a1ce80c4c97f Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 26 Feb 2020 19:06:42 +0800 Subject: [PATCH 260/881] Fix tests --- clause/benchmarks_test.go | 4 ++-- clause/clause_test.go | 2 +- clause/expression_test.go | 2 +- logger/sql_test.go | 6 +++--- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/clause/benchmarks_test.go b/clause/benchmarks_test.go index 3813fd8e..33d3430a 100644 --- a/clause/benchmarks_test.go +++ b/clause/benchmarks_test.go @@ -11,7 +11,7 @@ import ( ) func BenchmarkSelect(b *testing.B) { - user, _, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) + user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) for i := 0; i < b.N; i++ { stmt := gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}} @@ -27,7 +27,7 @@ func BenchmarkSelect(b *testing.B) { } func BenchmarkComplexSelect(b *testing.B) { - user, _, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) + user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) for i := 0; i < b.N; i++ { stmt := gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}} diff --git a/clause/clause_test.go b/clause/clause_test.go index 8e458043..30ea9343 100644 --- a/clause/clause_test.go +++ b/clause/clause_test.go @@ -18,7 +18,7 @@ func checkBuildClauses(t *testing.T, clauses []clause.Interface, result string, var ( buildNames []string buildNamesMap = map[string]bool{} - user, _, _ = schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) + user, _ = schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) stmt = gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}} ) diff --git a/clause/expression_test.go b/clause/expression_test.go index 363b4047..e51d189e 100644 --- a/clause/expression_test.go +++ b/clause/expression_test.go @@ -24,7 +24,7 @@ func TestExpr(t *testing.T) { for idx, result := range results { t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { - user, _, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) + user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) stmt := &gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}} clause.Expr{SQL: result.SQL, Vars: result.Vars}.Build(stmt) if stmt.SQL.String() != result.Result { diff --git a/logger/sql_test.go b/logger/sql_test.go index aee064d8..dd7b80c8 100644 --- a/logger/sql_test.go +++ b/logger/sql_test.go @@ -30,19 +30,19 @@ func TestExplainSQL(t *testing.T) { Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass")`, }, { - SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (@p0, @p1, @p2, @p3, @p4, @p5, @p6, @p7, @p8, @p9, @p10)", + SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (@p1, @p2, @p3, @p4, @p5, @p6, @p7, @p8, @p9, @p10, @p11)", NumericRegexp: regexp.MustCompile("@p(\\d+)"), Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.com", myrole, pwd}, Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`, }, { - SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ($2, $3, $0, $1, $6, $7, $4, $5, $8, $9, $10)", + SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ($3, $4, $1, $2, $7, $8, $5, $6, $9, $10, $11)", NumericRegexp: regexp.MustCompile("\\$(\\d+)"), Vars: []interface{}{999.99, true, "jinzhu", 1, &tt, nil, []byte("12345"), tt, "w@g.com", myrole, pwd}, Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`, }, { - SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (@p0, @p10, @p1, @p2, @p3, @p4, @p5, @p6, @p7, @p8, @p9)", + SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (@p1, @p11, @p2, @p3, @p4, @p5, @p6, @p7, @p8, @p9, @p10)", NumericRegexp: regexp.MustCompile("@p(\\d+)"), Vars: []interface{}{"jinzhu", 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.com", myrole, pwd, 1}, Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`, From 0da8191f60660e4d9ebffdb84ad8aeda46235862 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 2 Mar 2020 23:43:34 +0800 Subject: [PATCH 261/881] Update test helper --- tests/utils.go | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/tests/utils.go b/tests/utils.go index d12df2dc..292a357d 100644 --- a/tests/utils.go +++ b/tests/utils.go @@ -3,6 +3,7 @@ package tests import ( "reflect" "testing" + "time" ) func AssertEqual(t *testing.T, r, e interface{}, names ...string) { @@ -11,9 +12,18 @@ func AssertEqual(t *testing.T, r, e interface{}, names ...string) { expects := reflect.Indirect(reflect.ValueOf(e)).FieldByName(name).Interface() if !reflect.DeepEqual(got, expects) { - t.Run(name, func(t *testing.T) { - t.Errorf("expects: %v, got %v", expects, got) - }) + got = reflect.Indirect(reflect.ValueOf(got)).Interface() + expects = reflect.Indirect(reflect.ValueOf(got)).Interface() + if curTime, ok := got.(time.Time); ok { + format := "2006-01-02T15:04:05Z07:00" + if curTime.Format(format) != expects.(time.Time).Format(format) { + t.Errorf("expects: %v, got %v", expects.(time.Time).Format(format), curTime.Format(format)) + } + } else { + t.Run(name, func(t *testing.T) { + t.Errorf("expects: %v, got %v", expects, got) + }) + } } } } From 1403ee70c33bc455168af57bc32839ec2cd4d9ac Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 3 Mar 2020 14:18:12 +0800 Subject: [PATCH 262/881] Make Query works --- callbacks/query.go | 29 ++++++++++++++++++++++++++--- dialects/sqlite/sqlite.go | 4 +++- finisher_api.go | 7 ++++++- statement.go | 21 +++++++++++---------- tests/tests.go | 1 - 5 files changed, 46 insertions(+), 16 deletions(-) diff --git a/callbacks/query.go b/callbacks/query.go index baacbd24..21b58aaf 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -1,6 +1,7 @@ package callbacks import ( + "database/sql" "reflect" "github.com/jinzhu/gorm" @@ -15,9 +16,31 @@ func Query(db *gorm.DB) { } rows, err := db.DB.QueryContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) - db.AddError(err) - _ = rows - // scan rows + if err != nil { + db.AddError(err) + return + } + defer rows.Close() + + columns, _ := rows.Columns() + values := make([]interface{}, len(columns)) + + for idx, column := range columns { + if field, ok := db.Statement.Schema.FieldsByDBName[column]; ok { + values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface() + } else { + values[idx] = sql.RawBytes{} + } + } + + for rows.Next() { + db.RowsAffected++ + rows.Scan(values...) + } + + if db.RowsAffected == 0 && db.Statement.RaiseErrorOnNotFound { + db.AddError(gorm.ErrRecordNotFound) + } } func Preload(db *gorm.DB) { diff --git a/dialects/sqlite/sqlite.go b/dialects/sqlite/sqlite.go index a6aba066..5f9d49df 100644 --- a/dialects/sqlite/sqlite.go +++ b/dialects/sqlite/sqlite.go @@ -59,8 +59,10 @@ func (dialector Dialector) DataTypeOf(field *schema.Field) string { } case schema.Float: return "real" - case schema.String, schema.Time: + case schema.String: return "text" + case schema.Time: + return "datetime" case schema.Bytes: return "blob" } diff --git a/finisher_api.go b/finisher_api.go index 72c3d2aa..83988546 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -28,6 +28,7 @@ func (db *DB) First(out interface{}, where ...interface{}) (tx *DB) { Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Desc: true, }) + tx.Statement.RaiseErrorOnNotFound = true tx.Statement.Dest = out tx.callbacks.Query().Execute(tx) return @@ -35,7 +36,8 @@ func (db *DB) First(out interface{}, where ...interface{}) (tx *DB) { // Take return a record that match given conditions, the order will depend on the database implementation func (db *DB) Take(out interface{}, where ...interface{}) (tx *DB) { - tx = db.getInstance() + tx = db.getInstance().Limit(1) + tx.Statement.RaiseErrorOnNotFound = true tx.Statement.Dest = out tx.callbacks.Query().Execute(tx) return @@ -46,6 +48,7 @@ func (db *DB) Last(out interface{}, where ...interface{}) (tx *DB) { tx = db.getInstance().Limit(1).Order(clause.OrderByColumn{ Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, }) + tx.Statement.RaiseErrorOnNotFound = true tx.Statement.Dest = out tx.callbacks.Query().Execute(tx) return @@ -54,6 +57,8 @@ func (db *DB) Last(out interface{}, where ...interface{}) (tx *DB) { // Find find records that match given conditions func (db *DB) Find(out interface{}, where ...interface{}) (tx *DB) { tx = db.getInstance() + tx.Statement.Dest = out + tx.callbacks.Query().Execute(tx) return } diff --git a/statement.go b/statement.go index ad30ed08..bad83717 100644 --- a/statement.go +++ b/statement.go @@ -40,16 +40,17 @@ func (inst *Instance) AddError(err error) { // Statement statement type Statement struct { - Table string - Model interface{} - Dest interface{} - ReflectValue reflect.Value - Clauses map[string]clause.Clause - Selects []string // selected columns - Omits []string // omit columns - Settings sync.Map - DB *DB - Schema *schema.Schema + Table string + Model interface{} + Dest interface{} + ReflectValue reflect.Value + Clauses map[string]clause.Clause + Selects []string // selected columns + Omits []string // omit columns + Settings sync.Map + DB *DB + Schema *schema.Schema + RaiseErrorOnNotFound bool // SQL Builder SQL strings.Builder diff --git a/tests/tests.go b/tests/tests.go index 53700710..5e47c09e 100644 --- a/tests/tests.go +++ b/tests/tests.go @@ -18,7 +18,6 @@ func RunTestsSuit(t *testing.T, db *gorm.DB) { func TestCreate(t *testing.T, db *gorm.DB) { db.AutoMigrate(&User{}) - db = db.Debug() t.Run("Create", func(t *testing.T) { var user = User{ From b0e1bccf4ad5f803df27a8974491bcbc04a4b02c Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 4 Mar 2020 11:32:36 +0800 Subject: [PATCH 263/881] Support scan into map, slice, struct --- callbacks/query.go | 21 +------- callbacks/scan.go | 98 ++++++++++++++++++++++++++++++++++++ finisher_api.go | 2 +- schema/schema_helper_test.go | 40 ++------------- tests/tests.go | 93 +++++++++++++++++++++++++++++++++- tests/utils.go | 41 +++++++++++---- utils/utils.go | 4 +- 7 files changed, 228 insertions(+), 71 deletions(-) create mode 100644 callbacks/scan.go diff --git a/callbacks/query.go b/callbacks/query.go index 21b58aaf..26c0e0ad 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -1,7 +1,6 @@ package callbacks import ( - "database/sql" "reflect" "github.com/jinzhu/gorm" @@ -22,25 +21,7 @@ func Query(db *gorm.DB) { } defer rows.Close() - columns, _ := rows.Columns() - values := make([]interface{}, len(columns)) - - for idx, column := range columns { - if field, ok := db.Statement.Schema.FieldsByDBName[column]; ok { - values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface() - } else { - values[idx] = sql.RawBytes{} - } - } - - for rows.Next() { - db.RowsAffected++ - rows.Scan(values...) - } - - if db.RowsAffected == 0 && db.Statement.RaiseErrorOnNotFound { - db.AddError(gorm.ErrRecordNotFound) - } + Scan(rows, db) } func Preload(db *gorm.DB) { diff --git a/callbacks/scan.go b/callbacks/scan.go new file mode 100644 index 00000000..c9f948b1 --- /dev/null +++ b/callbacks/scan.go @@ -0,0 +1,98 @@ +package callbacks + +import ( + "database/sql" + "reflect" + + "github.com/jinzhu/gorm" +) + +func Scan(rows *sql.Rows, db *gorm.DB) { + columns, _ := rows.Columns() + values := make([]interface{}, len(columns)) + + switch dest := db.Statement.Dest.(type) { + case map[string]interface{}, *map[string]interface{}: + for idx, _ := range columns { + values[idx] = new(interface{}) + } + + if rows.Next() { + db.RowsAffected++ + rows.Scan(values...) + } + + mapValue, ok := dest.(map[string]interface{}) + if ok { + if v, ok := dest.(*map[string]interface{}); ok { + mapValue = *v + } + } + + for idx, column := range columns { + mapValue[column] = *(values[idx].(*interface{})) + } + case *[]map[string]interface{}: + for idx, _ := range columns { + values[idx] = new(interface{}) + } + + for rows.Next() { + db.RowsAffected++ + rows.Scan(values...) + + v := map[string]interface{}{} + for idx, column := range columns { + v[column] = *(values[idx].(*interface{})) + } + *dest = append(*dest, v) + } + default: + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + isPtr := db.Statement.ReflectValue.Type().Elem().Kind() == reflect.Ptr + db.Statement.ReflectValue.Set(reflect.MakeSlice(db.Statement.ReflectValue.Type(), 0, 0)) + + for rows.Next() { + elem := reflect.New(db.Statement.Schema.ModelType).Elem() + for idx, column := range columns { + if field := db.Statement.Schema.LookUpField(column); field != nil { + values[idx] = field.ReflectValueOf(elem).Addr().Interface() + } else if db.RowsAffected == 0 { + values[idx] = sql.RawBytes{} + } + } + + db.RowsAffected++ + if err := rows.Scan(values...); err != nil { + db.AddError(err) + } + + if isPtr { + db.Statement.ReflectValue.Set(reflect.Append(db.Statement.ReflectValue, elem.Addr())) + } else { + db.Statement.ReflectValue.Set(reflect.Append(db.Statement.ReflectValue, elem)) + } + } + case reflect.Struct: + for idx, column := range columns { + if field := db.Statement.Schema.LookUpField(column); field != nil { + values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface() + } else { + values[idx] = sql.RawBytes{} + } + } + + if rows.Next() { + db.RowsAffected++ + if err := rows.Scan(values...); err != nil { + db.AddError(err) + } + } + } + } + + if db.RowsAffected == 0 && db.Statement.RaiseErrorOnNotFound { + db.AddError(gorm.ErrRecordNotFound) + } +} diff --git a/finisher_api.go b/finisher_api.go index 83988546..c918c08a 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -26,7 +26,6 @@ func (db *DB) First(out interface{}, where ...interface{}) (tx *DB) { // TODO handle where tx = db.getInstance().Limit(1).Order(clause.OrderByColumn{ Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, - Desc: true, }) tx.Statement.RaiseErrorOnNotFound = true tx.Statement.Dest = out @@ -47,6 +46,7 @@ func (db *DB) Take(out interface{}, where ...interface{}) (tx *DB) { func (db *DB) Last(out interface{}, where ...interface{}) (tx *DB) { tx = db.getInstance().Limit(1).Order(clause.OrderByColumn{ Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, + Desc: true, }) tx.Statement.RaiseErrorOnNotFound = true tx.Statement.Dest = out diff --git a/schema/schema_helper_test.go b/schema/schema_helper_test.go index 196d19c4..146ba13a 100644 --- a/schema/schema_helper_test.go +++ b/schema/schema_helper_test.go @@ -1,7 +1,6 @@ package schema_test import ( - "database/sql/driver" "fmt" "reflect" "strings" @@ -13,7 +12,7 @@ import ( func checkSchema(t *testing.T, s *schema.Schema, v schema.Schema, primaryFields []string) { t.Run("CheckSchema/"+s.Name, func(t *testing.T) { - tests.AssertEqual(t, s, v, "Name", "Table") + tests.AssertObjEqual(t, s, v, "Name", "Table") for idx, field := range primaryFields { var found bool @@ -53,7 +52,7 @@ func checkSchemaField(t *testing.T, s *schema.Schema, f *schema.Field, fc func(* if parsedField, ok := s.FieldsByName[f.Name]; !ok { t.Errorf("schema %v failed to look up field with name %v", s, f.Name) } else { - tests.AssertEqual(t, parsedField, f, "Name", "DBName", "BindNames", "DataType", "DBDataType", "PrimaryKey", "AutoIncrement", "Creatable", "Updatable", "HasDefaultValue", "DefaultValue", "NotNull", "Unique", "Comment", "Size", "Precision", "Tag", "TagSettings") + tests.AssertObjEqual(t, parsedField, f, "Name", "DBName", "BindNames", "DataType", "DBDataType", "PrimaryKey", "AutoIncrement", "Creatable", "Updatable", "HasDefaultValue", "DefaultValue", "NotNull", "Unique", "Comment", "Size", "Precision", "Tag", "TagSettings") if field, ok := s.FieldsByDBName[f.DBName]; !ok || parsedField != field { t.Errorf("schema %v failed to look up field with dbname %v", s, f.DBName) @@ -195,39 +194,8 @@ func checkSchemaRelation(t *testing.T, s *schema.Schema, relation Relation) { func checkField(t *testing.T, s *schema.Schema, value reflect.Value, values map[string]interface{}) { for k, v := range values { t.Run("CheckField/"+k, func(t *testing.T) { - var ( - checker func(fv interface{}, v interface{}) - field = s.FieldsByDBName[k] - fv, _ = field.ValueOf(value) - ) - - checker = func(fv interface{}, v interface{}) { - if reflect.ValueOf(fv).Type() == reflect.ValueOf(v).Type() && fv != v { - t.Errorf("expects: %p, but got %p", v, fv) - } else if reflect.ValueOf(v).Type().ConvertibleTo(reflect.ValueOf(fv).Type()) { - if reflect.ValueOf(v).Convert(reflect.ValueOf(fv).Type()).Interface() != fv { - t.Errorf("expects: %p, but got %p", v, fv) - } - } else if reflect.ValueOf(fv).Type().ConvertibleTo(reflect.ValueOf(v).Type()) { - if reflect.ValueOf(fv).Convert(reflect.ValueOf(fv).Type()).Interface() != v { - t.Errorf("expects: %p, but got %p", v, fv) - } - } else if valuer, isValuer := fv.(driver.Valuer); isValuer { - valuerv, _ := valuer.Value() - checker(valuerv, v) - } else if valuer, isValuer := v.(driver.Valuer); isValuer { - valuerv, _ := valuer.Value() - checker(fv, valuerv) - } else if reflect.ValueOf(fv).Kind() == reflect.Ptr { - checker(reflect.ValueOf(fv).Elem().Interface(), v) - } else if reflect.ValueOf(v).Kind() == reflect.Ptr { - checker(fv, reflect.ValueOf(v).Elem().Interface()) - } else { - t.Errorf("expects: %+v, but got %+v", v, fv) - } - } - - checker(fv, v) + fv, _ := s.FieldsByDBName[k].ValueOf(value) + tests.AssertEqual(t, v, fv) }) } } diff --git a/tests/tests.go b/tests/tests.go index 5e47c09e..2f0dfd34 100644 --- a/tests/tests.go +++ b/tests/tests.go @@ -1,6 +1,9 @@ package tests import ( + "log" + "reflect" + "strconv" "testing" "time" @@ -14,6 +17,7 @@ func Now() *time.Time { func RunTestsSuit(t *testing.T, db *gorm.DB) { TestCreate(t, db) + TestFind(t, db) } func TestCreate(t *testing.T, db *gorm.DB) { @@ -38,7 +42,94 @@ func TestCreate(t *testing.T, db *gorm.DB) { if err := db.Where("id = ?", user.ID).First(&newUser).Error; err != nil { t.Errorf("errors happened when query: %v", err) } else { - AssertEqual(t, newUser, user, "Name", "Age", "Birthday") + AssertObjEqual(t, newUser, user, "Name", "Age", "Birthday") + } + }) +} + +func TestFind(t *testing.T, db *gorm.DB) { + db.Migrator().DropTable(&User{}) + db.AutoMigrate(&User{}) + + t.Run("Find", func(t *testing.T) { + var users = []User{{ + Name: "find", + Age: 1, + Birthday: Now(), + }, { + Name: "find", + Age: 2, + Birthday: Now(), + }, { + Name: "find", + Age: 3, + Birthday: Now(), + }} + + if err := db.Create(&users).Error; err != nil { + t.Errorf("errors happened when create users: %v", err) + } + + t.Run("First", func(t *testing.T) { + var first User + if err := db.Where("name = ?", "find").First(&first).Error; err != nil { + t.Errorf("errors happened when query first: %v", err) + } else { + AssertObjEqual(t, first, users[0], "Name", "Age", "Birthday") + } + }) + + t.Run("Last", func(t *testing.T) { + var last User + if err := db.Where("name = ?", "find").Last(&last).Error; err != nil { + t.Errorf("errors happened when query last: %v", err) + } else { + AssertObjEqual(t, last, users[2], "Name", "Age", "Birthday") + } + }) + + var all []User + if err := db.Where("name = ?", "find").Find(&all).Error; err != nil || len(all) != 3 { + t.Errorf("errors happened when query find: %v, length: %v", err, len(all)) + } else { + for idx, user := range users { + t.Run("FindAll#"+strconv.Itoa(idx+1), func(t *testing.T) { + AssertObjEqual(t, all[idx], user, "Name", "Age", "Birthday") + }) + } + } + + t.Run("FirstMap", func(t *testing.T) { + var first = map[string]interface{}{} + if err := db.Model(&User{}).Where("name = ?", "find").First(first).Error; err != nil { + t.Errorf("errors happened when query first: %v", err) + } else { + for _, name := range []string{"Name", "Age", "Birthday"} { + t.Run(name, func(t *testing.T) { + dbName := db.NamingStrategy.ColumnName("", name) + reflectValue := reflect.Indirect(reflect.ValueOf(users[0])) + AssertEqual(t, first[dbName], reflectValue.FieldByName(name).Interface()) + }) + } + } + }) + + var allMap = []map[string]interface{}{} + if err := db.Model(&User{}).Where("name = ?", "find").Find(&allMap).Error; err != nil { + t.Errorf("errors happened when query first: %v", err) + } else { + log.Printf("all map %+v %+v", len(allMap), allMap) + for idx, user := range users { + t.Run("FindAllMap#"+strconv.Itoa(idx+1), func(t *testing.T) { + for _, name := range []string{"Name", "Age", "Birthday"} { + t.Run(name, func(t *testing.T) { + dbName := db.NamingStrategy.ColumnName("", name) + reflectValue := reflect.Indirect(reflect.ValueOf(user)) + AssertEqual(t, allMap[idx][dbName], reflectValue.FieldByName(name).Interface()) + }) + } + }) + } } }) } diff --git a/tests/utils.go b/tests/utils.go index 292a357d..9d61c422 100644 --- a/tests/utils.go +++ b/tests/utils.go @@ -6,24 +6,43 @@ import ( "time" ) -func AssertEqual(t *testing.T, r, e interface{}, names ...string) { +func AssertObjEqual(t *testing.T, r, e interface{}, names ...string) { for _, name := range names { got := reflect.Indirect(reflect.ValueOf(r)).FieldByName(name).Interface() - expects := reflect.Indirect(reflect.ValueOf(e)).FieldByName(name).Interface() + expect := reflect.Indirect(reflect.ValueOf(e)).FieldByName(name).Interface() + t.Run(name, func(t *testing.T) { + AssertEqual(t, got, expect) + }) + } +} - if !reflect.DeepEqual(got, expects) { - got = reflect.Indirect(reflect.ValueOf(got)).Interface() - expects = reflect.Indirect(reflect.ValueOf(got)).Interface() +func AssertEqual(t *testing.T, got, expect interface{}) { + if !reflect.DeepEqual(got, expect) { + isEqual := func() { if curTime, ok := got.(time.Time); ok { format := "2006-01-02T15:04:05Z07:00" - if curTime.Format(format) != expects.(time.Time).Format(format) { - t.Errorf("expects: %v, got %v", expects.(time.Time).Format(format), curTime.Format(format)) + if curTime.Format(format) != expect.(time.Time).Format(format) { + t.Errorf("expect: %v, got %v", expect.(time.Time).Format(format), curTime.Format(format)) } - } else { - t.Run(name, func(t *testing.T) { - t.Errorf("expects: %v, got %v", expects, got) - }) + } else if got != expect { + t.Errorf("expect: %#v, got %#v", expect, got) } } + + if got != nil { + got = reflect.Indirect(reflect.ValueOf(got)).Interface() + } + + if expect != nil { + expect = reflect.Indirect(reflect.ValueOf(expect)).Interface() + } + + if reflect.ValueOf(got).Type().ConvertibleTo(reflect.ValueOf(expect).Type()) { + got = reflect.ValueOf(got).Convert(reflect.ValueOf(expect).Type()).Interface() + isEqual() + } else if reflect.ValueOf(expect).Type().ConvertibleTo(reflect.ValueOf(got).Type()) { + expect = reflect.ValueOf(got).Convert(reflect.ValueOf(got).Type()).Interface() + isEqual() + } } } diff --git a/utils/utils.go b/utils/utils.go index 315ba930..86ea557b 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -6,8 +6,8 @@ import ( "runtime" ) -var goSrcRegexp = regexp.MustCompile(`/gorm/.*.go`) -var goTestRegexp = regexp.MustCompile(`/gorm/.*test.go`) +var goSrcRegexp = regexp.MustCompile(`/gorm/.*\.go`) +var goTestRegexp = regexp.MustCompile(`/gorm/.*test\.go`) func FileWithLineNum() string { for i := 2; i < 15; i++ { From 9f7f4b430ea438e4427bb0c20f036d06aeabea08 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 4 Mar 2020 22:16:39 +0800 Subject: [PATCH 264/881] Refactor find slice --- callbacks/scan.go | 12 ++++++++---- logger/logger.go | 2 +- tests/docker-compose.yml | 2 +- 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/callbacks/scan.go b/callbacks/scan.go index c9f948b1..f8f1ef54 100644 --- a/callbacks/scan.go +++ b/callbacks/scan.go @@ -5,6 +5,7 @@ import ( "reflect" "github.com/jinzhu/gorm" + "github.com/jinzhu/gorm/schema" ) func Scan(rows *sql.Rows, db *gorm.DB) { @@ -52,14 +53,17 @@ func Scan(rows *sql.Rows, db *gorm.DB) { case reflect.Slice, reflect.Array: isPtr := db.Statement.ReflectValue.Type().Elem().Kind() == reflect.Ptr db.Statement.ReflectValue.Set(reflect.MakeSlice(db.Statement.ReflectValue.Type(), 0, 0)) + fields := make([]*schema.Field, len(columns)) + + for idx, column := range columns { + fields[idx] = db.Statement.Schema.LookUpField(column) + } for rows.Next() { elem := reflect.New(db.Statement.Schema.ModelType).Elem() - for idx, column := range columns { - if field := db.Statement.Schema.LookUpField(column); field != nil { + for idx, field := range fields { + if field != nil { values[idx] = field.ReflectValueOf(elem).Addr().Interface() - } else if db.RowsAffected == 0 { - values[idx] = sql.RawBytes{} } } diff --git a/logger/logger.go b/logger/logger.go index d3b97b9d..2a765628 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -121,7 +121,7 @@ func (l logger) Error(msg string, data ...interface{}) { // Trace print sql message func (l logger) Trace(begin time.Time, fc func() (string, int64), err error) { - if elapsed := time.Now().Sub(begin); err != nil || elapsed > l.SlowThreshold { + if elapsed := time.Now().Sub(begin); err != nil || (elapsed > l.SlowThreshold && l.SlowThreshold != 0) { sql, rows := fc() fileline := utils.FileWithLineNum() if err != nil { diff --git a/tests/docker-compose.yml b/tests/docker-compose.yml index 79bf5fc3..6bf3fadf 100644 --- a/tests/docker-compose.yml +++ b/tests/docker-compose.yml @@ -15,8 +15,8 @@ services: ports: - 9920:5432 environment: - - POSTGRES_USER=gorm - POSTGRES_DB=gorm + - POSTGRES_USER=gorm - POSTGRES_PASSWORD=gorm mssql: image: 'mcmoe/mssqldocker:latest' From 0c34123796a056335e9020f7db97c514f3d1e87f Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 4 Mar 2020 23:56:42 +0800 Subject: [PATCH 265/881] Add Limit, Offset --- chainable_api.go | 6 ++++-- clause/limit.go | 12 ++++++++---- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/chainable_api.go b/chainable_api.go index 770b2236..49f260d3 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -168,14 +168,16 @@ func (db *DB) Order(value interface{}) (tx *DB) { } // Limit specify the number of records to be retrieved -func (db *DB) Limit(limit int64) (tx *DB) { +func (db *DB) Limit(limit int) (tx *DB) { tx = db.getInstance() + tx.Statement.AddClause(clause.Limit{Limit: limit}) return } // Offset specify the number of records to skip before starting to return the records -func (db *DB) Offset(offset int64) (tx *DB) { +func (db *DB) Offset(offset int) (tx *DB) { tx = db.getInstance() + tx.Statement.AddClause(clause.Limit{Offset: offset}) return } diff --git a/clause/limit.go b/clause/limit.go index 7b16f339..7775e6bf 100644 --- a/clause/limit.go +++ b/clause/limit.go @@ -18,11 +18,11 @@ func (limit Limit) Build(builder Builder) { if limit.Limit > 0 { builder.Write("LIMIT ") builder.Write(strconv.Itoa(limit.Limit)) + } - if limit.Offset > 0 { - builder.Write(" OFFSET ") - builder.Write(strconv.Itoa(limit.Offset)) - } + if limit.Offset > 0 { + builder.Write(" OFFSET ") + builder.Write(strconv.Itoa(limit.Offset)) } } @@ -33,10 +33,14 @@ func (limit Limit) MergeClause(clause *Clause) { if v, ok := clause.Expression.(Limit); ok { if limit.Limit == 0 && v.Limit > 0 { limit.Limit = v.Limit + } else if limit.Limit < 0 { + limit.Limit = 0 } if limit.Offset == 0 && v.Offset > 0 { limit.Offset = v.Offset + } else if limit.Offset < 0 { + limit.Offset = 0 } } From cbd55dbcd53ec368465d8fdbdba383f8285406ac Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 7 Mar 2020 13:43:20 +0800 Subject: [PATCH 266/881] Add Update test --- callbacks/helper.go | 3 ++- callbacks/update.go | 58 +++++++++++++++++++++++++++++++++++++++++++ clause/limit.go | 8 +++--- finisher_api.go | 29 ++++++++++++++++++---- tests/tests.go | 60 +++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 148 insertions(+), 10 deletions(-) diff --git a/callbacks/helper.go b/callbacks/helper.go index 56c0767d..baad2302 100644 --- a/callbacks/helper.go +++ b/callbacks/helper.go @@ -44,13 +44,14 @@ func ConvertMapToValues(stmt *gorm.Statement, mapValue map[string]interface{}) ( sort.Strings(keys) for _, k := range keys { + value := mapValue[k] if field := stmt.Schema.LookUpField(k); field != nil { k = field.DBName } if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) { columns = append(columns, k) - values.Values[0] = append(values.Values[0], mapValue[k]) + values.Values[0] = append(values.Values[0], value) } } return diff --git a/callbacks/update.go b/callbacks/update.go index 82df3e81..9e1e9b78 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -2,8 +2,10 @@ package callbacks import ( "reflect" + "sort" "github.com/jinzhu/gorm" + "github.com/jinzhu/gorm/clause" ) func BeforeUpdate(db *gorm.DB) { @@ -40,6 +42,17 @@ func BeforeUpdate(db *gorm.DB) { } func Update(db *gorm.DB) { + db.Statement.AddClauseIfNotExists(clause.Update{}) + db.Statement.AddClause(ConvertToAssignments(db.Statement)) + db.Statement.Build("UPDATE", "SET", "WHERE") + + result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) + + if err == nil { + db.RowsAffected, _ = result.RowsAffected() + } else { + db.AddError(err) + } } func AfterUpdate(db *gorm.DB) { @@ -74,3 +87,48 @@ func AfterUpdate(db *gorm.DB) { } } } + +// ConvertToAssignments convert to update assignments +func ConvertToAssignments(stmt *gorm.Statement) clause.Set { + selectColumns, restricted := SelectAndOmitColumns(stmt) + reflectModelValue := reflect.ValueOf(stmt.Model) + + switch value := stmt.Dest.(type) { + case map[string]interface{}: + var set clause.Set = make([]clause.Assignment, 0, len(value)) + + var keys []string + for k, _ := range value { + keys = append(keys, k) + } + sort.Strings(keys) + + for _, k := range keys { + if field := stmt.Schema.LookUpField(k); field != nil { + if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { + set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value[k]}) + field.Set(reflectModelValue, value[k]) + } + } else if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) { + set = append(set, clause.Assignment{Column: clause.Column{Name: k}, Value: value[k]}) + } + } + + return set + default: + switch stmt.ReflectValue.Kind() { + case reflect.Struct: + var set clause.Set = make([]clause.Assignment, 0, len(stmt.Schema.FieldsByDBName)) + for _, field := range stmt.Schema.FieldsByDBName { + if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { + value, _ := field.ValueOf(stmt.ReflectValue) + set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value}) + field.Set(reflectModelValue, value) + } + } + return set + } + } + + return clause.Set{} +} diff --git a/clause/limit.go b/clause/limit.go index 7775e6bf..e30666af 100644 --- a/clause/limit.go +++ b/clause/limit.go @@ -18,11 +18,11 @@ func (limit Limit) Build(builder Builder) { if limit.Limit > 0 { builder.Write("LIMIT ") builder.Write(strconv.Itoa(limit.Limit)) - } - if limit.Offset > 0 { - builder.Write(" OFFSET ") - builder.Write(strconv.Itoa(limit.Offset)) + if limit.Offset > 0 { + builder.Write(" OFFSET ") + builder.Write(strconv.Itoa(limit.Offset)) + } } } diff --git a/finisher_api.go b/finisher_api.go index c918c08a..e2f89cf0 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -22,11 +22,13 @@ func (db *DB) Save(value interface{}) (tx *DB) { } // First find first record that match given conditions, order by primary key -func (db *DB) First(out interface{}, where ...interface{}) (tx *DB) { - // TODO handle where +func (db *DB) First(out interface{}, conds ...interface{}) (tx *DB) { tx = db.getInstance().Limit(1).Order(clause.OrderByColumn{ Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, }) + if len(conds) > 0 { + tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondtion(conds[0], conds[1:]...)}) + } tx.Statement.RaiseErrorOnNotFound = true tx.Statement.Dest = out tx.callbacks.Query().Execute(tx) @@ -34,8 +36,11 @@ func (db *DB) First(out interface{}, where ...interface{}) (tx *DB) { } // Take return a record that match given conditions, the order will depend on the database implementation -func (db *DB) Take(out interface{}, where ...interface{}) (tx *DB) { +func (db *DB) Take(out interface{}, conds ...interface{}) (tx *DB) { tx = db.getInstance().Limit(1) + if len(conds) > 0 { + tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondtion(conds[0], conds[1:]...)}) + } tx.Statement.RaiseErrorOnNotFound = true tx.Statement.Dest = out tx.callbacks.Query().Execute(tx) @@ -43,11 +48,14 @@ func (db *DB) Take(out interface{}, where ...interface{}) (tx *DB) { } // Last find last record that match given conditions, order by primary key -func (db *DB) Last(out interface{}, where ...interface{}) (tx *DB) { +func (db *DB) Last(out interface{}, conds ...interface{}) (tx *DB) { tx = db.getInstance().Limit(1).Order(clause.OrderByColumn{ Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Desc: true, }) + if len(conds) > 0 { + tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondtion(conds[0], conds[1:]...)}) + } tx.Statement.RaiseErrorOnNotFound = true tx.Statement.Dest = out tx.callbacks.Query().Execute(tx) @@ -55,8 +63,11 @@ func (db *DB) Last(out interface{}, where ...interface{}) (tx *DB) { } // Find find records that match given conditions -func (db *DB) Find(out interface{}, where ...interface{}) (tx *DB) { +func (db *DB) Find(out interface{}, conds ...interface{}) (tx *DB) { tx = db.getInstance() + if len(conds) > 0 { + tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondtion(conds[0], conds[1:]...)}) + } tx.Statement.Dest = out tx.callbacks.Query().Execute(tx) return @@ -75,22 +86,30 @@ func (db *DB) FirstOrCreate(out interface{}, where ...interface{}) (tx *DB) { // Update update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update func (db *DB) Update(column string, value interface{}) (tx *DB) { tx = db.getInstance() + tx.Statement.Dest = map[string]interface{}{column: value} + tx.callbacks.Update().Execute(tx) return } // Updates update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update func (db *DB) Updates(values interface{}) (tx *DB) { tx = db.getInstance() + tx.Statement.Dest = values + tx.callbacks.Update().Execute(tx) return } func (db *DB) UpdateColumn(column string, value interface{}) (tx *DB) { tx = db.getInstance() + tx.Statement.Dest = map[string]interface{}{column: value} + tx.callbacks.Update().Execute(tx) return } func (db *DB) UpdateColumns(values interface{}) (tx *DB) { tx = db.getInstance() + tx.Statement.Dest = values + tx.callbacks.Update().Execute(tx) return } diff --git a/tests/tests.go b/tests/tests.go index 2f0dfd34..18207268 100644 --- a/tests/tests.go +++ b/tests/tests.go @@ -18,6 +18,7 @@ func Now() *time.Time { func RunTestsSuit(t *testing.T, db *gorm.DB) { TestCreate(t, db) TestFind(t, db) + TestUpdate(t, db) } func TestCreate(t *testing.T, db *gorm.DB) { @@ -133,3 +134,62 @@ func TestFind(t *testing.T, db *gorm.DB) { } }) } + +func TestUpdate(t *testing.T, db *gorm.DB) { + db.Migrator().DropTable(&User{}) + db.AutoMigrate(&User{}) + + t.Run("Update", func(t *testing.T) { + var user = User{ + Name: "create", + Age: 18, + Birthday: Now(), + } + + if err := db.Create(&user).Error; err != nil { + t.Errorf("errors happened when create: %v", err) + } + + if err := db.Model(&user).Update("Age", 10).Error; err != nil { + t.Errorf("errors happened when update: %v", err) + } else if user.Age != 10 { + t.Errorf("Age should equals to 10, but got %v", user.Age) + } + + var result User + if err := db.Where("id = ?", user.ID).First(&result).Error; err != nil { + t.Errorf("errors happened when query: %v", err) + } else { + AssertObjEqual(t, result, user, "Name", "Age", "Birthday") + } + + values := map[string]interface{}{"Active": true, "age": 5} + if err := db.Model(&user).Updates(values).Error; err != nil { + t.Errorf("errors happened when update: %v", err) + } else if user.Age != 5 { + t.Errorf("Age should equals to 5, but got %v", user.Age) + } else if user.Active != true { + t.Errorf("Active should be true, but got %v", user.Active) + } + + var result2 User + if err := db.Where("id = ?", user.ID).First(&result2).Error; err != nil { + t.Errorf("errors happened when query: %v", err) + } else { + AssertObjEqual(t, result2, user, "Name", "Age", "Birthday") + } + + if err := db.Model(&user).Updates(User{Age: 2}).Error; err != nil { + t.Errorf("errors happened when update: %v", err) + } else if user.Age != 2 { + t.Errorf("Age should equals to 2, but got %v", user.Age) + } + + var result3 User + if err := db.Where("id = ?", user.ID).First(&result3).Error; err != nil { + t.Errorf("errors happened when query: %v", err) + } else { + AssertObjEqual(t, result3, user, "Name", "Age", "Birthday") + } + }) +} From 2da0ad5beda714bf4971d66ae58abb72ff6b38d1 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 8 Mar 2020 13:24:08 +0800 Subject: [PATCH 267/881] Add more tests for Update --- callbacks/helper.go | 7 ++++ callbacks/update.go | 50 +++++++++++++++++++----- finisher_api.go | 21 ++++++++++ schema/field.go | 32 ++++++++-------- tests/tests.go | 93 ++++++++++++++++++++++++++++++++++++++++----- utils/utils.go | 4 +- 6 files changed, 169 insertions(+), 38 deletions(-) diff --git a/callbacks/helper.go b/callbacks/helper.go index baad2302..433ab346 100644 --- a/callbacks/helper.go +++ b/callbacks/helper.go @@ -13,6 +13,13 @@ func SelectAndOmitColumns(stmt *gorm.Statement) (map[string]bool, bool) { // select columns for _, column := range stmt.Selects { + if column == "*" { + for _, dbName := range stmt.Schema.DBNames { + results[dbName] = true + } + return results, true + } + if field := stmt.Schema.LookUpField(column); field != nil { results[field.DBName] = true } else { diff --git a/callbacks/update.go b/callbacks/update.go index 9e1e9b78..ca31bf18 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -3,6 +3,7 @@ package callbacks import ( "reflect" "sort" + "time" "github.com/jinzhu/gorm" "github.com/jinzhu/gorm/clause" @@ -89,13 +90,13 @@ func AfterUpdate(db *gorm.DB) { } // ConvertToAssignments convert to update assignments -func ConvertToAssignments(stmt *gorm.Statement) clause.Set { +func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { selectColumns, restricted := SelectAndOmitColumns(stmt) reflectModelValue := reflect.ValueOf(stmt.Model) switch value := stmt.Dest.(type) { case map[string]interface{}: - var set clause.Set = make([]clause.Assignment, 0, len(value)) + set = make([]clause.Assignment, 0, len(value)) var keys []string for k, _ := range value { @@ -106,6 +107,9 @@ func ConvertToAssignments(stmt *gorm.Statement) clause.Set { for _, k := range keys { if field := stmt.Schema.LookUpField(k); field != nil { if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { + if field.AutoUpdateTime > 0 { + value[k] = time.Now() + } set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value[k]}) field.Set(reflectModelValue, value[k]) } @@ -114,21 +118,47 @@ func ConvertToAssignments(stmt *gorm.Statement) clause.Set { } } - return set + for _, field := range stmt.Schema.FieldsByDBName { + if field.AutoUpdateTime > 0 && value[field.Name] == nil && value[field.DBName] == nil { + now := time.Now() + set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now}) + field.Set(reflectModelValue, now) + } + } default: switch stmt.ReflectValue.Kind() { case reflect.Struct: - var set clause.Set = make([]clause.Assignment, 0, len(stmt.Schema.FieldsByDBName)) + set = make([]clause.Assignment, 0, len(stmt.Schema.FieldsByDBName)) for _, field := range stmt.Schema.FieldsByDBName { - if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { - value, _ := field.ValueOf(stmt.ReflectValue) - set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value}) - field.Set(reflectModelValue, value) + if !field.PrimaryKey || stmt.Dest != stmt.Model { + if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { + value, isZero := field.ValueOf(stmt.ReflectValue) + if field.AutoUpdateTime > 0 { + value = time.Now() + isZero = false + } + + if ok || !isZero { + set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value}) + field.Set(reflectModelValue, value) + } + } + } else { + if value, isZero := field.ValueOf(stmt.ReflectValue); !isZero { + stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}}) + } } } - return set } } - return clause.Set{} + if stmt.Dest != stmt.Model { + reflectValue := reflect.ValueOf(stmt.Model) + for _, field := range stmt.Schema.PrimaryFields { + if value, isZero := field.ValueOf(reflectValue); !isZero { + stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}}) + } + } + } + return } diff --git a/finisher_api.go b/finisher_api.go index e2f89cf0..0b729cc9 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -2,6 +2,7 @@ package gorm import ( "database/sql" + "reflect" "strings" "github.com/jinzhu/gorm/clause" @@ -18,6 +19,26 @@ func (db *DB) Create(value interface{}) (tx *DB) { // Save update value in database, if the value doesn't have primary key, will insert it func (db *DB) Save(value interface{}) (tx *DB) { tx = db.getInstance() + tx.Statement.Dest = value + + if err := tx.Statement.Parse(value); err != nil && tx.Statement.Schema != nil { + where := clause.Where{Exprs: make([]clause.Expression, len(tx.Statement.Schema.PrimaryFields))} + reflectValue := reflect.ValueOf(value) + for idx, pf := range tx.Statement.Schema.PrimaryFields { + if pv, isZero := pf.ValueOf(reflectValue); isZero { + tx.callbacks.Create().Execute(tx) + where.Exprs[idx] = clause.Eq{Column: pf.DBName, Value: pv} + return + } + } + + tx.Statement.AddClause(where) + } + + if len(tx.Statement.Selects) == 0 { + tx.Statement.Selects = []string{"*"} + } + tx.callbacks.Update().Execute(tx) return } diff --git a/schema/field.go b/schema/field.go index ea4e6a40..c6de669d 100644 --- a/schema/field.go +++ b/schema/field.go @@ -164,22 +164,6 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { field.DBDataType = val } - if v, ok := field.TagSettings["AUTOCREATETIME"]; ok || (field.Name == "CreatedAt" && (field.DataType == Time || field.DataType == Int)) { - if strings.ToUpper(v) == "NANO" { - field.AutoCreateTime = UnixNanosecond - } else { - field.AutoCreateTime = UnixSecond - } - } - - if v, ok := field.TagSettings["AUTOUPDATETIME"]; ok || (field.Name == "UpdatedAt" && (field.DataType == Time || field.DataType == Int)) { - if strings.ToUpper(v) == "NANO" { - field.AutoUpdateTime = UnixNanosecond - } else { - field.AutoUpdateTime = UnixSecond - } - } - switch fieldValue.Elem().Kind() { case reflect.Bool: field.DataType = Bool @@ -218,6 +202,22 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } } + if v, ok := field.TagSettings["AUTOCREATETIME"]; ok || (field.Name == "CreatedAt" && (field.DataType == Time || field.DataType == Int)) { + if strings.ToUpper(v) == "NANO" { + field.AutoCreateTime = UnixNanosecond + } else { + field.AutoCreateTime = UnixSecond + } + } + + if v, ok := field.TagSettings["AUTOUPDATETIME"]; ok || (field.Name == "UpdatedAt" && (field.DataType == Time || field.DataType == Int)) { + if strings.ToUpper(v) == "NANO" { + field.AutoUpdateTime = UnixNanosecond + } else { + field.AutoUpdateTime = UnixSecond + } + } + if field.Size == 0 { switch fieldValue.Kind() { case reflect.Int, reflect.Int64, reflect.Uint, reflect.Uint64, reflect.Float64: diff --git a/tests/tests.go b/tests/tests.go index 18207268..4181ad46 100644 --- a/tests/tests.go +++ b/tests/tests.go @@ -1,7 +1,6 @@ package tests import ( - "log" "reflect" "strconv" "testing" @@ -22,6 +21,7 @@ func RunTestsSuit(t *testing.T, db *gorm.DB) { } func TestCreate(t *testing.T, db *gorm.DB) { + db.Migrator().DropTable(&User{}) db.AutoMigrate(&User{}) t.Run("Create", func(t *testing.T) { @@ -39,6 +39,14 @@ func TestCreate(t *testing.T, db *gorm.DB) { t.Errorf("user's primary key should has value after create, got : %v", user.ID) } + if user.CreatedAt.IsZero() { + t.Errorf("user's created at should be not zero") + } + + if user.UpdatedAt.IsZero() { + t.Errorf("user's updated at should be not zero") + } + var newUser User if err := db.Where("id = ?", user.ID).First(&newUser).Error; err != nil { t.Errorf("errors happened when query: %v", err) @@ -119,7 +127,6 @@ func TestFind(t *testing.T, db *gorm.DB) { if err := db.Model(&User{}).Where("name = ?", "find").Find(&allMap).Error; err != nil { t.Errorf("errors happened when query first: %v", err) } else { - log.Printf("all map %+v %+v", len(allMap), allMap) for idx, user := range users { t.Run("FindAllMap#"+strconv.Itoa(idx+1), func(t *testing.T) { for _, name := range []string{"Name", "Age", "Birthday"} { @@ -140,21 +147,64 @@ func TestUpdate(t *testing.T, db *gorm.DB) { db.AutoMigrate(&User{}) t.Run("Update", func(t *testing.T) { - var user = User{ - Name: "create", - Age: 18, - Birthday: Now(), + var ( + users = []*User{{ + Name: "update-before", + Age: 1, + Birthday: Now(), + }, { + Name: "update", + Age: 18, + Birthday: Now(), + }, { + Name: "update-after", + Age: 1, + Birthday: Now(), + }} + user = users[1] + lastUpdatedAt time.Time + ) + + checkUpdatedTime := func(name string, n time.Time) { + if n.UnixNano() == lastUpdatedAt.UnixNano() { + t.Errorf("%v: user's updated at should be changed, but got %v, was %v", name, n, lastUpdatedAt) + } + lastUpdatedAt = n } - if err := db.Create(&user).Error; err != nil { + checkOtherData := func(name string) { + var beforeUser, afterUser User + if err := db.Where("id = ?", users[0].ID).First(&beforeUser).Error; err != nil { + t.Errorf("errors happened when query before user: %v", err) + } + t.Run(name, func(t *testing.T) { + AssertObjEqual(t, beforeUser, users[0], "Name", "Age", "Birthday") + }) + + if err := db.Where("id = ?", users[2].ID).First(&afterUser).Error; err != nil { + t.Errorf("errors happened when query after user: %v", err) + } + t.Run(name, func(t *testing.T) { + AssertObjEqual(t, afterUser, users[2], "Name", "Age", "Birthday") + }) + } + + if err := db.Create(&users).Error; err != nil { t.Errorf("errors happened when create: %v", err) + } else if user.ID == 0 { + t.Errorf("user's primary value should not zero, %v", user.ID) + } else if user.UpdatedAt.IsZero() { + t.Errorf("user's updated at should not zero, %v", user.UpdatedAt) } + lastUpdatedAt = user.UpdatedAt - if err := db.Model(&user).Update("Age", 10).Error; err != nil { + if err := db.Model(user).Update("Age", 10).Error; err != nil { t.Errorf("errors happened when update: %v", err) } else if user.Age != 10 { t.Errorf("Age should equals to 10, but got %v", user.Age) } + checkUpdatedTime("Update", user.UpdatedAt) + checkOtherData("Update") var result User if err := db.Where("id = ?", user.ID).First(&result).Error; err != nil { @@ -164,13 +214,15 @@ func TestUpdate(t *testing.T, db *gorm.DB) { } values := map[string]interface{}{"Active": true, "age": 5} - if err := db.Model(&user).Updates(values).Error; err != nil { + if err := db.Model(user).Updates(values).Error; err != nil { t.Errorf("errors happened when update: %v", err) } else if user.Age != 5 { t.Errorf("Age should equals to 5, but got %v", user.Age) } else if user.Active != true { t.Errorf("Active should be true, but got %v", user.Active) } + checkUpdatedTime("Updates with map", user.UpdatedAt) + checkOtherData("Updates with map") var result2 User if err := db.Where("id = ?", user.ID).First(&result2).Error; err != nil { @@ -179,11 +231,13 @@ func TestUpdate(t *testing.T, db *gorm.DB) { AssertObjEqual(t, result2, user, "Name", "Age", "Birthday") } - if err := db.Model(&user).Updates(User{Age: 2}).Error; err != nil { + if err := db.Model(user).Updates(User{Age: 2}).Error; err != nil { t.Errorf("errors happened when update: %v", err) } else if user.Age != 2 { t.Errorf("Age should equals to 2, but got %v", user.Age) } + checkUpdatedTime("Updates with struct", user.UpdatedAt) + checkOtherData("Updates with struct") var result3 User if err := db.Where("id = ?", user.ID).First(&result3).Error; err != nil { @@ -191,5 +245,24 @@ func TestUpdate(t *testing.T, db *gorm.DB) { } else { AssertObjEqual(t, result3, user, "Name", "Age", "Birthday") } + + user.Active = false + user.Age = 1 + if err := db.Save(user).Error; err != nil { + t.Errorf("errors happened when update: %v", err) + } else if user.Age != 1 { + t.Errorf("Age should equals to 1, but got %v", user.Age) + } else if user.Active != false { + t.Errorf("Active should equals to false, but got %v", user.Active) + } + checkUpdatedTime("Save", user.UpdatedAt) + checkOtherData("Save") + + var result4 User + if err := db.Where("id = ?", user.ID).First(&result4).Error; err != nil { + t.Errorf("errors happened when query: %v", err) + } else { + AssertObjEqual(t, result4, user, "Name", "Age", "Birthday") + } }) } diff --git a/utils/utils.go b/utils/utils.go index 86ea557b..e7ed512c 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -6,8 +6,8 @@ import ( "runtime" ) -var goSrcRegexp = regexp.MustCompile(`/gorm/.*\.go`) -var goTestRegexp = regexp.MustCompile(`/gorm/.*test\.go`) +var goSrcRegexp = regexp.MustCompile(`/gorm/.*.go`) +var goTestRegexp = regexp.MustCompile(`/gorm/.*test.*.go`) func FileWithLineNum() string { for i := 2; i < 15; i++ { From ce0e6f9f337172d44208e9451326a95f0e37f157 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 8 Mar 2020 14:51:52 +0800 Subject: [PATCH 268/881] Add Delete test --- callbacks/delete.go | 32 +++++++++++++++++++++++++ finisher_api.go | 7 +++++- helpers.go | 2 ++ logger/logger.go | 2 +- tests/tests.go | 58 +++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 99 insertions(+), 2 deletions(-) diff --git a/callbacks/delete.go b/callbacks/delete.go index d79f88fc..05d00d0a 100644 --- a/callbacks/delete.go +++ b/callbacks/delete.go @@ -4,6 +4,7 @@ import ( "reflect" "github.com/jinzhu/gorm" + "github.com/jinzhu/gorm/clause" ) func BeforeDelete(db *gorm.DB) { @@ -32,6 +33,37 @@ func BeforeDelete(db *gorm.DB) { } func Delete(db *gorm.DB) { + if db.Statement.SQL.String() == "" { + db.Statement.AddClauseIfNotExists(clause.Delete{}) + + values := []reflect.Value{db.Statement.ReflectValue} + if db.Statement.Dest != db.Statement.Model { + values = append(values, reflect.ValueOf(db.Statement.Model)) + } + for _, field := range db.Statement.Schema.PrimaryFields { + for _, value := range values { + if value, isZero := field.ValueOf(value); !isZero { + db.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}}) + } + } + } + + if _, ok := db.Statement.Clauses["WHERE"]; !ok { + db.AddError(gorm.ErrMissingWhereClause) + return + } + + db.Statement.AddClauseIfNotExists(clause.From{}) + db.Statement.Build("DELETE", "FROM", "WHERE") + } + + result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) + + if err == nil { + db.RowsAffected, _ = result.RowsAffected() + } else { + db.AddError(err) + } } func AfterDelete(db *gorm.DB) { diff --git a/finisher_api.go b/finisher_api.go index 0b729cc9..806c6723 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -135,8 +135,13 @@ func (db *DB) UpdateColumns(values interface{}) (tx *DB) { } // Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition -func (db *DB) Delete(value interface{}, where ...interface{}) (tx *DB) { +func (db *DB) Delete(value interface{}, conds ...interface{}) (tx *DB) { tx = db.getInstance() + if len(conds) > 0 { + tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondtion(conds[0], conds[1:]...)}) + } + tx.Statement.Dest = value + tx.callbacks.Delete().Execute(tx) return } diff --git a/helpers.go b/helpers.go index d7177ba7..241d3fbd 100644 --- a/helpers.go +++ b/helpers.go @@ -17,6 +17,8 @@ var ( ErrUnaddressable = errors.New("using unaddressable value") // ErrNotImplemented not implemented ErrNotImplemented = errors.New("not implemented") + // ErrMissingWhereClause missing where clause + ErrMissingWhereClause = errors.New("missing WHERE clause while deleting") ) // Model a basic GoLang struct which includes the following fields: ID, CreatedAt, UpdatedAt, DeletedAt diff --git a/logger/logger.go b/logger/logger.go index 2a765628..80ae31b1 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -121,7 +121,7 @@ func (l logger) Error(msg string, data ...interface{}) { // Trace print sql message func (l logger) Trace(begin time.Time, fc func() (string, int64), err error) { - if elapsed := time.Now().Sub(begin); err != nil || (elapsed > l.SlowThreshold && l.SlowThreshold != 0) { + if elapsed := time.Now().Sub(begin); elapsed > l.SlowThreshold && l.SlowThreshold != 0 { sql, rows := fc() fileline := utils.FileWithLineNum() if err != nil { diff --git a/tests/tests.go b/tests/tests.go index 4181ad46..a15a9d0d 100644 --- a/tests/tests.go +++ b/tests/tests.go @@ -1,6 +1,7 @@ package tests import ( + "errors" "reflect" "strconv" "testing" @@ -18,6 +19,7 @@ func RunTestsSuit(t *testing.T, db *gorm.DB) { TestCreate(t, db) TestFind(t, db) TestUpdate(t, db) + TestDelete(t, db) } func TestCreate(t *testing.T, db *gorm.DB) { @@ -266,3 +268,59 @@ func TestUpdate(t *testing.T, db *gorm.DB) { } }) } + +func TestDelete(t *testing.T, db *gorm.DB) { + db.Migrator().DropTable(&User{}) + db.AutoMigrate(&User{}) + + t.Run("Delete", func(t *testing.T) { + var users = []User{{ + Name: "find", + Age: 1, + Birthday: Now(), + }, { + Name: "find", + Age: 2, + Birthday: Now(), + }, { + Name: "find", + Age: 3, + Birthday: Now(), + }} + + if err := db.Create(&users).Error; err != nil { + t.Errorf("errors happened when create: %v", err) + } + + for _, user := range users { + if user.ID == 0 { + t.Errorf("user's primary key should has value after create, got : %v", user.ID) + } + } + + if err := db.Delete(&users[1]).Error; err != nil { + t.Errorf("errors happened when delete: %v", err) + } + + var result User + if err := db.Where("id = ?", users[1].ID).First(&result).Error; err == nil || !errors.Is(err, gorm.ErrRecordNotFound) { + t.Errorf("should returns record not found error, but got %v", err) + } + + for _, user := range []User{users[0], users[2]} { + if err := db.Where("id = ?", user.ID).First(&result).Error; err != nil { + t.Errorf("no error should returns when query %v, but got %v", user.ID, err) + } + } + + if err := db.Delete(&User{}).Error; err == nil || !errors.Is(err, gorm.ErrMissingWhereClause) { + t.Errorf("should returns missing WHERE clause while deleting error") + } + + for _, user := range []User{users[0], users[2]} { + if err := db.Where("id = ?", user.ID).First(&result).Error; err != nil { + t.Errorf("no error should returns when query %v, but got %v", user.ID, err) + } + } + }) +} From a158d1ada035e61e5309fbf594ad4f813e6db06a Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 8 Mar 2020 18:05:22 +0800 Subject: [PATCH 269/881] Add GroupBy test --- chainable_api.go | 8 ++++- clause/benchmarks_test.go | 2 +- clause/group_by.go | 8 ++--- clause/group_by_test.go | 6 ++-- finisher_api.go | 6 ---- tests/group_by.go | 62 +++++++++++++++++++++++++++++++++++++++ tests/tests.go | 2 ++ 7 files changed, 79 insertions(+), 15 deletions(-) create mode 100644 tests/group_by.go diff --git a/chainable_api.go b/chainable_api.go index 49f260d3..f0bf8018 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -135,14 +135,20 @@ func (db *DB) Joins(query string, args ...interface{}) (tx *DB) { } // Group specify the group method on the find -func (db *DB) Group(column string) (tx *DB) { +func (db *DB) Group(name string) (tx *DB) { tx = db.getInstance() + tx.Statement.AddClause(clause.GroupBy{ + Columns: []clause.Column{{Name: name}}, + }) return } // Having specify HAVING conditions for GROUP BY func (db *DB) Having(query interface{}, args ...interface{}) (tx *DB) { tx = db.getInstance() + tx.Statement.AddClause(clause.GroupBy{ + Having: tx.Statement.BuildCondtion(query, args...), + }) return } diff --git a/clause/benchmarks_test.go b/clause/benchmarks_test.go index 33d3430a..47001cd1 100644 --- a/clause/benchmarks_test.go +++ b/clause/benchmarks_test.go @@ -41,7 +41,7 @@ func BenchmarkComplexSelect(b *testing.B) { clause.Where{Exprs: []clause.Expression{ clause.Or(clause.Gt{Column: "score", Value: 100}, clause.Like{Column: "name", Value: "%linus%"}), }}, - clause.GroupBy{Columns: []clause.Column{{Name: "role"}}, Having: clause.Where{[]clause.Expression{clause.Eq{"role", "admin"}}}}, + clause.GroupBy{Columns: []clause.Column{{Name: "role"}}, Having: []clause.Expression{clause.Eq{"role", "admin"}}}, clause.Limit{Limit: 10, Offset: 20}, clause.OrderBy{Columns: []clause.OrderByColumn{{Column: clause.PrimaryColumn, Desc: true}}}, } diff --git a/clause/group_by.go b/clause/group_by.go index 8d164731..a245d50a 100644 --- a/clause/group_by.go +++ b/clause/group_by.go @@ -3,7 +3,7 @@ package clause // GroupBy group by clause type GroupBy struct { Columns []Column - Having Where + Having []Expression } // Name from clause name @@ -21,9 +21,9 @@ func (groupBy GroupBy) Build(builder Builder) { builder.WriteQuoted(column) } - if len(groupBy.Having.Exprs) > 0 { + if len(groupBy.Having) > 0 { builder.Write(" HAVING ") - groupBy.Having.Build(builder) + Where{Exprs: groupBy.Having}.Build(builder) } } @@ -31,7 +31,7 @@ func (groupBy GroupBy) Build(builder Builder) { func (groupBy GroupBy) MergeClause(clause *Clause) { if v, ok := clause.Expression.(GroupBy); ok { groupBy.Columns = append(v.Columns, groupBy.Columns...) - groupBy.Having.Exprs = append(v.Having.Exprs, groupBy.Having.Exprs...) + groupBy.Having = append(v.Having, groupBy.Having...) } clause.Expression = groupBy } diff --git a/clause/group_by_test.go b/clause/group_by_test.go index 35be84a4..98aad3eb 100644 --- a/clause/group_by_test.go +++ b/clause/group_by_test.go @@ -16,17 +16,17 @@ func TestGroupBy(t *testing.T) { { []clause.Interface{clause.Select{}, clause.From{}, clause.GroupBy{ Columns: []clause.Column{{Name: "role"}}, - Having: clause.Where{[]clause.Expression{clause.Eq{"role", "admin"}}}, + Having: []clause.Expression{clause.Eq{"role", "admin"}}, }}, "SELECT * FROM `users` GROUP BY `role` HAVING `role` = ?", []interface{}{"admin"}, }, { []clause.Interface{clause.Select{}, clause.From{}, clause.GroupBy{ Columns: []clause.Column{{Name: "role"}}, - Having: clause.Where{[]clause.Expression{clause.Eq{"role", "admin"}}}, + Having: []clause.Expression{clause.Eq{"role", "admin"}}, }, clause.GroupBy{ Columns: []clause.Column{{Name: "gender"}}, - Having: clause.Where{[]clause.Expression{clause.Neq{"gender", "U"}}}, + Having: []clause.Expression{clause.Neq{"gender", "U"}}, }}, "SELECT * FROM `users` GROUP BY `role`,`gender` HAVING `role` = ? AND `gender` <> ?", []interface{}{"admin", "U"}, }, diff --git a/finisher_api.go b/finisher_api.go index 806c6723..51d9b409 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -145,12 +145,6 @@ func (db *DB) Delete(value interface{}, conds ...interface{}) (tx *DB) { return } -//Preloads only preloads relations, don`t touch out -func (db *DB) Preloads(out interface{}) (tx *DB) { - tx = db.getInstance() - return -} - func (db *DB) Count(value interface{}) (tx *DB) { tx = db.getInstance() return diff --git a/tests/group_by.go b/tests/group_by.go new file mode 100644 index 00000000..b0bb4155 --- /dev/null +++ b/tests/group_by.go @@ -0,0 +1,62 @@ +package tests + +import ( + "testing" + + "github.com/jinzhu/gorm" +) + +func TestGroupBy(t *testing.T, db *gorm.DB) { + db.Migrator().DropTable(&User{}) + db.AutoMigrate(&User{}) + + t.Run("GroupBy", func(t *testing.T) { + var users = []User{{ + Name: "groupby", + Age: 10, + Birthday: Now(), + }, { + Name: "groupby", + Age: 20, + Birthday: Now(), + }, { + Name: "groupby", + Age: 30, + Birthday: Now(), + }, { + Name: "groupby1", + Age: 110, + Birthday: Now(), + }, { + Name: "groupby1", + Age: 220, + Birthday: Now(), + }, { + Name: "groupby1", + Age: 330, + Birthday: Now(), + }} + + if err := db.Create(&users).Error; err != nil { + t.Errorf("errors happened when create: %v", err) + } + + var name string + var total int + if err := db.Model(&User{}).Select("name, sum(age)").Where("name = ?", "groupby").Group("name").Row().Scan(&name, &total); err != nil { + t.Errorf("no error should happen, but got %v", err) + } + + if name != "groupby" || total != 60 { + t.Errorf("name should be groupby, but got %v, total should be 60, but got %v", name, total) + } + + if err := db.Model(&User{}).Select("name, sum(age) as total").Where("name LIKE ?", "groupby%").Group("name").Having("name = ?", "groupby1").Row().Scan(&name, &total); err != nil { + t.Errorf("no error should happen, but got %v", err) + } + + if name != "groupby1" || total != 660 { + t.Errorf("name should be groupby, but got %v, total should be 660, but got %v", name, total) + } + }) +} diff --git a/tests/tests.go b/tests/tests.go index a15a9d0d..65c1ca96 100644 --- a/tests/tests.go +++ b/tests/tests.go @@ -20,6 +20,8 @@ func RunTestsSuit(t *testing.T, db *gorm.DB) { TestFind(t, db) TestUpdate(t, db) TestDelete(t, db) + + TestGroupBy(t, db) } func TestCreate(t *testing.T, db *gorm.DB) { From 5fce17543a2b166c915bff00ad2581ba1626255e Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 8 Mar 2020 19:12:33 +0800 Subject: [PATCH 270/881] Add Joins --- chainable_api.go | 1 + clause/joins.go | 8 ++++++++ tests/joins.go | 10 ++++++++++ tests/tests.go | 1 + 4 files changed, 20 insertions(+) create mode 100644 clause/joins.go create mode 100644 tests/joins.go diff --git a/chainable_api.go b/chainable_api.go index f0bf8018..6f80d4be 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -128,6 +128,7 @@ func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) { } // Joins specify Joins conditions +// db.Joins("Account").Find(&user) // db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user) func (db *DB) Joins(query string, args ...interface{}) (tx *DB) { tx = db.getInstance() diff --git a/clause/joins.go b/clause/joins.go new file mode 100644 index 00000000..4983d6fd --- /dev/null +++ b/clause/joins.go @@ -0,0 +1,8 @@ +package clause + +// Joins joins clause +type Joins struct { + Name string + Query string + Vars []interface{} +} diff --git a/tests/joins.go b/tests/joins.go new file mode 100644 index 00000000..3c4bfbb5 --- /dev/null +++ b/tests/joins.go @@ -0,0 +1,10 @@ +package tests + +import ( + "testing" + + "github.com/jinzhu/gorm" +) + +func TestJoins(t *testing.T, db *gorm.DB) { +} diff --git a/tests/tests.go b/tests/tests.go index 65c1ca96..33013032 100644 --- a/tests/tests.go +++ b/tests/tests.go @@ -22,6 +22,7 @@ func RunTestsSuit(t *testing.T, db *gorm.DB) { TestDelete(t, db) TestGroupBy(t, db) + TestJoins(t, db) } func TestCreate(t *testing.T, db *gorm.DB) { From 078ba75b9cc749820610e11b205a2e219a5e7239 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 8 Mar 2020 23:30:16 +0800 Subject: [PATCH 271/881] Add QuoteTo method --- dialects/mssql/mssql.go | 7 +++-- dialects/mysql/mysql.go | 7 +++-- dialects/postgres/postgres.go | 7 +++-- dialects/sqlite/sqlite.go | 7 +++-- go.mod | 4 +++ gorm.go | 1 - interfaces.go | 3 +- statement.go | 55 +++++++++++++++-------------------- tests/dummy_dialecter.go | 8 +++-- 9 files changed, 56 insertions(+), 43 deletions(-) diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go index b93cc8f6..91574787 100644 --- a/dialects/mssql/mssql.go +++ b/dialects/mssql/mssql.go @@ -5,6 +5,7 @@ import ( "fmt" "regexp" "strconv" + "strings" _ "github.com/denisenkom/go-mssqldb" "github.com/jinzhu/gorm" @@ -42,8 +43,10 @@ func (dialector Dialector) BindVar(stmt *gorm.Statement, v interface{}) string { return "@p" + strconv.Itoa(len(stmt.Vars)) } -func (dialector Dialector) QuoteChars() [2]byte { - return [2]byte{'"', '"'} // `name` +func (dialector Dialector) QuoteTo(builder *strings.Builder, str string) { + builder.WriteByte('"') + builder.WriteString(str) + builder.WriteByte('"') } var numericPlaceholder = regexp.MustCompile("@p(\\d+)") diff --git a/dialects/mysql/mysql.go b/dialects/mysql/mysql.go index e1bf985a..9d16507e 100644 --- a/dialects/mysql/mysql.go +++ b/dialects/mysql/mysql.go @@ -4,6 +4,7 @@ import ( "database/sql" "fmt" "math" + "strings" _ "github.com/go-sql-driver/mysql" "github.com/jinzhu/gorm" @@ -39,8 +40,10 @@ func (dialector Dialector) BindVar(stmt *gorm.Statement, v interface{}) string { return "?" } -func (dialector Dialector) QuoteChars() [2]byte { - return [2]byte{'`', '`'} // `name` +func (dialector Dialector) QuoteTo(builder *strings.Builder, str string) { + builder.WriteByte('`') + builder.WriteString(str) + builder.WriteByte('`') } func (dialector Dialector) Explain(sql string, vars ...interface{}) string { diff --git a/dialects/postgres/postgres.go b/dialects/postgres/postgres.go index 3ee4ba9f..0005f7ed 100644 --- a/dialects/postgres/postgres.go +++ b/dialects/postgres/postgres.go @@ -5,6 +5,7 @@ import ( "fmt" "regexp" "strconv" + "strings" "github.com/jinzhu/gorm" "github.com/jinzhu/gorm/callbacks" @@ -42,8 +43,10 @@ func (dialector Dialector) BindVar(stmt *gorm.Statement, v interface{}) string { return "$" + strconv.Itoa(len(stmt.Vars)) } -func (dialector Dialector) QuoteChars() [2]byte { - return [2]byte{'"', '"'} // "name" +func (dialector Dialector) QuoteTo(builder *strings.Builder, str string) { + builder.WriteByte('"') + builder.WriteString(str) + builder.WriteByte('"') } var numericPlaceholder = regexp.MustCompile("\\$(\\d+)") diff --git a/dialects/sqlite/sqlite.go b/dialects/sqlite/sqlite.go index 5f9d49df..91762343 100644 --- a/dialects/sqlite/sqlite.go +++ b/dialects/sqlite/sqlite.go @@ -2,6 +2,7 @@ package sqlite import ( "database/sql" + "strings" "github.com/jinzhu/gorm" "github.com/jinzhu/gorm/callbacks" @@ -38,8 +39,10 @@ func (dialector Dialector) BindVar(stmt *gorm.Statement, v interface{}) string { return "?" } -func (dialector Dialector) QuoteChars() [2]byte { - return [2]byte{'`', '`'} // `name` +func (dialector Dialector) QuoteTo(builder *strings.Builder, str string) { + builder.WriteByte('`') + builder.WriteString(str) + builder.WriteByte('`') } func (dialector Dialector) Explain(sql string, vars ...interface{}) string { diff --git a/go.mod b/go.mod index cdb7e574..3e067d3c 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,10 @@ module github.com/jinzhu/gorm go 1.13 require ( + github.com/denisenkom/go-mssqldb v0.0.0-20200206145737-bbfc9a55622e // indirect + github.com/go-sql-driver/mysql v1.5.0 // indirect github.com/jinzhu/inflection v1.0.0 github.com/jinzhu/now v1.1.1 + github.com/lib/pq v1.3.0 // indirect + github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect ) diff --git a/gorm.go b/gorm.go index 2f10be60..eac95868 100644 --- a/gorm.go +++ b/gorm.go @@ -79,7 +79,6 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) { if dialector != nil { err = dialector.Initialize(db) - db.quoteChars = dialector.QuoteChars() } return } diff --git a/interfaces.go b/interfaces.go index f0d14dd8..c89c3624 100644 --- a/interfaces.go +++ b/interfaces.go @@ -3,6 +3,7 @@ package gorm import ( "context" "database/sql" + "strings" "github.com/jinzhu/gorm/schema" ) @@ -13,7 +14,7 @@ type Dialector interface { Migrator(db *DB) Migrator DataTypeOf(*schema.Field) string BindVar(stmt *Statement, v interface{}) string - QuoteChars() [2]byte + QuoteTo(*strings.Builder, string) Explain(sql string, vars ...interface{}) string } diff --git a/statement.go b/statement.go index bad83717..f04ea269 100644 --- a/statement.go +++ b/statement.go @@ -76,65 +76,58 @@ func (stmt *Statement) WriteByte(c byte) (err error) { return stmt.SQL.WriteByte(c) } -// WriteQuoted write quoted field -func (stmt *Statement) WriteQuoted(field interface{}) (err error) { - _, err = stmt.SQL.WriteString(stmt.Quote(field)) - return +// WriteQuoted write quoted value +func (stmt *Statement) WriteQuoted(value interface{}) error { + stmt.QuoteTo(&stmt.SQL, value) + return nil } -// Quote returns quoted value -func (stmt Statement) Quote(field interface{}) string { - var str strings.Builder - str.WriteByte(stmt.DB.quoteChars[0]) - +// QuoteTo write quoted value to writer +func (stmt Statement) QuoteTo(writer *strings.Builder, field interface{}) { switch v := field.(type) { case clause.Table: if v.Name == clause.CurrentTable { - str.WriteString(stmt.Table) + stmt.DB.Dialector.QuoteTo(writer, stmt.Table) } else { - str.WriteString(v.Name) + stmt.DB.Dialector.QuoteTo(writer, v.Name) } if v.Alias != "" { - str.WriteByte(stmt.DB.quoteChars[1]) - str.WriteString(" AS ") - str.WriteByte(stmt.DB.quoteChars[0]) - str.WriteString(v.Alias) - str.WriteByte(stmt.DB.quoteChars[1]) + writer.WriteString(" AS ") + stmt.DB.Dialector.QuoteTo(writer, v.Alias) } case clause.Column: if v.Table != "" { if v.Table == clause.CurrentTable { - str.WriteString(stmt.Table) + stmt.DB.Dialector.QuoteTo(writer, stmt.Table) } else { - str.WriteString(v.Table) + stmt.DB.Dialector.QuoteTo(writer, v.Table) } - str.WriteByte(stmt.DB.quoteChars[1]) - str.WriteByte('.') - str.WriteByte(stmt.DB.quoteChars[0]) + writer.WriteByte('.') } if v.Name == clause.PrimaryKey { if stmt.Schema != nil && stmt.Schema.PrioritizedPrimaryField != nil { - str.WriteString(stmt.Schema.PrioritizedPrimaryField.DBName) + stmt.DB.Dialector.QuoteTo(writer, stmt.Schema.PrioritizedPrimaryField.DBName) } } else { - str.WriteString(v.Name) + stmt.DB.Dialector.QuoteTo(writer, v.Name) } if v.Alias != "" { - str.WriteByte(stmt.DB.quoteChars[1]) - str.WriteString(" AS ") - str.WriteByte(stmt.DB.quoteChars[0]) - str.WriteString(v.Alias) - str.WriteByte(stmt.DB.quoteChars[1]) + writer.WriteString(" AS ") + stmt.DB.Dialector.QuoteTo(writer, v.Alias) } default: - str.WriteString(fmt.Sprint(field)) + stmt.DB.Dialector.QuoteTo(writer, fmt.Sprint(field)) } +} - str.WriteByte(stmt.DB.quoteChars[1]) - return str.String() +// Quote returns quoted value +func (stmt Statement) Quote(field interface{}) string { + var builder strings.Builder + stmt.QuoteTo(&builder, field) + return builder.String() } // Write write string diff --git a/tests/dummy_dialecter.go b/tests/dummy_dialecter.go index 04d6248d..9e3146fe 100644 --- a/tests/dummy_dialecter.go +++ b/tests/dummy_dialecter.go @@ -1,6 +1,8 @@ package tests import ( + "strings" + "github.com/jinzhu/gorm" "github.com/jinzhu/gorm/logger" "github.com/jinzhu/gorm/schema" @@ -21,8 +23,10 @@ func (DummyDialector) BindVar(stmt *gorm.Statement, v interface{}) string { return "?" } -func (DummyDialector) QuoteChars() [2]byte { - return [2]byte{'`', '`'} // `name` +func (DummyDialector) QuoteTo(builder *strings.Builder, str string) { + builder.WriteByte('`') + builder.WriteString(str) + builder.WriteByte('`') } func (DummyDialector) Explain(sql string, vars ...interface{}) string { From a145d7e01946a4f0777b0c1764bd8e24d3425789 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 9 Mar 2020 13:10:48 +0800 Subject: [PATCH 272/881] Refactor structure --- callbacks.go | 3 ++ callbacks/create.go | 2 +- callbacks/delete.go | 2 +- callbacks/query.go | 2 +- callbacks/raw.go | 2 +- callbacks/row.go | 4 +-- callbacks/update.go | 2 +- chainable_api.go | 5 +-- dialects/mssql/mssql.go | 3 +- dialects/mysql/mysql.go | 2 +- dialects/postgres/postgres.go | 3 +- dialects/sqlite/sqlite.go | 2 +- helpers.go => errors.go | 18 ---------- finisher_api.go | 8 ++--- gorm.go | 64 ++++++++++++++++++++--------------- interfaces.go | 4 +-- model.go | 15 ++++++++ statement.go | 36 +++++++------------- utils/utils.go | 5 +++ 19 files changed, 91 insertions(+), 91 deletions(-) rename helpers.go => errors.go (60%) create mode 100644 model.go diff --git a/callbacks.go b/callbacks.go index db8261c4..d1164019 100644 --- a/callbacks.go +++ b/callbacks.go @@ -90,6 +90,9 @@ func (p *processor) Execute(db *DB) { } if stmt := db.Statement; stmt != nil { + db.Error = stmt.Error + db.RowsAffected = stmt.RowsAffected + db.Logger.Trace(curTime, func() (string, int64) { return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...), db.RowsAffected }, db.Error) diff --git a/callbacks/create.go b/callbacks/create.go index 2e1b3381..42dcda27 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -50,7 +50,7 @@ func Create(db *gorm.DB) { db.Statement.AddClause(ConvertToCreateValues(db.Statement)) db.Statement.Build("INSERT", "VALUES", "ON_CONFLICT") - result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) + result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) if err == nil { if db.Statement.Schema != nil { diff --git a/callbacks/delete.go b/callbacks/delete.go index 05d00d0a..50b2880a 100644 --- a/callbacks/delete.go +++ b/callbacks/delete.go @@ -57,7 +57,7 @@ func Delete(db *gorm.DB) { db.Statement.Build("DELETE", "FROM", "WHERE") } - result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) + result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) if err == nil { db.RowsAffected, _ = result.RowsAffected() diff --git a/callbacks/query.go b/callbacks/query.go index 26c0e0ad..00820bfd 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -14,7 +14,7 @@ func Query(db *gorm.DB) { db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR") } - rows, err := db.DB.QueryContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) + rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) if err != nil { db.AddError(err) return diff --git a/callbacks/raw.go b/callbacks/raw.go index e8cad25d..ce125e61 100644 --- a/callbacks/raw.go +++ b/callbacks/raw.go @@ -5,7 +5,7 @@ import ( ) func RawExec(db *gorm.DB) { - result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) + result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) if err != nil { db.AddError(err) } else { diff --git a/callbacks/row.go b/callbacks/row.go index f7d6752d..b84cf694 100644 --- a/callbacks/row.go +++ b/callbacks/row.go @@ -14,8 +14,8 @@ func RowQuery(db *gorm.DB) { } if _, ok := db.Get("rows"); ok { - db.Statement.Dest, db.Error = db.DB.QueryContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) + db.Statement.Dest, db.Error = db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) } else { - db.Statement.Dest = db.DB.QueryRowContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) + db.Statement.Dest = db.Statement.ConnPool.QueryRowContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) } } diff --git a/callbacks/update.go b/callbacks/update.go index ca31bf18..eab9f929 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -47,7 +47,7 @@ func Update(db *gorm.DB) { db.Statement.AddClause(ConvertToAssignments(db.Statement)) db.Statement.Build("UPDATE", "SET", "WHERE") - result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) + result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) if err == nil { db.RowsAffected, _ = result.RowsAffected() diff --git a/chainable_api.go b/chainable_api.go index 6f80d4be..98c1898e 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -5,6 +5,7 @@ import ( "strings" "github.com/jinzhu/gorm/clause" + "github.com/jinzhu/gorm/utils" ) // Model specify the model you would like to run db operations @@ -64,7 +65,7 @@ func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) { } } case string: - fields := strings.FieldsFunc(v, isChar) + fields := strings.FieldsFunc(v, utils.IsChar) // normal field names if len(fields) == 1 || (len(fields) == 3 && strings.ToUpper(fields[1]) == "AS") { @@ -100,7 +101,7 @@ func (db *DB) Omit(columns ...string) (tx *DB) { tx = db.getInstance() if len(columns) == 1 && strings.ContainsRune(columns[0], ',') { - tx.Statement.Omits = strings.FieldsFunc(columns[0], isChar) + tx.Statement.Omits = strings.FieldsFunc(columns[0], utils.IsChar) } else { tx.Statement.Omits = columns } diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go index 91574787..7e51de75 100644 --- a/dialects/mssql/mssql.go +++ b/dialects/mssql/mssql.go @@ -26,8 +26,7 @@ func Open(dsn string) gorm.Dialector { func (dialector Dialector) Initialize(db *gorm.DB) (err error) { // register callbacks callbacks.RegisterDefaultCallbacks(db) - - db.DB, err = sql.Open("sqlserver", dialector.DSN) + db.ConnPool, err = sql.Open("sqlserver", dialector.DSN) return } diff --git a/dialects/mysql/mysql.go b/dialects/mysql/mysql.go index 9d16507e..55b5a53f 100644 --- a/dialects/mysql/mysql.go +++ b/dialects/mysql/mysql.go @@ -25,7 +25,7 @@ func Open(dsn string) gorm.Dialector { func (dialector Dialector) Initialize(db *gorm.DB) (err error) { // register callbacks callbacks.RegisterDefaultCallbacks(db) - db.DB, err = sql.Open("mysql", dialector.DSN) + db.ConnPool, err = sql.Open("mysql", dialector.DSN) return } diff --git a/dialects/postgres/postgres.go b/dialects/postgres/postgres.go index 0005f7ed..e90fa4ae 100644 --- a/dialects/postgres/postgres.go +++ b/dialects/postgres/postgres.go @@ -26,8 +26,7 @@ func Open(dsn string) gorm.Dialector { func (dialector Dialector) Initialize(db *gorm.DB) (err error) { // register callbacks callbacks.RegisterDefaultCallbacks(db) - - db.DB, err = sql.Open("postgres", dialector.DSN) + db.ConnPool, err = sql.Open("postgres", dialector.DSN) return } diff --git a/dialects/sqlite/sqlite.go b/dialects/sqlite/sqlite.go index 91762343..8e3cc058 100644 --- a/dialects/sqlite/sqlite.go +++ b/dialects/sqlite/sqlite.go @@ -23,7 +23,7 @@ func Open(dsn string) gorm.Dialector { func (dialector Dialector) Initialize(db *gorm.DB) (err error) { // register callbacks callbacks.RegisterDefaultCallbacks(db) - db.DB, err = sql.Open("sqlite3", dialector.DSN) + db.ConnPool, err = sql.Open("sqlite3", dialector.DSN) return } diff --git a/helpers.go b/errors.go similarity index 60% rename from helpers.go rename to errors.go index 241d3fbd..32f55e01 100644 --- a/helpers.go +++ b/errors.go @@ -2,8 +2,6 @@ package gorm import ( "errors" - "time" - "unicode" ) var ( @@ -20,19 +18,3 @@ var ( // ErrMissingWhereClause missing where clause ErrMissingWhereClause = errors.New("missing WHERE clause while deleting") ) - -// Model a basic GoLang struct which includes the following fields: ID, CreatedAt, UpdatedAt, DeletedAt -// It may be embeded into your model or you may build your own model without it -// type User struct { -// gorm.Model -// } -type Model struct { - ID uint `gorm:"primarykey"` - CreatedAt time.Time - UpdatedAt time.Time - DeletedAt *time.Time `gorm:"index"` -} - -func isChar(c rune) bool { - return !unicode.IsLetter(c) && !unicode.IsNumber(c) -} diff --git a/finisher_api.go b/finisher_api.go index 51d9b409..62c1af30 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -196,14 +196,14 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er // Begin begins a transaction func (db *DB) Begin(opts ...*sql.TxOptions) (tx *DB) { tx = db.getInstance() - if beginner, ok := tx.DB.(TxBeginner); ok { + if beginner, ok := tx.Statement.ConnPool.(TxBeginner); ok { var opt *sql.TxOptions var err error if len(opts) > 0 { opt = opts[0] } - if tx.DB, err = beginner.BeginTx(db.Context, opt); err != nil { + if tx.Statement.ConnPool, err = beginner.BeginTx(db.Statement.Context, opt); err != nil { tx.AddError(err) } } else { @@ -214,7 +214,7 @@ func (db *DB) Begin(opts ...*sql.TxOptions) (tx *DB) { // Commit commit a transaction func (db *DB) Commit() *DB { - if comminter, ok := db.DB.(TxCommiter); ok && comminter != nil { + if comminter, ok := db.Statement.ConnPool.(TxCommiter); ok && comminter != nil { db.AddError(comminter.Commit()) } else { db.AddError(ErrInvalidTransaction) @@ -224,7 +224,7 @@ func (db *DB) Commit() *DB { // Rollback rollback a transaction func (db *DB) Rollback() *DB { - if comminter, ok := db.DB.(TxCommiter); ok && comminter != nil { + if comminter, ok := db.Statement.ConnPool.(TxCommiter); ok && comminter != nil { db.AddError(comminter.Rollback()) } else { db.AddError(ErrInvalidTransaction) diff --git a/gorm.go b/gorm.go index eac95868..b238d572 100644 --- a/gorm.go +++ b/gorm.go @@ -21,23 +21,25 @@ type Config struct { Logger logger.Interface // NowFunc the function to be used when creating a new timestamp NowFunc func() time.Time -} -type shared struct { + // ClauseBuilders clause builder + ClauseBuilders map[string]clause.ClauseBuilder + // ConnPool db conn pool + ConnPool ConnPool + // Dialector database dialector + Dialector + callbacks *callbacks cacheStore *sync.Map - quoteChars [2]byte } // DB GORM DB definition type DB struct { *Config - Dialector - Instance - ClauseBuilders map[string]clause.ClauseBuilder - DB CommonDB - clone bool - *shared + Error error + RowsAffected int64 + Statement *Statement + clone bool } // Session session config when create session with Session() method @@ -65,14 +67,17 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) { config.NowFunc = func() time.Time { return time.Now().Local() } } + if dialector != nil { + config.Dialector = dialector + } + + if config.cacheStore == nil { + config.cacheStore = &sync.Map{} + } + db = &DB{ - Config: config, - Dialector: dialector, - ClauseBuilders: map[string]clause.ClauseBuilder{}, - clone: true, - shared: &shared{ - cacheStore: &sync.Map{}, - }, + Config: config, + clone: true, } db.callbacks = initializeCallbacks(db) @@ -91,7 +96,7 @@ func (db *DB) Session(config *Session) *DB { ) if config.Context != nil { - tx.Context = config.Context + tx.Statement.Context = config.Context } if config.Logger != nil { @@ -142,23 +147,26 @@ func (db *DB) AutoMigrate(dst ...interface{}) error { return db.Migrator().AutoMigrate(dst...) } +// AddError add error to db +func (db *DB) AddError(err error) { + db.Statement.AddError(err) +} + func (db *DB) getInstance() *DB { if db.clone { - ctx := db.Instance.Context - if ctx == nil { - ctx = context.Background() + ctx := context.Background() + if db.Statement != nil { + ctx = db.Statement.Context } return &DB{ - Instance: Instance{ - Context: ctx, - Statement: &Statement{DB: db, Clauses: map[string]clause.Clause{}}, + Config: db.Config, + Statement: &Statement{ + DB: db, + ConnPool: db.ConnPool, + Clauses: map[string]clause.Clause{}, + Context: ctx, }, - Config: db.Config, - Dialector: db.Dialector, - ClauseBuilders: db.ClauseBuilders, - DB: db.DB, - shared: db.shared, } } diff --git a/interfaces.go b/interfaces.go index c89c3624..9859d1fa 100644 --- a/interfaces.go +++ b/interfaces.go @@ -18,8 +18,8 @@ type Dialector interface { Explain(sql string, vars ...interface{}) string } -// CommonDB common db interface -type CommonDB interface { +// ConnPool db conns pool interface +type ConnPool interface { ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) diff --git a/model.go b/model.go new file mode 100644 index 00000000..fdee99dc --- /dev/null +++ b/model.go @@ -0,0 +1,15 @@ +package gorm + +import "time" + +// Model a basic GoLang struct which includes the following fields: ID, CreatedAt, UpdatedAt, DeletedAt +// It may be embeded into your model or you may build your own model without it +// type User struct { +// gorm.Model +// } +type Model struct { + ID uint `gorm:"primarykey"` + CreatedAt time.Time + UpdatedAt time.Time + DeletedAt *time.Time `gorm:"index"` +} diff --git a/statement.go b/statement.go index f04ea269..10b62567 100644 --- a/statement.go +++ b/statement.go @@ -14,30 +14,6 @@ import ( "github.com/jinzhu/gorm/schema" ) -// Instance db instance -type Instance struct { - Error error - RowsAffected int64 - Context context.Context - Statement *Statement -} - -func (instance *Instance) ToSQL(clauses ...string) (string, []interface{}) { - if len(clauses) > 0 { - instance.Statement.Build(clauses...) - } - return strings.TrimSpace(instance.Statement.SQL.String()), instance.Statement.Vars -} - -// AddError add error to instance -func (inst *Instance) AddError(err error) { - if inst.Error == nil { - inst.Error = err - } else if err != nil { - inst.Error = fmt.Errorf("%v; %w", inst.Error, err) - } -} - // Statement statement type Statement struct { Table string @@ -48,8 +24,12 @@ type Statement struct { Selects []string // selected columns Omits []string // omit columns Settings sync.Map + ConnPool ConnPool DB *DB Schema *schema.Schema + Context context.Context + Error error + RowsAffected int64 RaiseErrorOnNotFound bool // SQL Builder @@ -246,6 +226,14 @@ func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (con return conditions } +func (stmt *Statement) AddError(err error) { + if stmt.Error == nil { + stmt.Error = err + } else if err != nil { + stmt.Error = fmt.Errorf("%v; %w", stmt.Error, err) + } +} + // Build build sql with clauses names func (stmt *Statement) Build(clauses ...string) { var firstClauseWritten bool diff --git a/utils/utils.go b/utils/utils.go index e7ed512c..25cd585a 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -4,6 +4,7 @@ import ( "fmt" "regexp" "runtime" + "unicode" ) var goSrcRegexp = regexp.MustCompile(`/gorm/.*.go`) @@ -18,3 +19,7 @@ func FileWithLineNum() string { } return "" } + +func IsChar(c rune) bool { + return !unicode.IsLetter(c) && !unicode.IsNumber(c) +} From 3aa1891068543c96eb8e6b175c61c19e193906ba Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 9 Mar 2020 15:32:55 +0800 Subject: [PATCH 273/881] Add sync pool --- callbacks.go | 3 ++ chainable_api.go | 42 +++++++++---------- dialects/sqlite/sqlite_test.go | 6 +-- finisher_api.go | 76 +++++++++++++++++----------------- gorm.go | 56 +++++++++++++------------ migrator/migrator.go | 4 +- statement.go | 65 ++++++++++++++++++++--------- 7 files changed, 143 insertions(+), 109 deletions(-) diff --git a/callbacks.go b/callbacks.go index d1164019..e2907178 100644 --- a/callbacks.go +++ b/callbacks.go @@ -96,6 +96,9 @@ func (p *processor) Execute(db *DB) { db.Logger.Trace(curTime, func() (string, int64) { return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...), db.RowsAffected }, db.Error) + + stmt.reinit() + db.Config.statementPool.Put(stmt) } } diff --git a/chainable_api.go b/chainable_api.go index 98c1898e..c2a6247b 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -13,14 +13,14 @@ import ( // db.Model(&User{}).Update("name", "hello") // // if user's primary key is non-blank, will use it as condition, then will only update the user's name to `hello` // db.Model(&user).Update("name", "hello") -func (db *DB) Model(value interface{}) (tx *DB) { +func (db DB) Model(value interface{}) (tx DB) { tx = db.getInstance() tx.Statement.Model = value return } // Clauses Add clauses -func (db *DB) Clauses(conds ...clause.Expression) (tx *DB) { +func (db DB) Clauses(conds ...clause.Expression) (tx DB) { tx = db.getInstance() var whereConds []interface{} @@ -39,14 +39,14 @@ func (db *DB) Clauses(conds ...clause.Expression) (tx *DB) { } // Table specify the table you would like to run db operations -func (db *DB) Table(name string) (tx *DB) { +func (db DB) Table(name string) (tx DB) { tx = db.getInstance() tx.Statement.Table = name return } // Select specify fields that you want when querying, creating, updating -func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) { +func (db DB) Select(query interface{}, args ...interface{}) (tx DB) { tx = db.getInstance() switch v := query.(type) { @@ -97,7 +97,7 @@ func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) { } // Omit specify fields that you want to ignore when creating, updating and querying -func (db *DB) Omit(columns ...string) (tx *DB) { +func (db DB) Omit(columns ...string) (tx DB) { tx = db.getInstance() if len(columns) == 1 && strings.ContainsRune(columns[0], ',') { @@ -108,21 +108,21 @@ func (db *DB) Omit(columns ...string) (tx *DB) { return } -func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) { +func (db DB) Where(query interface{}, args ...interface{}) (tx DB) { tx = db.getInstance() tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondtion(query, args...)}) return } // Not add NOT condition -func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) { +func (db DB) Not(query interface{}, args ...interface{}) (tx DB) { tx = db.getInstance() tx.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.Not(tx.Statement.BuildCondtion(query, args...)...)}}) return } // Or add OR conditions -func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) { +func (db DB) Or(query interface{}, args ...interface{}) (tx DB) { tx = db.getInstance() tx.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(tx.Statement.BuildCondtion(query, args...)...)}}) return @@ -131,13 +131,13 @@ func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) { // Joins specify Joins conditions // db.Joins("Account").Find(&user) // db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user) -func (db *DB) Joins(query string, args ...interface{}) (tx *DB) { +func (db DB) Joins(query string, args ...interface{}) (tx DB) { tx = db.getInstance() return } // Group specify the group method on the find -func (db *DB) Group(name string) (tx *DB) { +func (db DB) Group(name string) (tx DB) { tx = db.getInstance() tx.Statement.AddClause(clause.GroupBy{ Columns: []clause.Column{{Name: name}}, @@ -146,7 +146,7 @@ func (db *DB) Group(name string) (tx *DB) { } // Having specify HAVING conditions for GROUP BY -func (db *DB) Having(query interface{}, args ...interface{}) (tx *DB) { +func (db DB) Having(query interface{}, args ...interface{}) (tx DB) { tx = db.getInstance() tx.Statement.AddClause(clause.GroupBy{ Having: tx.Statement.BuildCondtion(query, args...), @@ -157,7 +157,7 @@ func (db *DB) Having(query interface{}, args ...interface{}) (tx *DB) { // Order specify order when retrieve records from database // db.Order("name DESC") // db.Order(gorm.Expr("name = ? DESC", "first")) // sql expression -func (db *DB) Order(value interface{}) (tx *DB) { +func (db DB) Order(value interface{}) (tx DB) { tx = db.getInstance() switch v := value.(type) { @@ -176,20 +176,20 @@ func (db *DB) Order(value interface{}) (tx *DB) { } // Limit specify the number of records to be retrieved -func (db *DB) Limit(limit int) (tx *DB) { +func (db DB) Limit(limit int) (tx DB) { tx = db.getInstance() tx.Statement.AddClause(clause.Limit{Limit: limit}) return } // Offset specify the number of records to skip before starting to return the records -func (db *DB) Offset(offset int) (tx *DB) { +func (db DB) Offset(offset int) (tx DB) { tx = db.getInstance() tx.Statement.AddClause(clause.Limit{Offset: offset}) return } -// Scopes pass current database connection to arguments `func(*DB) *DB`, which could be used to add conditions dynamically +// Scopes pass current database connection to arguments `func(DB) DB`, which could be used to add conditions dynamically // func AmountGreaterThan1000(db *gorm.DB) *gorm.DB { // return db.Where("amount > ?", 1000) // } @@ -201,7 +201,7 @@ func (db *DB) Offset(offset int) (tx *DB) { // } // // db.Scopes(AmountGreaterThan1000, OrderStatus([]string{"paid", "shipped"})).Find(&orders) -func (db *DB) Scopes(funcs ...func(*DB) *DB) *DB { +func (db DB) Scopes(funcs ...func(DB) DB) DB { for _, f := range funcs { db = f(db) } @@ -210,27 +210,27 @@ func (db *DB) Scopes(funcs ...func(*DB) *DB) *DB { // Preload preload associations with given conditions // db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users) -func (db *DB) Preload(column string, conditions ...interface{}) (tx *DB) { +func (db DB) Preload(column string, conditions ...interface{}) (tx DB) { tx = db.getInstance() return } -func (db *DB) Assign(attrs ...interface{}) (tx *DB) { +func (db DB) Assign(attrs ...interface{}) (tx DB) { tx = db.getInstance() return } -func (db *DB) Attrs(attrs ...interface{}) (tx *DB) { +func (db DB) Attrs(attrs ...interface{}) (tx DB) { tx = db.getInstance() return } -func (db *DB) Unscoped() (tx *DB) { +func (db DB) Unscoped() (tx DB) { tx = db.getInstance() return } -func (db *DB) Raw(sql string, values ...interface{}) (tx *DB) { +func (db DB) Raw(sql string, values ...interface{}) (tx DB) { tx = db.getInstance() tx.Statement.SQL = strings.Builder{} clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement) diff --git a/dialects/sqlite/sqlite_test.go b/dialects/sqlite/sqlite_test.go index a42bc8ee..7a07db01 100644 --- a/dialects/sqlite/sqlite_test.go +++ b/dialects/sqlite/sqlite_test.go @@ -12,7 +12,7 @@ import ( ) var ( - DB *gorm.DB + DB gorm.DB err error ) @@ -23,9 +23,9 @@ func init() { } func TestCURD(t *testing.T) { - tests.RunTestsSuit(t, DB) + tests.RunTestsSuit(t, &DB) } func TestMigrate(t *testing.T) { - tests.TestMigrate(t, DB) + tests.TestMigrate(t, &DB) } diff --git a/finisher_api.go b/finisher_api.go index 62c1af30..4b3829a2 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -9,15 +9,15 @@ import ( ) // Create insert the value into database -func (db *DB) Create(value interface{}) (tx *DB) { +func (db DB) Create(value interface{}) (tx DB) { tx = db.getInstance() tx.Statement.Dest = value - tx.callbacks.Create().Execute(tx) + tx.callbacks.Create().Execute(&tx) return } // Save update value in database, if the value doesn't have primary key, will insert it -func (db *DB) Save(value interface{}) (tx *DB) { +func (db DB) Save(value interface{}) (tx DB) { tx = db.getInstance() tx.Statement.Dest = value @@ -26,7 +26,7 @@ func (db *DB) Save(value interface{}) (tx *DB) { reflectValue := reflect.ValueOf(value) for idx, pf := range tx.Statement.Schema.PrimaryFields { if pv, isZero := pf.ValueOf(reflectValue); isZero { - tx.callbacks.Create().Execute(tx) + tx.callbacks.Create().Execute(&tx) where.Exprs[idx] = clause.Eq{Column: pf.DBName, Value: pv} return } @@ -38,12 +38,12 @@ func (db *DB) Save(value interface{}) (tx *DB) { if len(tx.Statement.Selects) == 0 { tx.Statement.Selects = []string{"*"} } - tx.callbacks.Update().Execute(tx) + tx.callbacks.Update().Execute(&tx) return } // First find first record that match given conditions, order by primary key -func (db *DB) First(out interface{}, conds ...interface{}) (tx *DB) { +func (db DB) First(out interface{}, conds ...interface{}) (tx DB) { tx = db.getInstance().Limit(1).Order(clause.OrderByColumn{ Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, }) @@ -52,24 +52,24 @@ func (db *DB) First(out interface{}, conds ...interface{}) (tx *DB) { } tx.Statement.RaiseErrorOnNotFound = true tx.Statement.Dest = out - tx.callbacks.Query().Execute(tx) + tx.callbacks.Query().Execute(&tx) return } // Take return a record that match given conditions, the order will depend on the database implementation -func (db *DB) Take(out interface{}, conds ...interface{}) (tx *DB) { +func (db DB) Take(out interface{}, conds ...interface{}) (tx DB) { tx = db.getInstance().Limit(1) if len(conds) > 0 { tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondtion(conds[0], conds[1:]...)}) } tx.Statement.RaiseErrorOnNotFound = true tx.Statement.Dest = out - tx.callbacks.Query().Execute(tx) + tx.callbacks.Query().Execute(&tx) return } // Last find last record that match given conditions, order by primary key -func (db *DB) Last(out interface{}, conds ...interface{}) (tx *DB) { +func (db DB) Last(out interface{}, conds ...interface{}) (tx DB) { tx = db.getInstance().Limit(1).Order(clause.OrderByColumn{ Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Desc: true, @@ -79,101 +79,101 @@ func (db *DB) Last(out interface{}, conds ...interface{}) (tx *DB) { } tx.Statement.RaiseErrorOnNotFound = true tx.Statement.Dest = out - tx.callbacks.Query().Execute(tx) + tx.callbacks.Query().Execute(&tx) return } // Find find records that match given conditions -func (db *DB) Find(out interface{}, conds ...interface{}) (tx *DB) { +func (db DB) Find(out interface{}, conds ...interface{}) (tx DB) { tx = db.getInstance() if len(conds) > 0 { tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondtion(conds[0], conds[1:]...)}) } tx.Statement.Dest = out - tx.callbacks.Query().Execute(tx) + tx.callbacks.Query().Execute(&tx) return } -func (db *DB) FirstOrInit(out interface{}, where ...interface{}) (tx *DB) { +func (db DB) FirstOrInit(out interface{}, where ...interface{}) (tx DB) { tx = db.getInstance() return } -func (db *DB) FirstOrCreate(out interface{}, where ...interface{}) (tx *DB) { +func (db DB) FirstOrCreate(out interface{}, where ...interface{}) (tx DB) { tx = db.getInstance() return } // Update update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update -func (db *DB) Update(column string, value interface{}) (tx *DB) { +func (db DB) Update(column string, value interface{}) (tx DB) { tx = db.getInstance() tx.Statement.Dest = map[string]interface{}{column: value} - tx.callbacks.Update().Execute(tx) + tx.callbacks.Update().Execute(&tx) return } // Updates update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update -func (db *DB) Updates(values interface{}) (tx *DB) { +func (db DB) Updates(values interface{}) (tx DB) { tx = db.getInstance() tx.Statement.Dest = values - tx.callbacks.Update().Execute(tx) + tx.callbacks.Update().Execute(&tx) return } -func (db *DB) UpdateColumn(column string, value interface{}) (tx *DB) { +func (db DB) UpdateColumn(column string, value interface{}) (tx DB) { tx = db.getInstance() tx.Statement.Dest = map[string]interface{}{column: value} - tx.callbacks.Update().Execute(tx) + tx.callbacks.Update().Execute(&tx) return } -func (db *DB) UpdateColumns(values interface{}) (tx *DB) { +func (db DB) UpdateColumns(values interface{}) (tx DB) { tx = db.getInstance() tx.Statement.Dest = values - tx.callbacks.Update().Execute(tx) + tx.callbacks.Update().Execute(&tx) return } // Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition -func (db *DB) Delete(value interface{}, conds ...interface{}) (tx *DB) { +func (db DB) Delete(value interface{}, conds ...interface{}) (tx DB) { tx = db.getInstance() if len(conds) > 0 { tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondtion(conds[0], conds[1:]...)}) } tx.Statement.Dest = value - tx.callbacks.Delete().Execute(tx) + tx.callbacks.Delete().Execute(&tx) return } -func (db *DB) Count(value interface{}) (tx *DB) { +func (db DB) Count(value interface{}) (tx DB) { tx = db.getInstance() return } -func (db *DB) Row() *sql.Row { +func (db DB) Row() *sql.Row { tx := db.getInstance() - tx.callbacks.Row().Execute(tx) + tx.callbacks.Row().Execute(&tx) return tx.Statement.Dest.(*sql.Row) } -func (db *DB) Rows() (*sql.Rows, error) { +func (db DB) Rows() (*sql.Rows, error) { tx := db.Set("rows", true) - tx.callbacks.Row().Execute(tx) + tx.callbacks.Row().Execute(&tx) return tx.Statement.Dest.(*sql.Rows), tx.Error } // Scan scan value to a struct -func (db *DB) Scan(dest interface{}) (tx *DB) { +func (db DB) Scan(dest interface{}) (tx DB) { tx = db.getInstance() return } -func (db *DB) ScanRows(rows *sql.Rows, result interface{}) error { +func (db DB) ScanRows(rows *sql.Rows, result interface{}) error { return nil } // Transaction start a transaction as a block, return error will rollback, otherwise to commit. -func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err error) { +func (db DB) Transaction(fc func(tx DB) error, opts ...*sql.TxOptions) (err error) { panicked := true tx := db.Begin(opts...) defer func() { @@ -194,7 +194,7 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er } // Begin begins a transaction -func (db *DB) Begin(opts ...*sql.TxOptions) (tx *DB) { +func (db DB) Begin(opts ...*sql.TxOptions) (tx DB) { tx = db.getInstance() if beginner, ok := tx.Statement.ConnPool.(TxBeginner); ok { var opt *sql.TxOptions @@ -213,7 +213,7 @@ func (db *DB) Begin(opts ...*sql.TxOptions) (tx *DB) { } // Commit commit a transaction -func (db *DB) Commit() *DB { +func (db DB) Commit() DB { if comminter, ok := db.Statement.ConnPool.(TxCommiter); ok && comminter != nil { db.AddError(comminter.Commit()) } else { @@ -223,7 +223,7 @@ func (db *DB) Commit() *DB { } // Rollback rollback a transaction -func (db *DB) Rollback() *DB { +func (db DB) Rollback() DB { if comminter, ok := db.Statement.ConnPool.(TxCommiter); ok && comminter != nil { db.AddError(comminter.Rollback()) } else { @@ -233,10 +233,10 @@ func (db *DB) Rollback() *DB { } // Exec execute raw sql -func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) { +func (db DB) Exec(sql string, values ...interface{}) (tx DB) { tx = db.getInstance() tx.Statement.SQL = strings.Builder{} clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement) - tx.callbacks.Raw().Execute(tx) + tx.callbacks.Raw().Execute(&tx) return } diff --git a/gorm.go b/gorm.go index b238d572..b7d3e929 100644 --- a/gorm.go +++ b/gorm.go @@ -29,8 +29,9 @@ type Config struct { // Dialector database dialector Dialector - callbacks *callbacks - cacheStore *sync.Map + statementPool sync.Pool + callbacks *callbacks + cacheStore *sync.Map } // DB GORM DB definition @@ -50,7 +51,7 @@ type Session struct { } // Open initialize db session based on dialector -func Open(dialector Dialector, config *Config) (db *DB, err error) { +func Open(dialector Dialector, config *Config) (db DB, err error) { if config == nil { config = &Config{} } @@ -75,21 +76,32 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) { config.cacheStore = &sync.Map{} } - db = &DB{ + config.statementPool = sync.Pool{ + New: func() interface{} { + return &Statement{ + DB: db, + ConnPool: db.ConnPool, + Clauses: map[string]clause.Clause{}, + Context: context.Background(), + } + }, + } + + db = DB{ Config: config, clone: true, } - db.callbacks = initializeCallbacks(db) + db.callbacks = initializeCallbacks(&db) if dialector != nil { - err = dialector.Initialize(db) + err = dialector.Initialize(&db) } return } // Session create new db session -func (db *DB) Session(config *Session) *DB { +func (db DB) Session(config *Session) DB { var ( tx = db.getInstance() txConfig = *tx.Config @@ -113,24 +125,24 @@ func (db *DB) Session(config *Session) *DB { } // WithContext change current instance db's context to ctx -func (db *DB) WithContext(ctx context.Context) *DB { +func (db DB) WithContext(ctx context.Context) DB { return db.Session(&Session{Context: ctx}) } // Debug start debug mode -func (db *DB) Debug() (tx *DB) { +func (db DB) Debug() (tx DB) { return db.Session(&Session{Logger: db.Logger.LogMode(logger.Info)}) } // Set store value with key into current db instance's context -func (db *DB) Set(key string, value interface{}) *DB { +func (db DB) Set(key string, value interface{}) DB { tx := db.getInstance() tx.Statement.Settings.Store(key, value) return tx } // Get get value with key from current db instance's context -func (db *DB) Get(key string) (interface{}, bool) { +func (db DB) Get(key string) (interface{}, bool) { if db.Statement != nil { return db.Statement.Settings.Load(key) } @@ -138,36 +150,28 @@ func (db *DB) Get(key string) (interface{}, bool) { } // Callback returns callback manager -func (db *DB) Callback() *callbacks { +func (db DB) Callback() *callbacks { return db.callbacks } // AutoMigrate run auto migration for given models -func (db *DB) AutoMigrate(dst ...interface{}) error { +func (db DB) AutoMigrate(dst ...interface{}) error { return db.Migrator().AutoMigrate(dst...) } // AddError add error to db -func (db *DB) AddError(err error) { +func (db DB) AddError(err error) { db.Statement.AddError(err) } -func (db *DB) getInstance() *DB { +func (db DB) getInstance() DB { if db.clone { - ctx := context.Background() + stmt := db.Config.statementPool.Get().(*Statement) if db.Statement != nil { - ctx = db.Statement.Context + stmt.Context = db.Statement.Context } - return &DB{ - Config: db.Config, - Statement: &Statement{ - DB: db, - ConnPool: db.ConnPool, - Clauses: map[string]clause.Clause{}, - Context: ctx, - }, - } + return DB{Config: db.Config, Statement: stmt} } return db diff --git a/migrator/migrator.go b/migrator/migrator.go index 730e8cfe..b2458bfc 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -27,7 +27,7 @@ type Config struct { func (m Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement) error) error { stmt := m.DB.Statement if stmt == nil { - stmt = &gorm.Statement{DB: m.DB} + stmt = &gorm.Statement{DB: *m.DB} } if err := stmt.Parse(value); err != nil { @@ -496,7 +496,7 @@ func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []i parseDependence := func(value interface{}, addToList bool) { dep := Dependency{ - Statement: &gorm.Statement{DB: m.DB, Dest: value}, + Statement: &gorm.Statement{DB: *m.DB, Dest: value}, } dep.Parse(value) diff --git a/statement.go b/statement.go index 10b62567..0190df7c 100644 --- a/statement.go +++ b/statement.go @@ -25,17 +25,16 @@ type Statement struct { Omits []string // omit columns Settings sync.Map ConnPool ConnPool - DB *DB + DB DB Schema *schema.Schema Context context.Context Error error RowsAffected int64 RaiseErrorOnNotFound bool - - // SQL Builder - SQL strings.Builder - Vars []interface{} - NamedVars []sql.NamedArg + SQL strings.Builder + Vars []interface{} + NamedVars []sql.NamedArg + placeholders strings.Builder } // StatementOptimizer statement optimizer interface @@ -112,41 +111,43 @@ func (stmt Statement) Quote(field interface{}) string { // Write write string func (stmt *Statement) AddVar(vars ...interface{}) string { - var placeholders strings.Builder + stmt.placeholders = strings.Builder{} + stmt.placeholders.Reset() + for idx, v := range vars { if idx > 0 { - placeholders.WriteByte(',') + stmt.placeholders.WriteByte(',') } switch v := v.(type) { case sql.NamedArg: if len(v.Name) > 0 { stmt.NamedVars = append(stmt.NamedVars, v) - placeholders.WriteByte('@') - placeholders.WriteString(v.Name) + stmt.placeholders.WriteByte('@') + stmt.placeholders.WriteString(v.Name) } else { stmt.Vars = append(stmt.Vars, v.Value) - placeholders.WriteString(stmt.DB.Dialector.BindVar(stmt, v.Value)) + stmt.placeholders.WriteString(stmt.DB.Dialector.BindVar(stmt, v.Value)) } case clause.Column, clause.Table: - placeholders.WriteString(stmt.Quote(v)) + stmt.placeholders.WriteString(stmt.Quote(v)) case clause.Expr: - placeholders.WriteString(v.SQL) + stmt.placeholders.WriteString(v.SQL) stmt.Vars = append(stmt.Vars, v.Vars...) case []interface{}: if len(v) > 0 { - placeholders.WriteByte('(') - placeholders.WriteString(stmt.AddVar(v...)) - placeholders.WriteByte(')') + stmt.placeholders.WriteByte('(') + stmt.placeholders.WriteString(stmt.AddVar(v...)) + stmt.placeholders.WriteByte(')') } else { - placeholders.WriteString("(NULL)") + stmt.placeholders.WriteString("(NULL)") } default: stmt.Vars = append(stmt.Vars, v) - placeholders.WriteString(stmt.DB.Dialector.BindVar(stmt, v)) + stmt.placeholders.WriteString(stmt.DB.Dialector.BindVar(stmt, v)) } } - return placeholders.String() + return stmt.placeholders.String() } // AddClause add clause @@ -261,3 +262,29 @@ func (stmt *Statement) Parse(value interface{}) (err error) { } return err } + +func (stmt *Statement) reinit() { + stmt.Table = "" + stmt.Model = nil + stmt.Selects = nil + stmt.Omits = nil + stmt.ConnPool = stmt.DB.Config.ConnPool + stmt.Schema = nil + stmt.Context = context.Background() + stmt.Error = nil + stmt.RowsAffected = 0 + stmt.RaiseErrorOnNotFound = false + + stmt.SQL.Reset() + stmt.Vars = nil + stmt.NamedVars = nil + + for k := range stmt.Clauses { + delete(stmt.Clauses, k) + } + + stmt.Settings.Range(func(k, _ interface{}) bool { + stmt.Settings.Delete(k) + return true + }) +} From 504f42760a2f4be453c51798bc075dc7fd414bd5 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 9 Mar 2020 17:07:00 +0800 Subject: [PATCH 274/881] Refactor clause Writer --- clause/clause.go | 11 ++++--- clause/delete.go | 4 +-- clause/expression.go | 60 +++++++++++++++++++++-------------- clause/from.go | 8 ++--- clause/group_by.go | 2 +- clause/insert.go | 4 +-- clause/limit.go | 8 ++--- clause/locking.go | 8 ++--- clause/order_by.go | 2 +- clause/set.go | 2 +- clause/update.go | 2 +- clause/values.go | 6 ++-- clause/where.go | 12 +++---- dialects/mssql/mssql.go | 10 +++--- dialects/mysql/mysql.go | 10 +++--- dialects/postgres/postgres.go | 10 +++--- dialects/sqlite/sqlite.go | 10 +++--- interfaces.go | 4 +-- statement.go | 41 ++++++++++-------------- tests/dummy_dialecter.go | 11 +++---- 20 files changed, 117 insertions(+), 108 deletions(-) diff --git a/clause/clause.go b/clause/clause.go index df8e3a57..59b229ce 100644 --- a/clause/clause.go +++ b/clause/clause.go @@ -12,13 +12,16 @@ type ClauseBuilder interface { Build(Clause, Builder) } +type Writer interface { + WriteByte(byte) error + WriteString(string) (int, error) +} + // Builder builder interface type Builder interface { - WriteByte(byte) error - Write(sql ...string) error + Writer WriteQuoted(field interface{}) error - AddVar(vars ...interface{}) string - Quote(field interface{}) string + AddVar(Writer, ...interface{}) } // Clause diff --git a/clause/delete.go b/clause/delete.go index 2a622b45..fc462cd7 100644 --- a/clause/delete.go +++ b/clause/delete.go @@ -9,11 +9,11 @@ func (d Delete) Name() string { } func (d Delete) Build(builder Builder) { - builder.Write("DELETE") + builder.WriteString("DELETE") if d.Modifier != "" { builder.WriteByte(' ') - builder.Write(d.Modifier) + builder.WriteString(d.Modifier) } } diff --git a/clause/expression.go b/clause/expression.go index d72db08d..8150f838 100644 --- a/clause/expression.go +++ b/clause/expression.go @@ -1,9 +1,5 @@ package clause -import ( - "strings" -) - // Expression expression interface type Expression interface { Build(builder Builder) @@ -22,11 +18,15 @@ type Expr struct { // Build build raw expression func (expr Expr) Build(builder Builder) { - sql := expr.SQL - for _, v := range expr.Vars { - sql = strings.Replace(sql, "?", builder.AddVar(v), 1) + var idx int + for _, v := range []byte(expr.SQL) { + if v == '?' { + builder.AddVar(builder, expr.Vars[idx]) + idx++ + } else { + builder.WriteByte(v) + } } - builder.Write(sql) } // IN Whether a value is within a set of values @@ -40,11 +40,14 @@ func (in IN) Build(builder Builder) { switch len(in.Values) { case 0: - builder.Write(" IN (NULL)") + builder.WriteString(" IN (NULL)") case 1: - builder.Write(" = ", builder.AddVar(in.Values...)) + builder.WriteString(" = ") + builder.AddVar(builder, in.Values...) default: - builder.Write(" IN (", builder.AddVar(in.Values...), ")") + builder.WriteString(" IN (") + builder.AddVar(builder, in.Values...) + builder.WriteByte(')') } } @@ -52,9 +55,12 @@ func (in IN) NegationBuild(builder Builder) { switch len(in.Values) { case 0: case 1: - builder.Write(" <> ", builder.AddVar(in.Values...)) + builder.WriteString(" <> ") + builder.AddVar(builder, in.Values...) default: - builder.Write(" NOT IN (", builder.AddVar(in.Values...), ")") + builder.WriteString(" NOT IN (") + builder.AddVar(builder, in.Values...) + builder.WriteByte(')') } } @@ -68,9 +74,10 @@ func (eq Eq) Build(builder Builder) { builder.WriteQuoted(eq.Column) if eq.Value == nil { - builder.Write(" IS NULL") + builder.WriteString(" IS NULL") } else { - builder.Write(" = ", builder.AddVar(eq.Value)) + builder.WriteString(" = ") + builder.AddVar(builder, eq.Value) } } @@ -85,9 +92,10 @@ func (neq Neq) Build(builder Builder) { builder.WriteQuoted(neq.Column) if neq.Value == nil { - builder.Write(" IS NOT NULL") + builder.WriteString(" IS NOT NULL") } else { - builder.Write(" <> ", builder.AddVar(neq.Value)) + builder.WriteString(" <> ") + builder.AddVar(builder, neq.Value) } } @@ -100,7 +108,8 @@ type Gt Eq func (gt Gt) Build(builder Builder) { builder.WriteQuoted(gt.Column) - builder.Write(" > ", builder.AddVar(gt.Value)) + builder.WriteString(" > ") + builder.AddVar(builder, gt.Value) } func (gt Gt) NegationBuild(builder Builder) { @@ -112,7 +121,8 @@ type Gte Eq func (gte Gte) Build(builder Builder) { builder.WriteQuoted(gte.Column) - builder.Write(" >= ", builder.AddVar(gte.Value)) + builder.WriteString(" >= ") + builder.AddVar(builder, gte.Value) } func (gte Gte) NegationBuild(builder Builder) { @@ -124,7 +134,8 @@ type Lt Eq func (lt Lt) Build(builder Builder) { builder.WriteQuoted(lt.Column) - builder.Write(" < ", builder.AddVar(lt.Value)) + builder.WriteString(" < ") + builder.AddVar(builder, lt.Value) } func (lt Lt) NegationBuild(builder Builder) { @@ -136,7 +147,8 @@ type Lte Eq func (lte Lte) Build(builder Builder) { builder.WriteQuoted(lte.Column) - builder.Write(" <= ", builder.AddVar(lte.Value)) + builder.WriteString(" <= ") + builder.AddVar(builder, lte.Value) } func (lte Lte) NegationBuild(builder Builder) { @@ -148,12 +160,14 @@ type Like Eq func (like Like) Build(builder Builder) { builder.WriteQuoted(like.Column) - builder.Write(" LIKE ", builder.AddVar(like.Value)) + builder.WriteString(" LIKE ") + builder.AddVar(builder, like.Value) } func (like Like) NegationBuild(builder Builder) { builder.WriteQuoted(like.Column) - builder.Write(" NOT LIKE ", builder.AddVar(like.Value)) + builder.WriteString(" NOT LIKE ") + builder.AddVar(builder, like.Value) } // Map diff --git a/clause/from.go b/clause/from.go index f01065b5..5e8c5d25 100644 --- a/clause/from.go +++ b/clause/from.go @@ -50,18 +50,18 @@ func (from From) Build(builder Builder) { func (join Join) Build(builder Builder) { if join.Type != "" { - builder.Write(string(join.Type)) + builder.WriteString(string(join.Type)) builder.WriteByte(' ') } - builder.Write("JOIN ") + builder.WriteString("JOIN ") builder.WriteQuoted(join.Table) if len(join.ON.Exprs) > 0 { - builder.Write(" ON ") + builder.WriteString(" ON ") join.ON.Build(builder) } else if len(join.Using) > 0 { - builder.Write(" USING (") + builder.WriteString(" USING (") for idx, c := range join.Using { if idx > 0 { builder.WriteByte(',') diff --git a/clause/group_by.go b/clause/group_by.go index a245d50a..c1383c36 100644 --- a/clause/group_by.go +++ b/clause/group_by.go @@ -22,7 +22,7 @@ func (groupBy GroupBy) Build(builder Builder) { } if len(groupBy.Having) > 0 { - builder.Write(" HAVING ") + builder.WriteString(" HAVING ") Where{Exprs: groupBy.Having}.Build(builder) } } diff --git a/clause/insert.go b/clause/insert.go index 3f86c98f..8efaa035 100644 --- a/clause/insert.go +++ b/clause/insert.go @@ -13,11 +13,11 @@ func (insert Insert) Name() string { // Build build insert clause func (insert Insert) Build(builder Builder) { if insert.Modifier != "" { - builder.Write(insert.Modifier) + builder.WriteString(insert.Modifier) builder.WriteByte(' ') } - builder.Write("INTO ") + builder.WriteString("INTO ") if insert.Table.Name == "" { builder.WriteQuoted(currentTable) } else { diff --git a/clause/limit.go b/clause/limit.go index e30666af..ba5cf6c4 100644 --- a/clause/limit.go +++ b/clause/limit.go @@ -16,12 +16,12 @@ func (limit Limit) Name() string { // Build build where clause func (limit Limit) Build(builder Builder) { if limit.Limit > 0 { - builder.Write("LIMIT ") - builder.Write(strconv.Itoa(limit.Limit)) + builder.WriteString("LIMIT ") + builder.WriteString(strconv.Itoa(limit.Limit)) if limit.Offset > 0 { - builder.Write(" OFFSET ") - builder.Write(strconv.Itoa(limit.Offset)) + builder.WriteString(" OFFSET ") + builder.WriteString(strconv.Itoa(limit.Offset)) } } } diff --git a/clause/locking.go b/clause/locking.go index 48b84b34..3be1063b 100644 --- a/clause/locking.go +++ b/clause/locking.go @@ -22,16 +22,16 @@ func (f For) Build(builder Builder) { builder.WriteByte(' ') } - builder.Write("FOR ") - builder.Write(locking.Strength) + builder.WriteString("FOR ") + builder.WriteString(locking.Strength) if locking.Table.Name != "" { - builder.Write(" OF ") + builder.WriteString(" OF ") builder.WriteQuoted(locking.Table) } if locking.Options != "" { builder.WriteByte(' ') - builder.Write(locking.Options) + builder.WriteString(locking.Options) } } } diff --git a/clause/order_by.go b/clause/order_by.go index 2734f2bc..307bf930 100644 --- a/clause/order_by.go +++ b/clause/order_by.go @@ -24,7 +24,7 @@ func (orderBy OrderBy) Build(builder Builder) { builder.WriteQuoted(column.Column) if column.Desc { - builder.Write(" DESC") + builder.WriteString(" DESC") } } } diff --git a/clause/set.go b/clause/set.go index 3b7e972d..de78b1be 100644 --- a/clause/set.go +++ b/clause/set.go @@ -19,7 +19,7 @@ func (set Set) Build(builder Builder) { } builder.WriteQuoted(assignment.Column) builder.WriteByte('=') - builder.Write(builder.AddVar(assignment.Value)) + builder.AddVar(builder, assignment.Value) } } else { builder.WriteQuoted(PrimaryColumn) diff --git a/clause/update.go b/clause/update.go index c375b373..f9d68ac6 100644 --- a/clause/update.go +++ b/clause/update.go @@ -13,7 +13,7 @@ func (update Update) Name() string { // Build build update clause func (update Update) Build(builder Builder) { if update.Modifier != "" { - builder.Write(update.Modifier) + builder.WriteString(update.Modifier) builder.WriteByte(' ') } diff --git a/clause/values.go b/clause/values.go index 2c8dcf89..a997fc26 100644 --- a/clause/values.go +++ b/clause/values.go @@ -22,7 +22,7 @@ func (values Values) Build(builder Builder) { } builder.WriteByte(')') - builder.Write(" VALUES ") + builder.WriteString(" VALUES ") for idx, value := range values.Values { if idx > 0 { @@ -30,11 +30,11 @@ func (values Values) Build(builder Builder) { } builder.WriteByte('(') - builder.Write(builder.AddVar(value...)) + builder.AddVar(builder, value...) builder.WriteByte(')') } } else { - builder.Write("DEFAULT VALUES") + builder.WriteString("DEFAULT VALUES") } } diff --git a/clause/where.go b/clause/where.go index 0ee1a141..08c78b22 100644 --- a/clause/where.go +++ b/clause/where.go @@ -26,9 +26,9 @@ func (where Where) Build(builder Builder) { if expr != nil { if idx > 0 { if v, ok := expr.(OrConditions); ok && len(v.Exprs) == 1 { - builder.Write(" OR ") + builder.WriteString(" OR ") } else { - builder.Write(" AND ") + builder.WriteString(" AND ") } } @@ -65,7 +65,7 @@ func (and AndConditions) Build(builder Builder) { } for idx, c := range and.Exprs { if idx > 0 { - builder.Write(" AND ") + builder.WriteString(" AND ") } c.Build(builder) } @@ -91,7 +91,7 @@ func (or OrConditions) Build(builder Builder) { } for idx, c := range or.Exprs { if idx > 0 { - builder.Write(" OR ") + builder.WriteString(" OR ") } c.Build(builder) } @@ -117,13 +117,13 @@ func (not NotConditions) Build(builder Builder) { } for idx, c := range not.Exprs { if idx > 0 { - builder.Write(" AND ") + builder.WriteString(" AND ") } if negationBuilder, ok := c.(NegationExpressionBuilder); ok { negationBuilder.NegationBuild(builder) } else { - builder.Write(" NOT ") + builder.WriteString(" NOT ") c.Build(builder) } } diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go index 7e51de75..0842fa79 100644 --- a/dialects/mssql/mssql.go +++ b/dialects/mssql/mssql.go @@ -5,11 +5,11 @@ import ( "fmt" "regexp" "strconv" - "strings" _ "github.com/denisenkom/go-mssqldb" "github.com/jinzhu/gorm" "github.com/jinzhu/gorm/callbacks" + "github.com/jinzhu/gorm/clause" "github.com/jinzhu/gorm/logger" "github.com/jinzhu/gorm/migrator" "github.com/jinzhu/gorm/schema" @@ -42,10 +42,10 @@ func (dialector Dialector) BindVar(stmt *gorm.Statement, v interface{}) string { return "@p" + strconv.Itoa(len(stmt.Vars)) } -func (dialector Dialector) QuoteTo(builder *strings.Builder, str string) { - builder.WriteByte('"') - builder.WriteString(str) - builder.WriteByte('"') +func (dialector Dialector) QuoteTo(writer clause.Writer, str string) { + writer.WriteByte('"') + writer.WriteString(str) + writer.WriteByte('"') } var numericPlaceholder = regexp.MustCompile("@p(\\d+)") diff --git a/dialects/mysql/mysql.go b/dialects/mysql/mysql.go index 55b5a53f..cff779e3 100644 --- a/dialects/mysql/mysql.go +++ b/dialects/mysql/mysql.go @@ -4,11 +4,11 @@ import ( "database/sql" "fmt" "math" - "strings" _ "github.com/go-sql-driver/mysql" "github.com/jinzhu/gorm" "github.com/jinzhu/gorm/callbacks" + "github.com/jinzhu/gorm/clause" "github.com/jinzhu/gorm/logger" "github.com/jinzhu/gorm/migrator" "github.com/jinzhu/gorm/schema" @@ -40,10 +40,10 @@ func (dialector Dialector) BindVar(stmt *gorm.Statement, v interface{}) string { return "?" } -func (dialector Dialector) QuoteTo(builder *strings.Builder, str string) { - builder.WriteByte('`') - builder.WriteString(str) - builder.WriteByte('`') +func (dialector Dialector) QuoteTo(writer clause.Writer, str string) { + writer.WriteByte('`') + writer.WriteString(str) + writer.WriteByte('`') } func (dialector Dialector) Explain(sql string, vars ...interface{}) string { diff --git a/dialects/postgres/postgres.go b/dialects/postgres/postgres.go index e90fa4ae..99569f06 100644 --- a/dialects/postgres/postgres.go +++ b/dialects/postgres/postgres.go @@ -5,10 +5,10 @@ import ( "fmt" "regexp" "strconv" - "strings" "github.com/jinzhu/gorm" "github.com/jinzhu/gorm/callbacks" + "github.com/jinzhu/gorm/clause" "github.com/jinzhu/gorm/logger" "github.com/jinzhu/gorm/migrator" "github.com/jinzhu/gorm/schema" @@ -42,10 +42,10 @@ func (dialector Dialector) BindVar(stmt *gorm.Statement, v interface{}) string { return "$" + strconv.Itoa(len(stmt.Vars)) } -func (dialector Dialector) QuoteTo(builder *strings.Builder, str string) { - builder.WriteByte('"') - builder.WriteString(str) - builder.WriteByte('"') +func (dialector Dialector) QuoteTo(writer clause.Writer, str string) { + writer.WriteByte('"') + writer.WriteString(str) + writer.WriteByte('"') } var numericPlaceholder = regexp.MustCompile("\\$(\\d+)") diff --git a/dialects/sqlite/sqlite.go b/dialects/sqlite/sqlite.go index 8e3cc058..4105863f 100644 --- a/dialects/sqlite/sqlite.go +++ b/dialects/sqlite/sqlite.go @@ -2,10 +2,10 @@ package sqlite import ( "database/sql" - "strings" "github.com/jinzhu/gorm" "github.com/jinzhu/gorm/callbacks" + "github.com/jinzhu/gorm/clause" "github.com/jinzhu/gorm/logger" "github.com/jinzhu/gorm/migrator" "github.com/jinzhu/gorm/schema" @@ -39,10 +39,10 @@ func (dialector Dialector) BindVar(stmt *gorm.Statement, v interface{}) string { return "?" } -func (dialector Dialector) QuoteTo(builder *strings.Builder, str string) { - builder.WriteByte('`') - builder.WriteString(str) - builder.WriteByte('`') +func (dialector Dialector) QuoteTo(writer clause.Writer, str string) { + writer.WriteByte('`') + writer.WriteString(str) + writer.WriteByte('`') } func (dialector Dialector) Explain(sql string, vars ...interface{}) string { diff --git a/interfaces.go b/interfaces.go index 9859d1fa..310f801a 100644 --- a/interfaces.go +++ b/interfaces.go @@ -3,8 +3,8 @@ package gorm import ( "context" "database/sql" - "strings" + "github.com/jinzhu/gorm/clause" "github.com/jinzhu/gorm/schema" ) @@ -14,7 +14,7 @@ type Dialector interface { Migrator(db *DB) Migrator DataTypeOf(*schema.Field) string BindVar(stmt *Statement, v interface{}) string - QuoteTo(*strings.Builder, string) + QuoteTo(clause.Writer, string) Explain(sql string, vars ...interface{}) string } diff --git a/statement.go b/statement.go index 0190df7c..e632b409 100644 --- a/statement.go +++ b/statement.go @@ -34,7 +34,6 @@ type Statement struct { SQL strings.Builder Vars []interface{} NamedVars []sql.NamedArg - placeholders strings.Builder } // StatementOptimizer statement optimizer interface @@ -43,15 +42,12 @@ type StatementOptimizer interface { } // Write write string -func (stmt *Statement) Write(sql ...string) (err error) { - for _, s := range sql { - _, err = stmt.SQL.WriteString(s) - } - return +func (stmt *Statement) WriteString(str string) (int, error) { + return stmt.SQL.WriteString(str) } // Write write string -func (stmt *Statement) WriteByte(c byte) (err error) { +func (stmt *Statement) WriteByte(c byte) error { return stmt.SQL.WriteByte(c) } @@ -62,7 +58,7 @@ func (stmt *Statement) WriteQuoted(value interface{}) error { } // QuoteTo write quoted value to writer -func (stmt Statement) QuoteTo(writer *strings.Builder, field interface{}) { +func (stmt Statement) QuoteTo(writer clause.Writer, field interface{}) { switch v := field.(type) { case clause.Table: if v.Name == clause.CurrentTable { @@ -110,44 +106,41 @@ func (stmt Statement) Quote(field interface{}) string { } // Write write string -func (stmt *Statement) AddVar(vars ...interface{}) string { - stmt.placeholders = strings.Builder{} - stmt.placeholders.Reset() - +func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) { for idx, v := range vars { if idx > 0 { - stmt.placeholders.WriteByte(',') + writer.WriteByte(',') } switch v := v.(type) { case sql.NamedArg: if len(v.Name) > 0 { stmt.NamedVars = append(stmt.NamedVars, v) - stmt.placeholders.WriteByte('@') - stmt.placeholders.WriteString(v.Name) + writer.WriteByte('@') + writer.WriteString(v.Name) } else { stmt.Vars = append(stmt.Vars, v.Value) - stmt.placeholders.WriteString(stmt.DB.Dialector.BindVar(stmt, v.Value)) + writer.WriteString(stmt.DB.Dialector.BindVar(stmt, v.Value)) } case clause.Column, clause.Table: - stmt.placeholders.WriteString(stmt.Quote(v)) + stmt.QuoteTo(writer, v) case clause.Expr: - stmt.placeholders.WriteString(v.SQL) + writer.WriteString(v.SQL) stmt.Vars = append(stmt.Vars, v.Vars...) case []interface{}: if len(v) > 0 { - stmt.placeholders.WriteByte('(') - stmt.placeholders.WriteString(stmt.AddVar(v...)) - stmt.placeholders.WriteByte(')') + writer.WriteByte('(') + stmt.skipResetPlacehodler = true + stmt.AddVar(writer, v...) + writer.WriteByte(')') } else { - stmt.placeholders.WriteString("(NULL)") + writer.WriteString("(NULL)") } default: stmt.Vars = append(stmt.Vars, v) - stmt.placeholders.WriteString(stmt.DB.Dialector.BindVar(stmt, v)) + writer.WriteString(stmt.DB.Dialector.BindVar(stmt, v)) } } - return stmt.placeholders.String() } // AddClause add clause diff --git a/tests/dummy_dialecter.go b/tests/dummy_dialecter.go index 9e3146fe..f6e9d9f9 100644 --- a/tests/dummy_dialecter.go +++ b/tests/dummy_dialecter.go @@ -1,9 +1,8 @@ package tests import ( - "strings" - "github.com/jinzhu/gorm" + "github.com/jinzhu/gorm/clause" "github.com/jinzhu/gorm/logger" "github.com/jinzhu/gorm/schema" ) @@ -23,10 +22,10 @@ func (DummyDialector) BindVar(stmt *gorm.Statement, v interface{}) string { return "?" } -func (DummyDialector) QuoteTo(builder *strings.Builder, str string) { - builder.WriteByte('`') - builder.WriteString(str) - builder.WriteByte('`') +func (DummyDialector) QuoteTo(writer clause.Writer, str string) { + writer.WriteByte('`') + writer.WriteString(str) + writer.WriteByte('`') } func (DummyDialector) Explain(sql string, vars ...interface{}) string { From 2a0c3e39f22cc840019fb42287d130b9c4cf2609 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 9 Mar 2020 17:59:54 +0800 Subject: [PATCH 275/881] AddVar accept writer --- dialects/mssql/mssql.go | 5 +++-- dialects/mysql/mysql.go | 4 ++-- dialects/postgres/postgres.go | 5 +++-- dialects/sqlite/sqlite.go | 4 ++-- interfaces.go | 2 +- statement.go | 5 ++--- tests/dummy_dialecter.go | 4 ++-- 7 files changed, 15 insertions(+), 14 deletions(-) diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go index 0842fa79..8cf1e2e2 100644 --- a/dialects/mssql/mssql.go +++ b/dialects/mssql/mssql.go @@ -38,8 +38,9 @@ func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator { }}} } -func (dialector Dialector) BindVar(stmt *gorm.Statement, v interface{}) string { - return "@p" + strconv.Itoa(len(stmt.Vars)) +func (dialector Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) { + writer.WriteString("@p") + writer.WriteString(strconv.Itoa(len(stmt.Vars))) } func (dialector Dialector) QuoteTo(writer clause.Writer, str string) { diff --git a/dialects/mysql/mysql.go b/dialects/mysql/mysql.go index cff779e3..514dfc14 100644 --- a/dialects/mysql/mysql.go +++ b/dialects/mysql/mysql.go @@ -36,8 +36,8 @@ func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator { }}} } -func (dialector Dialector) BindVar(stmt *gorm.Statement, v interface{}) string { - return "?" +func (dialector Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) { + writer.WriteByte('?') } func (dialector Dialector) QuoteTo(writer clause.Writer, str string) { diff --git a/dialects/postgres/postgres.go b/dialects/postgres/postgres.go index 99569f06..c2ddd82c 100644 --- a/dialects/postgres/postgres.go +++ b/dialects/postgres/postgres.go @@ -38,8 +38,9 @@ func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator { }}} } -func (dialector Dialector) BindVar(stmt *gorm.Statement, v interface{}) string { - return "$" + strconv.Itoa(len(stmt.Vars)) +func (dialector Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) { + writer.WriteByte('$') + writer.WriteString(strconv.Itoa(len(stmt.Vars))) } func (dialector Dialector) QuoteTo(writer clause.Writer, str string) { diff --git a/dialects/sqlite/sqlite.go b/dialects/sqlite/sqlite.go index 4105863f..c4837463 100644 --- a/dialects/sqlite/sqlite.go +++ b/dialects/sqlite/sqlite.go @@ -35,8 +35,8 @@ func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator { }}} } -func (dialector Dialector) BindVar(stmt *gorm.Statement, v interface{}) string { - return "?" +func (dialector Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) { + writer.WriteByte('?') } func (dialector Dialector) QuoteTo(writer clause.Writer, str string) { diff --git a/interfaces.go b/interfaces.go index 310f801a..9dd00c15 100644 --- a/interfaces.go +++ b/interfaces.go @@ -13,7 +13,7 @@ type Dialector interface { Initialize(*DB) error Migrator(db *DB) Migrator DataTypeOf(*schema.Field) string - BindVar(stmt *Statement, v interface{}) string + BindVarTo(writer clause.Writer, stmt *Statement, v interface{}) QuoteTo(clause.Writer, string) Explain(sql string, vars ...interface{}) string } diff --git a/statement.go b/statement.go index e632b409..6bc8b384 100644 --- a/statement.go +++ b/statement.go @@ -120,7 +120,7 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) { writer.WriteString(v.Name) } else { stmt.Vars = append(stmt.Vars, v.Value) - writer.WriteString(stmt.DB.Dialector.BindVar(stmt, v.Value)) + stmt.DB.Dialector.BindVarTo(writer, stmt, v.Value) } case clause.Column, clause.Table: stmt.QuoteTo(writer, v) @@ -130,7 +130,6 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) { case []interface{}: if len(v) > 0 { writer.WriteByte('(') - stmt.skipResetPlacehodler = true stmt.AddVar(writer, v...) writer.WriteByte(')') } else { @@ -138,7 +137,7 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) { } default: stmt.Vars = append(stmt.Vars, v) - writer.WriteString(stmt.DB.Dialector.BindVar(stmt, v)) + stmt.DB.Dialector.BindVarTo(writer, stmt, v) } } } diff --git a/tests/dummy_dialecter.go b/tests/dummy_dialecter.go index f6e9d9f9..63af0c9c 100644 --- a/tests/dummy_dialecter.go +++ b/tests/dummy_dialecter.go @@ -18,8 +18,8 @@ func (DummyDialector) Migrator(*gorm.DB) gorm.Migrator { return nil } -func (DummyDialector) BindVar(stmt *gorm.Statement, v interface{}) string { - return "?" +func (DummyDialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) { + writer.WriteByte('?') } func (DummyDialector) QuoteTo(writer clause.Writer, str string) { From 9e8a4db36ba0b6c8d5ddd6e23f3968126f06dae1 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 9 Mar 2020 20:37:01 +0800 Subject: [PATCH 276/881] Use *gorm.DB to replace gorm.DB --- callbacks.go | 1 - chainable_api.go | 40 +++++++++--------- dialects/sqlite/sqlite_test.go | 6 +-- finisher_api.go | 76 +++++++++++++++++----------------- gorm.go | 35 +++++++++------- migrator/migrator.go | 4 +- statement.go | 10 +---- 7 files changed, 84 insertions(+), 88 deletions(-) diff --git a/callbacks.go b/callbacks.go index e2907178..e1b2b410 100644 --- a/callbacks.go +++ b/callbacks.go @@ -90,7 +90,6 @@ func (p *processor) Execute(db *DB) { } if stmt := db.Statement; stmt != nil { - db.Error = stmt.Error db.RowsAffected = stmt.RowsAffected db.Logger.Trace(curTime, func() (string, int64) { diff --git a/chainable_api.go b/chainable_api.go index c2a6247b..432caa4f 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -13,14 +13,14 @@ import ( // db.Model(&User{}).Update("name", "hello") // // if user's primary key is non-blank, will use it as condition, then will only update the user's name to `hello` // db.Model(&user).Update("name", "hello") -func (db DB) Model(value interface{}) (tx DB) { +func (db *DB) Model(value interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.Model = value return } // Clauses Add clauses -func (db DB) Clauses(conds ...clause.Expression) (tx DB) { +func (db *DB) Clauses(conds ...clause.Expression) (tx *DB) { tx = db.getInstance() var whereConds []interface{} @@ -39,14 +39,14 @@ func (db DB) Clauses(conds ...clause.Expression) (tx DB) { } // Table specify the table you would like to run db operations -func (db DB) Table(name string) (tx DB) { +func (db *DB) Table(name string) (tx *DB) { tx = db.getInstance() tx.Statement.Table = name return } // Select specify fields that you want when querying, creating, updating -func (db DB) Select(query interface{}, args ...interface{}) (tx DB) { +func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) { tx = db.getInstance() switch v := query.(type) { @@ -97,7 +97,7 @@ func (db DB) Select(query interface{}, args ...interface{}) (tx DB) { } // Omit specify fields that you want to ignore when creating, updating and querying -func (db DB) Omit(columns ...string) (tx DB) { +func (db *DB) Omit(columns ...string) (tx *DB) { tx = db.getInstance() if len(columns) == 1 && strings.ContainsRune(columns[0], ',') { @@ -108,21 +108,21 @@ func (db DB) Omit(columns ...string) (tx DB) { return } -func (db DB) Where(query interface{}, args ...interface{}) (tx DB) { +func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondtion(query, args...)}) return } // Not add NOT condition -func (db DB) Not(query interface{}, args ...interface{}) (tx DB) { +func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.Not(tx.Statement.BuildCondtion(query, args...)...)}}) return } // Or add OR conditions -func (db DB) Or(query interface{}, args ...interface{}) (tx DB) { +func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(tx.Statement.BuildCondtion(query, args...)...)}}) return @@ -131,13 +131,13 @@ func (db DB) Or(query interface{}, args ...interface{}) (tx DB) { // Joins specify Joins conditions // db.Joins("Account").Find(&user) // db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user) -func (db DB) Joins(query string, args ...interface{}) (tx DB) { +func (db *DB) Joins(query string, args ...interface{}) (tx *DB) { tx = db.getInstance() return } // Group specify the group method on the find -func (db DB) Group(name string) (tx DB) { +func (db *DB) Group(name string) (tx *DB) { tx = db.getInstance() tx.Statement.AddClause(clause.GroupBy{ Columns: []clause.Column{{Name: name}}, @@ -146,7 +146,7 @@ func (db DB) Group(name string) (tx DB) { } // Having specify HAVING conditions for GROUP BY -func (db DB) Having(query interface{}, args ...interface{}) (tx DB) { +func (db *DB) Having(query interface{}, args ...interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.AddClause(clause.GroupBy{ Having: tx.Statement.BuildCondtion(query, args...), @@ -157,7 +157,7 @@ func (db DB) Having(query interface{}, args ...interface{}) (tx DB) { // Order specify order when retrieve records from database // db.Order("name DESC") // db.Order(gorm.Expr("name = ? DESC", "first")) // sql expression -func (db DB) Order(value interface{}) (tx DB) { +func (db *DB) Order(value interface{}) (tx *DB) { tx = db.getInstance() switch v := value.(type) { @@ -176,14 +176,14 @@ func (db DB) Order(value interface{}) (tx DB) { } // Limit specify the number of records to be retrieved -func (db DB) Limit(limit int) (tx DB) { +func (db *DB) Limit(limit int) (tx *DB) { tx = db.getInstance() tx.Statement.AddClause(clause.Limit{Limit: limit}) return } // Offset specify the number of records to skip before starting to return the records -func (db DB) Offset(offset int) (tx DB) { +func (db *DB) Offset(offset int) (tx *DB) { tx = db.getInstance() tx.Statement.AddClause(clause.Limit{Offset: offset}) return @@ -201,7 +201,7 @@ func (db DB) Offset(offset int) (tx DB) { // } // // db.Scopes(AmountGreaterThan1000, OrderStatus([]string{"paid", "shipped"})).Find(&orders) -func (db DB) Scopes(funcs ...func(DB) DB) DB { +func (db *DB) Scopes(funcs ...func(*DB) *DB) *DB { for _, f := range funcs { db = f(db) } @@ -210,27 +210,27 @@ func (db DB) Scopes(funcs ...func(DB) DB) DB { // Preload preload associations with given conditions // db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users) -func (db DB) Preload(column string, conditions ...interface{}) (tx DB) { +func (db *DB) Preload(column string, conditions ...interface{}) (tx *DB) { tx = db.getInstance() return } -func (db DB) Assign(attrs ...interface{}) (tx DB) { +func (db *DB) Assign(attrs ...interface{}) (tx *DB) { tx = db.getInstance() return } -func (db DB) Attrs(attrs ...interface{}) (tx DB) { +func (db *DB) Attrs(attrs ...interface{}) (tx *DB) { tx = db.getInstance() return } -func (db DB) Unscoped() (tx DB) { +func (db *DB) Unscoped() (tx *DB) { tx = db.getInstance() return } -func (db DB) Raw(sql string, values ...interface{}) (tx DB) { +func (db *DB) Raw(sql string, values ...interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.SQL = strings.Builder{} clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement) diff --git a/dialects/sqlite/sqlite_test.go b/dialects/sqlite/sqlite_test.go index 7a07db01..a42bc8ee 100644 --- a/dialects/sqlite/sqlite_test.go +++ b/dialects/sqlite/sqlite_test.go @@ -12,7 +12,7 @@ import ( ) var ( - DB gorm.DB + DB *gorm.DB err error ) @@ -23,9 +23,9 @@ func init() { } func TestCURD(t *testing.T) { - tests.RunTestsSuit(t, &DB) + tests.RunTestsSuit(t, DB) } func TestMigrate(t *testing.T) { - tests.TestMigrate(t, &DB) + tests.TestMigrate(t, DB) } diff --git a/finisher_api.go b/finisher_api.go index 4b3829a2..62c1af30 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -9,15 +9,15 @@ import ( ) // Create insert the value into database -func (db DB) Create(value interface{}) (tx DB) { +func (db *DB) Create(value interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.Dest = value - tx.callbacks.Create().Execute(&tx) + tx.callbacks.Create().Execute(tx) return } // Save update value in database, if the value doesn't have primary key, will insert it -func (db DB) Save(value interface{}) (tx DB) { +func (db *DB) Save(value interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.Dest = value @@ -26,7 +26,7 @@ func (db DB) Save(value interface{}) (tx DB) { reflectValue := reflect.ValueOf(value) for idx, pf := range tx.Statement.Schema.PrimaryFields { if pv, isZero := pf.ValueOf(reflectValue); isZero { - tx.callbacks.Create().Execute(&tx) + tx.callbacks.Create().Execute(tx) where.Exprs[idx] = clause.Eq{Column: pf.DBName, Value: pv} return } @@ -38,12 +38,12 @@ func (db DB) Save(value interface{}) (tx DB) { if len(tx.Statement.Selects) == 0 { tx.Statement.Selects = []string{"*"} } - tx.callbacks.Update().Execute(&tx) + tx.callbacks.Update().Execute(tx) return } // First find first record that match given conditions, order by primary key -func (db DB) First(out interface{}, conds ...interface{}) (tx DB) { +func (db *DB) First(out interface{}, conds ...interface{}) (tx *DB) { tx = db.getInstance().Limit(1).Order(clause.OrderByColumn{ Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, }) @@ -52,24 +52,24 @@ func (db DB) First(out interface{}, conds ...interface{}) (tx DB) { } tx.Statement.RaiseErrorOnNotFound = true tx.Statement.Dest = out - tx.callbacks.Query().Execute(&tx) + tx.callbacks.Query().Execute(tx) return } // Take return a record that match given conditions, the order will depend on the database implementation -func (db DB) Take(out interface{}, conds ...interface{}) (tx DB) { +func (db *DB) Take(out interface{}, conds ...interface{}) (tx *DB) { tx = db.getInstance().Limit(1) if len(conds) > 0 { tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondtion(conds[0], conds[1:]...)}) } tx.Statement.RaiseErrorOnNotFound = true tx.Statement.Dest = out - tx.callbacks.Query().Execute(&tx) + tx.callbacks.Query().Execute(tx) return } // Last find last record that match given conditions, order by primary key -func (db DB) Last(out interface{}, conds ...interface{}) (tx DB) { +func (db *DB) Last(out interface{}, conds ...interface{}) (tx *DB) { tx = db.getInstance().Limit(1).Order(clause.OrderByColumn{ Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Desc: true, @@ -79,101 +79,101 @@ func (db DB) Last(out interface{}, conds ...interface{}) (tx DB) { } tx.Statement.RaiseErrorOnNotFound = true tx.Statement.Dest = out - tx.callbacks.Query().Execute(&tx) + tx.callbacks.Query().Execute(tx) return } // Find find records that match given conditions -func (db DB) Find(out interface{}, conds ...interface{}) (tx DB) { +func (db *DB) Find(out interface{}, conds ...interface{}) (tx *DB) { tx = db.getInstance() if len(conds) > 0 { tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondtion(conds[0], conds[1:]...)}) } tx.Statement.Dest = out - tx.callbacks.Query().Execute(&tx) + tx.callbacks.Query().Execute(tx) return } -func (db DB) FirstOrInit(out interface{}, where ...interface{}) (tx DB) { +func (db *DB) FirstOrInit(out interface{}, where ...interface{}) (tx *DB) { tx = db.getInstance() return } -func (db DB) FirstOrCreate(out interface{}, where ...interface{}) (tx DB) { +func (db *DB) FirstOrCreate(out interface{}, where ...interface{}) (tx *DB) { tx = db.getInstance() return } // Update update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update -func (db DB) Update(column string, value interface{}) (tx DB) { +func (db *DB) Update(column string, value interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.Dest = map[string]interface{}{column: value} - tx.callbacks.Update().Execute(&tx) + tx.callbacks.Update().Execute(tx) return } // Updates update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update -func (db DB) Updates(values interface{}) (tx DB) { +func (db *DB) Updates(values interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.Dest = values - tx.callbacks.Update().Execute(&tx) + tx.callbacks.Update().Execute(tx) return } -func (db DB) UpdateColumn(column string, value interface{}) (tx DB) { +func (db *DB) UpdateColumn(column string, value interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.Dest = map[string]interface{}{column: value} - tx.callbacks.Update().Execute(&tx) + tx.callbacks.Update().Execute(tx) return } -func (db DB) UpdateColumns(values interface{}) (tx DB) { +func (db *DB) UpdateColumns(values interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.Dest = values - tx.callbacks.Update().Execute(&tx) + tx.callbacks.Update().Execute(tx) return } // Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition -func (db DB) Delete(value interface{}, conds ...interface{}) (tx DB) { +func (db *DB) Delete(value interface{}, conds ...interface{}) (tx *DB) { tx = db.getInstance() if len(conds) > 0 { tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondtion(conds[0], conds[1:]...)}) } tx.Statement.Dest = value - tx.callbacks.Delete().Execute(&tx) + tx.callbacks.Delete().Execute(tx) return } -func (db DB) Count(value interface{}) (tx DB) { +func (db *DB) Count(value interface{}) (tx *DB) { tx = db.getInstance() return } -func (db DB) Row() *sql.Row { +func (db *DB) Row() *sql.Row { tx := db.getInstance() - tx.callbacks.Row().Execute(&tx) + tx.callbacks.Row().Execute(tx) return tx.Statement.Dest.(*sql.Row) } -func (db DB) Rows() (*sql.Rows, error) { +func (db *DB) Rows() (*sql.Rows, error) { tx := db.Set("rows", true) - tx.callbacks.Row().Execute(&tx) + tx.callbacks.Row().Execute(tx) return tx.Statement.Dest.(*sql.Rows), tx.Error } // Scan scan value to a struct -func (db DB) Scan(dest interface{}) (tx DB) { +func (db *DB) Scan(dest interface{}) (tx *DB) { tx = db.getInstance() return } -func (db DB) ScanRows(rows *sql.Rows, result interface{}) error { +func (db *DB) ScanRows(rows *sql.Rows, result interface{}) error { return nil } // Transaction start a transaction as a block, return error will rollback, otherwise to commit. -func (db DB) Transaction(fc func(tx DB) error, opts ...*sql.TxOptions) (err error) { +func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err error) { panicked := true tx := db.Begin(opts...) defer func() { @@ -194,7 +194,7 @@ func (db DB) Transaction(fc func(tx DB) error, opts ...*sql.TxOptions) (err erro } // Begin begins a transaction -func (db DB) Begin(opts ...*sql.TxOptions) (tx DB) { +func (db *DB) Begin(opts ...*sql.TxOptions) (tx *DB) { tx = db.getInstance() if beginner, ok := tx.Statement.ConnPool.(TxBeginner); ok { var opt *sql.TxOptions @@ -213,7 +213,7 @@ func (db DB) Begin(opts ...*sql.TxOptions) (tx DB) { } // Commit commit a transaction -func (db DB) Commit() DB { +func (db *DB) Commit() *DB { if comminter, ok := db.Statement.ConnPool.(TxCommiter); ok && comminter != nil { db.AddError(comminter.Commit()) } else { @@ -223,7 +223,7 @@ func (db DB) Commit() DB { } // Rollback rollback a transaction -func (db DB) Rollback() DB { +func (db *DB) Rollback() *DB { if comminter, ok := db.Statement.ConnPool.(TxCommiter); ok && comminter != nil { db.AddError(comminter.Rollback()) } else { @@ -233,10 +233,10 @@ func (db DB) Rollback() DB { } // Exec execute raw sql -func (db DB) Exec(sql string, values ...interface{}) (tx DB) { +func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.SQL = strings.Builder{} clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement) - tx.callbacks.Raw().Execute(&tx) + tx.callbacks.Raw().Execute(tx) return } diff --git a/gorm.go b/gorm.go index b7d3e929..2d78c8d9 100644 --- a/gorm.go +++ b/gorm.go @@ -2,6 +2,7 @@ package gorm import ( "context" + "fmt" "sync" "time" @@ -51,7 +52,7 @@ type Session struct { } // Open initialize db session based on dialector -func Open(dialector Dialector, config *Config) (db DB, err error) { +func Open(dialector Dialector, config *Config) (db *DB, err error) { if config == nil { config = &Config{} } @@ -87,21 +88,21 @@ func Open(dialector Dialector, config *Config) (db DB, err error) { }, } - db = DB{ + db = &DB{ Config: config, clone: true, } - db.callbacks = initializeCallbacks(&db) + db.callbacks = initializeCallbacks(db) if dialector != nil { - err = dialector.Initialize(&db) + err = dialector.Initialize(db) } return } // Session create new db session -func (db DB) Session(config *Session) DB { +func (db *DB) Session(config *Session) *DB { var ( tx = db.getInstance() txConfig = *tx.Config @@ -125,24 +126,24 @@ func (db DB) Session(config *Session) DB { } // WithContext change current instance db's context to ctx -func (db DB) WithContext(ctx context.Context) DB { +func (db *DB) WithContext(ctx context.Context) *DB { return db.Session(&Session{Context: ctx}) } // Debug start debug mode -func (db DB) Debug() (tx DB) { +func (db *DB) Debug() (tx *DB) { return db.Session(&Session{Logger: db.Logger.LogMode(logger.Info)}) } // Set store value with key into current db instance's context -func (db DB) Set(key string, value interface{}) DB { +func (db *DB) Set(key string, value interface{}) *DB { tx := db.getInstance() tx.Statement.Settings.Store(key, value) return tx } // Get get value with key from current db instance's context -func (db DB) Get(key string) (interface{}, bool) { +func (db *DB) Get(key string) (interface{}, bool) { if db.Statement != nil { return db.Statement.Settings.Load(key) } @@ -150,28 +151,32 @@ func (db DB) Get(key string) (interface{}, bool) { } // Callback returns callback manager -func (db DB) Callback() *callbacks { +func (db *DB) Callback() *callbacks { return db.callbacks } // AutoMigrate run auto migration for given models -func (db DB) AutoMigrate(dst ...interface{}) error { +func (db *DB) AutoMigrate(dst ...interface{}) error { return db.Migrator().AutoMigrate(dst...) } // AddError add error to db -func (db DB) AddError(err error) { - db.Statement.AddError(err) +func (db *DB) AddError(err error) { + if db.Error == nil { + db.Error = err + } else if err != nil { + db.Error = fmt.Errorf("%v; %w", db.Error, err) + } } -func (db DB) getInstance() DB { +func (db *DB) getInstance() *DB { if db.clone { stmt := db.Config.statementPool.Get().(*Statement) if db.Statement != nil { stmt.Context = db.Statement.Context } - return DB{Config: db.Config, Statement: stmt} + return &DB{Config: db.Config, Statement: stmt} } return db diff --git a/migrator/migrator.go b/migrator/migrator.go index b2458bfc..730e8cfe 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -27,7 +27,7 @@ type Config struct { func (m Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement) error) error { stmt := m.DB.Statement if stmt == nil { - stmt = &gorm.Statement{DB: *m.DB} + stmt = &gorm.Statement{DB: m.DB} } if err := stmt.Parse(value); err != nil { @@ -496,7 +496,7 @@ func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []i parseDependence := func(value interface{}, addToList bool) { dep := Dependency{ - Statement: &gorm.Statement{DB: *m.DB, Dest: value}, + Statement: &gorm.Statement{DB: m.DB, Dest: value}, } dep.Parse(value) diff --git a/statement.go b/statement.go index 6bc8b384..fb3599ec 100644 --- a/statement.go +++ b/statement.go @@ -16,6 +16,7 @@ import ( // Statement statement type Statement struct { + *DB Table string Model interface{} Dest interface{} @@ -25,7 +26,6 @@ type Statement struct { Omits []string // omit columns Settings sync.Map ConnPool ConnPool - DB DB Schema *schema.Schema Context context.Context Error error @@ -219,14 +219,6 @@ func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (con return conditions } -func (stmt *Statement) AddError(err error) { - if stmt.Error == nil { - stmt.Error = err - } else if err != nil { - stmt.Error = fmt.Errorf("%v; %w", stmt.Error, err) - } -} - // Build build sql with clauses names func (stmt *Statement) Build(clauses ...string) { var firstClauseWritten bool From af080e677317015c36070227e889c2943f92752a Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 12 Mar 2020 08:39:42 +0800 Subject: [PATCH 277/881] Fix primary key tag --- callbacks.go | 2 -- chainable_api.go | 3 ++- clause/from.go | 41 -------------------------------- clause/joins.go | 44 +++++++++++++++++++++++++++++++---- dialects/mssql/mssql.go | 2 +- dialects/mysql/mysql.go | 6 ++++- dialects/mysql/mysql_test.go | 2 +- dialects/postgres/postgres.go | 2 +- logger/sql.go | 2 +- schema/field.go | 2 +- statement.go | 14 ++++------- tests/model.go | 2 +- tests/tests.go | 2 +- 13 files changed, 58 insertions(+), 66 deletions(-) diff --git a/callbacks.go b/callbacks.go index e1b2b410..78f1192e 100644 --- a/callbacks.go +++ b/callbacks.go @@ -90,8 +90,6 @@ func (p *processor) Execute(db *DB) { } if stmt := db.Statement; stmt != nil { - db.RowsAffected = stmt.RowsAffected - db.Logger.Trace(curTime, func() (string, int64) { return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...), db.RowsAffected }, db.Error) diff --git a/chainable_api.go b/chainable_api.go index 432caa4f..7a6e8b7c 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -108,13 +108,14 @@ func (db *DB) Omit(columns ...string) (tx *DB) { return } +// Where add conditions func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondtion(query, args...)}) return } -// Not add NOT condition +// Not add NOT conditions func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.Not(tx.Statement.BuildCondtion(query, args...)...)}}) diff --git a/clause/from.go b/clause/from.go index 5e8c5d25..59b0bfaf 100644 --- a/clause/from.go +++ b/clause/from.go @@ -6,23 +6,6 @@ type From struct { Joins []Join } -type JoinType string - -const ( - CrossJoin JoinType = "CROSS" - InnerJoin = "INNER" - LeftJoin = "LEFT" - RightJoin = "RIGHT" -) - -// Join join clause for from -type Join struct { - Type JoinType - Table Table - ON Where - Using []string -} - // Name from clause name func (from From) Name() string { return "FROM" @@ -48,30 +31,6 @@ func (from From) Build(builder Builder) { } } -func (join Join) Build(builder Builder) { - if join.Type != "" { - builder.WriteString(string(join.Type)) - builder.WriteByte(' ') - } - - builder.WriteString("JOIN ") - builder.WriteQuoted(join.Table) - - if len(join.ON.Exprs) > 0 { - builder.WriteString(" ON ") - join.ON.Build(builder) - } else if len(join.Using) > 0 { - builder.WriteString(" USING (") - for idx, c := range join.Using { - if idx > 0 { - builder.WriteByte(',') - } - builder.WriteQuoted(c) - } - builder.WriteByte(')') - } -} - // MergeClause merge from clause func (from From) MergeClause(clause *Clause) { if v, ok := clause.Expression.(From); ok { diff --git a/clause/joins.go b/clause/joins.go index 4983d6fd..a78bde39 100644 --- a/clause/joins.go +++ b/clause/joins.go @@ -1,8 +1,42 @@ package clause -// Joins joins clause -type Joins struct { - Name string - Query string - Vars []interface{} +type JoinType string + +const ( + CrossJoin JoinType = "CROSS" + InnerJoin = "INNER" + LeftJoin = "LEFT" + RightJoin = "RIGHT" +) + +// Join join clause for from +type Join struct { + Type JoinType + Table Table + ON Where + Using []string +} + +func (join Join) Build(builder Builder) { + if join.Type != "" { + builder.WriteString(string(join.Type)) + builder.WriteByte(' ') + } + + builder.WriteString("JOIN ") + builder.WriteQuoted(join.Table) + + if len(join.ON.Exprs) > 0 { + builder.WriteString(" ON ") + join.ON.Build(builder) + } else if len(join.Using) > 0 { + builder.WriteString(" USING (") + for idx, c := range join.Using { + if idx > 0 { + builder.WriteByte(',') + } + builder.WriteQuoted(c) + } + builder.WriteByte(')') + } } diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go index 8cf1e2e2..e5bc7dd2 100644 --- a/dialects/mssql/mssql.go +++ b/dialects/mssql/mssql.go @@ -70,7 +70,7 @@ func (dialector Dialector) DataTypeOf(field *schema.Field) string { sqlType = "bigint" } - if field.AutoIncrement { + if field.AutoIncrement || field == field.Schema.PrioritizedPrimaryField { return sqlType + " IDENTITY(1,1)" } return sqlType diff --git a/dialects/mysql/mysql.go b/dialects/mysql/mysql.go index 514dfc14..af796847 100644 --- a/dialects/mysql/mysql.go +++ b/dialects/mysql/mysql.go @@ -71,7 +71,7 @@ func (dialector Dialector) DataTypeOf(field *schema.Field) string { sqlType += " unsigned" } - if field.AutoIncrement { + if field.AutoIncrement || field == field.Schema.PrioritizedPrimaryField { sqlType += " AUTO_INCREMENT" } return sqlType @@ -94,6 +94,10 @@ func (dialector Dialector) DataTypeOf(field *schema.Field) string { return fmt.Sprintf("varchar(%d)", size) case schema.Time: precision := "" + if field.Precision == 0 { + field.Precision = 3 + } + if field.Precision > 0 { precision = fmt.Sprintf("(%d)", field.Precision) } diff --git a/dialects/mysql/mysql_test.go b/dialects/mysql/mysql_test.go index 5bc1debd..cb3b240a 100644 --- a/dialects/mysql/mysql_test.go +++ b/dialects/mysql/mysql_test.go @@ -16,7 +16,7 @@ var ( ) func init() { - dsn := "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True" + dsn := "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True&loc=Local" if os.Getenv("GORM_DSN") != "" { dsn = os.Getenv("GORM_DSN") } diff --git a/dialects/postgres/postgres.go b/dialects/postgres/postgres.go index c2ddd82c..7589025d 100644 --- a/dialects/postgres/postgres.go +++ b/dialects/postgres/postgres.go @@ -60,7 +60,7 @@ func (dialector Dialector) DataTypeOf(field *schema.Field) string { case schema.Bool: return "boolean" case schema.Int, schema.Uint: - if field.AutoIncrement { + if field.AutoIncrement || field == field.Schema.PrioritizedPrimaryField { switch { case field.Size < 16: return "smallserial" diff --git a/logger/sql.go b/logger/sql.go index cb50ccf6..41c514fd 100644 --- a/logger/sql.go +++ b/logger/sql.go @@ -33,7 +33,7 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, v if v.IsZero() { vars[idx] = escaper + "0000-00-00 00:00:00" + escaper } else { - vars[idx] = escaper + v.Format("2006-01-02 15:04:05") + escaper + vars[idx] = escaper + v.Format("2006-01-02 15:04:05.999") + escaper } case []byte: if isPrintable(v) { diff --git a/schema/field.go b/schema/field.go index c6de669d..ee1baf3c 100644 --- a/schema/field.go +++ b/schema/field.go @@ -219,7 +219,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } if field.Size == 0 { - switch fieldValue.Kind() { + switch reflect.Indirect(fieldValue).Kind() { case reflect.Int, reflect.Int64, reflect.Uint, reflect.Uint64, reflect.Float64: field.Size = 64 case reflect.Int8, reflect.Uint8: diff --git a/statement.go b/statement.go index fb3599ec..298a4c56 100644 --- a/statement.go +++ b/statement.go @@ -28,17 +28,15 @@ type Statement struct { ConnPool ConnPool Schema *schema.Schema Context context.Context - Error error - RowsAffected int64 RaiseErrorOnNotFound bool SQL strings.Builder Vars []interface{} NamedVars []sql.NamedArg } -// StatementOptimizer statement optimizer interface -type StatementOptimizer interface { - OptimizeStatement(*Statement) +// StatementModifier statement modifier interface +type StatementModifier interface { + ModifyStatement(*Statement) } // Write write string @@ -144,8 +142,8 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) { // AddClause add clause func (stmt *Statement) AddClause(v clause.Interface) { - if optimizer, ok := v.(StatementOptimizer); ok { - optimizer.OptimizeStatement(stmt) + if optimizer, ok := v.(StatementModifier); ok { + optimizer.ModifyStatement(stmt) } c, ok := stmt.Clauses[v.Name()] @@ -255,8 +253,6 @@ func (stmt *Statement) reinit() { stmt.ConnPool = stmt.DB.Config.ConnPool stmt.Schema = nil stmt.Context = context.Background() - stmt.Error = nil - stmt.RowsAffected = 0 stmt.RaiseErrorOnNotFound = false stmt.SQL.Reset() diff --git a/tests/model.go b/tests/model.go index b2d5efe1..4d686a57 100644 --- a/tests/model.go +++ b/tests/model.go @@ -21,7 +21,7 @@ type User struct { Toys []Toy `gorm:"polymorphic:Owner"` CompanyID *int Company Company - ManagerID uint + ManagerID *uint Manager *User Team []User `gorm:"foreignkey:ManagerID"` Languages []Language `gorm:"many2many:UserSpeak"` diff --git a/tests/tests.go b/tests/tests.go index 33013032..c26d743e 100644 --- a/tests/tests.go +++ b/tests/tests.go @@ -81,7 +81,7 @@ func TestFind(t *testing.T, db *gorm.DB) { }} if err := db.Create(&users).Error; err != nil { - t.Errorf("errors happened when create users: %v", err) + t.Fatal("errors happened when create users: %v", err) } t.Run("First", func(t *testing.T) { From f7f633590fefb3a503a4cbda894787d8a11b2540 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 12 Mar 2020 13:05:22 +0800 Subject: [PATCH 278/881] Fix tests with mysql, postgres --- callbacks/callbacks.go | 9 ++- callbacks/create.go | 110 +++++++++++++++++++++++++---- dialects/mssql/mssql.go | 2 +- dialects/mysql/mysql.go | 2 +- dialects/postgres/postgres.go | 4 +- dialects/postgres/postgres_test.go | 2 +- dialects/sqlite/sqlite.go | 4 +- schema/schema_test.go | 16 ++--- statement.go | 2 + tests/docker-compose.yml | 1 + tests/tests.go | 12 ++-- 11 files changed, 129 insertions(+), 35 deletions(-) diff --git a/callbacks/callbacks.go b/callbacks/callbacks.go index 0a48ada6..1985aec2 100644 --- a/callbacks/callbacks.go +++ b/callbacks/callbacks.go @@ -4,7 +4,12 @@ import ( "github.com/jinzhu/gorm" ) -func RegisterDefaultCallbacks(db *gorm.DB) { +type Config struct { + LastInsertIDReversed bool + WithReturning bool +} + +func RegisterDefaultCallbacks(db *gorm.DB, config *Config) { enableTransaction := func(db *gorm.DB) bool { return !db.SkipDefaultTransaction } @@ -13,7 +18,7 @@ func RegisterDefaultCallbacks(db *gorm.DB) { createCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction) createCallback.Register("gorm:before_create", BeforeCreate) createCallback.Register("gorm:save_before_associations", SaveBeforeAssociations) - createCallback.Register("gorm:create", Create) + createCallback.Register("gorm:create", Create(config)) createCallback.Register("gorm:save_after_associations", SaveAfterAssociations) createCallback.Register("gorm:after_create", AfterCreate) createCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) diff --git a/callbacks/create.go b/callbacks/create.go index 42dcda27..3f6a81e4 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -5,6 +5,7 @@ import ( "github.com/jinzhu/gorm" "github.com/jinzhu/gorm/clause" + "github.com/jinzhu/gorm/schema" ) func BeforeCreate(db *gorm.DB) { @@ -43,32 +44,113 @@ func BeforeCreate(db *gorm.DB) { func SaveBeforeAssociations(db *gorm.DB) { } -func Create(db *gorm.DB) { +func Create(config *Config) func(db *gorm.DB) { + if config.WithReturning { + return CreateWithReturning + } else { + return func(db *gorm.DB) { + db.Statement.AddClauseIfNotExists(clause.Insert{ + Table: clause.Table{Name: db.Statement.Table}, + }) + db.Statement.AddClause(ConvertToCreateValues(db.Statement)) + + db.Statement.Build("INSERT", "VALUES", "ON_CONFLICT") + result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) + + if err == nil { + if db.Statement.Schema != nil { + if insertID, err := result.LastInsertId(); err == nil { + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + if config.LastInsertIDReversed { + for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- { + db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID) + insertID-- + } + } else { + for i := 0; i < db.Statement.ReflectValue.Len(); i++ { + db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID) + insertID++ + } + } + case reflect.Struct: + db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID) + } + } else { + db.AddError(err) + } + } + db.RowsAffected, _ = result.RowsAffected() + } else { + db.AddError(err) + } + } + } +} + +func CreateWithReturning(db *gorm.DB) { db.Statement.AddClauseIfNotExists(clause.Insert{ Table: clause.Table{Name: db.Statement.Table}, }) db.Statement.AddClause(ConvertToCreateValues(db.Statement)) db.Statement.Build("INSERT", "VALUES", "ON_CONFLICT") - result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) - if err == nil { - if db.Statement.Schema != nil { - if insertID, err := result.LastInsertId(); err == nil { - switch db.Statement.ReflectValue.Kind() { - case reflect.Slice, reflect.Array: - for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- { - db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID) - insertID-- + if sch := db.Statement.Schema; sch != nil && len(sch.FieldsWithDefaultDBValue) > 0 { + db.Statement.WriteString(" RETURNING ") + + var ( + idx int + fields = make([]*schema.Field, len(sch.FieldsWithDefaultDBValue)) + values = make([]interface{}, len(sch.FieldsWithDefaultDBValue)) + ) + + for dbName, field := range sch.FieldsWithDefaultDBValue { + if idx != 0 { + db.Statement.WriteByte(',') + } + + fields[idx] = field + db.Statement.WriteQuoted(dbName) + idx++ + } + + rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) + + if err == nil { + defer rows.Close() + + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + for rows.Next() { + for idx, field := range fields { + values[idx] = field.ReflectValueOf(db.Statement.ReflectValue.Index(int(db.RowsAffected))).Addr().Interface() } - case reflect.Struct: - db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID) + if err := rows.Scan(values...); err != nil { + db.AddError(err) + } + db.RowsAffected++ + } + case reflect.Struct: + for idx, field := range fields { + values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface() + } + + if rows.Next() { + err = rows.Scan(values...) } } } - db.RowsAffected, _ = result.RowsAffected() + + if err != nil { + db.AddError(err) + } } else { - db.AddError(err) + if result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); err == nil { + db.RowsAffected, _ = result.RowsAffected() + } else { + db.AddError(err) + } } } diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go index e5bc7dd2..ad6782c7 100644 --- a/dialects/mssql/mssql.go +++ b/dialects/mssql/mssql.go @@ -25,7 +25,7 @@ func Open(dsn string) gorm.Dialector { func (dialector Dialector) Initialize(db *gorm.DB) (err error) { // register callbacks - callbacks.RegisterDefaultCallbacks(db) + callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{}) db.ConnPool, err = sql.Open("sqlserver", dialector.DSN) return } diff --git a/dialects/mysql/mysql.go b/dialects/mysql/mysql.go index af796847..7b8f0491 100644 --- a/dialects/mysql/mysql.go +++ b/dialects/mysql/mysql.go @@ -24,7 +24,7 @@ func Open(dsn string) gorm.Dialector { func (dialector Dialector) Initialize(db *gorm.DB) (err error) { // register callbacks - callbacks.RegisterDefaultCallbacks(db) + callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{}) db.ConnPool, err = sql.Open("mysql", dialector.DSN) return } diff --git a/dialects/postgres/postgres.go b/dialects/postgres/postgres.go index 7589025d..73a19e9d 100644 --- a/dialects/postgres/postgres.go +++ b/dialects/postgres/postgres.go @@ -25,7 +25,9 @@ func Open(dsn string) gorm.Dialector { func (dialector Dialector) Initialize(db *gorm.DB) (err error) { // register callbacks - callbacks.RegisterDefaultCallbacks(db) + callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{ + WithReturning: true, + }) db.ConnPool, err = sql.Open("postgres", dialector.DSN) return } diff --git a/dialects/postgres/postgres_test.go b/dialects/postgres/postgres_test.go index a1252d92..2185c19c 100644 --- a/dialects/postgres/postgres_test.go +++ b/dialects/postgres/postgres_test.go @@ -16,7 +16,7 @@ var ( ) func init() { - dsn := "user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable" + dsn := "user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable TimeZone=Asia/Shanghai" if os.Getenv("GORM_DSN") != "" { dsn = os.Getenv("GORM_DSN") } diff --git a/dialects/sqlite/sqlite.go b/dialects/sqlite/sqlite.go index c4837463..51829b17 100644 --- a/dialects/sqlite/sqlite.go +++ b/dialects/sqlite/sqlite.go @@ -22,7 +22,9 @@ func Open(dsn string) gorm.Dialector { func (dialector Dialector) Initialize(db *gorm.DB) (err error) { // register callbacks - callbacks.RegisterDefaultCallbacks(db) + callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{ + LastInsertIDReversed: true, + }) db.ConnPool, err = sql.Open("sqlite3", dialector.DSN) return } diff --git a/schema/schema_test.go b/schema/schema_test.go index ce225010..7d13e614 100644 --- a/schema/schema_test.go +++ b/schema/schema_test.go @@ -32,15 +32,15 @@ func checkUserSchema(t *testing.T, user *schema.Schema) { // check fields fields := []schema.Field{ - {Name: "ID", DBName: "id", BindNames: []string{"Model", "ID"}, DataType: schema.Uint, PrimaryKey: true, Tag: `gorm:"primarykey"`, TagSettings: map[string]string{"PRIMARYKEY": "PRIMARYKEY"}}, + {Name: "ID", DBName: "id", BindNames: []string{"Model", "ID"}, DataType: schema.Uint, PrimaryKey: true, Tag: `gorm:"primarykey"`, TagSettings: map[string]string{"PRIMARYKEY": "PRIMARYKEY"}, Size: 64}, {Name: "CreatedAt", DBName: "created_at", BindNames: []string{"Model", "CreatedAt"}, DataType: schema.Time}, {Name: "UpdatedAt", DBName: "updated_at", BindNames: []string{"Model", "UpdatedAt"}, DataType: schema.Time}, {Name: "DeletedAt", DBName: "deleted_at", BindNames: []string{"Model", "DeletedAt"}, Tag: `gorm:"index"`, DataType: schema.Time}, {Name: "Name", DBName: "name", BindNames: []string{"Name"}, DataType: schema.String}, - {Name: "Age", DBName: "age", BindNames: []string{"Age"}, DataType: schema.Uint}, + {Name: "Age", DBName: "age", BindNames: []string{"Age"}, DataType: schema.Uint, Size: 64}, {Name: "Birthday", DBName: "birthday", BindNames: []string{"Birthday"}, DataType: schema.Time}, - {Name: "CompanyID", DBName: "company_id", BindNames: []string{"CompanyID"}, DataType: schema.Int}, - {Name: "ManagerID", DBName: "manager_id", BindNames: []string{"ManagerID"}, DataType: schema.Uint}, + {Name: "CompanyID", DBName: "company_id", BindNames: []string{"CompanyID"}, DataType: schema.Int, Size: 64}, + {Name: "ManagerID", DBName: "manager_id", BindNames: []string{"ManagerID"}, DataType: schema.Uint, Size: 64}, {Name: "Active", DBName: "active", BindNames: []string{"Active"}, DataType: schema.Bool}, } @@ -83,7 +83,7 @@ func checkUserSchema(t *testing.T, user *schema.Schema) { JoinTable: JoinTable{Name: "UserSpeak", Table: "user_speaks", Fields: []schema.Field{ { Name: "UserID", DBName: "user_id", BindNames: []string{"UserID"}, DataType: schema.Uint, - Tag: `gorm:"primarykey"`, Creatable: true, Updatable: true, PrimaryKey: true, + Tag: `gorm:"primarykey"`, Creatable: true, Updatable: true, PrimaryKey: true, Size: 64, }, { Name: "LanguageCode", DBName: "language_code", BindNames: []string{"LanguageCode"}, DataType: schema.String, @@ -97,11 +97,11 @@ func checkUserSchema(t *testing.T, user *schema.Schema) { JoinTable: JoinTable{Name: "user_friends", Table: "user_friends", Fields: []schema.Field{ { Name: "UserID", DBName: "user_id", BindNames: []string{"UserID"}, DataType: schema.Uint, - Tag: `gorm:"primarykey"`, Creatable: true, Updatable: true, PrimaryKey: true, + Tag: `gorm:"primarykey"`, Creatable: true, Updatable: true, PrimaryKey: true, Size: 64, }, { Name: "FriendID", DBName: "friend_id", BindNames: []string{"FriendID"}, DataType: schema.Uint, - Tag: `gorm:"primarykey"`, Creatable: true, Updatable: true, PrimaryKey: true, + Tag: `gorm:"primarykey"`, Creatable: true, Updatable: true, PrimaryKey: true, Size: 64, }, }}, References: []Reference{{"ID", "User", "UserID", "user_friends", "", true}, {"ID", "User", "FriendID", "user_friends", "", false}}, @@ -124,7 +124,7 @@ func TestParseSchemaWithAdvancedDataType(t *testing.T) { // check fields fields := []schema.Field{ - {Name: "ID", DBName: "id", BindNames: []string{"ID"}, DataType: schema.Int, PrimaryKey: true}, + {Name: "ID", DBName: "id", BindNames: []string{"ID"}, DataType: schema.Int, PrimaryKey: true, Size: 64}, {Name: "Name", DBName: "name", BindNames: []string{"Name"}, DataType: schema.String}, {Name: "Birthday", DBName: "birthday", BindNames: []string{"Birthday"}, DataType: schema.Time}, {Name: "RegisteredAt", DBName: "registered_at", BindNames: []string{"RegisteredAt"}, DataType: schema.Time}, diff --git a/statement.go b/statement.go index 298a4c56..e45bd8bb 100644 --- a/statement.go +++ b/statement.go @@ -91,6 +91,8 @@ func (stmt Statement) QuoteTo(writer clause.Writer, field interface{}) { writer.WriteString(" AS ") stmt.DB.Dialector.QuoteTo(writer, v.Alias) } + case string: + stmt.DB.Dialector.QuoteTo(writer, v) default: stmt.DB.Dialector.QuoteTo(writer, fmt.Sprint(field)) } diff --git a/tests/docker-compose.yml b/tests/docker-compose.yml index 6bf3fadf..05e0956e 100644 --- a/tests/docker-compose.yml +++ b/tests/docker-compose.yml @@ -15,6 +15,7 @@ services: ports: - 9920:5432 environment: + - TZ=Asia/Shanghai - POSTGRES_DB=gorm - POSTGRES_USER=gorm - POSTGRES_PASSWORD=gorm diff --git a/tests/tests.go b/tests/tests.go index c26d743e..aa48f699 100644 --- a/tests/tests.go +++ b/tests/tests.go @@ -37,7 +37,7 @@ func TestCreate(t *testing.T, db *gorm.DB) { } if err := db.Create(&user).Error; err != nil { - t.Errorf("errors happened when create: %v", err) + t.Fatalf("errors happened when create: %v", err) } if user.ID == 0 { @@ -81,7 +81,7 @@ func TestFind(t *testing.T, db *gorm.DB) { }} if err := db.Create(&users).Error; err != nil { - t.Fatal("errors happened when create users: %v", err) + t.Fatalf("errors happened when create users: %v", err) } t.Run("First", func(t *testing.T) { @@ -195,11 +195,11 @@ func TestUpdate(t *testing.T, db *gorm.DB) { } if err := db.Create(&users).Error; err != nil { - t.Errorf("errors happened when create: %v", err) + t.Fatalf("errors happened when create: %v", err) } else if user.ID == 0 { - t.Errorf("user's primary value should not zero, %v", user.ID) + t.Fatalf("user's primary value should not zero, %v", user.ID) } else if user.UpdatedAt.IsZero() { - t.Errorf("user's updated at should not zero, %v", user.UpdatedAt) + t.Fatalf("user's updated at should not zero, %v", user.UpdatedAt) } lastUpdatedAt = user.UpdatedAt @@ -297,7 +297,7 @@ func TestDelete(t *testing.T, db *gorm.DB) { for _, user := range users { if user.ID == 0 { - t.Errorf("user's primary key should has value after create, got : %v", user.ID) + t.Fatalf("user's primary key should has value after create, got : %v", user.ID) } } From 477efab8cd9881ffe79a040d87ef1531d5ba0b7f Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 14 Mar 2020 19:00:41 +0800 Subject: [PATCH 279/881] Refactor logger --- logger/logger.go | 89 ++++++++++++++++++++++++++---------------------- 1 file changed, 48 insertions(+), 41 deletions(-) diff --git a/logger/logger.go b/logger/logger.go index 80ae31b1..ee6c0da1 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -10,16 +10,17 @@ import ( // Colors const ( - Reset = "\033[0m" - Red = "\033[31m" - Green = "\033[32m" - Yellow = "\033[33m" - Blue = "\033[34m" - Magenta = "\033[35m" - Cyan = "\033[36m" - White = "\033[37m" - Redbold = "\033[31;1m" - YellowBold = "\033[33;1m" + Reset = "\033[0m" + Red = "\033[31m" + Green = "\033[32m" + Yellow = "\033[33m" + Blue = "\033[34m" + Magenta = "\033[35m" + Cyan = "\033[36m" + White = "\033[37m" + MagentaBold = "\033[35;1m" + RedBold = "\033[31;1m" + YellowBold = "\033[33;1m" ) // LogLevel @@ -59,37 +60,40 @@ var Default = New(log.New(os.Stdout, "\r\n", log.LstdFlags), Config{ func New(writer Writer, config Config) Interface { var ( - infoPrefix = "%s\n[info] " - warnPrefix = "%s\n[warn] " - errPrefix = "%s\n[error] " - tracePrefix = "%s\n[%v] [rows:%d] %s" - traceErrPrefix = "%s\n[%v] [rows:%d] %s" + infoStr = "%s\n[info] " + warnStr = "%s\n[warn] " + errStr = "%s\n[error] " + traceStr = "%s\n[%v] [rows:%d] %s" + traceWarnStr = "%s\n[%v] [rows:%d] %s" + traceErrStr = "%s %s\n[%v] [rows:%d] %s" ) if config.Colorful { - infoPrefix = Green + "%s\n" + Reset + Green + "[info] " + Reset - warnPrefix = Blue + "%s\n" + Reset + Magenta + "[warn] " + Reset - errPrefix = Magenta + "%s\n" + Reset + Red + "[error] " + Reset - tracePrefix = Green + "%s\n" + Reset + YellowBold + "[%.3fms] " + Green + "[rows:%d]" + Reset + " %s" - traceErrPrefix = Magenta + "%s\n" + Reset + Redbold + "[%.3fms] " + Yellow + "[rows:%d]" + Reset + " %s" + infoStr = Green + "%s\n" + Reset + Green + "[info] " + Reset + warnStr = Blue + "%s\n" + Reset + Magenta + "[warn] " + Reset + errStr = Magenta + "%s\n" + Reset + Red + "[error] " + Reset + traceStr = Green + "%s\n" + Reset + Yellow + "[%.3fms] " + Blue + "[rows:%d]" + Reset + " %s" + traceWarnStr = Green + "%s\n" + Reset + RedBold + "[%.3fms] " + Yellow + "[rows:%d]" + Magenta + " %s" + Reset + traceErrStr = RedBold + "%s " + MagentaBold + "%s\n" + Reset + Yellow + "[%.3fms] " + Blue + "[rows:%d]" + Reset + " %s" } return logger{ - Writer: writer, - Config: config, - infoPrefix: infoPrefix, - warnPrefix: warnPrefix, - errPrefix: errPrefix, - tracePrefix: tracePrefix, - traceErrPrefix: traceErrPrefix, + Writer: writer, + Config: config, + infoStr: infoStr, + warnStr: warnStr, + errStr: errStr, + traceStr: traceStr, + traceWarnStr: traceWarnStr, + traceErrStr: traceErrStr, } } type logger struct { Writer Config - infoPrefix, warnPrefix, errPrefix string - tracePrefix, traceErrPrefix string + infoStr, warnStr, errStr string + traceStr, traceErrStr, traceWarnStr string } // LogMode log mode @@ -101,35 +105,38 @@ func (l logger) LogMode(level LogLevel) Interface { // Info print info func (l logger) Info(msg string, data ...interface{}) { if l.LogLevel >= Info { - l.Printf(l.infoPrefix+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...) + l.Printf(l.infoStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...) } } // Warn print warn messages func (l logger) Warn(msg string, data ...interface{}) { if l.LogLevel >= Warn { - l.Printf(l.warnPrefix+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...) + l.Printf(l.warnStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...) } } // Error print error messages func (l logger) Error(msg string, data ...interface{}) { if l.LogLevel >= Error { - l.Printf(l.errPrefix+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...) + l.Printf(l.errStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...) } } // Trace print sql message func (l logger) Trace(begin time.Time, fc func() (string, int64), err error) { - if elapsed := time.Now().Sub(begin); elapsed > l.SlowThreshold && l.SlowThreshold != 0 { - sql, rows := fc() - fileline := utils.FileWithLineNum() - if err != nil { - fileline += " " + err.Error() + if l.LogLevel > 0 { + elapsed := time.Now().Sub(begin) + switch { + case err != nil: + sql, rows := fc() + l.Printf(l.traceErrStr, utils.FileWithLineNum(), err, float64(elapsed.Nanoseconds())/1e6, rows, sql) + case elapsed > l.SlowThreshold && l.SlowThreshold != 0: + sql, rows := fc() + l.Printf(l.traceWarnStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, rows, sql) + case l.LogLevel >= Info: + sql, rows := fc() + l.Printf(l.traceStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, rows, sql) } - l.Printf(l.traceErrPrefix, fileline, float64(elapsed.Nanoseconds())/1e6, rows, sql) - } else if l.LogLevel >= Info { - sql, rows := fc() - l.Printf(l.tracePrefix, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, rows, sql) } } From 3a126233bff544590896a14b36b092c9a5941189 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 23 Mar 2020 22:40:12 +0800 Subject: [PATCH 280/881] Fix select with * --- callbacks/helper.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/callbacks/helper.go b/callbacks/helper.go index 433ab346..0dd6ff43 100644 --- a/callbacks/helper.go +++ b/callbacks/helper.go @@ -17,7 +17,7 @@ func SelectAndOmitColumns(stmt *gorm.Statement) (map[string]bool, bool) { for _, dbName := range stmt.Schema.DBNames { results[dbName] = true } - return results, true + break } if field := stmt.Schema.LookUpField(column); field != nil { From be537f29ec080d0bcef2f1db1587b051e69958d7 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 30 Mar 2020 09:31:02 +0800 Subject: [PATCH 281/881] [migrator] Use full data type when add column --- migrator/migrator.go | 42 +++++++++++++++++++++++------------------- 1 file changed, 23 insertions(+), 19 deletions(-) diff --git a/migrator/migrator.go b/migrator/migrator.go index 730e8cfe..5e246c3f 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -45,6 +45,27 @@ func (m Migrator) DataTypeOf(field *schema.Field) string { return m.Dialector.DataTypeOf(field) } +func (m Migrator) FullDataTypeOf(field *schema.Field) string { + dataType := m.DataTypeOf(field) + + if field.AutoIncrement { + dataType += " AUTO_INCREMENT" + } + + if field.NotNull { + dataType += " NOT NULL" + } + + if field.Unique { + dataType += " UNIQUE" + } + + if field.HasDefaultValue { + dataType += " DEFAULT " + field.DefaultValue + } + return dataType +} + // AutoMigrate func (m Migrator) AutoMigrate(values ...interface{}) error { // TODO smart migrate data type @@ -113,24 +134,7 @@ func (m Migrator) CreateTable(values ...interface{}) error { field := stmt.Schema.FieldsByDBName[dbName] createTableSQL += fmt.Sprintf("? ?") hasPrimaryKeyInDataType = hasPrimaryKeyInDataType || strings.Contains(strings.ToUpper(field.DBDataType), "PRIMARY KEY") - values = append(values, clause.Column{Name: dbName}, clause.Expr{SQL: m.DataTypeOf(field)}) - - if field.AutoIncrement { - createTableSQL += " AUTO_INCREMENT" - } - - if field.NotNull { - createTableSQL += " NOT NULL" - } - - if field.Unique { - createTableSQL += " UNIQUE" - } - - if field.DefaultValue != "" { - createTableSQL += " DEFAULT ?" - values = append(values, clause.Expr{SQL: field.DefaultValue}) - } + values = append(values, clause.Column{Name: dbName}, clause.Expr{SQL: m.FullDataTypeOf(field)}) createTableSQL += "," } @@ -220,7 +224,7 @@ func (m Migrator) AddColumn(value interface{}, field string) error { if field := stmt.Schema.LookUpField(field); field != nil { return m.DB.Exec( "ALTER TABLE ? ADD ? ?", - clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, clause.Expr{SQL: m.DataTypeOf(field)}, + clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, clause.Expr{SQL: m.FullDataTypeOf(field)}, ).Error } return fmt.Errorf("failed to look up field with name: %s", field) From 511bd664900a6818e883b0b6eb2e3e4243efefac Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 3 Apr 2020 07:15:30 +0800 Subject: [PATCH 282/881] Fix print code lines --- utils/utils.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/utils/utils.go b/utils/utils.go index 25cd585a..8521d09b 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -7,8 +7,8 @@ import ( "unicode" ) -var goSrcRegexp = regexp.MustCompile(`/gorm/.*.go`) -var goTestRegexp = regexp.MustCompile(`/gorm/.*test.*.go`) +var goSrcRegexp = regexp.MustCompile(`jinzhu/gorm(@.*)?/.*.go`) +var goTestRegexp = regexp.MustCompile(`jinzhu/gorm(@.*)?/.*test.go`) func FileWithLineNum() string { for i := 2; i < 15; i++ { From d39bdc35132dc7e6181ac1d2b5524df75c157a08 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 3 Apr 2020 07:57:52 +0800 Subject: [PATCH 283/881] Fix create index --- migrator/migrator.go | 2 +- schema/index.go | 6 +++++- schema/index_test.go | 22 ++++++++++++++-------- 3 files changed, 20 insertions(+), 10 deletions(-) diff --git a/migrator/migrator.go b/migrator/migrator.go index 5e246c3f..763b4ec3 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -150,7 +150,7 @@ func (m Migrator) CreateTable(values ...interface{}) error { for _, idx := range stmt.Schema.ParseIndexes() { if m.CreateIndexAfterCreateTable { - tx.Migrator().CreateIndex(value, idx.Name) + defer tx.Migrator().CreateIndex(value, idx.Name) } else { createTableSQL += "INDEX ? ?," values = append(values, clause.Expr{SQL: idx.Name}, tx.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt)) diff --git a/schema/index.go b/schema/index.go index 26c7a558..c5c96aa4 100644 --- a/schema/index.go +++ b/schema/index.go @@ -26,7 +26,7 @@ type IndexOption struct { func (schema *Schema) ParseIndexes() map[string]Index { var indexes = map[string]Index{} - for _, field := range schema.FieldsByDBName { + for _, field := range schema.Fields { if field.TagSettings["INDEX"] != "" || field.TagSettings["UNIQUE_INDEX"] != "" { for _, index := range parseFieldIndexes(field) { idx := indexes[index.Name] @@ -66,6 +66,10 @@ func parseFieldIndexes(field *Field) (indexes []Index) { length, _ = strconv.Atoi(settings["LENGTH"]) ) + if idx == -1 { + idx = len(tag) + } + if idx != -1 { name = tag[0:idx] } diff --git a/schema/index_test.go b/schema/index_test.go index d0e8dfe0..398ddbb7 100644 --- a/schema/index_test.go +++ b/schema/index_test.go @@ -9,13 +9,15 @@ import ( ) type UserIndex struct { - Name string `gorm:"index"` - Name2 string `gorm:"index:idx_name,unique"` - Name3 string `gorm:"index:,sort:desc,collate:utf8,type:btree,length:10,where:name3 != 'jinzhu'"` - Name4 string `gorm:"unique_index"` - Name5 int64 `gorm:"index:,class:FULLTEXT,comment:hello \\, world,where:age > 10"` - Name6 int64 `gorm:"index:profile,comment:hello \\, world,where:age > 10"` - Age int64 `gorm:"index:profile,expression:ABS(age)"` + Name string `gorm:"index"` + Name2 string `gorm:"index:idx_name,unique"` + Name3 string `gorm:"index:,sort:desc,collate:utf8,type:btree,length:10,where:name3 != 'jinzhu'"` + Name4 string `gorm:"unique_index"` + Name5 int64 `gorm:"index:,class:FULLTEXT,comment:hello \\, world,where:age > 10"` + Name6 int64 `gorm:"index:profile,comment:hello \\, world,where:age > 10"` + Age int64 `gorm:"index:profile,expression:ABS(age)"` + OID int64 `gorm:"index:idx_id"` + MemberNumber string `gorm:"index:idx_id"` } func TestParseIndex(t *testing.T) { @@ -64,6 +66,10 @@ func TestParseIndex(t *testing.T) { Expression: "ABS(age)", }}, }, + "idx_id": { + Name: "idx_id", + Fields: []schema.IndexOption{{}, {}}, + }, } indices := user.ParseIndexes() @@ -71,7 +77,7 @@ func TestParseIndex(t *testing.T) { for k, result := range results { v, ok := indices[k] if !ok { - t.Errorf("Failed to found index %v from parsed indices %+v", k, indices) + t.Fatalf("Failed to found index %v from parsed indices %+v", k, indices) } for _, name := range []string{"Name", "Class", "Type", "Where", "Comment"} { From 29cd35219fc13c3019b0c7515562e28434ad0056 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 8 Apr 2020 08:15:00 +0800 Subject: [PATCH 284/881] Add creatable, updatable, readable permission --- callbacks/create.go | 6 +++--- callbacks/helper.go | 24 +++++++++++++++++------- callbacks/scan.go | 8 ++++++-- callbacks/update.go | 2 +- schema/field.go | 18 ++++++++++++++++++ 5 files changed, 45 insertions(+), 13 deletions(-) diff --git a/callbacks/create.go b/callbacks/create.go index 3f6a81e4..97a2832c 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -194,13 +194,13 @@ func AfterCreate(db *gorm.DB) { func ConvertToCreateValues(stmt *gorm.Statement) clause.Values { switch value := stmt.Dest.(type) { case map[string]interface{}: - return ConvertMapToValues(stmt, value) + return ConvertMapToValuesForCreate(stmt, value) case []map[string]interface{}: - return ConvertSliceOfMapToValues(stmt, value) + return ConvertSliceOfMapToValuesForCreate(stmt, value) default: var ( values = clause.Values{} - selectColumns, restricted = SelectAndOmitColumns(stmt) + selectColumns, restricted = SelectAndOmitColumns(stmt, true, false) curTime = stmt.DB.NowFunc() isZero = false ) diff --git a/callbacks/helper.go b/callbacks/helper.go index 0dd6ff43..8a69fbd1 100644 --- a/callbacks/helper.go +++ b/callbacks/helper.go @@ -8,7 +8,7 @@ import ( ) // SelectAndOmitColumns get select and omit columns, select -> true, omit -> false -func SelectAndOmitColumns(stmt *gorm.Statement) (map[string]bool, bool) { +func SelectAndOmitColumns(stmt *gorm.Statement, requireCreate, requireUpdate bool) (map[string]bool, bool) { results := map[string]bool{} // select columns @@ -36,13 +36,23 @@ func SelectAndOmitColumns(stmt *gorm.Statement) (map[string]bool, bool) { } } + if stmt.Schema != nil { + for _, field := range stmt.Schema.FieldsByDBName { + if requireCreate && !field.Creatable { + results[field.DBName] = false + } else if requireUpdate && !field.Updatable { + results[field.DBName] = false + } + } + } + return results, len(stmt.Selects) > 0 } -// ConvertMapToValues convert map to values -func ConvertMapToValues(stmt *gorm.Statement, mapValue map[string]interface{}) (values clause.Values) { +// ConvertMapToValuesForCreate convert map to values +func ConvertMapToValuesForCreate(stmt *gorm.Statement, mapValue map[string]interface{}) (values clause.Values) { columns := make([]string, 0, len(mapValue)) - selectColumns, restricted := SelectAndOmitColumns(stmt) + selectColumns, restricted := SelectAndOmitColumns(stmt, true, false) var keys []string for k, _ := range mapValue { @@ -64,12 +74,12 @@ func ConvertMapToValues(stmt *gorm.Statement, mapValue map[string]interface{}) ( return } -// ConvertSliceOfMapToValues convert slice of map to values -func ConvertSliceOfMapToValues(stmt *gorm.Statement, mapValues []map[string]interface{}) (values clause.Values) { +// ConvertSliceOfMapToValuesForCreate convert slice of map to values +func ConvertSliceOfMapToValuesForCreate(stmt *gorm.Statement, mapValues []map[string]interface{}) (values clause.Values) { var ( columns = []string{} result = map[string][]interface{}{} - selectColumns, restricted = SelectAndOmitColumns(stmt) + selectColumns, restricted = SelectAndOmitColumns(stmt, true, false) ) for idx, mapValue := range mapValues { diff --git a/callbacks/scan.go b/callbacks/scan.go index f8f1ef54..2bd0143c 100644 --- a/callbacks/scan.go +++ b/callbacks/scan.go @@ -56,7 +56,11 @@ func Scan(rows *sql.Rows, db *gorm.DB) { fields := make([]*schema.Field, len(columns)) for idx, column := range columns { - fields[idx] = db.Statement.Schema.LookUpField(column) + if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable { + fields[idx] = field + } else { + values[idx] = sql.RawBytes{} + } } for rows.Next() { @@ -80,7 +84,7 @@ func Scan(rows *sql.Rows, db *gorm.DB) { } case reflect.Struct: for idx, column := range columns { - if field := db.Statement.Schema.LookUpField(column); field != nil { + if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable { values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface() } else { values[idx] = sql.RawBytes{} diff --git a/callbacks/update.go b/callbacks/update.go index eab9f929..53c646e9 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -91,7 +91,7 @@ func AfterUpdate(db *gorm.DB) { // ConvertToAssignments convert to update assignments func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { - selectColumns, restricted := SelectAndOmitColumns(stmt) + selectColumns, restricted := SelectAndOmitColumns(stmt, false, true) reflectModelValue := reflect.ValueOf(stmt.Model) switch value := stmt.Dest.(type) { diff --git a/schema/field.go b/schema/field.go index ee1baf3c..a8e55acd 100644 --- a/schema/field.go +++ b/schema/field.go @@ -42,6 +42,7 @@ type Field struct { AutoIncrement bool Creatable bool Updatable bool + Readable bool HasDefaultValue bool AutoCreateTime TimeType AutoUpdateTime TimeType @@ -73,6 +74,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { StructField: fieldStruct, Creatable: true, Updatable: true, + Readable: true, Tag: fieldStruct.Tag, TagSettings: ParseTagSetting(fieldStruct.Tag.Get("gorm"), ";"), Schema: schema, @@ -117,6 +119,21 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { if _, ok := field.TagSettings["-"]; ok { field.Creatable = false field.Updatable = false + field.Readable = false + } + + if v, ok := field.TagSettings["<-"]; ok { + if !strings.Contains(v, "create") { + field.Creatable = false + } + + if !strings.Contains(v, "update") { + field.Updatable = false + } + } + + if _, ok := field.TagSettings["->"]; ok { + field.Readable = false } if dbName, ok := field.TagSettings["COLUMN"]; ok { @@ -235,6 +252,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { var err error field.Creatable = false field.Updatable = false + field.Readable = false if field.EmbeddedSchema, err = Parse(fieldValue.Interface(), &sync.Map{}, schema.namer); err != nil { schema.err = err } From a46d48ccb3f243f2ae06515eff0026852d088131 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 8 Apr 2020 08:32:28 +0800 Subject: [PATCH 285/881] Add tests for controlling field permission with tag --- schema/field.go | 18 ++++++++++++------ schema/field_test.go | 31 +++++++++++++++++++++++++++++++ schema/schema_helper_test.go | 2 +- schema/schema_test.go | 10 ++++++---- 4 files changed, 50 insertions(+), 11 deletions(-) diff --git a/schema/field.go b/schema/field.go index a8e55acd..a5c3b41f 100644 --- a/schema/field.go +++ b/schema/field.go @@ -123,17 +123,23 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } if v, ok := field.TagSettings["<-"]; ok { - if !strings.Contains(v, "create") { - field.Creatable = false + if v != "<-" { + if !strings.Contains(v, "create") { + field.Creatable = false + } + + if !strings.Contains(v, "update") { + field.Updatable = false + } } - if !strings.Contains(v, "update") { - field.Updatable = false - } + field.Readable = false } if _, ok := field.TagSettings["->"]; ok { - field.Readable = false + field.Creatable = false + field.Updatable = false + field.Readable = true } if dbName, ok := field.TagSettings["COLUMN"]; ok { diff --git a/schema/field_test.go b/schema/field_test.go index 15dfa41d..c04149ff 100644 --- a/schema/field_test.go +++ b/schema/field_test.go @@ -216,3 +216,34 @@ func TestAdvancedDataTypeValuerAndSetter(t *testing.T) { } checkField(t, userSchema, reflectValue, newValues2) } + +type UserWithPermissionControl struct { + ID uint + Name string `gorm:"-"` + Name2 string `gorm:"->"` + Name3 string `gorm:"<-"` + Name4 string `gorm:"<-:create"` + Name5 string `gorm:"<-:update"` + Name6 string `gorm:"<-:create,update"` +} + +func TestParseFieldWithPermission(t *testing.T) { + user, err := schema.Parse(&UserWithPermissionControl{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("Failed to parse user with permission, got error %v", err) + } + + fields := []schema.Field{ + {Name: "ID", DBName: "id", BindNames: []string{"ID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, Creatable: true, Updatable: true, Readable: true}, + {Name: "Name", DBName: "name", BindNames: []string{"Name"}, DataType: schema.String, Tag: `gorm:"-"`, Creatable: false, Updatable: false, Readable: false}, + {Name: "Name2", DBName: "name2", BindNames: []string{"Name2"}, DataType: schema.String, Tag: `gorm:"->"`, Creatable: false, Updatable: false, Readable: true}, + {Name: "Name3", DBName: "name3", BindNames: []string{"Name3"}, DataType: schema.String, Tag: `gorm:"<-"`, Creatable: true, Updatable: true, Readable: false}, + {Name: "Name4", DBName: "name4", BindNames: []string{"Name4"}, DataType: schema.String, Tag: `gorm:"<-:create"`, Creatable: true, Updatable: false, Readable: false}, + {Name: "Name5", DBName: "name5", BindNames: []string{"Name5"}, DataType: schema.String, Tag: `gorm:"<-:update"`, Creatable: false, Updatable: true, Readable: false}, + {Name: "Name6", DBName: "name6", BindNames: []string{"Name6"}, DataType: schema.String, Tag: `gorm:"<-:create,update"`, Creatable: true, Updatable: true, Readable: false}, + } + + for _, f := range fields { + checkSchemaField(t, user, &f, func(f *schema.Field) {}) + } +} diff --git a/schema/schema_helper_test.go b/schema/schema_helper_test.go index 146ba13a..24920515 100644 --- a/schema/schema_helper_test.go +++ b/schema/schema_helper_test.go @@ -52,7 +52,7 @@ func checkSchemaField(t *testing.T, s *schema.Schema, f *schema.Field, fc func(* if parsedField, ok := s.FieldsByName[f.Name]; !ok { t.Errorf("schema %v failed to look up field with name %v", s, f.Name) } else { - tests.AssertObjEqual(t, parsedField, f, "Name", "DBName", "BindNames", "DataType", "DBDataType", "PrimaryKey", "AutoIncrement", "Creatable", "Updatable", "HasDefaultValue", "DefaultValue", "NotNull", "Unique", "Comment", "Size", "Precision", "Tag", "TagSettings") + tests.AssertObjEqual(t, parsedField, f, "Name", "DBName", "BindNames", "DataType", "DBDataType", "PrimaryKey", "AutoIncrement", "Creatable", "Updatable", "Readable", "HasDefaultValue", "DefaultValue", "NotNull", "Unique", "Comment", "Size", "Precision", "Tag", "TagSettings") if field, ok := s.FieldsByDBName[f.DBName]; !ok || parsedField != field { t.Errorf("schema %v failed to look up field with dbname %v", s, f.DBName) diff --git a/schema/schema_test.go b/schema/schema_test.go index 7d13e614..958e035f 100644 --- a/schema/schema_test.go +++ b/schema/schema_test.go @@ -48,6 +48,7 @@ func checkUserSchema(t *testing.T, user *schema.Schema) { checkSchemaField(t, user, &f, func(f *schema.Field) { f.Creatable = true f.Updatable = true + f.Readable = true }) } @@ -83,11 +84,11 @@ func checkUserSchema(t *testing.T, user *schema.Schema) { JoinTable: JoinTable{Name: "UserSpeak", Table: "user_speaks", Fields: []schema.Field{ { Name: "UserID", DBName: "user_id", BindNames: []string{"UserID"}, DataType: schema.Uint, - Tag: `gorm:"primarykey"`, Creatable: true, Updatable: true, PrimaryKey: true, Size: 64, + Tag: `gorm:"primarykey"`, Creatable: true, Updatable: true, Readable: true, PrimaryKey: true, Size: 64, }, { Name: "LanguageCode", DBName: "language_code", BindNames: []string{"LanguageCode"}, DataType: schema.String, - Tag: `gorm:"primarykey"`, Creatable: true, Updatable: true, PrimaryKey: true, + Tag: `gorm:"primarykey"`, Creatable: true, Updatable: true, Readable: true, PrimaryKey: true, }, }}, References: []Reference{{"ID", "User", "UserID", "UserSpeak", "", true}, {"Code", "Language", "LanguageCode", "UserSpeak", "", false}}, @@ -97,11 +98,11 @@ func checkUserSchema(t *testing.T, user *schema.Schema) { JoinTable: JoinTable{Name: "user_friends", Table: "user_friends", Fields: []schema.Field{ { Name: "UserID", DBName: "user_id", BindNames: []string{"UserID"}, DataType: schema.Uint, - Tag: `gorm:"primarykey"`, Creatable: true, Updatable: true, PrimaryKey: true, Size: 64, + Tag: `gorm:"primarykey"`, Creatable: true, Updatable: true, Readable: true, PrimaryKey: true, Size: 64, }, { Name: "FriendID", DBName: "friend_id", BindNames: []string{"FriendID"}, DataType: schema.Uint, - Tag: `gorm:"primarykey"`, Creatable: true, Updatable: true, PrimaryKey: true, Size: 64, + Tag: `gorm:"primarykey"`, Creatable: true, Updatable: true, Readable: true, PrimaryKey: true, Size: 64, }, }}, References: []Reference{{"ID", "User", "UserID", "user_friends", "", true}, {"ID", "User", "FriendID", "user_friends", "", false}}, @@ -137,6 +138,7 @@ func TestParseSchemaWithAdvancedDataType(t *testing.T) { checkSchemaField(t, user, &f, func(f *schema.Field) { f.Creatable = true f.Updatable = true + f.Readable = true }) } } From e1bcca6b332c5a9d59d794806e65ed060789b40c Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 12 Apr 2020 13:16:15 +0800 Subject: [PATCH 286/881] Compatible with tag PRIMARY_KEY --- schema/field.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/schema/field.go b/schema/field.go index a5c3b41f..ec419383 100644 --- a/schema/field.go +++ b/schema/field.go @@ -148,6 +148,8 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { if val, ok := field.TagSettings["PRIMARYKEY"]; ok && checkTruth(val) { field.PrimaryKey = true + } else if val, ok := field.TagSettings["PRIMARY_KEY"]; ok && checkTruth(val) { + field.PrimaryKey = true } if val, ok := field.TagSettings["AUTOINCREMENT"]; ok && checkTruth(val) { From a992c1ea38c4a05a934fb30928a560c3a54190d2 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 12 Apr 2020 13:22:52 +0800 Subject: [PATCH 287/881] Fix check has column, index for sqlite --- dialects/sqlite/migrator.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/dialects/sqlite/migrator.go b/dialects/sqlite/migrator.go index 4ddcbb5d..601de126 100644 --- a/dialects/sqlite/migrator.go +++ b/dialects/sqlite/migrator.go @@ -30,8 +30,8 @@ func (m Migrator) HasColumn(value interface{}, field string) bool { } return m.DB.Raw( - "SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE ? OR sql LIKE ? OR sql LIKE ?)", - stmt.Table, `%"`+name+`" %`, `%`+name+` %`, "%`"+name+"`%", + "SELECT count(*) FROM sqlite_master WHERE type = ? AND tbl_name = ? AND (sql LIKE ? OR sql LIKE ? OR sql LIKE ?)", + "table", stmt.Table, `%"`+name+`" %`, `%`+name+` %`, "%`"+name+"`%", ).Row().Scan(&count) }) return count > 0 @@ -41,8 +41,8 @@ func (m Migrator) HasIndex(value interface{}, name string) bool { var count int m.RunWithValue(value, func(stmt *gorm.Statement) error { return m.DB.Raw( - "SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE ?", - stmt.Table, "%INDEX "+name+" ON%", + "SELECT count(*) FROM sqlite_master WHERE type = ? AND tbl_name = ? AND sql LIKE ?", + "index", stmt.Table, "%INDEX "+name+" ON%", ).Row().Scan(&count) }) return count > 0 From 50aa9be4f10d8a1562fef223efcf9fee6a02d256 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 15 Apr 2020 09:14:24 +0800 Subject: [PATCH 288/881] Add joins support --- callbacks/query.go | 73 ++++++++++++++++++++++++++++++++++++++++++++-- chainable_api.go | 10 ++++++- clause/joins.go | 51 +++++++++++++++++--------------- statement.go | 10 +++++++ 4 files changed, 118 insertions(+), 26 deletions(-) diff --git a/callbacks/query.go b/callbacks/query.go index 00820bfd..ae22f4d0 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -1,6 +1,7 @@ package callbacks import ( + "fmt" "reflect" "github.com/jinzhu/gorm" @@ -9,8 +10,76 @@ import ( func Query(db *gorm.DB) { if db.Statement.SQL.String() == "" { - db.Statement.AddClauseIfNotExists(clause.Select{}) - db.Statement.AddClauseIfNotExists(clause.From{}) + clauseSelect := clause.Select{} + + if len(db.Statement.Selects) > 0 { + for _, name := range db.Statement.Selects { + if f := db.Statement.Schema.LookUpField(name); f != nil { + clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ + Name: f.DBName, + }) + } + } + } + + if len(db.Statement.Joins) != 0 { + joins := []clause.Join{} + + if len(db.Statement.Selects) == 0 { + for _, dbName := range db.Statement.Schema.DBNames { + clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ + Name: dbName, + }) + } + } + + for name, conds := range db.Statement.Joins { + if relation, ok := db.Statement.Schema.Relationships.Relations[name]; ok { + for _, s := range relation.FieldSchema.DBNames { + clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ + Table: relation.FieldSchema.Table, + Name: s, + }) + } + + var exprs []clause.Expression + for _, ref := range relation.References { + if ref.OwnPrimaryKey { + exprs = append(exprs, clause.Expr{ + SQL: fmt.Sprintf("%s.%s = %s.%s", db.Statement.Schema.Table, ref.PrimaryKey.DBName, relation.FieldSchema.Table, ref.ForeignKey.DBName), + }) + } else { + if ref.PrimaryValue == "" { + exprs = append(exprs, clause.Expr{ + SQL: fmt.Sprintf("%s.%s = %s.%s", db.Statement.Schema.Table, ref.ForeignKey.DBName, relation.FieldSchema.Table, ref.PrimaryKey.DBName), + }) + } else { + exprs = append(exprs, clause.Expr{ + SQL: fmt.Sprintf("%s.%s = ?", relation.FieldSchema.Table, ref.PrimaryKey.DBName), + Vars: []interface{}{ref.PrimaryValue}, + }) + } + } + } + + joins = append(joins, clause.Join{ + Type: clause.LeftJoin, + Table: clause.Table{Name: relation.FieldSchema.Table}, + ON: clause.Where{Exprs: exprs}, + }) + } else { + joins = append(joins, clause.Join{ + Expression: clause.Expr{SQL: name, Vars: conds}, + }) + } + } + + db.Statement.AddClause(clause.From{Joins: joins}) + } else { + db.Statement.AddClauseIfNotExists(clause.From{}) + } + + db.Statement.AddClauseIfNotExists(clauseSelect) db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR") } diff --git a/chainable_api.go b/chainable_api.go index 7a6e8b7c..6b91c9ad 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -134,6 +134,10 @@ func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) { // db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user) func (db *DB) Joins(query string, args ...interface{}) (tx *DB) { tx = db.getInstance() + if tx.Statement.Joins == nil { + tx.Statement.Joins = map[string][]interface{}{} + } + tx.Statement.Joins[query] = args return } @@ -211,8 +215,12 @@ func (db *DB) Scopes(funcs ...func(*DB) *DB) *DB { // Preload preload associations with given conditions // db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users) -func (db *DB) Preload(column string, conditions ...interface{}) (tx *DB) { +func (db *DB) Preload(query string, args ...interface{}) (tx *DB) { tx = db.getInstance() + if tx.Statement.Preloads == nil { + tx.Statement.Preloads = map[string][]interface{}{} + } + tx.Statement.Preloads[query] = args return } diff --git a/clause/joins.go b/clause/joins.go index a78bde39..8d9055cd 100644 --- a/clause/joins.go +++ b/clause/joins.go @@ -11,32 +11,37 @@ const ( // Join join clause for from type Join struct { - Type JoinType - Table Table - ON Where - Using []string + Type JoinType + Table Table + ON Where + Using []string + Expression Expression } func (join Join) Build(builder Builder) { - if join.Type != "" { - builder.WriteString(string(join.Type)) - builder.WriteByte(' ') - } - - builder.WriteString("JOIN ") - builder.WriteQuoted(join.Table) - - if len(join.ON.Exprs) > 0 { - builder.WriteString(" ON ") - join.ON.Build(builder) - } else if len(join.Using) > 0 { - builder.WriteString(" USING (") - for idx, c := range join.Using { - if idx > 0 { - builder.WriteByte(',') - } - builder.WriteQuoted(c) + if join.Expression != nil { + join.Expression.Build(builder) + } else { + if join.Type != "" { + builder.WriteString(string(join.Type)) + builder.WriteByte(' ') + } + + builder.WriteString("JOIN ") + builder.WriteQuoted(join.Table) + + if len(join.ON.Exprs) > 0 { + builder.WriteString(" ON ") + join.ON.Build(builder) + } else if len(join.Using) > 0 { + builder.WriteString(" USING (") + for idx, c := range join.Using { + if idx > 0 { + builder.WriteByte(',') + } + builder.WriteQuoted(c) + } + builder.WriteByte(')') } - builder.WriteByte(')') } } diff --git a/statement.go b/statement.go index e45bd8bb..3f2ceca3 100644 --- a/statement.go +++ b/statement.go @@ -24,6 +24,8 @@ type Statement struct { Clauses map[string]clause.Clause Selects []string // selected columns Omits []string // omit columns + Joins map[string][]interface{} + Preloads map[string][]interface{} Settings sync.Map ConnPool ConnPool Schema *schema.Schema @@ -265,6 +267,14 @@ func (stmt *Statement) reinit() { delete(stmt.Clauses, k) } + for k := range stmt.Joins { + delete(stmt.Joins, k) + } + + for k := range stmt.Preloads { + delete(stmt.Preloads, k) + } + stmt.Settings.Range(func(k, _ interface{}) bool { stmt.Settings.Delete(k) return true From b4b249ddcb451327109020bac6a372102c1bcb1e Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 15 Apr 2020 19:13:36 +0800 Subject: [PATCH 289/881] Refactor test files --- tests/create.go | 43 +++++++ tests/delete.go | 64 ++++++++++ tests/joins.go | 5 + tests/query.go | 95 +++++++++++++++ tests/tests.go | 306 ------------------------------------------------ tests/update.go | 133 +++++++++++++++++++++ 6 files changed, 340 insertions(+), 306 deletions(-) create mode 100644 tests/create.go create mode 100644 tests/delete.go create mode 100644 tests/query.go create mode 100644 tests/update.go diff --git a/tests/create.go b/tests/create.go new file mode 100644 index 00000000..dfd73bd3 --- /dev/null +++ b/tests/create.go @@ -0,0 +1,43 @@ +package tests + +import ( + "testing" + + "github.com/jinzhu/gorm" +) + +func TestCreate(t *testing.T, db *gorm.DB) { + db.Migrator().DropTable(&User{}) + db.AutoMigrate(&User{}) + + t.Run("Create", func(t *testing.T) { + var user = User{ + Name: "create", + Age: 18, + Birthday: Now(), + } + + if err := db.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + if user.ID == 0 { + t.Errorf("user's primary key should has value after create, got : %v", user.ID) + } + + if user.CreatedAt.IsZero() { + t.Errorf("user's created at should be not zero") + } + + if user.UpdatedAt.IsZero() { + t.Errorf("user's updated at should be not zero") + } + + var newUser User + if err := db.Where("id = ?", user.ID).First(&newUser).Error; err != nil { + t.Errorf("errors happened when query: %v", err) + } else { + AssertObjEqual(t, newUser, user, "Name", "Age", "Birthday") + } + }) +} diff --git a/tests/delete.go b/tests/delete.go new file mode 100644 index 00000000..45701ff0 --- /dev/null +++ b/tests/delete.go @@ -0,0 +1,64 @@ +package tests + +import ( + "errors" + "testing" + + "github.com/jinzhu/gorm" +) + +func TestDelete(t *testing.T, db *gorm.DB) { + db.Migrator().DropTable(&User{}) + db.AutoMigrate(&User{}) + + t.Run("Delete", func(t *testing.T) { + var users = []User{{ + Name: "find", + Age: 1, + Birthday: Now(), + }, { + Name: "find", + Age: 2, + Birthday: Now(), + }, { + Name: "find", + Age: 3, + Birthday: Now(), + }} + + if err := db.Create(&users).Error; err != nil { + t.Errorf("errors happened when create: %v", err) + } + + for _, user := range users { + if user.ID == 0 { + t.Fatalf("user's primary key should has value after create, got : %v", user.ID) + } + } + + if err := db.Delete(&users[1]).Error; err != nil { + t.Errorf("errors happened when delete: %v", err) + } + + var result User + if err := db.Where("id = ?", users[1].ID).First(&result).Error; err == nil || !errors.Is(err, gorm.ErrRecordNotFound) { + t.Errorf("should returns record not found error, but got %v", err) + } + + for _, user := range []User{users[0], users[2]} { + if err := db.Where("id = ?", user.ID).First(&result).Error; err != nil { + t.Errorf("no error should returns when query %v, but got %v", user.ID, err) + } + } + + if err := db.Delete(&User{}).Error; err == nil || !errors.Is(err, gorm.ErrMissingWhereClause) { + t.Errorf("should returns missing WHERE clause while deleting error") + } + + for _, user := range []User{users[0], users[2]} { + if err := db.Where("id = ?", user.ID).First(&result).Error; err != nil { + t.Errorf("no error should returns when query %v, but got %v", user.ID, err) + } + } + }) +} diff --git a/tests/joins.go b/tests/joins.go index 3c4bfbb5..2a8cdc8b 100644 --- a/tests/joins.go +++ b/tests/joins.go @@ -7,4 +7,9 @@ import ( ) func TestJoins(t *testing.T, db *gorm.DB) { + db.Migrator().DropTable(&User{}) + db.AutoMigrate(&User{}) + + t.Run("Joins", func(t *testing.T) { + }) } diff --git a/tests/query.go b/tests/query.go new file mode 100644 index 00000000..5eabfb48 --- /dev/null +++ b/tests/query.go @@ -0,0 +1,95 @@ +package tests + +import ( + "reflect" + "strconv" + "testing" + + "github.com/jinzhu/gorm" +) + +func TestFind(t *testing.T, db *gorm.DB) { + db.Migrator().DropTable(&User{}) + db.AutoMigrate(&User{}) + + t.Run("Find", func(t *testing.T) { + var users = []User{{ + Name: "find", + Age: 1, + Birthday: Now(), + }, { + Name: "find", + Age: 2, + Birthday: Now(), + }, { + Name: "find", + Age: 3, + Birthday: Now(), + }} + + if err := db.Create(&users).Error; err != nil { + t.Fatalf("errors happened when create users: %v", err) + } + + t.Run("First", func(t *testing.T) { + var first User + if err := db.Where("name = ?", "find").First(&first).Error; err != nil { + t.Errorf("errors happened when query first: %v", err) + } else { + AssertObjEqual(t, first, users[0], "Name", "Age", "Birthday") + } + }) + + t.Run("Last", func(t *testing.T) { + var last User + if err := db.Where("name = ?", "find").Last(&last).Error; err != nil { + t.Errorf("errors happened when query last: %v", err) + } else { + AssertObjEqual(t, last, users[2], "Name", "Age", "Birthday") + } + }) + + var all []User + if err := db.Where("name = ?", "find").Find(&all).Error; err != nil || len(all) != 3 { + t.Errorf("errors happened when query find: %v, length: %v", err, len(all)) + } else { + for idx, user := range users { + t.Run("FindAll#"+strconv.Itoa(idx+1), func(t *testing.T) { + AssertObjEqual(t, all[idx], user, "Name", "Age", "Birthday") + }) + } + } + + t.Run("FirstMap", func(t *testing.T) { + var first = map[string]interface{}{} + if err := db.Model(&User{}).Where("name = ?", "find").First(first).Error; err != nil { + t.Errorf("errors happened when query first: %v", err) + } else { + for _, name := range []string{"Name", "Age", "Birthday"} { + t.Run(name, func(t *testing.T) { + dbName := db.NamingStrategy.ColumnName("", name) + reflectValue := reflect.Indirect(reflect.ValueOf(users[0])) + AssertEqual(t, first[dbName], reflectValue.FieldByName(name).Interface()) + }) + } + } + }) + + var allMap = []map[string]interface{}{} + if err := db.Model(&User{}).Where("name = ?", "find").Find(&allMap).Error; err != nil { + t.Errorf("errors happened when query first: %v", err) + } else { + for idx, user := range users { + t.Run("FindAllMap#"+strconv.Itoa(idx+1), func(t *testing.T) { + for _, name := range []string{"Name", "Age", "Birthday"} { + t.Run(name, func(t *testing.T) { + dbName := db.NamingStrategy.ColumnName("", name) + reflectValue := reflect.Indirect(reflect.ValueOf(user)) + AssertEqual(t, allMap[idx][dbName], reflectValue.FieldByName(name).Interface()) + }) + } + }) + } + } + }) +} diff --git a/tests/tests.go b/tests/tests.go index aa48f699..cc9c1a78 100644 --- a/tests/tests.go +++ b/tests/tests.go @@ -1,9 +1,6 @@ package tests import ( - "errors" - "reflect" - "strconv" "testing" "time" @@ -24,306 +21,3 @@ func RunTestsSuit(t *testing.T, db *gorm.DB) { TestGroupBy(t, db) TestJoins(t, db) } - -func TestCreate(t *testing.T, db *gorm.DB) { - db.Migrator().DropTable(&User{}) - db.AutoMigrate(&User{}) - - t.Run("Create", func(t *testing.T) { - var user = User{ - Name: "create", - Age: 18, - Birthday: Now(), - } - - if err := db.Create(&user).Error; err != nil { - t.Fatalf("errors happened when create: %v", err) - } - - if user.ID == 0 { - t.Errorf("user's primary key should has value after create, got : %v", user.ID) - } - - if user.CreatedAt.IsZero() { - t.Errorf("user's created at should be not zero") - } - - if user.UpdatedAt.IsZero() { - t.Errorf("user's updated at should be not zero") - } - - var newUser User - if err := db.Where("id = ?", user.ID).First(&newUser).Error; err != nil { - t.Errorf("errors happened when query: %v", err) - } else { - AssertObjEqual(t, newUser, user, "Name", "Age", "Birthday") - } - }) -} - -func TestFind(t *testing.T, db *gorm.DB) { - db.Migrator().DropTable(&User{}) - db.AutoMigrate(&User{}) - - t.Run("Find", func(t *testing.T) { - var users = []User{{ - Name: "find", - Age: 1, - Birthday: Now(), - }, { - Name: "find", - Age: 2, - Birthday: Now(), - }, { - Name: "find", - Age: 3, - Birthday: Now(), - }} - - if err := db.Create(&users).Error; err != nil { - t.Fatalf("errors happened when create users: %v", err) - } - - t.Run("First", func(t *testing.T) { - var first User - if err := db.Where("name = ?", "find").First(&first).Error; err != nil { - t.Errorf("errors happened when query first: %v", err) - } else { - AssertObjEqual(t, first, users[0], "Name", "Age", "Birthday") - } - }) - - t.Run("Last", func(t *testing.T) { - var last User - if err := db.Where("name = ?", "find").Last(&last).Error; err != nil { - t.Errorf("errors happened when query last: %v", err) - } else { - AssertObjEqual(t, last, users[2], "Name", "Age", "Birthday") - } - }) - - var all []User - if err := db.Where("name = ?", "find").Find(&all).Error; err != nil || len(all) != 3 { - t.Errorf("errors happened when query find: %v, length: %v", err, len(all)) - } else { - for idx, user := range users { - t.Run("FindAll#"+strconv.Itoa(idx+1), func(t *testing.T) { - AssertObjEqual(t, all[idx], user, "Name", "Age", "Birthday") - }) - } - } - - t.Run("FirstMap", func(t *testing.T) { - var first = map[string]interface{}{} - if err := db.Model(&User{}).Where("name = ?", "find").First(first).Error; err != nil { - t.Errorf("errors happened when query first: %v", err) - } else { - for _, name := range []string{"Name", "Age", "Birthday"} { - t.Run(name, func(t *testing.T) { - dbName := db.NamingStrategy.ColumnName("", name) - reflectValue := reflect.Indirect(reflect.ValueOf(users[0])) - AssertEqual(t, first[dbName], reflectValue.FieldByName(name).Interface()) - }) - } - } - }) - - var allMap = []map[string]interface{}{} - if err := db.Model(&User{}).Where("name = ?", "find").Find(&allMap).Error; err != nil { - t.Errorf("errors happened when query first: %v", err) - } else { - for idx, user := range users { - t.Run("FindAllMap#"+strconv.Itoa(idx+1), func(t *testing.T) { - for _, name := range []string{"Name", "Age", "Birthday"} { - t.Run(name, func(t *testing.T) { - dbName := db.NamingStrategy.ColumnName("", name) - reflectValue := reflect.Indirect(reflect.ValueOf(user)) - AssertEqual(t, allMap[idx][dbName], reflectValue.FieldByName(name).Interface()) - }) - } - }) - } - } - }) -} - -func TestUpdate(t *testing.T, db *gorm.DB) { - db.Migrator().DropTable(&User{}) - db.AutoMigrate(&User{}) - - t.Run("Update", func(t *testing.T) { - var ( - users = []*User{{ - Name: "update-before", - Age: 1, - Birthday: Now(), - }, { - Name: "update", - Age: 18, - Birthday: Now(), - }, { - Name: "update-after", - Age: 1, - Birthday: Now(), - }} - user = users[1] - lastUpdatedAt time.Time - ) - - checkUpdatedTime := func(name string, n time.Time) { - if n.UnixNano() == lastUpdatedAt.UnixNano() { - t.Errorf("%v: user's updated at should be changed, but got %v, was %v", name, n, lastUpdatedAt) - } - lastUpdatedAt = n - } - - checkOtherData := func(name string) { - var beforeUser, afterUser User - if err := db.Where("id = ?", users[0].ID).First(&beforeUser).Error; err != nil { - t.Errorf("errors happened when query before user: %v", err) - } - t.Run(name, func(t *testing.T) { - AssertObjEqual(t, beforeUser, users[0], "Name", "Age", "Birthday") - }) - - if err := db.Where("id = ?", users[2].ID).First(&afterUser).Error; err != nil { - t.Errorf("errors happened when query after user: %v", err) - } - t.Run(name, func(t *testing.T) { - AssertObjEqual(t, afterUser, users[2], "Name", "Age", "Birthday") - }) - } - - if err := db.Create(&users).Error; err != nil { - t.Fatalf("errors happened when create: %v", err) - } else if user.ID == 0 { - t.Fatalf("user's primary value should not zero, %v", user.ID) - } else if user.UpdatedAt.IsZero() { - t.Fatalf("user's updated at should not zero, %v", user.UpdatedAt) - } - lastUpdatedAt = user.UpdatedAt - - if err := db.Model(user).Update("Age", 10).Error; err != nil { - t.Errorf("errors happened when update: %v", err) - } else if user.Age != 10 { - t.Errorf("Age should equals to 10, but got %v", user.Age) - } - checkUpdatedTime("Update", user.UpdatedAt) - checkOtherData("Update") - - var result User - if err := db.Where("id = ?", user.ID).First(&result).Error; err != nil { - t.Errorf("errors happened when query: %v", err) - } else { - AssertObjEqual(t, result, user, "Name", "Age", "Birthday") - } - - values := map[string]interface{}{"Active": true, "age": 5} - if err := db.Model(user).Updates(values).Error; err != nil { - t.Errorf("errors happened when update: %v", err) - } else if user.Age != 5 { - t.Errorf("Age should equals to 5, but got %v", user.Age) - } else if user.Active != true { - t.Errorf("Active should be true, but got %v", user.Active) - } - checkUpdatedTime("Updates with map", user.UpdatedAt) - checkOtherData("Updates with map") - - var result2 User - if err := db.Where("id = ?", user.ID).First(&result2).Error; err != nil { - t.Errorf("errors happened when query: %v", err) - } else { - AssertObjEqual(t, result2, user, "Name", "Age", "Birthday") - } - - if err := db.Model(user).Updates(User{Age: 2}).Error; err != nil { - t.Errorf("errors happened when update: %v", err) - } else if user.Age != 2 { - t.Errorf("Age should equals to 2, but got %v", user.Age) - } - checkUpdatedTime("Updates with struct", user.UpdatedAt) - checkOtherData("Updates with struct") - - var result3 User - if err := db.Where("id = ?", user.ID).First(&result3).Error; err != nil { - t.Errorf("errors happened when query: %v", err) - } else { - AssertObjEqual(t, result3, user, "Name", "Age", "Birthday") - } - - user.Active = false - user.Age = 1 - if err := db.Save(user).Error; err != nil { - t.Errorf("errors happened when update: %v", err) - } else if user.Age != 1 { - t.Errorf("Age should equals to 1, but got %v", user.Age) - } else if user.Active != false { - t.Errorf("Active should equals to false, but got %v", user.Active) - } - checkUpdatedTime("Save", user.UpdatedAt) - checkOtherData("Save") - - var result4 User - if err := db.Where("id = ?", user.ID).First(&result4).Error; err != nil { - t.Errorf("errors happened when query: %v", err) - } else { - AssertObjEqual(t, result4, user, "Name", "Age", "Birthday") - } - }) -} - -func TestDelete(t *testing.T, db *gorm.DB) { - db.Migrator().DropTable(&User{}) - db.AutoMigrate(&User{}) - - t.Run("Delete", func(t *testing.T) { - var users = []User{{ - Name: "find", - Age: 1, - Birthday: Now(), - }, { - Name: "find", - Age: 2, - Birthday: Now(), - }, { - Name: "find", - Age: 3, - Birthday: Now(), - }} - - if err := db.Create(&users).Error; err != nil { - t.Errorf("errors happened when create: %v", err) - } - - for _, user := range users { - if user.ID == 0 { - t.Fatalf("user's primary key should has value after create, got : %v", user.ID) - } - } - - if err := db.Delete(&users[1]).Error; err != nil { - t.Errorf("errors happened when delete: %v", err) - } - - var result User - if err := db.Where("id = ?", users[1].ID).First(&result).Error; err == nil || !errors.Is(err, gorm.ErrRecordNotFound) { - t.Errorf("should returns record not found error, but got %v", err) - } - - for _, user := range []User{users[0], users[2]} { - if err := db.Where("id = ?", user.ID).First(&result).Error; err != nil { - t.Errorf("no error should returns when query %v, but got %v", user.ID, err) - } - } - - if err := db.Delete(&User{}).Error; err == nil || !errors.Is(err, gorm.ErrMissingWhereClause) { - t.Errorf("should returns missing WHERE clause while deleting error") - } - - for _, user := range []User{users[0], users[2]} { - if err := db.Where("id = ?", user.ID).First(&result).Error; err != nil { - t.Errorf("no error should returns when query %v, but got %v", user.ID, err) - } - } - }) -} diff --git a/tests/update.go b/tests/update.go new file mode 100644 index 00000000..3a94313e --- /dev/null +++ b/tests/update.go @@ -0,0 +1,133 @@ +package tests + +import ( + "testing" + "time" + + "github.com/jinzhu/gorm" +) + +func TestUpdate(t *testing.T, db *gorm.DB) { + db.Migrator().DropTable(&User{}) + db.AutoMigrate(&User{}) + + t.Run("Update", func(t *testing.T) { + var ( + users = []*User{{ + Name: "update-before", + Age: 1, + Birthday: Now(), + }, { + Name: "update", + Age: 18, + Birthday: Now(), + }, { + Name: "update-after", + Age: 1, + Birthday: Now(), + }} + user = users[1] + lastUpdatedAt time.Time + ) + + checkUpdatedTime := func(name string, n time.Time) { + if n.UnixNano() == lastUpdatedAt.UnixNano() { + t.Errorf("%v: user's updated at should be changed, but got %v, was %v", name, n, lastUpdatedAt) + } + lastUpdatedAt = n + } + + checkOtherData := func(name string) { + var beforeUser, afterUser User + if err := db.Where("id = ?", users[0].ID).First(&beforeUser).Error; err != nil { + t.Errorf("errors happened when query before user: %v", err) + } + t.Run(name, func(t *testing.T) { + AssertObjEqual(t, beforeUser, users[0], "Name", "Age", "Birthday") + }) + + if err := db.Where("id = ?", users[2].ID).First(&afterUser).Error; err != nil { + t.Errorf("errors happened when query after user: %v", err) + } + t.Run(name, func(t *testing.T) { + AssertObjEqual(t, afterUser, users[2], "Name", "Age", "Birthday") + }) + } + + if err := db.Create(&users).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } else if user.ID == 0 { + t.Fatalf("user's primary value should not zero, %v", user.ID) + } else if user.UpdatedAt.IsZero() { + t.Fatalf("user's updated at should not zero, %v", user.UpdatedAt) + } + lastUpdatedAt = user.UpdatedAt + + if err := db.Model(user).Update("Age", 10).Error; err != nil { + t.Errorf("errors happened when update: %v", err) + } else if user.Age != 10 { + t.Errorf("Age should equals to 10, but got %v", user.Age) + } + checkUpdatedTime("Update", user.UpdatedAt) + checkOtherData("Update") + + var result User + if err := db.Where("id = ?", user.ID).First(&result).Error; err != nil { + t.Errorf("errors happened when query: %v", err) + } else { + AssertObjEqual(t, result, user, "Name", "Age", "Birthday") + } + + values := map[string]interface{}{"Active": true, "age": 5} + if err := db.Model(user).Updates(values).Error; err != nil { + t.Errorf("errors happened when update: %v", err) + } else if user.Age != 5 { + t.Errorf("Age should equals to 5, but got %v", user.Age) + } else if user.Active != true { + t.Errorf("Active should be true, but got %v", user.Active) + } + checkUpdatedTime("Updates with map", user.UpdatedAt) + checkOtherData("Updates with map") + + var result2 User + if err := db.Where("id = ?", user.ID).First(&result2).Error; err != nil { + t.Errorf("errors happened when query: %v", err) + } else { + AssertObjEqual(t, result2, user, "Name", "Age", "Birthday") + } + + if err := db.Model(user).Updates(User{Age: 2}).Error; err != nil { + t.Errorf("errors happened when update: %v", err) + } else if user.Age != 2 { + t.Errorf("Age should equals to 2, but got %v", user.Age) + } + checkUpdatedTime("Updates with struct", user.UpdatedAt) + checkOtherData("Updates with struct") + + var result3 User + if err := db.Where("id = ?", user.ID).First(&result3).Error; err != nil { + t.Errorf("errors happened when query: %v", err) + } else { + AssertObjEqual(t, result3, user, "Name", "Age", "Birthday") + } + + user.Active = false + user.Age = 1 + if err := db.Save(user).Error; err != nil { + t.Errorf("errors happened when update: %v", err) + } else if user.Age != 1 { + t.Errorf("Age should equals to 1, but got %v", user.Age) + } else if user.Active != false { + t.Errorf("Active should equals to false, but got %v", user.Active) + } + checkUpdatedTime("Save", user.UpdatedAt) + checkOtherData("Save") + + var result4 User + if err := db.Where("id = ?", user.ID).First(&result4).Error; err != nil { + t.Errorf("errors happened when query: %v", err) + } else { + AssertObjEqual(t, result4, user, "Name", "Age", "Birthday") + } + }) +} From 345ff7577c985b8c1f7e7f759391e989a9041034 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 15 Apr 2020 23:58:26 +0800 Subject: [PATCH 290/881] Save before associations --- callbacks/create.go | 23 +++++++++++++++++++++++ logger/sql.go | 22 ++++++++++++---------- tests/create.go | 41 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 76 insertions(+), 10 deletions(-) diff --git a/callbacks/create.go b/callbacks/create.go index 97a2832c..e21e04c2 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -42,6 +42,29 @@ func BeforeCreate(db *gorm.DB) { } func SaveBeforeAssociations(db *gorm.DB) { + if db.Statement.Schema != nil { + for _, rel := range db.Statement.Schema.Relationships.BelongsTo { + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice: + case reflect.Struct: + if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero { + f := rel.Field.ReflectValueOf(db.Statement.ReflectValue) + if f.Kind() == reflect.Ptr { + db.Session(&gorm.Session{}).Create(f.Interface()) + } else { + db.Session(&gorm.Session{}).Create(f.Addr().Interface()) + } + + for _, ref := range rel.References { + if !ref.OwnPrimaryKey { + fv, _ := ref.PrimaryKey.ValueOf(f) + ref.ForeignKey.Set(db.Statement.ReflectValue, fv) + } + } + } + } + } + } } func Create(config *Config) func(db *gorm.DB) { diff --git a/logger/sql.go b/logger/sql.go index 41c514fd..9c0f54d7 100644 --- a/logger/sql.go +++ b/logger/sql.go @@ -51,20 +51,22 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, v if v == nil { vars[idx] = "NULL" } else { - rv := reflect.Indirect(reflect.ValueOf(v)) + rv := reflect.ValueOf(v) + if !rv.IsValid() { vars[idx] = "NULL" - return - } - - for _, t := range convertableTypes { - if rv.Type().ConvertibleTo(t) { - convertParams(rv.Convert(t).Interface(), idx) - return + } else if rv.Kind() == reflect.Ptr && !rv.IsZero() { + convertParams(reflect.Indirect(rv).Interface(), idx) + } else { + for _, t := range convertableTypes { + if rv.Type().ConvertibleTo(t) { + convertParams(rv.Convert(t).Interface(), idx) + return + } } - } - vars[idx] = escaper + strings.Replace(fmt.Sprint(v), escaper, "\\"+escaper, -1) + escaper + vars[idx] = escaper + strings.Replace(fmt.Sprint(v), escaper, "\\"+escaper, -1) + escaper + } } } } diff --git a/tests/create.go b/tests/create.go index dfd73bd3..74a010dc 100644 --- a/tests/create.go +++ b/tests/create.go @@ -40,4 +40,45 @@ func TestCreate(t *testing.T, db *gorm.DB) { AssertObjEqual(t, newUser, user, "Name", "Age", "Birthday") } }) + + TestCreateAssociations(t, db) +} + +func TestCreateAssociations(t *testing.T, db *gorm.DB) { + db.Migrator().DropTable(&Company{}) + db.Migrator().AutoMigrate(&Company{}) + + t.Run("Create-BelongsToAssociation", func(t *testing.T) { + var user = User{ + Name: "create", + Age: 18, + Birthday: Now(), + Company: Company{Name: "company-belongs-to-association"}, + Manager: &User{Name: "manager-belongs-to-association"}, + } + + if err := db.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + if user.CompanyID == nil { + t.Errorf("Failed to create belongs to association - Company") + } else { + var company Company + db.First(&company, "id = ?", *user.CompanyID) + if company.Name != "company-belongs-to-association" { + t.Errorf("Failed to query saved belongs to association - Company") + } + } + + if user.ManagerID == nil { + t.Errorf("Failed to create belongs to association - Manager") + } else { + var manager User + db.First(&manager, "id = ?", *user.ManagerID) + if manager.Name != "manager-belongs-to-association" { + t.Errorf("Failed to query saved belongs to association - Manager") + } + } + }) } From 56ca9a87e06eea0ec63101d5e81c1359e7f45537 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 16 Apr 2020 10:29:18 +0800 Subject: [PATCH 291/881] Add permission check when create associations --- callbacks/associations.go | 72 +++++++++++++++++++++++++++++++++++++++ callbacks/create.go | 26 -------------- finisher_api.go | 5 +-- schema/field.go | 11 +++--- schema/utils.go | 7 ---- utils/utils.go | 15 ++++++++ 6 files changed, 94 insertions(+), 42 deletions(-) create mode 100644 callbacks/associations.go diff --git a/callbacks/associations.go b/callbacks/associations.go new file mode 100644 index 00000000..1df0103a --- /dev/null +++ b/callbacks/associations.go @@ -0,0 +1,72 @@ +package callbacks + +import ( + "reflect" + + "github.com/jinzhu/gorm" + "github.com/jinzhu/gorm/schema" + "github.com/jinzhu/gorm/utils" +) + +func SaveBeforeAssociations(db *gorm.DB) { + if db.Statement.Schema != nil { + for _, rel := range db.Statement.Schema.Relationships.BelongsTo { + creatable, updatable, saveRef := saveAssociationCheck(db, rel.Field) + + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice: + case reflect.Struct: + if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero { + f := rel.Field.ReflectValueOf(db.Statement.ReflectValue) + + _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(f) + + if isZero && creatable { + if f.Kind() == reflect.Ptr { + db.Session(&gorm.Session{}).Create(f.Interface()) + } else { + db.Session(&gorm.Session{}).Create(f.Addr().Interface()) + } + } else if !isZero && updatable { + if f.Kind() == reflect.Ptr { + db.Session(&gorm.Session{}).Save(f.Interface()) + } else { + db.Session(&gorm.Session{}).Save(f.Addr().Interface()) + } + } else { + continue + } + + if saveRef { + for _, ref := range rel.References { + if !ref.OwnPrimaryKey { + fv, _ := ref.PrimaryKey.ValueOf(f) + ref.ForeignKey.Set(db.Statement.ReflectValue, fv) + } + } + } + } + } + } + } +} + +func saveAssociationCheck(db *gorm.DB, field *schema.Field) (bool, bool, bool) { + creatable := field.Creatable + updatable := field.Updatable + saveRef := true + + if value, ok := db.Get("gorm:association_autocreate"); creatable && ok { + creatable = utils.CheckTruth(value) + } + + if value, ok := db.Get("gorm:association_autoupdate"); updatable && ok { + updatable = utils.CheckTruth(value) + } + + if value, ok := db.Get("gorm:association_save_reference"); ok { + saveRef = utils.CheckTruth(value) + } + + return creatable, updatable, saveRef +} diff --git a/callbacks/create.go b/callbacks/create.go index e21e04c2..829c9c4c 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -41,32 +41,6 @@ func BeforeCreate(db *gorm.DB) { } } -func SaveBeforeAssociations(db *gorm.DB) { - if db.Statement.Schema != nil { - for _, rel := range db.Statement.Schema.Relationships.BelongsTo { - switch db.Statement.ReflectValue.Kind() { - case reflect.Slice: - case reflect.Struct: - if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero { - f := rel.Field.ReflectValueOf(db.Statement.ReflectValue) - if f.Kind() == reflect.Ptr { - db.Session(&gorm.Session{}).Create(f.Interface()) - } else { - db.Session(&gorm.Session{}).Create(f.Addr().Interface()) - } - - for _, ref := range rel.References { - if !ref.OwnPrimaryKey { - fv, _ := ref.PrimaryKey.ValueOf(f) - ref.ForeignKey.Set(db.Statement.ReflectValue, fv) - } - } - } - } - } - } -} - func Create(config *Config) func(db *gorm.DB) { if config.WithReturning { return CreateWithReturning diff --git a/finisher_api.go b/finisher_api.go index 62c1af30..9e29e327 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -21,7 +21,7 @@ func (db *DB) Save(value interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.Dest = value - if err := tx.Statement.Parse(value); err != nil && tx.Statement.Schema != nil { + if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil { where := clause.Where{Exprs: make([]clause.Expression, len(tx.Statement.Schema.PrimaryFields))} reflectValue := reflect.ValueOf(value) for idx, pf := range tx.Statement.Schema.PrimaryFields { @@ -35,9 +35,6 @@ func (db *DB) Save(value interface{}) (tx *DB) { tx.Statement.AddClause(where) } - if len(tx.Statement.Selects) == 0 { - tx.Statement.Selects = []string{"*"} - } tx.callbacks.Update().Execute(tx) return } diff --git a/schema/field.go b/schema/field.go index ec419383..7b37733b 100644 --- a/schema/field.go +++ b/schema/field.go @@ -10,6 +10,7 @@ import ( "sync" "time" + "github.com/jinzhu/gorm/utils" "github.com/jinzhu/now" ) @@ -146,13 +147,13 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { field.DBName = dbName } - if val, ok := field.TagSettings["PRIMARYKEY"]; ok && checkTruth(val) { + if val, ok := field.TagSettings["PRIMARYKEY"]; ok && utils.CheckTruth(val) { field.PrimaryKey = true - } else if val, ok := field.TagSettings["PRIMARY_KEY"]; ok && checkTruth(val) { + } else if val, ok := field.TagSettings["PRIMARY_KEY"]; ok && utils.CheckTruth(val) { field.PrimaryKey = true } - if val, ok := field.TagSettings["AUTOINCREMENT"]; ok && checkTruth(val) { + if val, ok := field.TagSettings["AUTOINCREMENT"]; ok && utils.CheckTruth(val) { field.AutoIncrement = true field.HasDefaultValue = true } @@ -173,11 +174,11 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { field.Precision, _ = strconv.Atoi(p) } - if val, ok := field.TagSettings["NOT NULL"]; ok && checkTruth(val) { + if val, ok := field.TagSettings["NOT NULL"]; ok && utils.CheckTruth(val) { field.NotNull = true } - if val, ok := field.TagSettings["UNIQUE"]; ok && checkTruth(val) { + if val, ok := field.TagSettings["UNIQUE"]; ok && utils.CheckTruth(val) { field.Unique = true } diff --git a/schema/utils.go b/schema/utils.go index d7572d3d..7be78bc5 100644 --- a/schema/utils.go +++ b/schema/utils.go @@ -37,13 +37,6 @@ func ParseTagSetting(str string, sep string) map[string]string { return settings } -func checkTruth(val string) bool { - if strings.ToLower(val) == "false" { - return false - } - return true -} - func toColumns(val string) (results []string) { if val != "" { for _, v := range strings.Split(val, ",") { diff --git a/utils/utils.go b/utils/utils.go index 8521d09b..8dd500a5 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -2,8 +2,10 @@ package utils import ( "fmt" + "reflect" "regexp" "runtime" + "strings" "unicode" ) @@ -23,3 +25,16 @@ func FileWithLineNum() string { func IsChar(c rune) bool { return !unicode.IsLetter(c) && !unicode.IsNumber(c) } + +func CheckTruth(val interface{}) bool { + if v, ok := val.(bool); ok { + return v + } + + if v, ok := val.(string); ok { + v = strings.ToLower(v) + return v != "false" + } + + return !reflect.ValueOf(val).IsZero() +} From fb44625c33b4790c74b052c4628721ce17794741 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 17 Apr 2020 08:23:47 +0800 Subject: [PATCH 292/881] Save HasOne association --- callbacks/associations.go | 50 ++++++++++++++++++++++++++++++++++++++- callbacks/create.go | 3 --- tests/create.go | 25 ++++++++++++++++++++ 3 files changed, 74 insertions(+), 4 deletions(-) diff --git a/callbacks/associations.go b/callbacks/associations.go index 1df0103a..283a2666 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -10,15 +10,18 @@ import ( func SaveBeforeAssociations(db *gorm.DB) { if db.Statement.Schema != nil { + // Save Belongs To associations for _, rel := range db.Statement.Schema.Relationships.BelongsTo { creatable, updatable, saveRef := saveAssociationCheck(db, rel.Field) + if !(creatable || updatable) { + continue + } switch db.Statement.ReflectValue.Kind() { case reflect.Slice: case reflect.Struct: if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero { f := rel.Field.ReflectValueOf(db.Statement.ReflectValue) - _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(f) if isZero && creatable { @@ -51,6 +54,51 @@ func SaveBeforeAssociations(db *gorm.DB) { } } +func SaveAfterAssociations(db *gorm.DB) { + // Save Has One associations + for _, rel := range db.Statement.Schema.Relationships.HasOne { + creatable, updatable, saveRef := saveAssociationCheck(db, rel.Field) + if !(creatable || updatable) { + continue + } + + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice: + case reflect.Struct: + if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero { + f := rel.Field.ReflectValueOf(db.Statement.ReflectValue) + + if saveRef { + for _, ref := range rel.References { + if ref.OwnPrimaryKey { + fv, _ := ref.PrimaryKey.ValueOf(db.Statement.ReflectValue) + ref.ForeignKey.Set(f, fv) + } + } + } + + _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(f) + + if isZero && creatable { + if f.Kind() == reflect.Ptr { + db.Session(&gorm.Session{}).Create(f.Interface()) + } else { + db.Session(&gorm.Session{}).Create(f.Addr().Interface()) + } + } else if !isZero && updatable { + if f.Kind() == reflect.Ptr { + db.Session(&gorm.Session{}).Save(f.Interface()) + } else { + db.Session(&gorm.Session{}).Save(f.Addr().Interface()) + } + } else { + continue + } + } + } + } +} + func saveAssociationCheck(db *gorm.DB, field *schema.Field) (bool, bool, bool) { creatable := field.Creatable updatable := field.Updatable diff --git a/callbacks/create.go b/callbacks/create.go index 829c9c4c..9dc8dc67 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -151,9 +151,6 @@ func CreateWithReturning(db *gorm.DB) { } } -func SaveAfterAssociations(db *gorm.DB) { -} - func AfterCreate(db *gorm.DB) { if db.Statement.Schema != nil && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) { callMethod := func(value interface{}) bool { diff --git a/tests/create.go b/tests/create.go index 74a010dc..b8e9245b 100644 --- a/tests/create.go +++ b/tests/create.go @@ -81,4 +81,29 @@ func TestCreateAssociations(t *testing.T, db *gorm.DB) { } } }) + + t.Run("Create-HasOneAssociation", func(t *testing.T) { + var user = User{ + Name: "create", + Age: 18, + Birthday: Now(), + Account: Account{Number: "account-has-one-association"}, + } + + if err := db.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + if user.Account.ID == 0 { + t.Errorf("Failed to create has one association - Account") + } else if user.Account.UserID.Int64 != int64(user.ID) { + t.Errorf("Failed to create has one association - Account") + } else { + var account Account + db.First(&account, "id = ?", user.Account.ID) + if user.Account.Number != "account-has-one-association" { + t.Errorf("Failed to query saved has one association - Account") + } + } + }) } From 952df527db254990e1ea250ab9670894a9aa92ab Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 17 Apr 2020 08:40:07 +0800 Subject: [PATCH 293/881] Test create polymorphic has one --- callbacks/associations.go | 2 ++ tests/create.go | 22 ++++++++++++++++++++++ tests/model.go | 1 + 3 files changed, 25 insertions(+) diff --git a/callbacks/associations.go b/callbacks/associations.go index 283a2666..bbfbbc3d 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -73,6 +73,8 @@ func SaveAfterAssociations(db *gorm.DB) { if ref.OwnPrimaryKey { fv, _ := ref.PrimaryKey.ValueOf(db.Statement.ReflectValue) ref.ForeignKey.Set(f, fv) + } else if ref.PrimaryValue != "" { + ref.ForeignKey.Set(f, ref.PrimaryValue) } } } diff --git a/tests/create.go b/tests/create.go index b8e9245b..10b6b699 100644 --- a/tests/create.go +++ b/tests/create.go @@ -1,6 +1,7 @@ package tests import ( + "fmt" "testing" "github.com/jinzhu/gorm" @@ -106,4 +107,25 @@ func TestCreateAssociations(t *testing.T, db *gorm.DB) { } } }) + + t.Run("Create-HasOneAssociation-Polymorphic", func(t *testing.T) { + var pet = Pet{ + Name: "create", + Toy: Toy{Name: "Create-HasOneAssociation-Polymorphic"}, + } + + if err := db.Create(&pet).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + if pet.Toy.OwnerID != fmt.Sprint(pet.ID) || pet.Toy.OwnerType != "pets" { + t.Errorf("Failed to create polymorphic has one association - toy owner id %v, owner type %v", pet.Toy.OwnerID, pet.Toy.OwnerType) + } else { + var toy Toy + db.First(&toy, "owner_id = ? and owner_type = ?", pet.Toy.OwnerID, pet.Toy.OwnerType) + if toy.Name != "Create-HasOneAssociation-Polymorphic" { + t.Errorf("Failed to query saved polymorphic has one association") + } + } + }) } diff --git a/tests/model.go b/tests/model.go index 4d686a57..1ae7c160 100644 --- a/tests/model.go +++ b/tests/model.go @@ -44,6 +44,7 @@ type Pet struct { type Toy struct { gorm.Model + Name string OwnerID string OwnerType string } From 158bacefbef5d3995a81da02b238a5b3b8a3b024 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 19 Apr 2020 14:29:31 +0800 Subject: [PATCH 294/881] Add save has many relations --- callbacks/associations.go | 52 +++++++++++++++++++++++++++++++++++++ tests/create.go | 54 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 106 insertions(+) diff --git a/callbacks/associations.go b/callbacks/associations.go index bbfbbc3d..6d976eac 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -99,6 +99,58 @@ func SaveAfterAssociations(db *gorm.DB) { } } } + + // Save Has Many associations + for _, rel := range db.Statement.Schema.Relationships.HasMany { + creatable, updatable, _ := saveAssociationCheck(db, rel.Field) + if !(creatable || updatable) { + continue + } + + fieldType := rel.Field.IndirectFieldType.Elem() + isPtr := true + if fieldType.Kind() != reflect.Ptr { + isPtr = false + fieldType = reflect.PtrTo(fieldType) + } + elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 0) + + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice: + for i := 0; i < db.Statement.ReflectValue.Len(); i++ { + db.Statement.ReflectValue.Index(i) + } + case reflect.Struct: + if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero { + f := reflect.Indirect(rel.Field.ReflectValueOf(db.Statement.ReflectValue)) + + for i := 0; i < f.Len(); i++ { + elem := f.Index(i) + _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(elem) + for _, ref := range rel.References { + if ref.OwnPrimaryKey { + fv, _ := ref.PrimaryKey.ValueOf(db.Statement.ReflectValue) + ref.ForeignKey.Set(elem, fv) + } else if ref.PrimaryValue != "" { + ref.ForeignKey.Set(elem, ref.PrimaryValue) + } + } + + if isZero && creatable { + if isPtr { + elems = reflect.Append(elems, elem) + } else { + elems = reflect.Append(elems, elem.Addr()) + } + } + } + } + } + + if elems.Len() > 0 { + db.Session(&gorm.Session{}).Create(elems.Interface()) + } + } } func saveAssociationCheck(db *gorm.DB, field *schema.Field) (bool, bool, bool) { diff --git a/tests/create.go b/tests/create.go index 10b6b699..218e1e59 100644 --- a/tests/create.go +++ b/tests/create.go @@ -128,4 +128,58 @@ func TestCreateAssociations(t *testing.T, db *gorm.DB) { } } }) + + t.Run("Create-HasManyAssociation", func(t *testing.T) { + var user = User{ + Name: "create", + Age: 18, + Birthday: Now(), + Pets: []*Pet{{Name: "pet1"}, {Name: "pet2"}}, + } + + if err := db.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + for idx, pet := range user.Pets { + if pet.ID == 0 { + t.Fatalf("Failed to create pet #%v", idx) + } + + var result Pet + db.First(&result, "id = ?", pet.ID) + if result.Name != pet.Name { + t.Errorf("Failed to query pet") + } else if result.UserID != user.ID { + t.Errorf("Failed to save relation") + } + } + }) + + t.Run("Create-HasManyAssociation-Polymorphic", func(t *testing.T) { + var user = User{ + Name: "create", + Age: 18, + Birthday: Now(), + Toys: []Toy{{Name: "toy1"}, {Name: "toy2"}}, + } + + if err := db.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + for idx, toy := range user.Toys { + if toy.ID == 0 { + t.Fatalf("Failed to create toy #%v", idx) + } + + var result Toy + db.First(&result, "id = ?", toy.ID) + if result.Name != toy.Name { + t.Errorf("Failed to query saved toy") + } else if result.OwnerID != fmt.Sprint(user.ID) || result.OwnerType != "users" { + t.Errorf("Failed to save relation") + } + } + }) } From 7bcd95d4b882544c613fa3609a5fc91a0c0e2714 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 19 Apr 2020 23:11:56 +0800 Subject: [PATCH 295/881] Add save associations for bulk create --- callbacks/associations.go | 330 +++++++++++++++++++++++++------------- callbacks/helper.go | 11 +- gorm.go | 3 +- 3 files changed, 229 insertions(+), 115 deletions(-) diff --git a/callbacks/associations.go b/callbacks/associations.go index 6d976eac..8cc96029 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -10,41 +10,75 @@ import ( func SaveBeforeAssociations(db *gorm.DB) { if db.Statement.Schema != nil { + selectColumns, restricted := SelectAndOmitColumns(db.Statement, true, false) + // Save Belongs To associations for _, rel := range db.Statement.Schema.Relationships.BelongsTo { - creatable, updatable, saveRef := saveAssociationCheck(db, rel.Field) - if !(creatable || updatable) { + if !saveAssociationCheck(db, rel, selectColumns, restricted) { continue } switch db.Statement.ReflectValue.Kind() { case reflect.Slice: + var ( + objs []reflect.Value + fieldType = rel.Field.FieldType + isPtr = fieldType.Kind() == reflect.Ptr + ) + + if !isPtr { + fieldType = reflect.PtrTo(fieldType) + } + + elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 0) + for i := 0; i < db.Statement.ReflectValue.Len(); i++ { + obj := db.Statement.ReflectValue.Index(i) + if _, zero := rel.Field.ValueOf(obj); !zero { // check belongs to relation value + rv := rel.Field.ReflectValueOf(obj) // relation reflect value + if _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(rv); isZero { + objs = append(objs, obj) + if isPtr { + elems = reflect.Append(elems, rv) + } else { + elems = reflect.Append(elems, rv.Addr()) + } + } else { + for _, ref := range rel.References { + if !ref.OwnPrimaryKey { + pv, _ := ref.PrimaryKey.ValueOf(rv) + ref.ForeignKey.Set(objs[i], pv) + } + } + } + } + } + + if elems.Len() > 0 { + if db.AddError(db.Session(&gorm.Session{}).Create(elems.Interface()).Error) == nil { + for i := 0; i < elems.Len(); i++ { + for _, ref := range rel.References { + if !ref.OwnPrimaryKey { + pv, _ := ref.PrimaryKey.ValueOf(elems.Index(i)) + ref.ForeignKey.Set(objs[i], pv) + } + } + } + } + } case reflect.Struct: if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero { - f := rel.Field.ReflectValueOf(db.Statement.ReflectValue) - _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(f) - - if isZero && creatable { - if f.Kind() == reflect.Ptr { - db.Session(&gorm.Session{}).Create(f.Interface()) + rv := rel.Field.ReflectValueOf(db.Statement.ReflectValue) // relation reflect value + if _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(rv); isZero { + if rv.Kind() == reflect.Ptr { + db.Session(&gorm.Session{}).Create(rv.Interface()) } else { - db.Session(&gorm.Session{}).Create(f.Addr().Interface()) + db.Session(&gorm.Session{}).Create(rv.Addr().Interface()) } - } else if !isZero && updatable { - if f.Kind() == reflect.Ptr { - db.Session(&gorm.Session{}).Save(f.Interface()) - } else { - db.Session(&gorm.Session{}).Save(f.Addr().Interface()) - } - } else { - continue - } - if saveRef { for _, ref := range rel.References { if !ref.OwnPrimaryKey { - fv, _ := ref.PrimaryKey.ValueOf(f) - ref.ForeignKey.Set(db.Statement.ReflectValue, fv) + pv, _ := ref.PrimaryKey.ValueOf(rv) + ref.ForeignKey.Set(db.Statement.ReflectValue, pv) } } } @@ -55,20 +89,58 @@ func SaveBeforeAssociations(db *gorm.DB) { } func SaveAfterAssociations(db *gorm.DB) { - // Save Has One associations - for _, rel := range db.Statement.Schema.Relationships.HasOne { - creatable, updatable, saveRef := saveAssociationCheck(db, rel.Field) - if !(creatable || updatable) { - continue - } + if db.Statement.Schema != nil { + selectColumns, restricted := SelectAndOmitColumns(db.Statement, true, false) - switch db.Statement.ReflectValue.Kind() { - case reflect.Slice: - case reflect.Struct: - if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero { - f := rel.Field.ReflectValueOf(db.Statement.ReflectValue) + // Save Has One associations + for _, rel := range db.Statement.Schema.Relationships.HasOne { + if !saveAssociationCheck(db, rel, selectColumns, restricted) { + continue + } + + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice: + var ( + fieldType = rel.Field.FieldType + isPtr = fieldType.Kind() == reflect.Ptr + ) + + if !isPtr { + fieldType = reflect.PtrTo(fieldType) + } + + elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 0) + + for i := 0; i < db.Statement.ReflectValue.Len(); i++ { + obj := db.Statement.ReflectValue.Index(i) + if rv, zero := rel.Field.ValueOf(obj); !zero { + rv := reflect.ValueOf(rv) + for _, ref := range rel.References { + if ref.OwnPrimaryKey { + fv, _ := ref.PrimaryKey.ValueOf(obj) + ref.ForeignKey.Set(rv, fv) + } else if ref.PrimaryValue != "" { + ref.ForeignKey.Set(rv, ref.PrimaryValue) + } + } + + if _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(rv); isZero { + if isPtr { + elems = reflect.Append(elems, rv) + } else { + elems = reflect.Append(elems, rv.Addr()) + } + } + } + } + + if elems.Len() > 0 { + db.Session(&gorm.Session{}).Create(elems.Interface()) + } + case reflect.Struct: + if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero { + f := rel.Field.ReflectValueOf(db.Statement.ReflectValue) - if saveRef { for _, ref := range rel.References { if ref.OwnPrimaryKey { fv, _ := ref.PrimaryKey.ValueOf(db.Statement.ReflectValue) @@ -77,98 +149,134 @@ func SaveAfterAssociations(db *gorm.DB) { ref.ForeignKey.Set(f, ref.PrimaryValue) } } - } - _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(f) - - if isZero && creatable { - if f.Kind() == reflect.Ptr { - db.Session(&gorm.Session{}).Create(f.Interface()) - } else { - db.Session(&gorm.Session{}).Create(f.Addr().Interface()) - } - } else if !isZero && updatable { - if f.Kind() == reflect.Ptr { - db.Session(&gorm.Session{}).Save(f.Interface()) - } else { - db.Session(&gorm.Session{}).Save(f.Addr().Interface()) - } - } else { - continue - } - } - } - } - - // Save Has Many associations - for _, rel := range db.Statement.Schema.Relationships.HasMany { - creatable, updatable, _ := saveAssociationCheck(db, rel.Field) - if !(creatable || updatable) { - continue - } - - fieldType := rel.Field.IndirectFieldType.Elem() - isPtr := true - if fieldType.Kind() != reflect.Ptr { - isPtr = false - fieldType = reflect.PtrTo(fieldType) - } - elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 0) - - switch db.Statement.ReflectValue.Kind() { - case reflect.Slice: - for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - db.Statement.ReflectValue.Index(i) - } - case reflect.Struct: - if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero { - f := reflect.Indirect(rel.Field.ReflectValueOf(db.Statement.ReflectValue)) - - for i := 0; i < f.Len(); i++ { - elem := f.Index(i) - _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(elem) - for _, ref := range rel.References { - if ref.OwnPrimaryKey { - fv, _ := ref.PrimaryKey.ValueOf(db.Statement.ReflectValue) - ref.ForeignKey.Set(elem, fv) - } else if ref.PrimaryValue != "" { - ref.ForeignKey.Set(elem, ref.PrimaryValue) - } - } - - if isZero && creatable { - if isPtr { - elems = reflect.Append(elems, elem) + if _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(f); isZero { + if f.Kind() == reflect.Ptr { + db.Session(&gorm.Session{}).Create(f.Interface()) } else { - elems = reflect.Append(elems, elem.Addr()) + db.Session(&gorm.Session{}).Create(f.Addr().Interface()) } } } } } - if elems.Len() > 0 { - db.Session(&gorm.Session{}).Create(elems.Interface()) + // Save Has Many associations + for _, rel := range db.Statement.Schema.Relationships.HasMany { + if !saveAssociationCheck(db, rel, selectColumns, restricted) { + continue + } + + fieldType := rel.Field.IndirectFieldType.Elem() + isPtr := true + if fieldType.Kind() != reflect.Ptr { + isPtr = false + fieldType = reflect.PtrTo(fieldType) + } + elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 0) + appendToElems := func(v reflect.Value) { + if _, zero := rel.Field.ValueOf(v); !zero { + f := reflect.Indirect(rel.Field.ReflectValueOf(v)) + + for i := 0; i < f.Len(); i++ { + elem := f.Index(i) + for _, ref := range rel.References { + if ref.OwnPrimaryKey { + pv, _ := ref.PrimaryKey.ValueOf(v) + ref.ForeignKey.Set(elem, pv) + } else if ref.PrimaryValue != "" { + ref.ForeignKey.Set(elem, ref.PrimaryValue) + } + } + + if _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(elem); isZero { + if isPtr { + elems = reflect.Append(elems, elem) + } else { + elems = reflect.Append(elems, elem.Addr()) + } + } + } + } + } + + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice: + for i := 0; i < db.Statement.ReflectValue.Len(); i++ { + appendToElems(db.Statement.ReflectValue.Index(i)) + } + case reflect.Struct: + appendToElems(db.Statement.ReflectValue) + } + + if elems.Len() > 0 { + db.Session(&gorm.Session{}).Create(elems.Interface()) + } + } + + // Save Many2Many associations + for _, rel := range db.Statement.Schema.Relationships.Many2Many { + if !saveAssociationCheck(db, rel, selectColumns, restricted) { + continue + } + + fieldType := rel.Field.IndirectFieldType.Elem() + isPtr := true + if fieldType.Kind() != reflect.Ptr { + isPtr = false + fieldType = reflect.PtrTo(fieldType) + } + elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 0) + + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice: + for i := 0; i < db.Statement.ReflectValue.Len(); i++ { + db.Statement.ReflectValue.Index(i) + } + case reflect.Struct: + if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero { + f := reflect.Indirect(rel.Field.ReflectValueOf(db.Statement.ReflectValue)) + + for i := 0; i < f.Len(); i++ { + elem := f.Index(i) + for _, ref := range rel.References { + if ref.OwnPrimaryKey { + fv, _ := ref.PrimaryKey.ValueOf(db.Statement.ReflectValue) + ref.ForeignKey.Set(elem, fv) + } else if ref.PrimaryValue != "" { + ref.ForeignKey.Set(elem, ref.PrimaryValue) + } + } + + if _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(elem); isZero { + if isPtr { + elems = reflect.Append(elems, elem) + } else { + elems = reflect.Append(elems, elem.Addr()) + } + } + } + } + } + + if elems.Len() > 0 { + db.Session(&gorm.Session{}).Create(elems.Interface()) + } } } } -func saveAssociationCheck(db *gorm.DB, field *schema.Field) (bool, bool, bool) { - creatable := field.Creatable - updatable := field.Updatable - saveRef := true - - if value, ok := db.Get("gorm:association_autocreate"); creatable && ok { - creatable = utils.CheckTruth(value) +func saveAssociationCheck(db *gorm.DB, rel *schema.Relationship, selectColumns map[string]bool, restricted bool) bool { + savable := true + if value, ok := db.Get("gorm:save_association"); ok { + savable = utils.CheckTruth(value) } - if value, ok := db.Get("gorm:association_autoupdate"); updatable && ok { - updatable = utils.CheckTruth(value) + if savable { + if v, ok := selectColumns[rel.Name]; (ok && v) || (!ok && !restricted) { + return true + } } - if value, ok := db.Get("gorm:association_save_reference"); ok { - saveRef = utils.CheckTruth(value) - } - - return creatable, updatable, saveRef + return false } diff --git a/callbacks/helper.go b/callbacks/helper.go index 8a69fbd1..092c9c37 100644 --- a/callbacks/helper.go +++ b/callbacks/helper.go @@ -37,11 +37,16 @@ func SelectAndOmitColumns(stmt *gorm.Statement, requireCreate, requireUpdate boo } if stmt.Schema != nil { - for _, field := range stmt.Schema.FieldsByDBName { + for _, field := range stmt.Schema.Fields { + name := field.DBName + if name == "" { + name = field.Name + } + if requireCreate && !field.Creatable { - results[field.DBName] = false + results[name] = false } else if requireUpdate && !field.Updatable { - results[field.DBName] = false + results[name] = false } } } diff --git a/gorm.go b/gorm.go index 2d78c8d9..f8c944af 100644 --- a/gorm.go +++ b/gorm.go @@ -161,12 +161,13 @@ func (db *DB) AutoMigrate(dst ...interface{}) error { } // AddError add error to db -func (db *DB) AddError(err error) { +func (db *DB) AddError(err error) error { if db.Error == nil { db.Error = err } else if err != nil { db.Error = fmt.Errorf("%v; %w", db.Error, err) } + return db.Error } func (db *DB) getInstance() *DB { From 43a814ae708a08f56ab84904435201e5c57afebe Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 20 Apr 2020 11:47:29 +0800 Subject: [PATCH 296/881] Add bulk create associations tests --- callbacks/associations.go | 131 +++++++------- tests/create.go | 347 +++++++++++++++++++++++++++++++++----- 2 files changed, 380 insertions(+), 98 deletions(-) diff --git a/callbacks/associations.go b/callbacks/associations.go index 8cc96029..98e0d254 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -18,6 +18,15 @@ func SaveBeforeAssociations(db *gorm.DB) { continue } + setupReferences := func(obj reflect.Value, elem reflect.Value) { + for _, ref := range rel.References { + if !ref.OwnPrimaryKey { + pv, _ := ref.PrimaryKey.ValueOf(elem) + ref.ForeignKey.Set(obj, pv) + } + } + } + switch db.Statement.ReflectValue.Kind() { case reflect.Slice: var ( @@ -43,12 +52,7 @@ func SaveBeforeAssociations(db *gorm.DB) { elems = reflect.Append(elems, rv.Addr()) } } else { - for _, ref := range rel.References { - if !ref.OwnPrimaryKey { - pv, _ := ref.PrimaryKey.ValueOf(rv) - ref.ForeignKey.Set(objs[i], pv) - } - } + setupReferences(obj, rv) } } } @@ -56,31 +60,20 @@ func SaveBeforeAssociations(db *gorm.DB) { if elems.Len() > 0 { if db.AddError(db.Session(&gorm.Session{}).Create(elems.Interface()).Error) == nil { for i := 0; i < elems.Len(); i++ { - for _, ref := range rel.References { - if !ref.OwnPrimaryKey { - pv, _ := ref.PrimaryKey.ValueOf(elems.Index(i)) - ref.ForeignKey.Set(objs[i], pv) - } - } + setupReferences(objs[i], elems.Index(i)) } } } case reflect.Struct: if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero { rv := rel.Field.ReflectValueOf(db.Statement.ReflectValue) // relation reflect value - if _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(rv); isZero { - if rv.Kind() == reflect.Ptr { - db.Session(&gorm.Session{}).Create(rv.Interface()) - } else { - db.Session(&gorm.Session{}).Create(rv.Addr().Interface()) - } + if rv.Kind() != reflect.Ptr { + rv = rv.Addr() + } - for _, ref := range rel.References { - if !ref.OwnPrimaryKey { - pv, _ := ref.PrimaryKey.ValueOf(rv) - ref.ForeignKey.Set(db.Statement.ReflectValue, pv) - } - } + if _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(rv); isZero { + db.Session(&gorm.Session{}).Create(rv.Interface()) + setupReferences(db.Statement.ReflectValue, rv) } } } @@ -113,8 +106,13 @@ func SaveAfterAssociations(db *gorm.DB) { for i := 0; i < db.Statement.ReflectValue.Len(); i++ { obj := db.Statement.ReflectValue.Index(i) - if rv, zero := rel.Field.ValueOf(obj); !zero { - rv := reflect.ValueOf(rv) + + if _, zero := rel.Field.ValueOf(obj); !zero { + rv := rel.Field.ReflectValueOf(obj) + if rv.Kind() != reflect.Ptr { + rv = rv.Addr() + } + for _, ref := range rel.References { if ref.OwnPrimaryKey { fv, _ := ref.PrimaryKey.ValueOf(obj) @@ -125,11 +123,7 @@ func SaveAfterAssociations(db *gorm.DB) { } if _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(rv); isZero { - if isPtr { - elems = reflect.Append(elems, rv) - } else { - elems = reflect.Append(elems, rv.Addr()) - } + elems = reflect.Append(elems, rv) } } } @@ -140,6 +134,9 @@ func SaveAfterAssociations(db *gorm.DB) { case reflect.Struct: if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero { f := rel.Field.ReflectValueOf(db.Statement.ReflectValue) + if f.Kind() != reflect.Ptr { + f = f.Addr() + } for _, ref := range rel.References { if ref.OwnPrimaryKey { @@ -151,11 +148,7 @@ func SaveAfterAssociations(db *gorm.DB) { } if _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(f); isZero { - if f.Kind() == reflect.Ptr { - db.Session(&gorm.Session{}).Create(f.Interface()) - } else { - db.Session(&gorm.Session{}).Create(f.Addr().Interface()) - } + db.Session(&gorm.Session{}).Create(f.Interface()) } } } @@ -168,9 +161,8 @@ func SaveAfterAssociations(db *gorm.DB) { } fieldType := rel.Field.IndirectFieldType.Elem() - isPtr := true - if fieldType.Kind() != reflect.Ptr { - isPtr = false + isPtr := fieldType.Kind() == reflect.Ptr + if !isPtr { fieldType = reflect.PtrTo(fieldType) } elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 0) @@ -221,46 +213,71 @@ func SaveAfterAssociations(db *gorm.DB) { } fieldType := rel.Field.IndirectFieldType.Elem() - isPtr := true - if fieldType.Kind() != reflect.Ptr { - isPtr = false + isPtr := fieldType.Kind() == reflect.Ptr + if !isPtr { fieldType = reflect.PtrTo(fieldType) } elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 0) + joins := reflect.MakeSlice(reflect.SliceOf(rel.JoinTable.ModelType), 0, 0) + objs := []reflect.Value{} - switch db.Statement.ReflectValue.Kind() { - case reflect.Slice: - for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - db.Statement.ReflectValue.Index(i) + appendToJoins := func(obj reflect.Value, elem reflect.Value) { + joinValue := reflect.New(rel.JoinTable.ModelType) + for _, ref := range rel.References { + if ref.OwnPrimaryKey { + fv, _ := ref.PrimaryKey.ValueOf(obj) + ref.ForeignKey.Set(joinValue, fv) + } else if ref.PrimaryValue != "" { + ref.ForeignKey.Set(joinValue, ref.PrimaryValue) + } else { + fv, _ := ref.PrimaryKey.ValueOf(elem) + ref.ForeignKey.Set(joinValue, fv) + } } - case reflect.Struct: - if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero { - f := reflect.Indirect(rel.Field.ReflectValueOf(db.Statement.ReflectValue)) + + joins = reflect.Append(joins, joinValue) + } + + appendToElems := func(v reflect.Value) { + if _, zero := rel.Field.ValueOf(v); !zero { + f := reflect.Indirect(rel.Field.ReflectValueOf(v)) for i := 0; i < f.Len(); i++ { elem := f.Index(i) - for _, ref := range rel.References { - if ref.OwnPrimaryKey { - fv, _ := ref.PrimaryKey.ValueOf(db.Statement.ReflectValue) - ref.ForeignKey.Set(elem, fv) - } else if ref.PrimaryValue != "" { - ref.ForeignKey.Set(elem, ref.PrimaryValue) - } - } if _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(elem); isZero { + objs = append(objs, v) if isPtr { elems = reflect.Append(elems, elem) } else { elems = reflect.Append(elems, elem.Addr()) } + } else { + appendToJoins(v, elem) } } } } + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice: + for i := 0; i < db.Statement.ReflectValue.Len(); i++ { + appendToElems(db.Statement.ReflectValue.Index(i)) + } + case reflect.Struct: + appendToElems(db.Statement.ReflectValue) + } + if elems.Len() > 0 { db.Session(&gorm.Session{}).Create(elems.Interface()) + + for i := 0; i < elems.Len(); i++ { + appendToJoins(objs[i], elems.Index(i)) + } + } + + if joins.Len() > 0 { + db.Session(&gorm.Session{}).Create(joins.Interface()) } } } diff --git a/tests/create.go b/tests/create.go index 218e1e59..b4bdd47e 100644 --- a/tests/create.go +++ b/tests/create.go @@ -40,16 +40,53 @@ func TestCreate(t *testing.T, db *gorm.DB) { } else { AssertObjEqual(t, newUser, user, "Name", "Age", "Birthday") } - }) - TestCreateAssociations(t, db) + TestCreateAssociations(t, db) + }) } func TestCreateAssociations(t *testing.T, db *gorm.DB) { + TestCreateBelongsToAssociations(t, db) + TestCreateHasOneAssociations(t, db) + TestCreateHasManyAssociations(t, db) + TestCreateMany2ManyAssociations(t, db) +} + +func TestCreateBelongsToAssociations(t *testing.T, db *gorm.DB) { db.Migrator().DropTable(&Company{}) db.Migrator().AutoMigrate(&Company{}) - t.Run("Create-BelongsToAssociation", func(t *testing.T) { + check := func(t *testing.T, user User) { + if user.Company.Name != "" { + if user.CompanyID == nil { + t.Errorf("Company's foreign key should be saved") + } else { + var company Company + db.First(&company, "id = ?", *user.CompanyID) + if company.Name != user.Company.Name { + t.Errorf("Company's name should be same") + } + } + } else if user.CompanyID != nil { + t.Errorf("Company should not be created for zero value, got: %+v", user.CompanyID) + } + + if user.Manager != nil { + if user.ManagerID == nil { + t.Errorf("Manager's foreign key should be saved") + } else { + var manager User + db.First(&manager, "id = ?", *user.ManagerID) + if manager.Name != user.Manager.Name { + t.Errorf("Manager's name should be same") + } + } + } else if user.ManagerID != nil { + t.Errorf("Manager should not be created for zero value, got: %+v", user.ManagerID) + } + } + + t.Run("BelongsTo", func(t *testing.T) { var user = User{ Name: "create", Age: 18, @@ -62,28 +99,113 @@ func TestCreateAssociations(t *testing.T, db *gorm.DB) { t.Fatalf("errors happened when create: %v", err) } - if user.CompanyID == nil { - t.Errorf("Failed to create belongs to association - Company") - } else { - var company Company - db.First(&company, "id = ?", *user.CompanyID) - if company.Name != "company-belongs-to-association" { - t.Errorf("Failed to query saved belongs to association - Company") - } + check(t, user) + }) + + t.Run("BelongsToForBulkInsert", func(t *testing.T) { + var users = []User{{ + Name: "create-1", + Age: 18, + Birthday: Now(), + Company: Company{Name: "company-belongs-to-association-1"}, + Manager: &User{Name: "manager-belongs-to-association-1"}, + }, { + Name: "create-2", + Age: 28, + Birthday: Now(), + Company: Company{Name: "company-belongs-to-association-2"}, + }, { + Name: "create-3", + Age: 38, + Birthday: Now(), + Company: Company{Name: "company-belongs-to-association-3"}, + Manager: &User{Name: "manager-belongs-to-association-3"}, + }} + + if err := db.Create(&users).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) } - if user.ManagerID == nil { - t.Errorf("Failed to create belongs to association - Manager") - } else { - var manager User - db.First(&manager, "id = ?", *user.ManagerID) - if manager.Name != "manager-belongs-to-association" { - t.Errorf("Failed to query saved belongs to association - Manager") - } + for _, user := range users { + check(t, user) } }) - t.Run("Create-HasOneAssociation", func(t *testing.T) { + t.Run("BelongsToForBulkInsertPtrData", func(t *testing.T) { + var users = []*User{{ + Name: "create-1", + Age: 18, + Birthday: Now(), + Company: Company{Name: "company-belongs-to-association-1"}, + Manager: &User{Name: "manager-belongs-to-association-1"}, + }, { + Name: "create-2", + Age: 28, + Birthday: Now(), + Company: Company{Name: "company-belongs-to-association-2"}, + }, { + Name: "create-3", + Age: 38, + Birthday: Now(), + Company: Company{Name: "company-belongs-to-association-3"}, + Manager: &User{Name: "manager-belongs-to-association-3"}, + }} + + if err := db.Create(&users).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + for _, user := range users { + check(t, *user) + } + }) + + t.Run("BelongsToForBulkInsertWithoutPtr", func(t *testing.T) { + var users = []*User{{ + Name: "create-1", + Age: 18, + Birthday: Now(), + Company: Company{Name: "company-belongs-to-association-1"}, + Manager: &User{Name: "manager-belongs-to-association-1"}, + }, { + Name: "create-2", + Age: 28, + Birthday: Now(), + Company: Company{Name: "company-belongs-to-association-2"}, + }, { + Name: "create-3", + Age: 38, + Birthday: Now(), + Company: Company{Name: "company-belongs-to-association-3"}, + Manager: &User{Name: "manager-belongs-to-association-3"}, + }} + + if err := db.Create(users).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + for _, user := range users { + check(t, *user) + } + }) +} + +func TestCreateHasOneAssociations(t *testing.T, db *gorm.DB) { + check := func(t *testing.T, user User) { + if user.Account.ID == 0 { + t.Errorf("Account should be saved") + } else if user.Account.UserID.Int64 != int64(user.ID) { + t.Errorf("Account's foreign key should be saved") + } else { + var account Account + db.First(&account, "id = ?", user.Account.ID) + if account.Number != user.Account.Number { + t.Errorf("Account's number should be sme") + } + } + } + + t.Run("HasOne", func(t *testing.T) { var user = User{ Name: "create", Age: 18, @@ -95,20 +217,103 @@ func TestCreateAssociations(t *testing.T, db *gorm.DB) { t.Fatalf("errors happened when create: %v", err) } - if user.Account.ID == 0 { - t.Errorf("Failed to create has one association - Account") - } else if user.Account.UserID.Int64 != int64(user.ID) { - t.Errorf("Failed to create has one association - Account") - } else { - var account Account - db.First(&account, "id = ?", user.Account.ID) - if user.Account.Number != "account-has-one-association" { - t.Errorf("Failed to query saved has one association - Account") - } + check(t, user) + }) + + t.Run("HasOneForBulkInsert", func(t *testing.T) { + var users = []User{{ + Name: "create-1", + Age: 18, + Birthday: Now(), + Account: Account{Number: "account-has-one-association-1"}, + }, { + Name: "create-2", + Age: 28, + Birthday: Now(), + Account: Account{Number: "account-has-one-association-2"}, + }, { + Name: "create-3", + Age: 38, + Birthday: Now(), + Account: Account{Number: "account-has-one-association-3"}, + }} + + if err := db.Create(&users).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + for _, user := range users { + check(t, user) } }) - t.Run("Create-HasOneAssociation-Polymorphic", func(t *testing.T) { + t.Run("HasOneForBulkInsertPtrData", func(t *testing.T) { + var users = []*User{{ + Name: "create-1", + Age: 18, + Birthday: Now(), + Account: Account{Number: "account-has-one-association-1"}, + }, { + Name: "create-2", + Age: 28, + Birthday: Now(), + Account: Account{Number: "account-has-one-association-2"}, + }, { + Name: "create-3", + Age: 38, + Birthday: Now(), + Account: Account{Number: "account-has-one-association-3"}, + }} + + if err := db.Create(&users).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + for _, user := range users { + check(t, *user) + } + }) + + t.Run("HasOneForBulkInsertWithoutPtr", func(t *testing.T) { + var users = []User{{ + Name: "create-1", + Age: 18, + Birthday: Now(), + Account: Account{Number: "account-has-one-association-1"}, + }, { + Name: "create-2", + Age: 28, + Birthday: Now(), + Account: Account{Number: "account-has-one-association-2"}, + }, { + Name: "create-3", + Age: 38, + Birthday: Now(), + Account: Account{Number: "account-has-one-association-3"}, + }} + + if err := db.Create(users).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + for _, user := range users { + check(t, user) + } + }) + + checkPet := func(t *testing.T, pet Pet) { + if pet.Toy.OwnerID != fmt.Sprint(pet.ID) || pet.Toy.OwnerType != "pets" { + t.Errorf("Failed to create polymorphic has one association - toy owner id %v, owner type %v", pet.Toy.OwnerID, pet.Toy.OwnerType) + } else { + var toy Toy + db.First(&toy, "owner_id = ? and owner_type = ?", pet.Toy.OwnerID, pet.Toy.OwnerType) + if toy.Name != pet.Toy.Name { + t.Errorf("Failed to query saved polymorphic has one association") + } + } + } + + t.Run("PolymorphicHasOne", func(t *testing.T) { var pet = Pet{ Name: "create", Toy: Toy{Name: "Create-HasOneAssociation-Polymorphic"}, @@ -118,18 +323,75 @@ func TestCreateAssociations(t *testing.T, db *gorm.DB) { t.Fatalf("errors happened when create: %v", err) } - if pet.Toy.OwnerID != fmt.Sprint(pet.ID) || pet.Toy.OwnerType != "pets" { - t.Errorf("Failed to create polymorphic has one association - toy owner id %v, owner type %v", pet.Toy.OwnerID, pet.Toy.OwnerType) - } else { - var toy Toy - db.First(&toy, "owner_id = ? and owner_type = ?", pet.Toy.OwnerID, pet.Toy.OwnerType) - if toy.Name != "Create-HasOneAssociation-Polymorphic" { - t.Errorf("Failed to query saved polymorphic has one association") - } + checkPet(t, pet) + }) + + t.Run("PolymorphicHasOneForBulkInsert", func(t *testing.T) { + var pets = []Pet{{ + Name: "create-1", + Toy: Toy{Name: "Create-HasOneAssociation-Polymorphic-1"}, + }, { + Name: "create-2", + Toy: Toy{Name: "Create-HasOneAssociation-Polymorphic-2"}, + }, { + Name: "create-3", + Toy: Toy{Name: "Create-HasOneAssociation-Polymorphic-3"}, + }} + + if err := db.Create(&pets).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + for _, pet := range pets { + checkPet(t, pet) } }) - t.Run("Create-HasManyAssociation", func(t *testing.T) { + t.Run("PolymorphicHasOneForBulkInsertPtrData", func(t *testing.T) { + var pets = []*Pet{{ + Name: "create-1", + Toy: Toy{Name: "Create-HasOneAssociation-Polymorphic-1"}, + }, { + Name: "create-2", + Toy: Toy{Name: "Create-HasOneAssociation-Polymorphic-2"}, + }, { + Name: "create-3", + Toy: Toy{Name: "Create-HasOneAssociation-Polymorphic-3"}, + }} + + if err := db.Create(&pets).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + for _, pet := range pets { + checkPet(t, *pet) + } + }) + + t.Run("PolymorphicHasOneForBulkInsertWithoutPtr", func(t *testing.T) { + var pets = []*Pet{{ + Name: "create-1", + Toy: Toy{Name: "Create-HasOneAssociation-Polymorphic-1"}, + }, { + Name: "create-2", + Toy: Toy{Name: "Create-HasOneAssociation-Polymorphic-2"}, + }, { + Name: "create-3", + Toy: Toy{Name: "Create-HasOneAssociation-Polymorphic-3"}, + }} + + if err := db.Create(pets).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + for _, pet := range pets { + checkPet(t, *pet) + } + }) +} + +func TestCreateHasManyAssociations(t *testing.T, db *gorm.DB) { + t.Run("HasMany", func(t *testing.T) { var user = User{ Name: "create", Age: 18, @@ -156,7 +418,7 @@ func TestCreateAssociations(t *testing.T, db *gorm.DB) { } }) - t.Run("Create-HasManyAssociation-Polymorphic", func(t *testing.T) { + t.Run("PolymorphicHasMany", func(t *testing.T) { var user = User{ Name: "create", Age: 18, @@ -183,3 +445,6 @@ func TestCreateAssociations(t *testing.T, db *gorm.DB) { } }) } + +func TestCreateMany2ManyAssociations(t *testing.T, db *gorm.DB) { +} From 85f317446795b35c84e92d77cdf5d9583504e52d Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 20 Apr 2020 23:35:18 +0800 Subject: [PATCH 297/881] Test has many associations --- callbacks/associations.go | 2 +- tests/create.go | 252 ++++++++++++++++++++++++++++++++++---- 2 files changed, 231 insertions(+), 23 deletions(-) diff --git a/callbacks/associations.go b/callbacks/associations.go index 98e0d254..df19d5f5 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -218,7 +218,7 @@ func SaveAfterAssociations(db *gorm.DB) { fieldType = reflect.PtrTo(fieldType) } elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 0) - joins := reflect.MakeSlice(reflect.SliceOf(rel.JoinTable.ModelType), 0, 0) + joins := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(rel.JoinTable.ModelType)), 0, 0) objs := []reflect.Value{} appendToJoins := func(obj reflect.Value, elem reflect.Value) { diff --git a/tests/create.go b/tests/create.go index b4bdd47e..27ad7a49 100644 --- a/tests/create.go +++ b/tests/create.go @@ -46,6 +46,9 @@ func TestCreate(t *testing.T, db *gorm.DB) { } func TestCreateAssociations(t *testing.T, db *gorm.DB) { + db.Migrator().DropTable(&Account{}, &Company{}, &Pet{}, &Toy{}, &Language{}) + db.Migrator().AutoMigrate(&Account{}, &Company{}, &Pet{}, &Toy{}, &Language{}) + TestCreateBelongsToAssociations(t, db) TestCreateHasOneAssociations(t, db) TestCreateHasManyAssociations(t, db) @@ -53,9 +56,6 @@ func TestCreateAssociations(t *testing.T, db *gorm.DB) { } func TestCreateBelongsToAssociations(t *testing.T, db *gorm.DB) { - db.Migrator().DropTable(&Company{}) - db.Migrator().AutoMigrate(&Company{}) - check := func(t *testing.T, user User) { if user.Company.Name != "" { if user.CompanyID == nil { @@ -391,6 +391,22 @@ func TestCreateHasOneAssociations(t *testing.T, db *gorm.DB) { } func TestCreateHasManyAssociations(t *testing.T, db *gorm.DB) { + check := func(t *testing.T, user User) { + for _, pet := range user.Pets { + if pet.ID == 0 { + t.Errorf("Pet's foreign key should be saved") + } + + var result Pet + db.First(&result, "id = ?", pet.ID) + if result.Name != pet.Name { + t.Errorf("Pet's name should be same") + } else if result.UserID != user.ID { + t.Errorf("Pet's foreign key should be saved") + } + } + } + t.Run("HasMany", func(t *testing.T) { var user = User{ Name: "create", @@ -403,33 +419,91 @@ func TestCreateHasManyAssociations(t *testing.T, db *gorm.DB) { t.Fatalf("errors happened when create: %v", err) } - for idx, pet := range user.Pets { - if pet.ID == 0 { - t.Fatalf("Failed to create pet #%v", idx) - } - - var result Pet - db.First(&result, "id = ?", pet.ID) - if result.Name != pet.Name { - t.Errorf("Failed to query pet") - } else if result.UserID != user.ID { - t.Errorf("Failed to save relation") - } - } + check(t, user) }) - t.Run("PolymorphicHasMany", func(t *testing.T) { - var user = User{ - Name: "create", + t.Run("HasManyForBulkInsert", func(t *testing.T) { + var users = []User{{ + Name: "create-1", Age: 18, Birthday: Now(), - Toys: []Toy{{Name: "toy1"}, {Name: "toy2"}}, - } + Pets: []*Pet{{Name: "pet-1-1"}, {Name: "pet-1-2"}}, + }, { + Name: "create-2", + Age: 28, + Birthday: Now(), + Pets: []*Pet{{Name: "pet-2-1"}}, + }, { + Name: "create-3", + Age: 38, + Birthday: Now(), + Pets: []*Pet{{Name: "pet-3-1"}, {Name: "pet-3-2"}}, + }} - if err := db.Create(&user).Error; err != nil { + if err := db.Create(&users).Error; err != nil { t.Fatalf("errors happened when create: %v", err) } + for _, user := range users { + check(t, user) + } + }) + + t.Run("HasManyForBulkInsertPtrData", func(t *testing.T) { + var users = []*User{{ + Name: "create-1", + Age: 18, + Birthday: Now(), + Pets: []*Pet{{Name: "pet-1-1"}, {Name: "pet-1-2"}}, + }, { + Name: "create-2", + Age: 28, + Birthday: Now(), + Pets: []*Pet{{Name: "pet-2-1"}}, + }, { + Name: "create-3", + Age: 38, + Birthday: Now(), + Pets: []*Pet{{Name: "pet-3-1"}, {Name: "pet-3-2"}}, + }} + + if err := db.Create(&users).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + for _, user := range users { + check(t, *user) + } + }) + + t.Run("HasManyForBulkInsertWithoutPtr", func(t *testing.T) { + var users = []*User{{ + Name: "create-1", + Age: 18, + Birthday: Now(), + Pets: []*Pet{{Name: "pet-1-1"}, {Name: "pet-1-2"}}, + }, { + Name: "create-2", + Age: 28, + Birthday: Now(), + Pets: []*Pet{{Name: "pet-2-1"}}, + }, { + Name: "create-3", + Age: 38, + Birthday: Now(), + Pets: []*Pet{{Name: "pet-3-1"}, {Name: "pet-3-2"}}, + }} + + if err := db.Create(users).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + for _, user := range users { + check(t, *user) + } + }) + + checkToy := func(t *testing.T, user User) { for idx, toy := range user.Toys { if toy.ID == 0 { t.Fatalf("Failed to create toy #%v", idx) @@ -443,8 +517,142 @@ func TestCreateHasManyAssociations(t *testing.T, db *gorm.DB) { t.Errorf("Failed to save relation") } } + } + + t.Run("PolymorphicHasMany", func(t *testing.T) { + var user = User{ + Name: "create", + Age: 18, + Birthday: Now(), + Toys: []Toy{{Name: "toy1"}, {Name: "toy2"}}, + } + + if err := db.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + checkToy(t, user) + }) + + t.Run("PolymorphicHasManyForBulkInsert", func(t *testing.T) { + var users = []User{{ + Name: "create-1", + Age: 18, + Birthday: Now(), + Toys: []Toy{{Name: "toy-1-1"}, {Name: "toy-1-2"}}, + }, { + Name: "create-2", + Age: 28, + Birthday: Now(), + Toys: []Toy{{Name: "toy-2-1"}, {Name: "toy-2-2"}}, + }, { + Name: "create-3", + Age: 38, + Birthday: Now(), + Toys: []Toy{{Name: "toy-3-1"}, {Name: "toy-3-2"}}, + }} + + if err := db.Create(&users).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + for _, user := range users { + checkToy(t, user) + } + }) + + t.Run("PolymorphicHasManyForBulkInsertPtrData", func(t *testing.T) { + var users = []*User{{ + Name: "create-1", + Age: 18, + Birthday: Now(), + Toys: []Toy{{Name: "toy-1-1"}, {Name: "toy-1-2"}}, + }, { + Name: "create-2", + Age: 28, + Birthday: Now(), + Toys: []Toy{{Name: "toy-2-1"}, {Name: "toy-2-2"}}, + }, { + Name: "create-3", + Age: 38, + Birthday: Now(), + Toys: []Toy{{Name: "toy-3-1"}, {Name: "toy-3-2"}}, + }} + + if err := db.Create(&users).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + for _, user := range users { + checkToy(t, *user) + } + }) + + t.Run("PolymorphicHasManyForBulkInsertWithoutPtr", func(t *testing.T) { + var users = []User{{ + Name: "create-1", + Age: 18, + Birthday: Now(), + Toys: []Toy{{Name: "toy-1-1"}, {Name: "toy-1-2"}}, + }, { + Name: "create-2", + Age: 28, + Birthday: Now(), + Toys: []Toy{{Name: "toy-2-1"}, {Name: "toy-2-2"}}, + }, { + Name: "create-3", + Age: 38, + Birthday: Now(), + Toys: []Toy{{Name: "toy-3-1"}, {Name: "toy-3-2"}}, + }} + + if err := db.Create(users).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + for _, user := range users { + checkToy(t, user) + } }) } func TestCreateMany2ManyAssociations(t *testing.T, db *gorm.DB) { + check := func(t *testing.T, user User) { + for _, language := range user.Languages { + var result Language + db.First(&result, "code = ?", language.Code) + // TODO + // if result.Name != language.Name { + // t.Errorf("Language's name should be same") + // } + } + + for _, f := range user.Friends { + if f.ID == 0 { + t.Errorf("Friend's foreign key should be saved") + } + + var result User + db.First(&result, "id = ?", f.ID) + if result.Name != f.Name { + t.Errorf("Friend's name should be same") + } + } + } + + t.Run("Many2Many", func(t *testing.T) { + var user = User{ + Name: "create", + Age: 18, + Birthday: Now(), + Languages: []Language{{Code: "zh-CN", Name: "Chinese"}, {Code: "en", Name: "English"}}, + Friends: []*User{{Name: "friend-1"}, {Name: "friend-2"}}, + } + + if err := db.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + check(t, user) + }) } From 70d60ef72fccd8822c9dc54e0be492294e78c58d Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 28 Apr 2020 08:05:22 +0800 Subject: [PATCH 298/881] Fix create join table --- migrator/migrator.go | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/migrator/migrator.go b/migrator/migrator.go index 763b4ec3..f581f714 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -25,12 +25,14 @@ type Config struct { } func (m Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement) error) error { - stmt := m.DB.Statement - if stmt == nil { - stmt = &gorm.Statement{DB: m.DB} + stmt := &gorm.Statement{DB: m.DB} + if m.DB.Statement != nil { + stmt.Table = m.DB.Statement.Table } - if err := stmt.Parse(value); err != nil { + if table, ok := value.(string); ok { + stmt.Table = table + } else if err := stmt.Parse(value); err != nil { return err } @@ -105,8 +107,10 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { // create join table if rel.JoinTable != nil { joinValue := reflect.New(rel.JoinTable.ModelType).Interface() - if !tx.Migrator().HasTable(joinValue) { - defer tx.Migrator().CreateTable(joinValue) + if !tx.Migrator().HasTable(rel.JoinTable.Table) { + defer tx.Table(rel.JoinTable.Table).Migrator().CreateTable(joinValue) + } else { + defer tx.Table(rel.JoinTable.Table).Migrator().AutoMigrate(joinValue) } } } @@ -167,8 +171,8 @@ func (m Migrator) CreateTable(values ...interface{}) error { // create join table if rel.JoinTable != nil { joinValue := reflect.New(rel.JoinTable.ModelType).Interface() - if !tx.Migrator().HasTable(joinValue) { - defer tx.Migrator().CreateTable(joinValue) + if !tx.Migrator().HasTable(rel.JoinTable.Table) { + defer tx.Table(rel.JoinTable.Table).Migrator().CreateTable(joinValue) } } } @@ -207,6 +211,7 @@ func (m Migrator) DropTable(values ...interface{}) error { func (m Migrator) HasTable(value interface{}) bool { var count int64 + m.RunWithValue(value, func(stmt *gorm.Statement) error { currentDatabase := m.DB.Migrator().CurrentDatabase() return m.DB.Raw("SELECT count(*) FROM information_schema.tables WHERE table_schema = ? AND table_name = ? AND table_type = ?", currentDatabase, stmt.Table, "BASE TABLE").Row().Scan(&count) From 85246682c81e6c6039249bd0d209d021c190566b Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 29 Apr 2020 22:15:05 +0800 Subject: [PATCH 299/881] Test update associations --- tests/update.go | 249 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 249 insertions(+) diff --git a/tests/update.go b/tests/update.go index 3a94313e..82a2dc8b 100644 --- a/tests/update.go +++ b/tests/update.go @@ -1,6 +1,7 @@ package tests import ( + "fmt" "testing" "time" @@ -129,5 +130,253 @@ func TestUpdate(t *testing.T, db *gorm.DB) { } else { AssertObjEqual(t, result4, user, "Name", "Age", "Birthday") } + + TestUpdateAssociations(t, db) + }) +} + +func TestUpdateAssociations(t *testing.T, db *gorm.DB) { + db.Migrator().DropTable(&Account{}, &Company{}, &Pet{}, &Toy{}, &Language{}) + db.Migrator().AutoMigrate(&Account{}, &Company{}, &Pet{}, &Toy{}, &Language{}) + + TestUpdateBelongsToAssociations(t, db) + TestUpdateHasOneAssociations(t, db) + TestUpdateHasManyAssociations(t, db) + TestUpdateMany2ManyAssociations(t, db) +} + +func TestUpdateBelongsToAssociations(t *testing.T, db *gorm.DB) { + check := func(t *testing.T, user User) { + if user.Company.Name != "" { + if user.CompanyID == nil { + t.Errorf("Company's foreign key should be saved") + } else { + var company Company + db.First(&company, "id = ?", *user.CompanyID) + if company.Name != user.Company.Name { + t.Errorf("Company's name should be same") + } + } + } else if user.CompanyID != nil { + t.Errorf("Company should not be created for zero value, got: %+v", user.CompanyID) + } + + if user.Manager != nil { + if user.ManagerID == nil { + t.Errorf("Manager's foreign key should be saved") + } else { + var manager User + db.First(&manager, "id = ?", *user.ManagerID) + if manager.Name != user.Manager.Name { + t.Errorf("Manager's name should be same") + } + } + } else if user.ManagerID != nil { + t.Errorf("Manager should not be created for zero value, got: %+v", user.ManagerID) + } + } + + t.Run("BelongsTo", func(t *testing.T) { + var user = User{ + Name: "create", + Age: 18, + Birthday: Now(), + } + + if err := db.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + user.Company = Company{Name: "company-belongs-to-association"} + user.Manager = &User{Name: "manager-belongs-to-association"} + if err := db.Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + check(t, user) + }) +} + +func TestUpdateHasOneAssociations(t *testing.T, db *gorm.DB) { + check := func(t *testing.T, user User) { + if user.Account.ID == 0 { + t.Errorf("Account should be saved") + } else if user.Account.UserID.Int64 != int64(user.ID) { + t.Errorf("Account's foreign key should be saved") + } else { + var account Account + db.First(&account, "id = ?", user.Account.ID) + if account.Number != user.Account.Number { + t.Errorf("Account's number should be sme") + } + } + } + + t.Run("HasOne", func(t *testing.T) { + var user = User{ + Name: "create", + Age: 18, + Birthday: Now(), + } + + if err := db.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + user.Account = Account{Number: "account-has-one-association"} + + if err := db.Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + check(t, user) + }) + + checkPet := func(t *testing.T, pet Pet) { + if pet.Toy.OwnerID != fmt.Sprint(pet.ID) || pet.Toy.OwnerType != "pets" { + t.Errorf("Failed to create polymorphic has one association - toy owner id %v, owner type %v", pet.Toy.OwnerID, pet.Toy.OwnerType) + } else { + var toy Toy + db.First(&toy, "owner_id = ? and owner_type = ?", pet.Toy.OwnerID, pet.Toy.OwnerType) + if toy.Name != pet.Toy.Name { + t.Errorf("Failed to query saved polymorphic has one association") + } + } + } + + t.Run("PolymorphicHasOne", func(t *testing.T) { + var pet = Pet{ + Name: "create", + } + + if err := db.Create(&pet).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + pet.Toy = Toy{Name: "Update-HasOneAssociation-Polymorphic"} + + if err := db.Save(&pet).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + checkPet(t, pet) + }) +} + +func TestUpdateHasManyAssociations(t *testing.T, db *gorm.DB) { + check := func(t *testing.T, user User) { + for _, pet := range user.Pets { + if pet.ID == 0 { + t.Errorf("Pet's foreign key should be saved") + } + + var result Pet + db.First(&result, "id = ?", pet.ID) + if result.Name != pet.Name { + t.Errorf("Pet's name should be same") + } else if result.UserID != user.ID { + t.Errorf("Pet's foreign key should be saved") + } + } + } + + t.Run("HasMany", func(t *testing.T) { + var user = User{ + Name: "create", + Age: 18, + Birthday: Now(), + } + + if err := db.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + user.Pets = []*Pet{{Name: "pet1"}, {Name: "pet2"}} + if err := db.Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + check(t, user) + }) + + checkToy := func(t *testing.T, user User) { + for idx, toy := range user.Toys { + if toy.ID == 0 { + t.Fatalf("Failed to create toy #%v", idx) + } + + var result Toy + db.First(&result, "id = ?", toy.ID) + if result.Name != toy.Name { + t.Errorf("Failed to query saved toy") + } else if result.OwnerID != fmt.Sprint(user.ID) || result.OwnerType != "users" { + t.Errorf("Failed to save relation") + } + } + } + + t.Run("PolymorphicHasMany", func(t *testing.T) { + var user = User{ + Name: "create", + Age: 18, + Birthday: Now(), + } + + if err := db.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + user.Toys = []Toy{{Name: "toy1"}, {Name: "toy2"}} + if err := db.Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + checkToy(t, user) + }) +} + +func TestUpdateMany2ManyAssociations(t *testing.T, db *gorm.DB) { + check := func(t *testing.T, user User) { + for _, language := range user.Languages { + var result Language + db.First(&result, "code = ?", language.Code) + // TODO + // if result.Name != language.Name { + // t.Errorf("Language's name should be same") + // } + } + + for _, f := range user.Friends { + if f.ID == 0 { + t.Errorf("Friend's foreign key should be saved") + } + + var result User + db.First(&result, "id = ?", f.ID) + if result.Name != f.Name { + t.Errorf("Friend's name should be same") + } + } + } + + t.Run("Many2Many", func(t *testing.T) { + var user = User{ + Name: "create", + Age: 18, + Birthday: Now(), + } + + if err := db.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + user.Languages = []Language{{Code: "zh-CN", Name: "Chinese"}, {Code: "en", Name: "English"}} + user.Friends = []*User{{Name: "friend-1"}, {Name: "friend-2"}} + + if err := db.Save(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + check(t, user) }) } From 9dfed613db7e2cb92a6e463bf063bb8fc1f9fd83 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 29 Apr 2020 23:47:18 +0800 Subject: [PATCH 300/881] Test inner joins --- callbacks/query.go | 14 ++++++---- callbacks/scan.go | 26 +++++++++++++++-- tests/joins.go | 70 ++++++++++++++++++++++++++++++++++++++++++++-- 3 files changed, 100 insertions(+), 10 deletions(-) diff --git a/callbacks/query.go b/callbacks/query.go index ae22f4d0..a3b59b48 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -28,7 +28,8 @@ func Query(db *gorm.DB) { if len(db.Statement.Selects) == 0 { for _, dbName := range db.Statement.Schema.DBNames { clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ - Name: dbName, + Table: db.Statement.Table, + Name: dbName, }) } } @@ -37,8 +38,9 @@ func Query(db *gorm.DB) { if relation, ok := db.Statement.Schema.Relationships.Relations[name]; ok { for _, s := range relation.FieldSchema.DBNames { clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ - Table: relation.FieldSchema.Table, + Table: relation.Name, Name: s, + Alias: relation.Name + "__" + s, }) } @@ -46,16 +48,16 @@ func Query(db *gorm.DB) { for _, ref := range relation.References { if ref.OwnPrimaryKey { exprs = append(exprs, clause.Expr{ - SQL: fmt.Sprintf("%s.%s = %s.%s", db.Statement.Schema.Table, ref.PrimaryKey.DBName, relation.FieldSchema.Table, ref.ForeignKey.DBName), + SQL: fmt.Sprintf("%s.%s = %s.%s", db.Statement.Schema.Table, ref.PrimaryKey.DBName, relation.Name, ref.ForeignKey.DBName), }) } else { if ref.PrimaryValue == "" { exprs = append(exprs, clause.Expr{ - SQL: fmt.Sprintf("%s.%s = %s.%s", db.Statement.Schema.Table, ref.ForeignKey.DBName, relation.FieldSchema.Table, ref.PrimaryKey.DBName), + SQL: fmt.Sprintf("%s.%s = %s.%s", db.Statement.Schema.Table, ref.ForeignKey.DBName, relation.Name, ref.PrimaryKey.DBName), }) } else { exprs = append(exprs, clause.Expr{ - SQL: fmt.Sprintf("%s.%s = ?", relation.FieldSchema.Table, ref.PrimaryKey.DBName), + SQL: fmt.Sprintf("%s.%s = ?", relation.Name, ref.PrimaryKey.DBName), Vars: []interface{}{ref.PrimaryValue}, }) } @@ -64,7 +66,7 @@ func Query(db *gorm.DB) { joins = append(joins, clause.Join{ Type: clause.LeftJoin, - Table: clause.Table{Name: relation.FieldSchema.Table}, + Table: clause.Table{Name: relation.FieldSchema.Table, Alias: relation.Name}, ON: clause.Where{Exprs: exprs}, }) } else { diff --git a/callbacks/scan.go b/callbacks/scan.go index 2bd0143c..6ea8bf23 100644 --- a/callbacks/scan.go +++ b/callbacks/scan.go @@ -3,6 +3,7 @@ package callbacks import ( "database/sql" "reflect" + "strings" "github.com/jinzhu/gorm" "github.com/jinzhu/gorm/schema" @@ -54,12 +55,21 @@ func Scan(rows *sql.Rows, db *gorm.DB) { isPtr := db.Statement.ReflectValue.Type().Elem().Kind() == reflect.Ptr db.Statement.ReflectValue.Set(reflect.MakeSlice(db.Statement.ReflectValue.Type(), 0, 0)) fields := make([]*schema.Field, len(columns)) + joinFields := make([][2]*schema.Field, len(columns)) for idx, column := range columns { if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable { fields[idx] = field + } else if names := strings.Split(column, "__"); len(names) > 1 { + if rel, ok := db.Statement.Schema.Relationships.Relations[names[0]]; ok { + if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { + joinFields[idx] = [2]*schema.Field{rel.Field, field} + continue + } + } + values[idx] = &sql.RawBytes{} } else { - values[idx] = sql.RawBytes{} + values[idx] = &sql.RawBytes{} } } @@ -68,6 +78,9 @@ func Scan(rows *sql.Rows, db *gorm.DB) { for idx, field := range fields { if field != nil { values[idx] = field.ReflectValueOf(elem).Addr().Interface() + } else if joinFields[idx][0] != nil { + relValue := joinFields[idx][0].ReflectValueOf(elem) + values[idx] = joinFields[idx][1].ReflectValueOf(relValue).Addr().Interface() } } @@ -86,8 +99,17 @@ func Scan(rows *sql.Rows, db *gorm.DB) { for idx, column := range columns { if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable { values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface() + } else if names := strings.Split(column, "__"); len(names) > 1 { + if rel, ok := db.Statement.Schema.Relationships.Relations[names[0]]; ok { + relValue := rel.Field.ReflectValueOf(db.Statement.ReflectValue) + if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { + values[idx] = field.ReflectValueOf(relValue).Addr().Interface() + continue + } + } + values[idx] = &sql.RawBytes{} } else { - values[idx] = sql.RawBytes{} + values[idx] = &sql.RawBytes{} } } diff --git a/tests/joins.go b/tests/joins.go index 2a8cdc8b..86f9f104 100644 --- a/tests/joins.go +++ b/tests/joins.go @@ -7,9 +7,75 @@ import ( ) func TestJoins(t *testing.T, db *gorm.DB) { - db.Migrator().DropTable(&User{}) - db.AutoMigrate(&User{}) + db.Migrator().DropTable(&User{}, &Account{}, &Company{}) + db.AutoMigrate(&User{}, &Account{}, &Company{}) + + check := func(t *testing.T, oldUser, newUser User) { + if newUser.Company.ID != oldUser.Company.ID { + t.Errorf("Company is not equal when load with joins, loaded company id: %v", newUser.Company.ID) + } + + if newUser.Manager == nil || newUser.Manager.ID != oldUser.Manager.ID { + t.Errorf("Manager is not equal when load with joins: loaded manager: %+v", newUser.Manager) + } + + if newUser.Account.ID != oldUser.Account.ID { + t.Errorf("Account is not equal when load with joins, loaded account id: %v, expect: %v", newUser.Account.ID, oldUser.Account.ID) + } + } t.Run("Joins", func(t *testing.T) { + user := User{ + Name: "joins-1", + Company: Company{Name: "company"}, + Manager: &User{Name: "manager"}, + Account: Account{Number: "account-has-one-association"}, + } + + db.Create(&user) + + var user2 User + if err := db.Joins("Company").Joins("Manager").Joins("Account").First(&user2, "users.name = ?", user.Name).Error; err != nil { + t.Fatalf("Failed to load with joins, got error: %v", err) + } + + check(t, user, user2) + }) + + t.Run("JoinsForSlice", func(t *testing.T) { + users := []User{{ + Name: "slice-joins-1", + Company: Company{Name: "company"}, + Manager: &User{Name: "manager"}, + Account: Account{Number: "account-has-one-association"}, + }, { + Name: "slice-joins-2", + Company: Company{Name: "company2"}, + Manager: &User{Name: "manager2"}, + Account: Account{Number: "account-has-one-association2"}, + }, { + Name: "slice-joins-3", + Company: Company{Name: "company3"}, + Manager: &User{Name: "manager3"}, + Account: Account{Number: "account-has-one-association3"}, + }} + + db.Create(&users) + + var users2 []User + if err := db.Joins("Company").Joins("Manager").Joins("Account").Find(&users2, "users.name LIKE ?", "slice-joins%").Error; err != nil { + t.Fatalf("Failed to load with joins, got error: %v", err) + } else if len(users2) != len(users) { + t.Fatalf("Failed to load join users, got: %v, expect: %v", len(users2), len(users)) + } + + for _, u2 := range users2 { + for _, u := range users { + if u.Name == u2.Name { + check(t, u, u2) + continue + } + } + } }) } From 8def7be5836026f5874332c0a5992b0f43d35817 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 5 May 2020 21:28:38 +0800 Subject: [PATCH 301/881] Add context to logger --- callbacks.go | 11 ++++++----- logger/logger.go | 19 ++++++++++--------- schema/schema.go | 5 +++-- 3 files changed, 19 insertions(+), 16 deletions(-) diff --git a/callbacks.go b/callbacks.go index 78f1192e..6c70b392 100644 --- a/callbacks.go +++ b/callbacks.go @@ -1,6 +1,7 @@ package gorm import ( + "context" "errors" "fmt" "reflect" @@ -90,7 +91,7 @@ func (p *processor) Execute(db *DB) { } if stmt := db.Statement; stmt != nil { - db.Logger.Trace(curTime, func() (string, int64) { + db.Logger.Trace(stmt.Context, curTime, func() (string, int64) { return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...), db.RowsAffected }, db.Error) @@ -141,7 +142,7 @@ func (p *processor) compile() (err error) { } if p.fns, err = sortCallbacks(p.callbacks); err != nil { - logger.Default.Error("Got error when compile callbacks, got %v", err) + logger.Default.Error(context.Background(), "Got error when compile callbacks, got %v", err) } return } @@ -164,7 +165,7 @@ func (c *callback) Register(name string, fn func(*DB)) error { } func (c *callback) Remove(name string) error { - logger.Default.Warn("removing callback `%v` from %v\n", name, utils.FileWithLineNum()) + logger.Default.Warn(context.Background(), "removing callback `%v` from %v\n", name, utils.FileWithLineNum()) c.name = name c.remove = true c.processor.callbacks = append(c.processor.callbacks, c) @@ -172,7 +173,7 @@ func (c *callback) Remove(name string) error { } func (c *callback) Replace(name string, fn func(*DB)) error { - logger.Default.Info("replacing callback `%v` from %v\n", name, utils.FileWithLineNum()) + logger.Default.Info(context.Background(), "replacing callback `%v` from %v\n", name, utils.FileWithLineNum()) c.name = name c.handler = fn c.replace = true @@ -199,7 +200,7 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) { for _, c := range cs { // show warning message the callback name already exists if idx := getRIndex(names, c.name); idx > -1 && !c.replace && !c.remove && !cs[idx].remove { - logger.Default.Warn("duplicated callback `%v` from %v\n", c.name, utils.FileWithLineNum()) + logger.Default.Warn(context.Background(), "duplicated callback `%v` from %v\n", c.name, utils.FileWithLineNum()) } names = append(names, c.name) } diff --git a/logger/logger.go b/logger/logger.go index ee6c0da1..24cee821 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -1,6 +1,7 @@ package logger import ( + "context" "log" "os" "time" @@ -46,10 +47,10 @@ type Config struct { // Interface logger interface type Interface interface { LogMode(LogLevel) Interface - Info(string, ...interface{}) - Warn(string, ...interface{}) - Error(string, ...interface{}) - Trace(begin time.Time, fc func() (string, int64), err error) + Info(context.Context, string, ...interface{}) + Warn(context.Context, string, ...interface{}) + Error(context.Context, string, ...interface{}) + Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) } var Default = New(log.New(os.Stdout, "\r\n", log.LstdFlags), Config{ @@ -103,35 +104,35 @@ func (l logger) LogMode(level LogLevel) Interface { } // Info print info -func (l logger) Info(msg string, data ...interface{}) { +func (l logger) Info(ctx context.Context, msg string, data ...interface{}) { if l.LogLevel >= Info { l.Printf(l.infoStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...) } } // Warn print warn messages -func (l logger) Warn(msg string, data ...interface{}) { +func (l logger) Warn(ctx context.Context, msg string, data ...interface{}) { if l.LogLevel >= Warn { l.Printf(l.warnStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...) } } // Error print error messages -func (l logger) Error(msg string, data ...interface{}) { +func (l logger) Error(ctx context.Context, msg string, data ...interface{}) { if l.LogLevel >= Error { l.Printf(l.errStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...) } } // Trace print sql message -func (l logger) Trace(begin time.Time, fc func() (string, int64), err error) { +func (l logger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) { if l.LogLevel > 0 { elapsed := time.Now().Sub(begin) switch { case err != nil: sql, rows := fc() l.Printf(l.traceErrStr, utils.FileWithLineNum(), err, float64(elapsed.Nanoseconds())/1e6, rows, sql) - case elapsed > l.SlowThreshold && l.SlowThreshold != 0: + case elapsed > l.SlowThreshold && l.SlowThreshold != 0 && l.LogLevel >= Warn: sql, rows := fc() l.Printf(l.traceWarnStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, rows, sql) case l.LogLevel >= Info: diff --git a/schema/schema.go b/schema/schema.go index 2ac6d312..3abac2ba 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -1,6 +1,7 @@ package schema import ( + "context" "errors" "fmt" "go/ast" @@ -83,7 +84,7 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) defer func() { if schema.err != nil { - logger.Default.Error(schema.err.Error()) + logger.Default.Error(context.Background(), schema.err.Error()) cacheStore.Delete(modelType) } }() @@ -174,7 +175,7 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) case "func(*gorm.DB)": // TODO hack reflect.Indirect(reflect.ValueOf(schema)).FieldByName(name).SetBool(true) default: - logger.Default.Warn("Model %v don't match %vInterface, should be %v(*gorm.DB)", schema, name, name) + logger.Default.Warn(context.Background(), "Model %v don't match %vInterface, should be %v(*gorm.DB)", schema, name, name) } } } From 41697d58d3b02b26c2f9af782052e3d39578b205 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 7 May 2020 10:03:48 +0800 Subject: [PATCH 302/881] Handle preload --- callbacks/preload.go | 9 +++++++++ callbacks/query.go | 45 ++++++++++++++++++++++++++++++++++++++++++++ errors.go | 2 ++ 3 files changed, 56 insertions(+) create mode 100644 callbacks/preload.go diff --git a/callbacks/preload.go b/callbacks/preload.go new file mode 100644 index 00000000..c8dcd05e --- /dev/null +++ b/callbacks/preload.go @@ -0,0 +1,9 @@ +package callbacks + +import ( + "github.com/jinzhu/gorm" + "github.com/jinzhu/gorm/schema" +) + +func preload(db *gorm.DB, preloadFields []string, rel *schema.Relationship) { +} diff --git a/callbacks/query.go b/callbacks/query.go index a3b59b48..ca9e84a9 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -3,9 +3,12 @@ package callbacks import ( "fmt" "reflect" + "sort" + "strings" "github.com/jinzhu/gorm" "github.com/jinzhu/gorm/clause" + "github.com/jinzhu/gorm/schema" ) func Query(db *gorm.DB) { @@ -96,6 +99,48 @@ func Query(db *gorm.DB) { } func Preload(db *gorm.DB) { + if len(db.Statement.Preloads) > 0 { + preloadMap := map[string][]string{} + + for name := range db.Statement.Preloads { + preloadFields := strings.Split(name, ".") + for idx := range preloadFields { + preloadMap[strings.Join(preloadFields[:idx+1], ".")] = preloadFields[:idx+1] + } + } + + preloadNames := make([]string, len(preloadMap)) + idx := 0 + for key := range preloadMap { + preloadNames[idx] = key + idx++ + } + sort.Strings(preloadNames) + + for _, name := range preloadNames { + curSchema := db.Statement.Schema + preloadFields := preloadMap[name] + + for idx, preloadField := range preloadFields { + if rel := curSchema.Relationships.Relations[preloadField]; rel != nil { + if idx == len(preloadFields)-1 { + conds := db.Statement.Preloads[strings.Join(preloadFields[:idx+1], ".")] + + switch rel.Type { + case schema.HasOne: + case schema.HasMany: + case schema.BelongsTo: + case schema.Many2Many: + } + } else { + curSchema = rel.FieldSchema + } + } else { + db.AddError(fmt.Errorf("%v: %w", name, gorm.ErrUnsupportedRelation)) + } + } + } + } } func AfterQuery(db *gorm.DB) { diff --git a/errors.go b/errors.go index 32f55e01..a990cc4a 100644 --- a/errors.go +++ b/errors.go @@ -17,4 +17,6 @@ var ( ErrNotImplemented = errors.New("not implemented") // ErrMissingWhereClause missing where clause ErrMissingWhereClause = errors.New("missing WHERE clause while deleting") + // ErrUnsupportedRelation unsupported relations + ErrUnsupportedRelation = errors.New("unsupported relations") ) From b549f9bb9a877ee2dc7b20fb768c9a278fbdc5e1 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 14 May 2020 12:19:12 +0800 Subject: [PATCH 303/881] Implement preload support --- callbacks/preload.go | 189 ++++++++++++++++++++++++++++++++++++++++++- callbacks/query.go | 25 +++--- statement.go | 9 +++ utils/utils.go | 22 +++++ 4 files changed, 229 insertions(+), 16 deletions(-) diff --git a/callbacks/preload.go b/callbacks/preload.go index c8dcd05e..112f67f7 100644 --- a/callbacks/preload.go +++ b/callbacks/preload.go @@ -1,9 +1,196 @@ package callbacks import ( + "reflect" + "github.com/jinzhu/gorm" + "github.com/jinzhu/gorm/clause" "github.com/jinzhu/gorm/schema" + "github.com/jinzhu/gorm/utils" ) -func preload(db *gorm.DB, preloadFields []string, rel *schema.Relationship) { +// getRelationsValue get relations's value from a reflect value +func getRelationsValue(reflectValue reflect.Value, rels []*schema.Relationship) (reflectResults reflect.Value) { + for _, rel := range rels { + reflectResults = reflect.MakeSlice(reflect.SliceOf(rel.FieldSchema.ModelType), 0, 0) + + appendToResults := func(value reflect.Value) { + if _, isZero := rel.Field.ValueOf(value); !isZero { + result := reflect.Indirect(rel.Field.ReflectValueOf(value)) + switch result.Kind() { + case reflect.Struct: + reflectResults = reflect.Append(reflectResults, result) + case reflect.Slice, reflect.Array: + for i := 0; i < value.Len(); i++ { + reflectResults = reflect.Append(reflectResults, reflect.Indirect(value.Index(i))) + } + } + } + } + + switch reflectValue.Kind() { + case reflect.Struct: + appendToResults(reflectValue) + case reflect.Slice: + for i := 0; i < reflectValue.Len(); i++ { + appendToResults(reflectValue.Index(i)) + } + } + + reflectValue = reflectResults + } + + return +} + +func getIdentityFieldValuesMap(reflectValue reflect.Value, fields []*schema.Field) (map[string][]reflect.Value, [][]interface{}) { + var ( + fieldValues = make([]reflect.Value, len(fields)) + results = [][]interface{}{} + dataResults = map[string][]reflect.Value{} + ) + + switch reflectValue.Kind() { + case reflect.Struct: + results = [][]interface{}{make([]interface{}, len(fields))} + + for idx, field := range fields { + fieldValues[idx] = field.ReflectValueOf(reflectValue) + results[0][idx] = fieldValues[idx].Interface() + } + + dataResults[utils.ToStringKey(fieldValues...)] = []reflect.Value{reflectValue} + case reflect.Slice, reflect.Array: + for i := 0; i < reflectValue.Len(); i++ { + for idx, field := range fields { + fieldValues[idx] = field.ReflectValueOf(reflectValue.Index(i)) + } + + dataKey := utils.ToStringKey(fieldValues...) + if _, ok := dataResults[dataKey]; !ok { + result := make([]interface{}, len(fieldValues)) + for idx, fieldValue := range fieldValues { + result[idx] = fieldValue.Interface() + } + results = append(results, result) + + dataResults[dataKey] = []reflect.Value{reflectValue.Index(i)} + } else { + dataResults[dataKey] = append(dataResults[dataKey], reflectValue.Index(i)) + } + } + } + + return dataResults, results +} + +func preloadData(tx *gorm.DB, resultSchema *schema.Schema, foreignKeys []string, foreignValues [][]interface{}) reflect.Value { + results := reflect.MakeSlice(reflect.SliceOf(resultSchema.ModelType), 0, 0) + queryValues := make([]interface{}, len(foreignValues)) + if len(foreignKeys) == 1 { + for idx, r := range foreignValues { + queryValues[idx] = r[0] + } + tx.Where(clause.IN{Column: foreignKeys[0], Values: queryValues}).Find(results.Addr().Interface()) + } else { + for idx, r := range foreignValues { + queryValues[idx] = r + } + tx.Where(clause.IN{Column: foreignKeys, Values: queryValues}).Find(results.Addr().Interface()) + } + + return results +} + +func preload(tx *gorm.DB, rels []*schema.Relationship, conds []interface{}) { + var ( + reflectValue = tx.Statement.ReflectValue + rel = rels[len(rels)-1] + relForeignKeys []string + relForeignFields []*schema.Field + foreignFields []*schema.Field + foreignValues [][]interface{} + identityMap = map[string][]reflect.Value{} + ) + + if len(rels) > 1 { + reflectValue = getRelationsValue(reflectValue, rels[:len(rels)]) + } + + if rel.JoinTable != nil { + var joinForeignFields, joinRelForeignFields []*schema.Field + var joinForeignKeys []string + for _, ref := range rel.References { + if ref.OwnPrimaryKey { + joinForeignKeys = append(joinForeignKeys, ref.ForeignKey.DBName) + joinForeignFields = append(joinForeignFields, ref.ForeignKey) + foreignFields = append(foreignFields, ref.PrimaryKey) + } else if ref.PrimaryValue != "" { + tx.Where(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue}) + } else { + joinRelForeignFields = append(joinRelForeignFields, ref.ForeignKey) + relForeignKeys = append(relForeignKeys, ref.PrimaryKey.DBName) + relForeignFields = append(relForeignFields, ref.PrimaryKey) + } + } + + joinIdentityMap, joinForeignValues := getIdentityFieldValuesMap(reflectValue, joinForeignFields) + joinResults := preloadData(tx, rel.JoinTable, joinForeignKeys, joinForeignValues) + + // convert join identity map to relation identity map + fieldValues := make([]reflect.Value, len(foreignFields)) + joinFieldValues := make([]reflect.Value, len(joinForeignFields)) + for i := 0; i < joinResults.Len(); i++ { + for idx, field := range foreignFields { + fieldValues[idx] = field.ReflectValueOf(joinResults.Index(i)) + } + + for idx, field := range joinForeignFields { + joinFieldValues[idx] = field.ReflectValueOf(joinResults.Index(i)) + } + + if results, ok := joinIdentityMap[utils.ToStringKey(fieldValues...)]; ok { + identityMap[utils.ToStringKey(joinFieldValues...)] = results + } + } + + _, foreignValues = getIdentityFieldValuesMap(joinResults, joinRelForeignFields) + } else { + for _, ref := range rel.References { + if ref.OwnPrimaryKey { + relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName) + relForeignFields = append(relForeignFields, ref.ForeignKey) + foreignFields = append(foreignFields, ref.PrimaryKey) + } else if ref.PrimaryValue != "" { + tx.Where(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue}) + } else { + relForeignKeys = append(relForeignKeys, ref.PrimaryKey.DBName) + relForeignFields = append(relForeignFields, ref.PrimaryKey) + foreignFields = append(foreignFields, ref.ForeignKey) + } + } + + identityMap, foreignValues = getIdentityFieldValuesMap(reflectValue, foreignFields) + } + + reflectResults := preloadData(tx, rel.FieldSchema, relForeignKeys, foreignValues) + + fieldValues := make([]reflect.Value, len(foreignFields)) + for i := 0; i < reflectResults.Len(); i++ { + for idx, field := range foreignFields { + fieldValues[idx] = field.ReflectValueOf(reflectResults.Index(i)) + } + + for _, data := range identityMap[utils.ToStringKey(fieldValues...)] { + reflectFieldValue := reflect.Indirect(rel.Field.ReflectValueOf(data)) + switch reflectFieldValue.Kind() { + case reflect.Struct: + elem := reflectResults.Index(i).Convert(reflectFieldValue.Type().Elem()) + rel.Field.Set(data, elem.Interface()) + case reflect.Slice, reflect.Array: + elem := reflectResults.Index(i).Convert(reflectFieldValue.Type().Elem()) + rel.Field.Set(data, reflect.Append(reflectFieldValue, elem).Interface()) + } + } + } } diff --git a/callbacks/query.go b/callbacks/query.go index ca9e84a9..2c187868 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -25,6 +25,7 @@ func Query(db *gorm.DB) { } } + // inline joins if len(db.Statement.Joins) != 0 { joins := []clause.Join{} @@ -101,7 +102,6 @@ func Query(db *gorm.DB) { func Preload(db *gorm.DB) { if len(db.Statement.Preloads) > 0 { preloadMap := map[string][]string{} - for name := range db.Statement.Preloads { preloadFields := strings.Split(name, ".") for idx := range preloadFields { @@ -118,27 +118,22 @@ func Preload(db *gorm.DB) { sort.Strings(preloadNames) for _, name := range preloadNames { - curSchema := db.Statement.Schema - preloadFields := preloadMap[name] + var ( + curSchema = db.Statement.Schema + preloadFields = preloadMap[name] + rels = make([]*schema.Relationship, len(preloadFields)) + ) for idx, preloadField := range preloadFields { if rel := curSchema.Relationships.Relations[preloadField]; rel != nil { - if idx == len(preloadFields)-1 { - conds := db.Statement.Preloads[strings.Join(preloadFields[:idx+1], ".")] - - switch rel.Type { - case schema.HasOne: - case schema.HasMany: - case schema.BelongsTo: - case schema.Many2Many: - } - } else { - curSchema = rel.FieldSchema - } + rels[idx] = rel + curSchema = rel.FieldSchema } else { db.AddError(fmt.Errorf("%v: %w", name, gorm.ErrUnsupportedRelation)) } } + + preload(db.Session(&gorm.Session{}), rels, db.Statement.Preloads[name]) } } } diff --git a/statement.go b/statement.go index 3f2ceca3..f3090eb7 100644 --- a/statement.go +++ b/statement.go @@ -95,6 +95,15 @@ func (stmt Statement) QuoteTo(writer clause.Writer, field interface{}) { } case string: stmt.DB.Dialector.QuoteTo(writer, v) + case []string: + writer.WriteByte('(') + for idx, d := range v { + if idx != 0 { + writer.WriteString(",") + } + stmt.DB.Dialector.QuoteTo(writer, d) + } + writer.WriteByte(')') default: stmt.DB.Dialector.QuoteTo(writer, fmt.Sprint(field)) } diff --git a/utils/utils.go b/utils/utils.go index 8dd500a5..f3dedec2 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -5,6 +5,7 @@ import ( "reflect" "regexp" "runtime" + "strconv" "strings" "unicode" ) @@ -38,3 +39,24 @@ func CheckTruth(val interface{}) bool { return !reflect.ValueOf(val).IsZero() } + +func ToStringKey(values ...reflect.Value) string { + results := make([]string, len(values)) + + for idx, value := range values { + rv := reflect.Indirect(value).Interface() + + switch v := rv.(type) { + case string: + results[idx] = v + case []byte: + results[idx] = string(v) + case uint: + results[idx] = strconv.FormatUint(uint64(v), 10) + default: + results[idx] = fmt.Sprint(v) + } + } + + return strings.Join(results, "_") +} From 42aae572401c223274a13a0b4f3775c2d8f35e9f Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 14 May 2020 13:48:51 +0800 Subject: [PATCH 304/881] Test Preload for BelongsTo/HasOne/HasMany --- callbacks/preload.go | 28 +++++---- callbacks/query.go | 2 +- tests/create.go | 138 +++++++++++++++++++++++++++++++++---------- utils/utils.go | 4 ++ 4 files changed, 130 insertions(+), 42 deletions(-) diff --git a/callbacks/preload.go b/callbacks/preload.go index 112f67f7..8ab014f6 100644 --- a/callbacks/preload.go +++ b/callbacks/preload.go @@ -85,27 +85,31 @@ func getIdentityFieldValuesMap(reflectValue reflect.Value, fields []*schema.Fiel } func preloadData(tx *gorm.DB, resultSchema *schema.Schema, foreignKeys []string, foreignValues [][]interface{}) reflect.Value { - results := reflect.MakeSlice(reflect.SliceOf(resultSchema.ModelType), 0, 0) + slice := reflect.MakeSlice(reflect.SliceOf(resultSchema.ModelType), 0, 0) + results := reflect.New(slice.Type()) + results.Elem().Set(slice) + queryValues := make([]interface{}, len(foreignValues)) if len(foreignKeys) == 1 { for idx, r := range foreignValues { queryValues[idx] = r[0] } - tx.Where(clause.IN{Column: foreignKeys[0], Values: queryValues}).Find(results.Addr().Interface()) + tx.Where(clause.IN{Column: foreignKeys[0], Values: queryValues}).Find(results.Interface()) } else { for idx, r := range foreignValues { queryValues[idx] = r } - tx.Where(clause.IN{Column: foreignKeys, Values: queryValues}).Find(results.Addr().Interface()) + tx.Where(clause.IN{Column: foreignKeys, Values: queryValues}).Find(results.Interface()) } - return results + return results.Elem() } -func preload(tx *gorm.DB, rels []*schema.Relationship, conds []interface{}) { +func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { var ( - reflectValue = tx.Statement.ReflectValue + reflectValue = db.Statement.ReflectValue rel = rels[len(rels)-1] + tx = db.Session(&gorm.Session{}) relForeignKeys []string relForeignFields []*schema.Field foreignFields []*schema.Field @@ -177,7 +181,7 @@ func preload(tx *gorm.DB, rels []*schema.Relationship, conds []interface{}) { fieldValues := make([]reflect.Value, len(foreignFields)) for i := 0; i < reflectResults.Len(); i++ { - for idx, field := range foreignFields { + for idx, field := range relForeignFields { fieldValues[idx] = field.ReflectValueOf(reflectResults.Index(i)) } @@ -185,11 +189,13 @@ func preload(tx *gorm.DB, rels []*schema.Relationship, conds []interface{}) { reflectFieldValue := reflect.Indirect(rel.Field.ReflectValueOf(data)) switch reflectFieldValue.Kind() { case reflect.Struct: - elem := reflectResults.Index(i).Convert(reflectFieldValue.Type().Elem()) - rel.Field.Set(data, elem.Interface()) + rel.Field.Set(data, reflectResults.Index(i).Interface()) case reflect.Slice, reflect.Array: - elem := reflectResults.Index(i).Convert(reflectFieldValue.Type().Elem()) - rel.Field.Set(data, reflect.Append(reflectFieldValue, elem).Interface()) + if reflectFieldValue.Type().Elem().Kind() == reflect.Ptr { + rel.Field.Set(data, reflect.Append(reflectFieldValue, reflectResults.Index(i).Addr()).Interface()) + } else { + rel.Field.Set(data, reflect.Append(reflectFieldValue, reflectResults.Index(i)).Interface()) + } } } } diff --git a/callbacks/query.go b/callbacks/query.go index 2c187868..4a89c575 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -133,7 +133,7 @@ func Preload(db *gorm.DB) { } } - preload(db.Session(&gorm.Session{}), rels, db.Statement.Preloads[name]) + preload(db, rels, db.Statement.Preloads[name]) } } } diff --git a/tests/create.go b/tests/create.go index 27ad7a49..45cd9794 100644 --- a/tests/create.go +++ b/tests/create.go @@ -56,14 +56,16 @@ func TestCreateAssociations(t *testing.T, db *gorm.DB) { } func TestCreateBelongsToAssociations(t *testing.T, db *gorm.DB) { - check := func(t *testing.T, user User) { - if user.Company.Name != "" { + check := func(t *testing.T, user User, old User) { + if old.Company.Name != "" { if user.CompanyID == nil { t.Errorf("Company's foreign key should be saved") } else { var company Company db.First(&company, "id = ?", *user.CompanyID) - if company.Name != user.Company.Name { + if company.Name != old.Company.Name { + t.Errorf("Company's name should be same") + } else if user.Company.Name != old.Company.Name { t.Errorf("Company's name should be same") } } @@ -71,7 +73,7 @@ func TestCreateBelongsToAssociations(t *testing.T, db *gorm.DB) { t.Errorf("Company should not be created for zero value, got: %+v", user.CompanyID) } - if user.Manager != nil { + if old.Manager != nil { if user.ManagerID == nil { t.Errorf("Manager's foreign key should be saved") } else { @@ -79,6 +81,8 @@ func TestCreateBelongsToAssociations(t *testing.T, db *gorm.DB) { db.First(&manager, "id = ?", *user.ManagerID) if manager.Name != user.Manager.Name { t.Errorf("Manager's name should be same") + } else if user.Manager.Name != old.Manager.Name { + t.Errorf("Manager's name should be same") } } } else if user.ManagerID != nil { @@ -99,7 +103,11 @@ func TestCreateBelongsToAssociations(t *testing.T, db *gorm.DB) { t.Fatalf("errors happened when create: %v", err) } - check(t, user) + check(t, user, user) + + var user2 User + db.Preload("Company").Preload("Manager").Find(&user2, "id = ?", user.ID) + check(t, user2, user) }) t.Run("BelongsToForBulkInsert", func(t *testing.T) { @@ -126,8 +134,22 @@ func TestCreateBelongsToAssociations(t *testing.T, db *gorm.DB) { t.Fatalf("errors happened when create: %v", err) } + var userIDs []uint for _, user := range users { - check(t, user) + userIDs = append(userIDs, user.ID) + check(t, user, user) + } + + var users2 []User + db.Preload("Company").Preload("Manager").Find(&users2, "id IN (?)", userIDs) + for idx, user := range users2 { + check(t, user, users[idx]) + } + + var users3 []User + db.Preload("Company").Preload("Manager").Find(users3, "id IN (?)", userIDs) + for idx, user := range users3 { + check(t, user, users[idx]) } }) @@ -156,7 +178,7 @@ func TestCreateBelongsToAssociations(t *testing.T, db *gorm.DB) { } for _, user := range users { - check(t, *user) + check(t, *user, *user) } }) @@ -185,13 +207,13 @@ func TestCreateBelongsToAssociations(t *testing.T, db *gorm.DB) { } for _, user := range users { - check(t, *user) + check(t, *user, *user) } }) } func TestCreateHasOneAssociations(t *testing.T, db *gorm.DB) { - check := func(t *testing.T, user User) { + check := func(t *testing.T, user User, old User) { if user.Account.ID == 0 { t.Errorf("Account should be saved") } else if user.Account.UserID.Int64 != int64(user.ID) { @@ -200,7 +222,9 @@ func TestCreateHasOneAssociations(t *testing.T, db *gorm.DB) { var account Account db.First(&account, "id = ?", user.Account.ID) if account.Number != user.Account.Number { - t.Errorf("Account's number should be sme") + t.Errorf("Account's number should be same") + } else if user.Account.Number != old.Account.Number { + t.Errorf("Account's number should be same") } } } @@ -217,7 +241,11 @@ func TestCreateHasOneAssociations(t *testing.T, db *gorm.DB) { t.Fatalf("errors happened when create: %v", err) } - check(t, user) + check(t, user, user) + + var user2 User + db.Preload("Account").Find(&user2, "id = ?", user.ID) + check(t, user2, user) }) t.Run("HasOneForBulkInsert", func(t *testing.T) { @@ -242,8 +270,16 @@ func TestCreateHasOneAssociations(t *testing.T, db *gorm.DB) { t.Fatalf("errors happened when create: %v", err) } + var userIDs []uint for _, user := range users { - check(t, user) + userIDs = append(userIDs, user.ID) + check(t, user, user) + } + + var users2 []User + db.Preload("Account").Find(&users2, "id IN (?)", userIDs) + for idx, user := range users2 { + check(t, user, users[idx]) } }) @@ -270,7 +306,7 @@ func TestCreateHasOneAssociations(t *testing.T, db *gorm.DB) { } for _, user := range users { - check(t, *user) + check(t, *user, *user) } }) @@ -297,11 +333,11 @@ func TestCreateHasOneAssociations(t *testing.T, db *gorm.DB) { } for _, user := range users { - check(t, user) + check(t, user, user) } }) - checkPet := func(t *testing.T, pet Pet) { + checkPet := func(t *testing.T, pet Pet, old Pet) { if pet.Toy.OwnerID != fmt.Sprint(pet.ID) || pet.Toy.OwnerType != "pets" { t.Errorf("Failed to create polymorphic has one association - toy owner id %v, owner type %v", pet.Toy.OwnerID, pet.Toy.OwnerType) } else { @@ -309,6 +345,8 @@ func TestCreateHasOneAssociations(t *testing.T, db *gorm.DB) { db.First(&toy, "owner_id = ? and owner_type = ?", pet.Toy.OwnerID, pet.Toy.OwnerType) if toy.Name != pet.Toy.Name { t.Errorf("Failed to query saved polymorphic has one association") + } else if old.Toy.Name != pet.Toy.Name { + t.Errorf("Failed to query saved polymorphic has one association") } } } @@ -323,7 +361,11 @@ func TestCreateHasOneAssociations(t *testing.T, db *gorm.DB) { t.Fatalf("errors happened when create: %v", err) } - checkPet(t, pet) + checkPet(t, pet, pet) + + var pet2 Pet + db.Preload("Toy").Find(&pet2, "id = ?", pet.ID) + checkPet(t, pet2, pet) }) t.Run("PolymorphicHasOneForBulkInsert", func(t *testing.T) { @@ -342,8 +384,16 @@ func TestCreateHasOneAssociations(t *testing.T, db *gorm.DB) { t.Fatalf("errors happened when create: %v", err) } + var petIDs []uint for _, pet := range pets { - checkPet(t, pet) + petIDs = append(petIDs, pet.ID) + checkPet(t, pet, pet) + } + + var pets2 []Pet + db.Preload("Toy").Find(&pets2, "id IN (?)", petIDs) + for idx, pet := range pets2 { + checkPet(t, pet, pets[idx]) } }) @@ -364,7 +414,7 @@ func TestCreateHasOneAssociations(t *testing.T, db *gorm.DB) { } for _, pet := range pets { - checkPet(t, *pet) + checkPet(t, *pet, *pet) } }) @@ -385,14 +435,14 @@ func TestCreateHasOneAssociations(t *testing.T, db *gorm.DB) { } for _, pet := range pets { - checkPet(t, *pet) + checkPet(t, *pet, *pet) } }) } func TestCreateHasManyAssociations(t *testing.T, db *gorm.DB) { - check := func(t *testing.T, user User) { - for _, pet := range user.Pets { + check := func(t *testing.T, user User, old User) { + for idx, pet := range user.Pets { if pet.ID == 0 { t.Errorf("Pet's foreign key should be saved") } @@ -403,6 +453,8 @@ func TestCreateHasManyAssociations(t *testing.T, db *gorm.DB) { t.Errorf("Pet's name should be same") } else if result.UserID != user.ID { t.Errorf("Pet's foreign key should be saved") + } else if result.Name != old.Pets[idx].Name { + t.Errorf("Pet's name should be same") } } } @@ -419,7 +471,11 @@ func TestCreateHasManyAssociations(t *testing.T, db *gorm.DB) { t.Fatalf("errors happened when create: %v", err) } - check(t, user) + check(t, user, user) + + var user2 User + db.Preload("Pets").Find(&user2, "id = ?", user.ID) + check(t, user2, user) }) t.Run("HasManyForBulkInsert", func(t *testing.T) { @@ -444,8 +500,16 @@ func TestCreateHasManyAssociations(t *testing.T, db *gorm.DB) { t.Fatalf("errors happened when create: %v", err) } + var userIDs []uint for _, user := range users { - check(t, user) + userIDs = append(userIDs, user.ID) + check(t, user, user) + } + + var users2 []User + db.Preload("Pets").Find(&users2, "id IN (?)", userIDs) + for idx, user := range users2 { + check(t, user, users[idx]) } }) @@ -472,7 +536,7 @@ func TestCreateHasManyAssociations(t *testing.T, db *gorm.DB) { } for _, user := range users { - check(t, *user) + check(t, *user, *user) } }) @@ -499,11 +563,11 @@ func TestCreateHasManyAssociations(t *testing.T, db *gorm.DB) { } for _, user := range users { - check(t, *user) + check(t, *user, *user) } }) - checkToy := func(t *testing.T, user User) { + checkToy := func(t *testing.T, user User, old User) { for idx, toy := range user.Toys { if toy.ID == 0 { t.Fatalf("Failed to create toy #%v", idx) @@ -513,6 +577,8 @@ func TestCreateHasManyAssociations(t *testing.T, db *gorm.DB) { db.First(&result, "id = ?", toy.ID) if result.Name != toy.Name { t.Errorf("Failed to query saved toy") + } else if result.Name != old.Toys[idx].Name { + t.Errorf("Failed to query saved toy") } else if result.OwnerID != fmt.Sprint(user.ID) || result.OwnerType != "users" { t.Errorf("Failed to save relation") } @@ -531,7 +597,11 @@ func TestCreateHasManyAssociations(t *testing.T, db *gorm.DB) { t.Fatalf("errors happened when create: %v", err) } - checkToy(t, user) + checkToy(t, user, user) + + var user2 User + db.Preload("Toys").Find(&user2, "id = ?", user.ID) + check(t, user2, user) }) t.Run("PolymorphicHasManyForBulkInsert", func(t *testing.T) { @@ -556,8 +626,16 @@ func TestCreateHasManyAssociations(t *testing.T, db *gorm.DB) { t.Fatalf("errors happened when create: %v", err) } + var userIDs []uint for _, user := range users { - checkToy(t, user) + userIDs = append(userIDs, user.ID) + checkToy(t, user, user) + } + + var users2 []User + db.Preload("Toys").Find(&users2, "id IN (?)", userIDs) + for idx, user := range users2 { + check(t, user, users[idx]) } }) @@ -584,7 +662,7 @@ func TestCreateHasManyAssociations(t *testing.T, db *gorm.DB) { } for _, user := range users { - checkToy(t, *user) + checkToy(t, *user, *user) } }) @@ -611,7 +689,7 @@ func TestCreateHasManyAssociations(t *testing.T, db *gorm.DB) { } for _, user := range users { - checkToy(t, user) + checkToy(t, user, user) } }) } diff --git a/utils/utils.go b/utils/utils.go index f3dedec2..5d6c9da2 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -1,6 +1,7 @@ package utils import ( + "database/sql/driver" "fmt" "reflect" "regexp" @@ -45,6 +46,9 @@ func ToStringKey(values ...reflect.Value) string { for idx, value := range values { rv := reflect.Indirect(value).Interface() + if valuer, ok := rv.(driver.Valuer); ok { + rv, _ = valuer.Value() + } switch v := rv.(type) { case string: From 92b812408c034faa5b03c503512036f0d529e848 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 14 May 2020 15:05:04 +0800 Subject: [PATCH 305/881] Test many2many associations --- tests/create.go | 67 +++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 59 insertions(+), 8 deletions(-) diff --git a/tests/create.go b/tests/create.go index 45cd9794..428f876c 100644 --- a/tests/create.go +++ b/tests/create.go @@ -695,17 +695,18 @@ func TestCreateHasManyAssociations(t *testing.T, db *gorm.DB) { } func TestCreateMany2ManyAssociations(t *testing.T, db *gorm.DB) { - check := func(t *testing.T, user User) { - for _, language := range user.Languages { + check := func(t *testing.T, user User, old User) { + for idx, language := range user.Languages { var result Language db.First(&result, "code = ?", language.Code) - // TODO - // if result.Name != language.Name { - // t.Errorf("Language's name should be same") - // } + if result.Name != language.Name { + t.Errorf("Language's name should be same") + } else if result.Name != old.Languages[idx].Name { + t.Errorf("Language's name should be same") + } } - for _, f := range user.Friends { + for idx, f := range user.Friends { if f.ID == 0 { t.Errorf("Friend's foreign key should be saved") } @@ -714,10 +715,14 @@ func TestCreateMany2ManyAssociations(t *testing.T, db *gorm.DB) { db.First(&result, "id = ?", f.ID) if result.Name != f.Name { t.Errorf("Friend's name should be same") + } else if result.Name != old.Friends[idx].Name { + t.Errorf("Language's name should be same") } } } + db.Create(&[]Language{{Code: "zh-CN", Name: "Chinese"}, {Code: "en", Name: "English"}}) + t.Run("Many2Many", func(t *testing.T) { var user = User{ Name: "create", @@ -731,6 +736,52 @@ func TestCreateMany2ManyAssociations(t *testing.T, db *gorm.DB) { t.Fatalf("errors happened when create: %v", err) } - check(t, user) + check(t, user, user) + + var user2 User + db.Preload("Languages").Preload("Friends").Find(&user2, "id = ?", user.ID) + check(t, user2, user) + }) + + t.Run("Many2ManyForBulkInsert", func(t *testing.T) { + var users = []User{ + { + Name: "create-1", + Age: 18, + Birthday: Now(), + Languages: []Language{{Code: "zh-CN", Name: "Chinese"}, {Code: "en", Name: "English"}}, + Friends: []*User{{Name: "friend-1-1"}, {Name: "friend-1-2"}}, + }, + { + Name: "create-2", + Age: 18, + Birthday: Now(), + Languages: []Language{{Code: "zh-CN", Name: "Chinese"}}, + Friends: []*User{{Name: "friend-2-1"}}, + }, + { + Name: "create-3", + Age: 18, + Birthday: Now(), + Languages: []Language{{Code: "zh-CN", Name: "Chinese"}, {Code: "en", Name: "English"}}, + Friends: []*User{{Name: "friend-3-1"}, {Name: "friend-3-2"}, {Name: "friend-3-3"}}, + }, + } + + if err := db.Create(&users).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + var userIDs []uint + for _, user := range users { + userIDs = append(userIDs, user.ID) + check(t, user, user) + } + + var users2 []User + db.Preload("Languages").Preload("Friends").Find(&users2, "id IN (?)", userIDs) + for idx, user := range users2 { + check(t, user, users[idx]) + } }) } From f999240e106552c62eef70d29d1da93d95f76a5b Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 14 May 2020 20:54:50 +0800 Subject: [PATCH 306/881] Define association API, add conds to when preloading --- association.go | 54 +++++++++++++++++++++++++++++++++++++++++++- callbacks/preload.go | 10 ++++---- 2 files changed, 58 insertions(+), 6 deletions(-) diff --git a/association.go b/association.go index 14bc54b6..a9345255 100644 --- a/association.go +++ b/association.go @@ -1,9 +1,61 @@ package gorm +import ( + "fmt" + + "github.com/jinzhu/gorm/schema" +) + // Association Mode contains some helper methods to handle relationship things easily. type Association struct { + DB *DB + Relationship *schema.Relationship + Error error } func (db *DB) Association(column string) *Association { - return nil + association := &Association{DB: db} + + if err := db.Statement.Parse(db.Statement.Model); err == nil { + association.Relationship = db.Statement.Schema.Relationships.Relations[column] + + if association.Relationship == nil { + association.Error = fmt.Errorf("%w: %v", ErrUnsupportedRelation, column) + } + } else { + association.Error = err + } + + return association +} + +func (association *Association) Find(out interface{}, conds ...interface{}) error { + if association.Error == nil { + for _, ref := range association.Relationship.References { + if ref.OwnPrimaryKey { + } + } + } + + return association.Error +} + +func (association *Association) Append(values ...interface{}) error { + return association.Error +} + +func (association *Association) Replace(values ...interface{}) error { + return association.Error +} + +func (association *Association) Delete(values ...interface{}) error { + return association.Error +} + +func (association *Association) Clear() error { + return association.Error +} + +func (association *Association) Count() int { + return 0 } diff --git a/callbacks/preload.go b/callbacks/preload.go index 8ab014f6..aaac31b5 100644 --- a/callbacks/preload.go +++ b/callbacks/preload.go @@ -84,7 +84,7 @@ func getIdentityFieldValuesMap(reflectValue reflect.Value, fields []*schema.Fiel return dataResults, results } -func preloadData(tx *gorm.DB, resultSchema *schema.Schema, foreignKeys []string, foreignValues [][]interface{}) reflect.Value { +func preloadData(tx *gorm.DB, resultSchema *schema.Schema, foreignKeys []string, foreignValues [][]interface{}, conds []interface{}) reflect.Value { slice := reflect.MakeSlice(reflect.SliceOf(resultSchema.ModelType), 0, 0) results := reflect.New(slice.Type()) results.Elem().Set(slice) @@ -94,12 +94,12 @@ func preloadData(tx *gorm.DB, resultSchema *schema.Schema, foreignKeys []string, for idx, r := range foreignValues { queryValues[idx] = r[0] } - tx.Where(clause.IN{Column: foreignKeys[0], Values: queryValues}).Find(results.Interface()) + tx.Where(clause.IN{Column: foreignKeys[0], Values: queryValues}).Find(results.Interface(), conds...) } else { for idx, r := range foreignValues { queryValues[idx] = r } - tx.Where(clause.IN{Column: foreignKeys, Values: queryValues}).Find(results.Interface()) + tx.Where(clause.IN{Column: foreignKeys, Values: queryValues}).Find(results.Interface(), conds...) } return results.Elem() @@ -139,7 +139,7 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { } joinIdentityMap, joinForeignValues := getIdentityFieldValuesMap(reflectValue, joinForeignFields) - joinResults := preloadData(tx, rel.JoinTable, joinForeignKeys, joinForeignValues) + joinResults := preloadData(tx, rel.JoinTable, joinForeignKeys, joinForeignValues, nil) // convert join identity map to relation identity map fieldValues := make([]reflect.Value, len(foreignFields)) @@ -177,7 +177,7 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { identityMap, foreignValues = getIdentityFieldValuesMap(reflectValue, foreignFields) } - reflectResults := preloadData(tx, rel.FieldSchema, relForeignKeys, foreignValues) + reflectResults := preloadData(tx, rel.FieldSchema, relForeignKeys, foreignValues, conds) fieldValues := make([]reflect.Value, len(foreignFields)) for i := 0; i < reflectResults.Len(); i++ { From 59365b776b061ea8dce6f29014b35cb1789d85f8 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 18 May 2020 13:07:11 +0800 Subject: [PATCH 307/881] Refacotr Preload --- callbacks/preload.go | 113 +++++-------------------------------------- schema/schema.go | 7 +++ schema/utils.go | 95 ++++++++++++++++++++++++++++++++++++ 3 files changed, 113 insertions(+), 102 deletions(-) diff --git a/callbacks/preload.go b/callbacks/preload.go index aaac31b5..9f23a2ca 100644 --- a/callbacks/preload.go +++ b/callbacks/preload.go @@ -9,102 +9,6 @@ import ( "github.com/jinzhu/gorm/utils" ) -// getRelationsValue get relations's value from a reflect value -func getRelationsValue(reflectValue reflect.Value, rels []*schema.Relationship) (reflectResults reflect.Value) { - for _, rel := range rels { - reflectResults = reflect.MakeSlice(reflect.SliceOf(rel.FieldSchema.ModelType), 0, 0) - - appendToResults := func(value reflect.Value) { - if _, isZero := rel.Field.ValueOf(value); !isZero { - result := reflect.Indirect(rel.Field.ReflectValueOf(value)) - switch result.Kind() { - case reflect.Struct: - reflectResults = reflect.Append(reflectResults, result) - case reflect.Slice, reflect.Array: - for i := 0; i < value.Len(); i++ { - reflectResults = reflect.Append(reflectResults, reflect.Indirect(value.Index(i))) - } - } - } - } - - switch reflectValue.Kind() { - case reflect.Struct: - appendToResults(reflectValue) - case reflect.Slice: - for i := 0; i < reflectValue.Len(); i++ { - appendToResults(reflectValue.Index(i)) - } - } - - reflectValue = reflectResults - } - - return -} - -func getIdentityFieldValuesMap(reflectValue reflect.Value, fields []*schema.Field) (map[string][]reflect.Value, [][]interface{}) { - var ( - fieldValues = make([]reflect.Value, len(fields)) - results = [][]interface{}{} - dataResults = map[string][]reflect.Value{} - ) - - switch reflectValue.Kind() { - case reflect.Struct: - results = [][]interface{}{make([]interface{}, len(fields))} - - for idx, field := range fields { - fieldValues[idx] = field.ReflectValueOf(reflectValue) - results[0][idx] = fieldValues[idx].Interface() - } - - dataResults[utils.ToStringKey(fieldValues...)] = []reflect.Value{reflectValue} - case reflect.Slice, reflect.Array: - for i := 0; i < reflectValue.Len(); i++ { - for idx, field := range fields { - fieldValues[idx] = field.ReflectValueOf(reflectValue.Index(i)) - } - - dataKey := utils.ToStringKey(fieldValues...) - if _, ok := dataResults[dataKey]; !ok { - result := make([]interface{}, len(fieldValues)) - for idx, fieldValue := range fieldValues { - result[idx] = fieldValue.Interface() - } - results = append(results, result) - - dataResults[dataKey] = []reflect.Value{reflectValue.Index(i)} - } else { - dataResults[dataKey] = append(dataResults[dataKey], reflectValue.Index(i)) - } - } - } - - return dataResults, results -} - -func preloadData(tx *gorm.DB, resultSchema *schema.Schema, foreignKeys []string, foreignValues [][]interface{}, conds []interface{}) reflect.Value { - slice := reflect.MakeSlice(reflect.SliceOf(resultSchema.ModelType), 0, 0) - results := reflect.New(slice.Type()) - results.Elem().Set(slice) - - queryValues := make([]interface{}, len(foreignValues)) - if len(foreignKeys) == 1 { - for idx, r := range foreignValues { - queryValues[idx] = r[0] - } - tx.Where(clause.IN{Column: foreignKeys[0], Values: queryValues}).Find(results.Interface(), conds...) - } else { - for idx, r := range foreignValues { - queryValues[idx] = r - } - tx.Where(clause.IN{Column: foreignKeys, Values: queryValues}).Find(results.Interface(), conds...) - } - - return results.Elem() -} - func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { var ( reflectValue = db.Statement.ReflectValue @@ -118,7 +22,7 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { ) if len(rels) > 1 { - reflectValue = getRelationsValue(reflectValue, rels[:len(rels)]) + reflectValue = schema.GetRelationsValues(reflectValue, rels[:len(rels)]) } if rel.JoinTable != nil { @@ -138,8 +42,11 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { } } - joinIdentityMap, joinForeignValues := getIdentityFieldValuesMap(reflectValue, joinForeignFields) - joinResults := preloadData(tx, rel.JoinTable, joinForeignKeys, joinForeignValues, nil) + joinIdentityMap, joinForeignValues := schema.GetIdentityFieldValuesMap(reflectValue, joinForeignFields) + + joinResults := rel.JoinTable.MakeSlice().Elem() + column, values := schema.ToQueryValues(joinForeignKeys, joinForeignValues) + tx.Where(clause.IN{Column: column, Values: values}).Find(joinResults.Addr().Interface()) // convert join identity map to relation identity map fieldValues := make([]reflect.Value, len(foreignFields)) @@ -158,7 +65,7 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { } } - _, foreignValues = getIdentityFieldValuesMap(joinResults, joinRelForeignFields) + _, foreignValues = schema.GetIdentityFieldValuesMap(joinResults, joinRelForeignFields) } else { for _, ref := range rel.References { if ref.OwnPrimaryKey { @@ -174,10 +81,12 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { } } - identityMap, foreignValues = getIdentityFieldValuesMap(reflectValue, foreignFields) + identityMap, foreignValues = schema.GetIdentityFieldValuesMap(reflectValue, foreignFields) } - reflectResults := preloadData(tx, rel.FieldSchema, relForeignKeys, foreignValues, conds) + reflectResults := rel.FieldSchema.MakeSlice().Elem() + column, values := schema.ToQueryValues(relForeignKeys, foreignValues) + tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), conds...) fieldValues := make([]reflect.Value, len(foreignFields)) for i := 0; i < reflectResults.Len(); i++ { diff --git a/schema/schema.go b/schema/schema.go index 3abac2ba..5a28797b 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -43,6 +43,13 @@ func (schema Schema) String() string { return fmt.Sprintf("%v.%v", schema.ModelType.PkgPath(), schema.ModelType.Name()) } +func (schema Schema) MakeSlice() reflect.Value { + slice := reflect.MakeSlice(reflect.SliceOf(schema.ModelType), 0, 0) + results := reflect.New(slice.Type()) + results.Elem().Set(slice) + return results +} + func (schema Schema) LookUpField(name string) *Field { if field, ok := schema.FieldsByDBName[name]; ok { return field diff --git a/schema/utils.go b/schema/utils.go index 7be78bc5..7a26332d 100644 --- a/schema/utils.go +++ b/schema/utils.go @@ -4,6 +4,8 @@ import ( "reflect" "regexp" "strings" + + "github.com/jinzhu/gorm/utils" ) func ParseTagSetting(str string, sep string) map[string]string { @@ -49,3 +51,96 @@ func toColumns(val string) (results []string) { func removeSettingFromTag(tag reflect.StructTag, name string) reflect.StructTag { return reflect.StructTag(regexp.MustCompile(`(?i)(gorm:.*?)(`+name+`:.*?)(;|("))`).ReplaceAllString(string(tag), "${1}${4}")) } + +// GetRelationsValues get relations's values from a reflect value +func GetRelationsValues(reflectValue reflect.Value, rels []*Relationship) (reflectResults reflect.Value) { + for _, rel := range rels { + reflectResults = reflect.MakeSlice(reflect.SliceOf(rel.FieldSchema.ModelType), 0, 0) + + appendToResults := func(value reflect.Value) { + if _, isZero := rel.Field.ValueOf(value); !isZero { + result := reflect.Indirect(rel.Field.ReflectValueOf(value)) + switch result.Kind() { + case reflect.Struct: + reflectResults = reflect.Append(reflectResults, result) + case reflect.Slice, reflect.Array: + for i := 0; i < value.Len(); i++ { + reflectResults = reflect.Append(reflectResults, reflect.Indirect(value.Index(i))) + } + } + } + } + + switch reflectValue.Kind() { + case reflect.Struct: + appendToResults(reflectValue) + case reflect.Slice: + for i := 0; i < reflectValue.Len(); i++ { + appendToResults(reflectValue.Index(i)) + } + } + + reflectValue = reflectResults + } + + return +} + +// GetIdentityFieldValuesMap get identity map from fields +func GetIdentityFieldValuesMap(reflectValue reflect.Value, fields []*Field) (map[string][]reflect.Value, [][]interface{}) { + var ( + fieldValues = make([]reflect.Value, len(fields)) + results = [][]interface{}{} + dataResults = map[string][]reflect.Value{} + ) + + switch reflectValue.Kind() { + case reflect.Struct: + results = [][]interface{}{make([]interface{}, len(fields))} + + for idx, field := range fields { + fieldValues[idx] = field.ReflectValueOf(reflectValue) + results[0][idx] = fieldValues[idx].Interface() + } + + dataResults[utils.ToStringKey(fieldValues...)] = []reflect.Value{reflectValue} + case reflect.Slice, reflect.Array: + for i := 0; i < reflectValue.Len(); i++ { + for idx, field := range fields { + fieldValues[idx] = field.ReflectValueOf(reflectValue.Index(i)) + } + + dataKey := utils.ToStringKey(fieldValues...) + if _, ok := dataResults[dataKey]; !ok { + result := make([]interface{}, len(fieldValues)) + for idx, fieldValue := range fieldValues { + result[idx] = fieldValue.Interface() + } + results = append(results, result) + + dataResults[dataKey] = []reflect.Value{reflectValue.Index(i)} + } else { + dataResults[dataKey] = append(dataResults[dataKey], reflectValue.Index(i)) + } + } + } + + return dataResults, results +} + +// ToQueryValues to query values +func ToQueryValues(foreignKeys []string, foreignValues [][]interface{}) (interface{}, []interface{}) { + queryValues := make([]interface{}, len(foreignValues)) + if len(foreignKeys) == 1 { + for idx, r := range foreignValues { + queryValues[idx] = r[0] + } + + return foreignKeys[0], queryValues + } else { + for idx, r := range foreignValues { + queryValues[idx] = r + } + } + return foreignKeys, queryValues +} From 922a8efc53e0d93fbabc9b87d0d7b3b8d941ef70 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 19 May 2020 21:50:06 +0800 Subject: [PATCH 308/881] Generate Query Conds for Relationship --- association.go | 29 ++++++++++++++++++++------- schema/relationship.go | 45 ++++++++++++++++++++++++++++++++++++++++++ schema/schema.go | 5 +++++ 3 files changed, 72 insertions(+), 7 deletions(-) diff --git a/association.go b/association.go index a9345255..82a2274e 100644 --- a/association.go +++ b/association.go @@ -3,6 +3,7 @@ package gorm import ( "fmt" + "github.com/jinzhu/gorm/clause" "github.com/jinzhu/gorm/schema" ) @@ -31,10 +32,6 @@ func (db *DB) Association(column string) *Association { func (association *Association) Find(out interface{}, conds ...interface{}) error { if association.Error == nil { - for _, ref := range association.Relationship.References { - if ref.OwnPrimaryKey { - } - } } return association.Error @@ -53,9 +50,27 @@ func (association *Association) Delete(values ...interface{}) error { } func (association *Association) Clear() error { - return association.Error + return association.Replace() } -func (association *Association) Count() int { - return 0 +func (association *Association) Count() (count int) { + if association.Error == nil { + var ( + tx = association.DB + conds = association.Relationship.ToQueryConditions(tx.Statement.ReflectValue) + ) + + if association.Relationship.JoinTable != nil { + tx.Clauses(clause.From{Joins: []clause.Join{{ + Table: clause.Table{Name: association.Relationship.JoinTable.Table}, + ON: clause.Where{Exprs: conds}, + }}}) + } else { + tx.Clauses(clause.Where{Exprs: conds}) + } + + association.Error = tx.Count(&count).Error + } + + return } diff --git a/schema/relationship.go b/schema/relationship.go index 4ffea8b3..59aaa7e4 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -6,6 +6,7 @@ import ( "regexp" "strings" + "github.com/jinzhu/gorm/clause" "github.com/jinzhu/inflection" ) @@ -345,3 +346,47 @@ func (rel *Relationship) ParseConstraint() *Constraint { return &constraint } + +func (rel *Relationship) ToQueryConditions(reflectValue reflect.Value) (conds []clause.Expression) { + foreignFields := []*Field{} + relForeignKeys := []string{} + + if rel.JoinTable != nil { + for _, ref := range rel.References { + if ref.OwnPrimaryKey { + foreignFields = append(foreignFields, ref.PrimaryKey) + relForeignKeys = append(relForeignKeys, ref.PrimaryKey.DBName) + } else if ref.PrimaryValue != "" { + conds = append(conds, clause.Eq{ + Column: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName}, + Value: ref.PrimaryValue, + }) + } else { + conds = append(conds, clause.Eq{ + Column: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName}, + Value: clause.Column{Table: rel.FieldSchema.Table, Name: ref.PrimaryKey.DBName}, + }) + } + } + } else { + for _, ref := range rel.References { + if ref.OwnPrimaryKey { + relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName) + foreignFields = append(foreignFields, ref.PrimaryKey) + } else if ref.PrimaryValue != "" { + conds = append(conds, clause.Eq{ + Column: clause.Column{Table: rel.FieldSchema.Table, Name: ref.ForeignKey.DBName}, + Value: ref.PrimaryValue, + }) + } else { + relForeignKeys = append(relForeignKeys, ref.PrimaryKey.DBName) + foreignFields = append(foreignFields, ref.ForeignKey) + } + } + } + + _, foreignValues := GetIdentityFieldValuesMap(reflectValue, foreignFields) + column, values := ToQueryValues(relForeignKeys, foreignValues) + conds = append(conds, clause.IN{Column: column, Values: values}) + return +} diff --git a/schema/schema.go b/schema/schema.go index 5a28797b..79faae12 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -8,6 +8,7 @@ import ( "reflect" "sync" + "github.com/jinzhu/gorm/clause" "github.com/jinzhu/gorm/logger" ) @@ -26,6 +27,10 @@ type Schema struct { FieldsByDBName map[string]*Field FieldsWithDefaultDBValue map[string]*Field // fields with default value assigned by database Relationships Relationships + CreateClauses []clause.Interface + QueryClauses []clause.Interface + UpdateClauses []clause.Interface + DeleteClauses []clause.Interface BeforeCreate, AfterCreate bool BeforeUpdate, AfterUpdate bool BeforeDelete, AfterDelete bool From 20cb57b1aceacf251f15e553b8082d8ed258b1a1 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 20 May 2020 02:03:43 +0800 Subject: [PATCH 309/881] Add association Delete support --- association.go | 90 +++++++++++++++++++++++++++++++++++++++++++++++++ schema/utils.go | 15 +++++++++ 2 files changed, 105 insertions(+) diff --git a/association.go b/association.go index 82a2274e..027f327e 100644 --- a/association.go +++ b/association.go @@ -2,9 +2,11 @@ package gorm import ( "fmt" + "reflect" "github.com/jinzhu/gorm/clause" "github.com/jinzhu/gorm/schema" + "github.com/jinzhu/gorm/utils" ) // Association Mode contains some helper methods to handle relationship things easily. @@ -46,6 +48,90 @@ func (association *Association) Replace(values ...interface{}) error { } func (association *Association) Delete(values ...interface{}) error { + if association.Error == nil { + var ( + tx = association.DB + rel = association.Relationship + reflectValue = tx.Statement.ReflectValue + conds = rel.ToQueryConditions(reflectValue) + relFields []*schema.Field + foreignKeys []string + updateAttrs = map[string]interface{}{} + ) + + for _, ref := range rel.References { + if ref.PrimaryValue == "" { + if rel.JoinTable == nil || !ref.OwnPrimaryKey { + if ref.OwnPrimaryKey { + relFields = append(relFields, ref.ForeignKey) + } else { + relFields = append(relFields, ref.PrimaryKey) + } + + foreignKeys = append(foreignKeys, ref.ForeignKey.DBName) + updateAttrs[ref.ForeignKey.DBName] = nil + } + } + } + + relValuesMap, relQueryValues := schema.GetIdentityFieldValuesMapFromValues(values, relFields) + column, values := schema.ToQueryValues(foreignKeys, relQueryValues) + tx.Where(clause.IN{Column: column, Values: values}) + + switch association.Relationship.Type { + case schema.HasOne, schema.HasMany: + modelValue := reflect.New(rel.FieldSchema.ModelType).Interface() + tx.Model(modelValue).Clauses(clause.Where{Exprs: conds}).UpdateColumns(updateAttrs) + case schema.BelongsTo: + tx.Clauses(clause.Where{Exprs: conds}).UpdateColumns(updateAttrs) + case schema.Many2Many: + modelValue := reflect.New(rel.JoinTable.ModelType).Interface() + tx.Clauses(clause.Where{Exprs: conds}).Delete(modelValue) + } + + if tx.Error == nil { + cleanUpDeletedRelations := func(data reflect.Value) { + if _, zero := rel.Field.ValueOf(data); !zero { + fieldValue := reflect.Indirect(rel.Field.ReflectValueOf(data)) + + fieldValues := make([]reflect.Value, len(relFields)) + switch fieldValue.Kind() { + case reflect.Slice, reflect.Array: + validFieldValues := reflect.Zero(rel.Field.FieldType) + for i := 0; i < fieldValue.Len(); i++ { + for idx, field := range relFields { + fieldValues[idx] = field.ReflectValueOf(fieldValue.Index(i)) + } + + if _, ok := relValuesMap[utils.ToStringKey(fieldValues...)]; !ok { + validFieldValues = reflect.Append(validFieldValues, fieldValue.Index(i)) + } + } + + rel.Field.Set(data, validFieldValues) + case reflect.Struct: + for idx, field := range relFields { + fieldValues[idx] = field.ReflectValueOf(data) + } + if _, ok := relValuesMap[utils.ToStringKey(fieldValues...)]; ok { + rel.Field.Set(data, reflect.Zero(rel.FieldSchema.ModelType)) + } + } + } + } + + switch reflectValue.Kind() { + case reflect.Slice, reflect.Array: + for i := 0; i < reflectValue.Len(); i++ { + cleanUpDeletedRelations(reflect.Indirect(reflectValue.Index(i))) + } + case reflect.Struct: + cleanUpDeletedRelations(reflectValue) + } + } else { + association.Error = tx.Error + } + } return association.Error } @@ -61,6 +147,10 @@ func (association *Association) Count() (count int) { ) if association.Relationship.JoinTable != nil { + for _, queryClause := range association.Relationship.JoinTable.QueryClauses { + tx.Clauses(queryClause) + } + tx.Clauses(clause.From{Joins: []clause.Join{{ Table: clause.Table{Name: association.Relationship.JoinTable.Table}, ON: clause.Where{Exprs: conds}, diff --git a/schema/utils.go b/schema/utils.go index 7a26332d..72bd149c 100644 --- a/schema/utils.go +++ b/schema/utils.go @@ -128,6 +128,21 @@ func GetIdentityFieldValuesMap(reflectValue reflect.Value, fields []*Field) (map return dataResults, results } +// GetIdentityFieldValuesMapFromValues get identity map from fields +func GetIdentityFieldValuesMapFromValues(values []interface{}, fields []*Field) (map[string][]reflect.Value, [][]interface{}) { + resultsMap := map[string][]reflect.Value{} + results := [][]interface{}{} + + for _, v := range values { + rm, rs := GetIdentityFieldValuesMap(reflect.Indirect(reflect.ValueOf(v)), fields) + for k, v := range rm { + resultsMap[k] = append(resultsMap[k], v...) + } + results = append(results, rs...) + } + return resultsMap, results +} + // ToQueryValues to query values func ToQueryValues(foreignKeys []string, foreignValues [][]interface{}) (interface{}, []interface{}) { queryValues := make([]interface{}, len(foreignValues)) From 0f21272c7fe254c90886a05e0cea359ac2f48fc1 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 20 May 2020 23:44:50 +0800 Subject: [PATCH 310/881] Finish implement association support --- association.go | 198 +++++++++++++++++++++++++++++++++++++- callbacks/associations.go | 8 +- 2 files changed, 201 insertions(+), 5 deletions(-) diff --git a/association.go b/association.go index 027f327e..a889157b 100644 --- a/association.go +++ b/association.go @@ -1,6 +1,7 @@ package gorm import ( + "errors" "fmt" "reflect" @@ -34,16 +35,119 @@ func (db *DB) Association(column string) *Association { func (association *Association) Find(out interface{}, conds ...interface{}) error { if association.Error == nil { + var ( + tx = association.DB + queryConds = association.Relationship.ToQueryConditions(tx.Statement.ReflectValue) + ) + + if association.Relationship.JoinTable != nil { + for _, queryClause := range association.Relationship.JoinTable.QueryClauses { + tx.Clauses(queryClause) + } + + tx.Clauses(clause.From{Joins: []clause.Join{{ + Table: clause.Table{Name: association.Relationship.JoinTable.Table}, + ON: clause.Where{Exprs: queryConds}, + }}}) + } else { + tx.Clauses(clause.Where{Exprs: queryConds}) + } + + association.Error = tx.Find(out, conds...).Error } return association.Error } func (association *Association) Append(values ...interface{}) error { + if association.Error == nil { + switch association.Relationship.Type { + case schema.HasOne, schema.BelongsTo: + if len(values) > 0 { + association.Error = association.Replace(values...) + } + default: + association.saveAssociation(false, values...) + } + } + return association.Error } func (association *Association) Replace(values ...interface{}) error { + if association.Error == nil { + association.saveAssociation(true, values...) + reflectValue := association.DB.Statement.ReflectValue + rel := association.Relationship + + switch rel.Type { + case schema.HasOne, schema.HasMany: + var ( + primaryFields []*schema.Field + foreignKeys []string + updateMap = map[string]interface{}{} + modelValue = reflect.New(rel.FieldSchema.ModelType).Interface() + ) + + for _, ref := range rel.References { + if ref.OwnPrimaryKey { + primaryFields = append(primaryFields, ref.PrimaryKey) + } else { + foreignKeys = append(foreignKeys, ref.ForeignKey.DBName) + updateMap[ref.ForeignKey.DBName] = nil + } + } + + _, values := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields) + column, queryValues := schema.ToQueryValues(foreignKeys, values) + association.DB.Model(modelValue).Where(clause.IN{Column: column, Values: queryValues}).UpdateColumns(updateMap) + case schema.Many2Many: + var primaryFields, relPrimaryFields []*schema.Field + var foreignKeys, relForeignKeys []string + modelValue := reflect.New(rel.JoinTable.ModelType).Interface() + conds := []clause.Expression{} + + for _, ref := range rel.References { + if ref.OwnPrimaryKey { + primaryFields = append(primaryFields, ref.PrimaryKey) + foreignKeys = append(foreignKeys, ref.ForeignKey.DBName) + } else if ref.PrimaryValue != "" { + conds = append(conds, clause.Eq{ + Column: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName}, + Value: ref.PrimaryValue, + }) + } else { + relPrimaryFields = append(relPrimaryFields, ref.PrimaryKey) + relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName) + } + } + + generateConds := func(rv reflect.Value) { + _, values := schema.GetIdentityFieldValuesMap(rv, primaryFields) + column, queryValues := schema.ToQueryValues(foreignKeys, values) + + relValue := rel.Field.ReflectValueOf(rv) + _, relValues := schema.GetIdentityFieldValuesMap(relValue, relPrimaryFields) + relColumn, relQueryValues := schema.ToQueryValues(relForeignKeys, relValues) + + conds = append(conds, clause.And( + clause.IN{Column: column, Values: queryValues}, + clause.Not(clause.IN{Column: relColumn, Values: relQueryValues}), + )) + } + + switch reflectValue.Kind() { + case reflect.Struct: + generateConds(reflectValue) + case reflect.Slice, reflect.Array: + for i := 0; i < reflectValue.Len(); i++ { + generateConds(reflectValue.Index(i)) + } + } + + association.DB.Where(conds).Delete(modelValue) + } + } return association.Error } @@ -78,7 +182,7 @@ func (association *Association) Delete(values ...interface{}) error { column, values := schema.ToQueryValues(foreignKeys, relQueryValues) tx.Where(clause.IN{Column: column, Values: values}) - switch association.Relationship.Type { + switch rel.Type { case schema.HasOne, schema.HasMany: modelValue := reflect.New(rel.FieldSchema.ModelType).Interface() tx.Model(modelValue).Clauses(clause.Where{Exprs: conds}).UpdateColumns(updateAttrs) @@ -164,3 +268,95 @@ func (association *Association) Count() (count int) { return } + +func (association *Association) saveAssociation(clear bool, values ...interface{}) { + reflectValue := association.DB.Statement.ReflectValue + + appendToRelations := func(source, rv reflect.Value, clear bool) { + switch association.Relationship.Type { + case schema.HasOne, schema.BelongsTo: + switch rv.Kind() { + case reflect.Slice, reflect.Array: + if rv.Len() > 0 { + association.Error = association.Relationship.Field.Set(source, rv.Index(0)) + } + case reflect.Struct: + association.Error = association.Relationship.Field.Set(source, rv) + } + case schema.HasMany, schema.Many2Many: + elemType := association.Relationship.Field.IndirectFieldType.Elem() + fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(reflectValue)) + if clear { + fieldValue = reflect.New(association.Relationship.Field.IndirectFieldType) + } + + appendToFieldValues := func(ev reflect.Value) { + if ev.Type().AssignableTo(elemType) { + fieldValue = reflect.Append(fieldValue, ev) + } else if ev.Type().Elem().AssignableTo(elemType) { + fieldValue = reflect.Append(fieldValue, ev.Elem()) + } else { + association.Error = fmt.Errorf("unsupported data type: %v for relation %v", ev.Type(), association.Relationship.Name) + } + } + + switch rv.Kind() { + case reflect.Slice, reflect.Array: + for i := 0; i < rv.Len(); i++ { + appendToFieldValues(reflect.Indirect(rv.Index(i))) + } + case reflect.Struct: + appendToFieldValues(rv) + } + + if association.Error == nil { + association.Error = association.Relationship.Field.Set(source, fieldValue) + } + } + } + + selectedColumns := []string{association.Relationship.Name} + hasZero := false + for _, ref := range association.Relationship.References { + if !ref.OwnPrimaryKey { + selectedColumns = append(selectedColumns, ref.ForeignKey.Name) + } + } + + switch reflectValue.Kind() { + case reflect.Slice, reflect.Array: + if len(values) != reflectValue.Len() { + if clear && len(values) == 0 { + for i := 0; i < reflectValue.Len(); i++ { + association.Relationship.Field.Set(reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType)) + } + break + } + association.Error = errors.New("invalid association values, length doesn't match") + } + + for i := 0; i < reflectValue.Len(); i++ { + appendToRelations(reflectValue.Index(i), reflect.Indirect(reflect.ValueOf(values[i])), clear) + + if !hasZero { + _, hasZero = association.DB.Statement.Schema.PrioritizedPrimaryField.ValueOf(reflectValue.Index(i)) + } + } + case reflect.Struct: + if clear && len(values) == 0 { + association.Relationship.Field.Set(reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType)) + } + + for idx, value := range values { + appendToRelations(reflectValue, reflect.Indirect(reflect.ValueOf(value)), clear && idx == 0) + } + + _, hasZero = association.DB.Statement.Schema.PrioritizedPrimaryField.ValueOf(reflectValue) + } + + if hasZero { + association.DB.Save(reflectValue.Interface()) + } else { + association.DB.Select(selectedColumns).Save(reflectValue.Interface()) + } +} diff --git a/callbacks/associations.go b/callbacks/associations.go index df19d5f5..a0c296e3 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -28,7 +28,7 @@ func SaveBeforeAssociations(db *gorm.DB) { } switch db.Statement.ReflectValue.Kind() { - case reflect.Slice: + case reflect.Slice, reflect.Array: var ( objs []reflect.Value fieldType = rel.Field.FieldType @@ -92,7 +92,7 @@ func SaveAfterAssociations(db *gorm.DB) { } switch db.Statement.ReflectValue.Kind() { - case reflect.Slice: + case reflect.Slice, reflect.Array: var ( fieldType = rel.Field.FieldType isPtr = fieldType.Kind() == reflect.Ptr @@ -193,7 +193,7 @@ func SaveAfterAssociations(db *gorm.DB) { } switch db.Statement.ReflectValue.Kind() { - case reflect.Slice: + case reflect.Slice, reflect.Array: for i := 0; i < db.Statement.ReflectValue.Len(); i++ { appendToElems(db.Statement.ReflectValue.Index(i)) } @@ -260,7 +260,7 @@ func SaveAfterAssociations(db *gorm.DB) { } switch db.Statement.ReflectValue.Kind() { - case reflect.Slice: + case reflect.Slice, reflect.Array: for i := 0; i < db.Statement.ReflectValue.Len(); i++ { appendToElems(db.Statement.ReflectValue.Index(i)) } From 72460df1bd40f8088cba45e8a79f4506bd31ab51 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 23 May 2020 11:57:28 +0800 Subject: [PATCH 311/881] Fix associations find --- association.go | 6 ++-- callbacks.go | 8 ++++- go.mod | 6 +--- tests/associations.go | 73 +++++++++++++++++++++++++++++++++++++++++++ tests/tests.go | 1 + 5 files changed, 86 insertions(+), 8 deletions(-) create mode 100644 tests/associations.go diff --git a/association.go b/association.go index a889157b..ab9090ac 100644 --- a/association.go +++ b/association.go @@ -26,6 +26,8 @@ func (db *DB) Association(column string) *Association { if association.Relationship == nil { association.Error = fmt.Errorf("%w: %v", ErrUnsupportedRelation, column) } + + db.Statement.ReflectValue = reflect.Indirect(reflect.ValueOf(db.Statement.Model)) } else { association.Error = err } @@ -36,8 +38,8 @@ func (db *DB) Association(column string) *Association { func (association *Association) Find(out interface{}, conds ...interface{}) error { if association.Error == nil { var ( - tx = association.DB - queryConds = association.Relationship.ToQueryConditions(tx.Statement.ReflectValue) + queryConds = association.Relationship.ToQueryConditions(association.DB.Statement.ReflectValue) + tx = association.DB.Model(out).Table("") ) if association.Relationship.JoinTable != nil { diff --git a/callbacks.go b/callbacks.go index 6c70b392..61cebc81 100644 --- a/callbacks.go +++ b/callbacks.go @@ -83,7 +83,13 @@ func (p *processor) Execute(db *DB) { db.AddError(err) } } - stmt.ReflectValue = reflect.Indirect(reflect.ValueOf(stmt.Dest)) + + if stmt.Dest != nil { + stmt.ReflectValue = reflect.Indirect(reflect.ValueOf(stmt.Dest)) + if !stmt.ReflectValue.IsValid() { + db.AddError(fmt.Errorf("invalid value")) + } + } } for _, f := range p.fns { diff --git a/go.mod b/go.mod index 3e067d3c..d3421e1b 100644 --- a/go.mod +++ b/go.mod @@ -1,12 +1,8 @@ module github.com/jinzhu/gorm -go 1.13 +go 1.14 require ( - github.com/denisenkom/go-mssqldb v0.0.0-20200206145737-bbfc9a55622e // indirect - github.com/go-sql-driver/mysql v1.5.0 // indirect github.com/jinzhu/inflection v1.0.0 github.com/jinzhu/now v1.1.1 - github.com/lib/pq v1.3.0 // indirect - github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect ) diff --git a/tests/associations.go b/tests/associations.go new file mode 100644 index 00000000..7e93e81e --- /dev/null +++ b/tests/associations.go @@ -0,0 +1,73 @@ +package tests + +import ( + "testing" + + "github.com/jinzhu/gorm" +) + +func TestAssociations(t *testing.T, db *gorm.DB) { + db.Migrator().DropTable(&Account{}, &Company{}, &Pet{}, &Toy{}, &Language{}) + db.Migrator().AutoMigrate(&Account{}, &Company{}, &Pet{}, &Toy{}, &Language{}) + + TestBelongsToAssociations(t, db) +} + +func TestBelongsToAssociations(t *testing.T, db *gorm.DB) { + check := func(t *testing.T, user User, old User) { + if old.Company.Name != "" { + if user.CompanyID == nil { + t.Errorf("Company's foreign key should be saved") + } else { + var company Company + db.First(&company, "id = ?", *user.CompanyID) + if company.Name != old.Company.Name { + t.Errorf("Company's name should be same, expects: %v, got %v", old.Company.Name, user.Company.Name) + } else if user.Company.Name != old.Company.Name { + t.Errorf("Company's name should be same, expects: %v, got %v", old.Company.Name, user.Company.Name) + } + } + } else if user.CompanyID != nil { + t.Errorf("Company should not be created for zero value, got: %+v", user.CompanyID) + } + + if old.Manager != nil { + if user.ManagerID == nil { + t.Errorf("Manager's foreign key should be saved") + } else { + var manager User + db.First(&manager, "id = ?", *user.ManagerID) + if manager.Name != user.Manager.Name { + t.Errorf("Manager's name should be same") + } else if user.Manager.Name != old.Manager.Name { + t.Errorf("Manager's name should be same") + } + } + } else if user.ManagerID != nil { + t.Errorf("Manager should not be created for zero value, got: %+v", user.ManagerID) + } + } + + t.Run("BelongsTo", func(t *testing.T) { + var user = User{ + Name: "create", + Age: 18, + Birthday: Now(), + Company: Company{Name: "company-belongs-to-association"}, + Manager: &User{Name: "manager-belongs-to-association"}, + } + + if err := db.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + check(t, user, user) + + var user2 User + db.Find(&user2, "id = ?", user.ID) + db.Model(&user2).Association("Company").Find(&user2.Company) + user2.Manager = &User{} + db.Model(&user2).Association("Manager").Find(user2.Manager) + check(t, user2, user) + }) +} diff --git a/tests/tests.go b/tests/tests.go index cc9c1a78..87005a71 100644 --- a/tests/tests.go +++ b/tests/tests.go @@ -20,4 +20,5 @@ func RunTestsSuit(t *testing.T, db *gorm.DB) { TestGroupBy(t, db) TestJoins(t, db) + TestAssociations(t, db) } From bb68f0d6b3715c62025ac0ec560aac96923c5e83 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 23 May 2020 16:08:50 +0800 Subject: [PATCH 312/881] Refactor tests --- go.mod | 5 + statement.go | 20 +- tests/create.go | 834 +++---------------------------------------- tests/create_test.go | 761 +++++++++++++++++++++++++++++++++++++++ tests/main_test.go | 95 +++++ tests/tests.go | 2 +- 6 files changed, 936 insertions(+), 781 deletions(-) create mode 100644 tests/create_test.go create mode 100644 tests/main_test.go diff --git a/go.mod b/go.mod index d3421e1b..45bcf69c 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,11 @@ module github.com/jinzhu/gorm go 1.14 require ( + github.com/denisenkom/go-mssqldb v0.0.0-20191124224453-732737034ffd + github.com/erikstmartin/go-testdb v0.0.0-20160219214506-8d10e4a1bae5 + github.com/go-sql-driver/mysql v1.5.0 github.com/jinzhu/inflection v1.0.0 github.com/jinzhu/now v1.1.1 + github.com/lib/pq v1.1.1 + github.com/mattn/go-sqlite3 v2.0.1+incompatible ) diff --git a/statement.go b/statement.go index f3090eb7..1ea5a56c 100644 --- a/statement.go +++ b/statement.go @@ -147,8 +147,24 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) { writer.WriteString("(NULL)") } default: - stmt.Vars = append(stmt.Vars, v) - stmt.DB.Dialector.BindVarTo(writer, stmt, v) + switch rv := reflect.ValueOf(v); rv.Kind() { + case reflect.Slice, reflect.Array: + if rv.Len() == 0 { + writer.WriteString("(NULL)") + } else { + writer.WriteByte('(') + for i := 0; i < rv.Len(); i++ { + if i > 0 { + writer.WriteByte(',') + } + stmt.AddVar(writer, rv.Index(i).Interface()) + } + writer.WriteByte(')') + } + default: + stmt.Vars = append(stmt.Vars, v) + stmt.DB.Dialector.BindVarTo(writer, stmt, v) + } } } } diff --git a/tests/create.go b/tests/create.go index 428f876c..ec57b8ee 100644 --- a/tests/create.go +++ b/tests/create.go @@ -1,787 +1,65 @@ package tests import ( - "fmt" - "testing" - - "github.com/jinzhu/gorm" + "strconv" + "time" ) -func TestCreate(t *testing.T, db *gorm.DB) { - db.Migrator().DropTable(&User{}) - db.AutoMigrate(&User{}) - - t.Run("Create", func(t *testing.T) { - var user = User{ - Name: "create", - Age: 18, - Birthday: Now(), - } - - if err := db.Create(&user).Error; err != nil { - t.Fatalf("errors happened when create: %v", err) - } - - if user.ID == 0 { - t.Errorf("user's primary key should has value after create, got : %v", user.ID) - } - - if user.CreatedAt.IsZero() { - t.Errorf("user's created at should be not zero") - } - - if user.UpdatedAt.IsZero() { - t.Errorf("user's updated at should be not zero") - } - - var newUser User - if err := db.Where("id = ?", user.ID).First(&newUser).Error; err != nil { - t.Errorf("errors happened when query: %v", err) - } else { - AssertObjEqual(t, newUser, user, "Name", "Age", "Birthday") - } - - TestCreateAssociations(t, db) - }) +type Config struct { + Account bool + Pets int + Toys int + Company bool + Manager bool + Team int + Languages int + Friends int } -func TestCreateAssociations(t *testing.T, db *gorm.DB) { - db.Migrator().DropTable(&Account{}, &Company{}, &Pet{}, &Toy{}, &Language{}) - db.Migrator().AutoMigrate(&Account{}, &Company{}, &Pet{}, &Toy{}, &Language{}) +func GetUser(name string, config Config) User { + var ( + birthday = time.Now() + user = User{ + Name: name, + Age: 18, + Birthday: &birthday, + } + ) - TestCreateBelongsToAssociations(t, db) - TestCreateHasOneAssociations(t, db) - TestCreateHasManyAssociations(t, db) - TestCreateMany2ManyAssociations(t, db) -} - -func TestCreateBelongsToAssociations(t *testing.T, db *gorm.DB) { - check := func(t *testing.T, user User, old User) { - if old.Company.Name != "" { - if user.CompanyID == nil { - t.Errorf("Company's foreign key should be saved") - } else { - var company Company - db.First(&company, "id = ?", *user.CompanyID) - if company.Name != old.Company.Name { - t.Errorf("Company's name should be same") - } else if user.Company.Name != old.Company.Name { - t.Errorf("Company's name should be same") - } - } - } else if user.CompanyID != nil { - t.Errorf("Company should not be created for zero value, got: %+v", user.CompanyID) - } - - if old.Manager != nil { - if user.ManagerID == nil { - t.Errorf("Manager's foreign key should be saved") - } else { - var manager User - db.First(&manager, "id = ?", *user.ManagerID) - if manager.Name != user.Manager.Name { - t.Errorf("Manager's name should be same") - } else if user.Manager.Name != old.Manager.Name { - t.Errorf("Manager's name should be same") - } - } - } else if user.ManagerID != nil { - t.Errorf("Manager should not be created for zero value, got: %+v", user.ManagerID) - } - } - - t.Run("BelongsTo", func(t *testing.T) { - var user = User{ - Name: "create", - Age: 18, - Birthday: Now(), - Company: Company{Name: "company-belongs-to-association"}, - Manager: &User{Name: "manager-belongs-to-association"}, - } - - if err := db.Create(&user).Error; err != nil { - t.Fatalf("errors happened when create: %v", err) - } - - check(t, user, user) - - var user2 User - db.Preload("Company").Preload("Manager").Find(&user2, "id = ?", user.ID) - check(t, user2, user) - }) - - t.Run("BelongsToForBulkInsert", func(t *testing.T) { - var users = []User{{ - Name: "create-1", - Age: 18, - Birthday: Now(), - Company: Company{Name: "company-belongs-to-association-1"}, - Manager: &User{Name: "manager-belongs-to-association-1"}, - }, { - Name: "create-2", - Age: 28, - Birthday: Now(), - Company: Company{Name: "company-belongs-to-association-2"}, - }, { - Name: "create-3", - Age: 38, - Birthday: Now(), - Company: Company{Name: "company-belongs-to-association-3"}, - Manager: &User{Name: "manager-belongs-to-association-3"}, - }} - - if err := db.Create(&users).Error; err != nil { - t.Fatalf("errors happened when create: %v", err) - } - - var userIDs []uint - for _, user := range users { - userIDs = append(userIDs, user.ID) - check(t, user, user) - } - - var users2 []User - db.Preload("Company").Preload("Manager").Find(&users2, "id IN (?)", userIDs) - for idx, user := range users2 { - check(t, user, users[idx]) - } - - var users3 []User - db.Preload("Company").Preload("Manager").Find(users3, "id IN (?)", userIDs) - for idx, user := range users3 { - check(t, user, users[idx]) - } - }) - - t.Run("BelongsToForBulkInsertPtrData", func(t *testing.T) { - var users = []*User{{ - Name: "create-1", - Age: 18, - Birthday: Now(), - Company: Company{Name: "company-belongs-to-association-1"}, - Manager: &User{Name: "manager-belongs-to-association-1"}, - }, { - Name: "create-2", - Age: 28, - Birthday: Now(), - Company: Company{Name: "company-belongs-to-association-2"}, - }, { - Name: "create-3", - Age: 38, - Birthday: Now(), - Company: Company{Name: "company-belongs-to-association-3"}, - Manager: &User{Name: "manager-belongs-to-association-3"}, - }} - - if err := db.Create(&users).Error; err != nil { - t.Fatalf("errors happened when create: %v", err) - } - - for _, user := range users { - check(t, *user, *user) - } - }) - - t.Run("BelongsToForBulkInsertWithoutPtr", func(t *testing.T) { - var users = []*User{{ - Name: "create-1", - Age: 18, - Birthday: Now(), - Company: Company{Name: "company-belongs-to-association-1"}, - Manager: &User{Name: "manager-belongs-to-association-1"}, - }, { - Name: "create-2", - Age: 28, - Birthday: Now(), - Company: Company{Name: "company-belongs-to-association-2"}, - }, { - Name: "create-3", - Age: 38, - Birthday: Now(), - Company: Company{Name: "company-belongs-to-association-3"}, - Manager: &User{Name: "manager-belongs-to-association-3"}, - }} - - if err := db.Create(users).Error; err != nil { - t.Fatalf("errors happened when create: %v", err) - } - - for _, user := range users { - check(t, *user, *user) - } - }) -} - -func TestCreateHasOneAssociations(t *testing.T, db *gorm.DB) { - check := func(t *testing.T, user User, old User) { - if user.Account.ID == 0 { - t.Errorf("Account should be saved") - } else if user.Account.UserID.Int64 != int64(user.ID) { - t.Errorf("Account's foreign key should be saved") - } else { - var account Account - db.First(&account, "id = ?", user.Account.ID) - if account.Number != user.Account.Number { - t.Errorf("Account's number should be same") - } else if user.Account.Number != old.Account.Number { - t.Errorf("Account's number should be same") - } - } - } - - t.Run("HasOne", func(t *testing.T) { - var user = User{ - Name: "create", - Age: 18, - Birthday: Now(), - Account: Account{Number: "account-has-one-association"}, - } - - if err := db.Create(&user).Error; err != nil { - t.Fatalf("errors happened when create: %v", err) - } - - check(t, user, user) - - var user2 User - db.Preload("Account").Find(&user2, "id = ?", user.ID) - check(t, user2, user) - }) - - t.Run("HasOneForBulkInsert", func(t *testing.T) { - var users = []User{{ - Name: "create-1", - Age: 18, - Birthday: Now(), - Account: Account{Number: "account-has-one-association-1"}, - }, { - Name: "create-2", - Age: 28, - Birthday: Now(), - Account: Account{Number: "account-has-one-association-2"}, - }, { - Name: "create-3", - Age: 38, - Birthday: Now(), - Account: Account{Number: "account-has-one-association-3"}, - }} - - if err := db.Create(&users).Error; err != nil { - t.Fatalf("errors happened when create: %v", err) - } - - var userIDs []uint - for _, user := range users { - userIDs = append(userIDs, user.ID) - check(t, user, user) - } - - var users2 []User - db.Preload("Account").Find(&users2, "id IN (?)", userIDs) - for idx, user := range users2 { - check(t, user, users[idx]) - } - }) - - t.Run("HasOneForBulkInsertPtrData", func(t *testing.T) { - var users = []*User{{ - Name: "create-1", - Age: 18, - Birthday: Now(), - Account: Account{Number: "account-has-one-association-1"}, - }, { - Name: "create-2", - Age: 28, - Birthday: Now(), - Account: Account{Number: "account-has-one-association-2"}, - }, { - Name: "create-3", - Age: 38, - Birthday: Now(), - Account: Account{Number: "account-has-one-association-3"}, - }} - - if err := db.Create(&users).Error; err != nil { - t.Fatalf("errors happened when create: %v", err) - } - - for _, user := range users { - check(t, *user, *user) - } - }) - - t.Run("HasOneForBulkInsertWithoutPtr", func(t *testing.T) { - var users = []User{{ - Name: "create-1", - Age: 18, - Birthday: Now(), - Account: Account{Number: "account-has-one-association-1"}, - }, { - Name: "create-2", - Age: 28, - Birthday: Now(), - Account: Account{Number: "account-has-one-association-2"}, - }, { - Name: "create-3", - Age: 38, - Birthday: Now(), - Account: Account{Number: "account-has-one-association-3"}, - }} - - if err := db.Create(users).Error; err != nil { - t.Fatalf("errors happened when create: %v", err) - } - - for _, user := range users { - check(t, user, user) - } - }) - - checkPet := func(t *testing.T, pet Pet, old Pet) { - if pet.Toy.OwnerID != fmt.Sprint(pet.ID) || pet.Toy.OwnerType != "pets" { - t.Errorf("Failed to create polymorphic has one association - toy owner id %v, owner type %v", pet.Toy.OwnerID, pet.Toy.OwnerType) - } else { - var toy Toy - db.First(&toy, "owner_id = ? and owner_type = ?", pet.Toy.OwnerID, pet.Toy.OwnerType) - if toy.Name != pet.Toy.Name { - t.Errorf("Failed to query saved polymorphic has one association") - } else if old.Toy.Name != pet.Toy.Name { - t.Errorf("Failed to query saved polymorphic has one association") - } - } - } - - t.Run("PolymorphicHasOne", func(t *testing.T) { - var pet = Pet{ - Name: "create", - Toy: Toy{Name: "Create-HasOneAssociation-Polymorphic"}, - } - - if err := db.Create(&pet).Error; err != nil { - t.Fatalf("errors happened when create: %v", err) - } - - checkPet(t, pet, pet) - - var pet2 Pet - db.Preload("Toy").Find(&pet2, "id = ?", pet.ID) - checkPet(t, pet2, pet) - }) - - t.Run("PolymorphicHasOneForBulkInsert", func(t *testing.T) { - var pets = []Pet{{ - Name: "create-1", - Toy: Toy{Name: "Create-HasOneAssociation-Polymorphic-1"}, - }, { - Name: "create-2", - Toy: Toy{Name: "Create-HasOneAssociation-Polymorphic-2"}, - }, { - Name: "create-3", - Toy: Toy{Name: "Create-HasOneAssociation-Polymorphic-3"}, - }} - - if err := db.Create(&pets).Error; err != nil { - t.Fatalf("errors happened when create: %v", err) - } - - var petIDs []uint - for _, pet := range pets { - petIDs = append(petIDs, pet.ID) - checkPet(t, pet, pet) - } - - var pets2 []Pet - db.Preload("Toy").Find(&pets2, "id IN (?)", petIDs) - for idx, pet := range pets2 { - checkPet(t, pet, pets[idx]) - } - }) - - t.Run("PolymorphicHasOneForBulkInsertPtrData", func(t *testing.T) { - var pets = []*Pet{{ - Name: "create-1", - Toy: Toy{Name: "Create-HasOneAssociation-Polymorphic-1"}, - }, { - Name: "create-2", - Toy: Toy{Name: "Create-HasOneAssociation-Polymorphic-2"}, - }, { - Name: "create-3", - Toy: Toy{Name: "Create-HasOneAssociation-Polymorphic-3"}, - }} - - if err := db.Create(&pets).Error; err != nil { - t.Fatalf("errors happened when create: %v", err) - } - - for _, pet := range pets { - checkPet(t, *pet, *pet) - } - }) - - t.Run("PolymorphicHasOneForBulkInsertWithoutPtr", func(t *testing.T) { - var pets = []*Pet{{ - Name: "create-1", - Toy: Toy{Name: "Create-HasOneAssociation-Polymorphic-1"}, - }, { - Name: "create-2", - Toy: Toy{Name: "Create-HasOneAssociation-Polymorphic-2"}, - }, { - Name: "create-3", - Toy: Toy{Name: "Create-HasOneAssociation-Polymorphic-3"}, - }} - - if err := db.Create(pets).Error; err != nil { - t.Fatalf("errors happened when create: %v", err) - } - - for _, pet := range pets { - checkPet(t, *pet, *pet) - } - }) -} - -func TestCreateHasManyAssociations(t *testing.T, db *gorm.DB) { - check := func(t *testing.T, user User, old User) { - for idx, pet := range user.Pets { - if pet.ID == 0 { - t.Errorf("Pet's foreign key should be saved") - } - - var result Pet - db.First(&result, "id = ?", pet.ID) - if result.Name != pet.Name { - t.Errorf("Pet's name should be same") - } else if result.UserID != user.ID { - t.Errorf("Pet's foreign key should be saved") - } else if result.Name != old.Pets[idx].Name { - t.Errorf("Pet's name should be same") - } - } - } - - t.Run("HasMany", func(t *testing.T) { - var user = User{ - Name: "create", - Age: 18, - Birthday: Now(), - Pets: []*Pet{{Name: "pet1"}, {Name: "pet2"}}, - } - - if err := db.Create(&user).Error; err != nil { - t.Fatalf("errors happened when create: %v", err) - } - - check(t, user, user) - - var user2 User - db.Preload("Pets").Find(&user2, "id = ?", user.ID) - check(t, user2, user) - }) - - t.Run("HasManyForBulkInsert", func(t *testing.T) { - var users = []User{{ - Name: "create-1", - Age: 18, - Birthday: Now(), - Pets: []*Pet{{Name: "pet-1-1"}, {Name: "pet-1-2"}}, - }, { - Name: "create-2", - Age: 28, - Birthday: Now(), - Pets: []*Pet{{Name: "pet-2-1"}}, - }, { - Name: "create-3", - Age: 38, - Birthday: Now(), - Pets: []*Pet{{Name: "pet-3-1"}, {Name: "pet-3-2"}}, - }} - - if err := db.Create(&users).Error; err != nil { - t.Fatalf("errors happened when create: %v", err) - } - - var userIDs []uint - for _, user := range users { - userIDs = append(userIDs, user.ID) - check(t, user, user) - } - - var users2 []User - db.Preload("Pets").Find(&users2, "id IN (?)", userIDs) - for idx, user := range users2 { - check(t, user, users[idx]) - } - }) - - t.Run("HasManyForBulkInsertPtrData", func(t *testing.T) { - var users = []*User{{ - Name: "create-1", - Age: 18, - Birthday: Now(), - Pets: []*Pet{{Name: "pet-1-1"}, {Name: "pet-1-2"}}, - }, { - Name: "create-2", - Age: 28, - Birthday: Now(), - Pets: []*Pet{{Name: "pet-2-1"}}, - }, { - Name: "create-3", - Age: 38, - Birthday: Now(), - Pets: []*Pet{{Name: "pet-3-1"}, {Name: "pet-3-2"}}, - }} - - if err := db.Create(&users).Error; err != nil { - t.Fatalf("errors happened when create: %v", err) - } - - for _, user := range users { - check(t, *user, *user) - } - }) - - t.Run("HasManyForBulkInsertWithoutPtr", func(t *testing.T) { - var users = []*User{{ - Name: "create-1", - Age: 18, - Birthday: Now(), - Pets: []*Pet{{Name: "pet-1-1"}, {Name: "pet-1-2"}}, - }, { - Name: "create-2", - Age: 28, - Birthday: Now(), - Pets: []*Pet{{Name: "pet-2-1"}}, - }, { - Name: "create-3", - Age: 38, - Birthday: Now(), - Pets: []*Pet{{Name: "pet-3-1"}, {Name: "pet-3-2"}}, - }} - - if err := db.Create(users).Error; err != nil { - t.Fatalf("errors happened when create: %v", err) - } - - for _, user := range users { - check(t, *user, *user) - } - }) - - checkToy := func(t *testing.T, user User, old User) { - for idx, toy := range user.Toys { - if toy.ID == 0 { - t.Fatalf("Failed to create toy #%v", idx) - } - - var result Toy - db.First(&result, "id = ?", toy.ID) - if result.Name != toy.Name { - t.Errorf("Failed to query saved toy") - } else if result.Name != old.Toys[idx].Name { - t.Errorf("Failed to query saved toy") - } else if result.OwnerID != fmt.Sprint(user.ID) || result.OwnerType != "users" { - t.Errorf("Failed to save relation") - } - } - } - - t.Run("PolymorphicHasMany", func(t *testing.T) { - var user = User{ - Name: "create", - Age: 18, - Birthday: Now(), - Toys: []Toy{{Name: "toy1"}, {Name: "toy2"}}, - } - - if err := db.Create(&user).Error; err != nil { - t.Fatalf("errors happened when create: %v", err) - } - - checkToy(t, user, user) - - var user2 User - db.Preload("Toys").Find(&user2, "id = ?", user.ID) - check(t, user2, user) - }) - - t.Run("PolymorphicHasManyForBulkInsert", func(t *testing.T) { - var users = []User{{ - Name: "create-1", - Age: 18, - Birthday: Now(), - Toys: []Toy{{Name: "toy-1-1"}, {Name: "toy-1-2"}}, - }, { - Name: "create-2", - Age: 28, - Birthday: Now(), - Toys: []Toy{{Name: "toy-2-1"}, {Name: "toy-2-2"}}, - }, { - Name: "create-3", - Age: 38, - Birthday: Now(), - Toys: []Toy{{Name: "toy-3-1"}, {Name: "toy-3-2"}}, - }} - - if err := db.Create(&users).Error; err != nil { - t.Fatalf("errors happened when create: %v", err) - } - - var userIDs []uint - for _, user := range users { - userIDs = append(userIDs, user.ID) - checkToy(t, user, user) - } - - var users2 []User - db.Preload("Toys").Find(&users2, "id IN (?)", userIDs) - for idx, user := range users2 { - check(t, user, users[idx]) - } - }) - - t.Run("PolymorphicHasManyForBulkInsertPtrData", func(t *testing.T) { - var users = []*User{{ - Name: "create-1", - Age: 18, - Birthday: Now(), - Toys: []Toy{{Name: "toy-1-1"}, {Name: "toy-1-2"}}, - }, { - Name: "create-2", - Age: 28, - Birthday: Now(), - Toys: []Toy{{Name: "toy-2-1"}, {Name: "toy-2-2"}}, - }, { - Name: "create-3", - Age: 38, - Birthday: Now(), - Toys: []Toy{{Name: "toy-3-1"}, {Name: "toy-3-2"}}, - }} - - if err := db.Create(&users).Error; err != nil { - t.Fatalf("errors happened when create: %v", err) - } - - for _, user := range users { - checkToy(t, *user, *user) - } - }) - - t.Run("PolymorphicHasManyForBulkInsertWithoutPtr", func(t *testing.T) { - var users = []User{{ - Name: "create-1", - Age: 18, - Birthday: Now(), - Toys: []Toy{{Name: "toy-1-1"}, {Name: "toy-1-2"}}, - }, { - Name: "create-2", - Age: 28, - Birthday: Now(), - Toys: []Toy{{Name: "toy-2-1"}, {Name: "toy-2-2"}}, - }, { - Name: "create-3", - Age: 38, - Birthday: Now(), - Toys: []Toy{{Name: "toy-3-1"}, {Name: "toy-3-2"}}, - }} - - if err := db.Create(users).Error; err != nil { - t.Fatalf("errors happened when create: %v", err) - } - - for _, user := range users { - checkToy(t, user, user) - } - }) -} - -func TestCreateMany2ManyAssociations(t *testing.T, db *gorm.DB) { - check := func(t *testing.T, user User, old User) { - for idx, language := range user.Languages { - var result Language - db.First(&result, "code = ?", language.Code) - if result.Name != language.Name { - t.Errorf("Language's name should be same") - } else if result.Name != old.Languages[idx].Name { - t.Errorf("Language's name should be same") - } - } - - for idx, f := range user.Friends { - if f.ID == 0 { - t.Errorf("Friend's foreign key should be saved") - } - - var result User - db.First(&result, "id = ?", f.ID) - if result.Name != f.Name { - t.Errorf("Friend's name should be same") - } else if result.Name != old.Friends[idx].Name { - t.Errorf("Language's name should be same") - } - } - } - - db.Create(&[]Language{{Code: "zh-CN", Name: "Chinese"}, {Code: "en", Name: "English"}}) - - t.Run("Many2Many", func(t *testing.T) { - var user = User{ - Name: "create", - Age: 18, - Birthday: Now(), - Languages: []Language{{Code: "zh-CN", Name: "Chinese"}, {Code: "en", Name: "English"}}, - Friends: []*User{{Name: "friend-1"}, {Name: "friend-2"}}, - } - - if err := db.Create(&user).Error; err != nil { - t.Fatalf("errors happened when create: %v", err) - } - - check(t, user, user) - - var user2 User - db.Preload("Languages").Preload("Friends").Find(&user2, "id = ?", user.ID) - check(t, user2, user) - }) - - t.Run("Many2ManyForBulkInsert", func(t *testing.T) { - var users = []User{ - { - Name: "create-1", - Age: 18, - Birthday: Now(), - Languages: []Language{{Code: "zh-CN", Name: "Chinese"}, {Code: "en", Name: "English"}}, - Friends: []*User{{Name: "friend-1-1"}, {Name: "friend-1-2"}}, - }, - { - Name: "create-2", - Age: 18, - Birthday: Now(), - Languages: []Language{{Code: "zh-CN", Name: "Chinese"}}, - Friends: []*User{{Name: "friend-2-1"}}, - }, - { - Name: "create-3", - Age: 18, - Birthday: Now(), - Languages: []Language{{Code: "zh-CN", Name: "Chinese"}, {Code: "en", Name: "English"}}, - Friends: []*User{{Name: "friend-3-1"}, {Name: "friend-3-2"}, {Name: "friend-3-3"}}, - }, - } - - if err := db.Create(&users).Error; err != nil { - t.Fatalf("errors happened when create: %v", err) - } - - var userIDs []uint - for _, user := range users { - userIDs = append(userIDs, user.ID) - check(t, user, user) - } - - var users2 []User - db.Preload("Languages").Preload("Friends").Find(&users2, "id IN (?)", userIDs) - for idx, user := range users2 { - check(t, user, users[idx]) - } - }) + if config.Account { + user.Account = Account{Number: name + "_account"} + } + + for i := 0; i < config.Pets; i++ { + user.Pets = append(user.Pets, &Pet{Name: name + "_pet_" + strconv.Itoa(i+1)}) + } + + for i := 0; i < config.Toys; i++ { + user.Toys = append(user.Toys, Toy{Name: name + "_toy_" + strconv.Itoa(i+1)}) + } + + if config.Company { + user.Company = Company{Name: "company-" + name} + } + + if config.Manager { + manager := GetUser(name+"_manager", Config{}) + user.Manager = &manager + } + + for i := 0; i < config.Team; i++ { + user.Team = append(user.Team, GetUser(name+"_team_"+strconv.Itoa(i+1), Config{})) + } + + for i := 0; i < config.Languages; i++ { + name := "Locale_" + strconv.Itoa(i+0) + user.Languages = append(user.Languages, Language{Code: name, Name: name}) + } + + for i := 0; i < config.Friends; i++ { + f := GetUser(name+"_friend_"+strconv.Itoa(i+1), Config{}) + user.Friends = append(user.Friends, &f) + } + + return user } diff --git a/tests/create_test.go b/tests/create_test.go new file mode 100644 index 00000000..471cecf6 --- /dev/null +++ b/tests/create_test.go @@ -0,0 +1,761 @@ +package tests_test + +import ( + "testing" + + . "github.com/jinzhu/gorm/tests" +) + +func TestCreate(t *testing.T) { + var user = GetUser("create", Config{}) + + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + if user.ID == 0 { + t.Errorf("user's primary key should has value after create, got : %v", user.ID) + } + + if user.CreatedAt.IsZero() { + t.Errorf("user's created at should be not zero") + } + + if user.UpdatedAt.IsZero() { + t.Errorf("user's updated at should be not zero") + } + + var newUser User + if err := DB.Where("id = ?", user.ID).First(&newUser).Error; err != nil { + t.Errorf("errors happened when query: %v", err) + } else { + AssertObjEqual(t, newUser, user, "Name", "Age", "Birthday") + } +} + +func TestCreateWithBelongsToAssociations(t *testing.T) { + check := func(t *testing.T, user User, old User) { + if old.Company.Name != "" { + if user.CompanyID == nil { + t.Errorf("Company's foreign key should be saved") + } else { + var company Company + DB.First(&company, "id = ?", *user.CompanyID) + if company.Name != old.Company.Name { + t.Errorf("Company's name should be same") + } else if user.Company.Name != old.Company.Name { + t.Errorf("Company's name should be same") + } + } + } else if user.CompanyID != nil { + t.Errorf("Company should not be created for zero value, got: %+v", user.CompanyID) + } + + if old.Manager != nil { + if user.ManagerID == nil { + t.Errorf("Manager's foreign key should be saved") + } else { + var manager User + DB.First(&manager, "id = ?", *user.ManagerID) + if manager.Name != user.Manager.Name { + t.Errorf("Manager's name should be same") + } else if user.Manager.Name != old.Manager.Name { + t.Errorf("Manager's name should be same") + } + } + } else if user.ManagerID != nil { + t.Errorf("Manager should not be created for zero value, got: %+v", user.ManagerID) + } + } + + t.Run("Struct", func(t *testing.T) { + var user = User{ + Name: "create", + Age: 18, + Birthday: Now(), + Company: Company{Name: "company-belongs-to-association"}, + Manager: &User{Name: "manager-belongs-to-association"}, + } + + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + check(t, user, user) + + var user2 User + DB.Preload("Company").Preload("Manager").Find(&user2, "id = ?", user.ID) + check(t, user2, user) + }) + + t.Run("BulkInsert", func(t *testing.T) { + var users = []User{{ + Name: "create-1", + Age: 18, + Birthday: Now(), + Company: Company{Name: "company-belongs-to-association-1"}, + Manager: &User{Name: "manager-belongs-to-association-1"}, + }, { + Name: "create-2", + Age: 28, + Birthday: Now(), + Company: Company{Name: "company-belongs-to-association-2"}, + }, { + Name: "create-3", + Age: 38, + Birthday: Now(), + Company: Company{Name: "company-belongs-to-association-3"}, + Manager: &User{Name: "manager-belongs-to-association-3"}, + }} + + if err := DB.Create(&users).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + var userIDs []uint + for _, user := range users { + userIDs = append(userIDs, user.ID) + check(t, user, user) + } + + t.Run("Preload", func(t *testing.T) { + var users2 []User + DB.Preload("Company").Preload("Manager").Find(&users2, "id IN ?", userIDs) + for idx, user := range users2 { + check(t, user, users[idx]) + } + }) + }) + + t.Run("BulkInsertPtrData", func(t *testing.T) { + var users = []*User{{ + Name: "create-1", + Age: 18, + Birthday: Now(), + Company: Company{Name: "company-belongs-to-association-1"}, + Manager: &User{Name: "manager-belongs-to-association-1"}, + }, { + Name: "create-2", + Age: 28, + Birthday: Now(), + Company: Company{Name: "company-belongs-to-association-2"}, + }, { + Name: "create-3", + Age: 38, + Birthday: Now(), + Company: Company{Name: "company-belongs-to-association-3"}, + Manager: &User{Name: "manager-belongs-to-association-3"}, + }} + + if err := DB.Create(&users).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + for _, user := range users { + check(t, *user, *user) + } + }) + + t.Run("BulkInsertWithoutPtr", func(t *testing.T) { + var users = []*User{{ + Name: "create-1", + Age: 18, + Birthday: Now(), + Company: Company{Name: "company-belongs-to-association-1"}, + Manager: &User{Name: "manager-belongs-to-association-1"}, + }, { + Name: "create-2", + Age: 28, + Birthday: Now(), + Company: Company{Name: "company-belongs-to-association-2"}, + }, { + Name: "create-3", + Age: 38, + Birthday: Now(), + Company: Company{Name: "company-belongs-to-association-3"}, + Manager: &User{Name: "manager-belongs-to-association-3"}, + }} + + if err := DB.Create(users).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + for _, user := range users { + check(t, *user, *user) + } + }) +} + +// func TestCreateHasOneAssociations(t *testing.T, db *gorm.DB) { +// check := func(t *testing.T, user User, old User) { +// if user.Account.ID == 0 { +// t.Errorf("Account should be saved") +// } else if user.Account.UserID.Int64 != int64(user.ID) { +// t.Errorf("Account's foreign key should be saved") +// } else { +// var account Account +// db.First(&account, "id = ?", user.Account.ID) +// if account.Number != user.Account.Number { +// t.Errorf("Account's number should be same") +// } else if user.Account.Number != old.Account.Number { +// t.Errorf("Account's number should be same") +// } +// } +// } + +// t.Run("HasOne", func(t *testing.T) { +// var user = User{ +// Name: "create", +// Age: 18, +// Birthday: Now(), +// Account: Account{Number: "account-has-one-association"}, +// } + +// if err := db.Create(&user).Error; err != nil { +// t.Fatalf("errors happened when create: %v", err) +// } + +// check(t, user, user) + +// var user2 User +// db.Preload("Account").Find(&user2, "id = ?", user.ID) +// check(t, user2, user) +// }) + +// t.Run("HasOneForBulkInsert", func(t *testing.T) { +// var users = []User{{ +// Name: "create-1", +// Age: 18, +// Birthday: Now(), +// Account: Account{Number: "account-has-one-association-1"}, +// }, { +// Name: "create-2", +// Age: 28, +// Birthday: Now(), +// Account: Account{Number: "account-has-one-association-2"}, +// }, { +// Name: "create-3", +// Age: 38, +// Birthday: Now(), +// Account: Account{Number: "account-has-one-association-3"}, +// }} + +// if err := db.Create(&users).Error; err != nil { +// t.Fatalf("errors happened when create: %v", err) +// } + +// var userIDs []uint +// for _, user := range users { +// userIDs = append(userIDs, user.ID) +// check(t, user, user) +// } + +// var users2 []User +// db.Preload("Account").Find(&users2, "id IN (?)", userIDs) +// for idx, user := range users2 { +// check(t, user, users[idx]) +// } +// }) + +// t.Run("HasOneForBulkInsertPtrData", func(t *testing.T) { +// var users = []*User{{ +// Name: "create-1", +// Age: 18, +// Birthday: Now(), +// Account: Account{Number: "account-has-one-association-1"}, +// }, { +// Name: "create-2", +// Age: 28, +// Birthday: Now(), +// Account: Account{Number: "account-has-one-association-2"}, +// }, { +// Name: "create-3", +// Age: 38, +// Birthday: Now(), +// Account: Account{Number: "account-has-one-association-3"}, +// }} + +// if err := db.Create(&users).Error; err != nil { +// t.Fatalf("errors happened when create: %v", err) +// } + +// for _, user := range users { +// check(t, *user, *user) +// } +// }) + +// t.Run("HasOneForBulkInsertWithoutPtr", func(t *testing.T) { +// var users = []User{{ +// Name: "create-1", +// Age: 18, +// Birthday: Now(), +// Account: Account{Number: "account-has-one-association-1"}, +// }, { +// Name: "create-2", +// Age: 28, +// Birthday: Now(), +// Account: Account{Number: "account-has-one-association-2"}, +// }, { +// Name: "create-3", +// Age: 38, +// Birthday: Now(), +// Account: Account{Number: "account-has-one-association-3"}, +// }} + +// if err := db.Create(users).Error; err != nil { +// t.Fatalf("errors happened when create: %v", err) +// } + +// for _, user := range users { +// check(t, user, user) +// } +// }) + +// checkPet := func(t *testing.T, pet Pet, old Pet) { +// if pet.Toy.OwnerID != fmt.Sprint(pet.ID) || pet.Toy.OwnerType != "pets" { +// t.Errorf("Failed to create polymorphic has one association - toy owner id %v, owner type %v", pet.Toy.OwnerID, pet.Toy.OwnerType) +// } else { +// var toy Toy +// db.First(&toy, "owner_id = ? and owner_type = ?", pet.Toy.OwnerID, pet.Toy.OwnerType) +// if toy.Name != pet.Toy.Name { +// t.Errorf("Failed to query saved polymorphic has one association") +// } else if old.Toy.Name != pet.Toy.Name { +// t.Errorf("Failed to query saved polymorphic has one association") +// } +// } +// } + +// t.Run("PolymorphicHasOne", func(t *testing.T) { +// var pet = Pet{ +// Name: "create", +// Toy: Toy{Name: "Create-HasOneAssociation-Polymorphic"}, +// } + +// if err := db.Create(&pet).Error; err != nil { +// t.Fatalf("errors happened when create: %v", err) +// } + +// checkPet(t, pet, pet) + +// var pet2 Pet +// db.Preload("Toy").Find(&pet2, "id = ?", pet.ID) +// checkPet(t, pet2, pet) +// }) + +// t.Run("PolymorphicHasOneForBulkInsert", func(t *testing.T) { +// var pets = []Pet{{ +// Name: "create-1", +// Toy: Toy{Name: "Create-HasOneAssociation-Polymorphic-1"}, +// }, { +// Name: "create-2", +// Toy: Toy{Name: "Create-HasOneAssociation-Polymorphic-2"}, +// }, { +// Name: "create-3", +// Toy: Toy{Name: "Create-HasOneAssociation-Polymorphic-3"}, +// }} + +// if err := db.Create(&pets).Error; err != nil { +// t.Fatalf("errors happened when create: %v", err) +// } + +// var petIDs []uint +// for _, pet := range pets { +// petIDs = append(petIDs, pet.ID) +// checkPet(t, pet, pet) +// } + +// var pets2 []Pet +// db.Preload("Toy").Find(&pets2, "id IN (?)", petIDs) +// for idx, pet := range pets2 { +// checkPet(t, pet, pets[idx]) +// } +// }) + +// t.Run("PolymorphicHasOneForBulkInsertPtrData", func(t *testing.T) { +// var pets = []*Pet{{ +// Name: "create-1", +// Toy: Toy{Name: "Create-HasOneAssociation-Polymorphic-1"}, +// }, { +// Name: "create-2", +// Toy: Toy{Name: "Create-HasOneAssociation-Polymorphic-2"}, +// }, { +// Name: "create-3", +// Toy: Toy{Name: "Create-HasOneAssociation-Polymorphic-3"}, +// }} + +// if err := db.Create(&pets).Error; err != nil { +// t.Fatalf("errors happened when create: %v", err) +// } + +// for _, pet := range pets { +// checkPet(t, *pet, *pet) +// } +// }) + +// t.Run("PolymorphicHasOneForBulkInsertWithoutPtr", func(t *testing.T) { +// var pets = []*Pet{{ +// Name: "create-1", +// Toy: Toy{Name: "Create-HasOneAssociation-Polymorphic-1"}, +// }, { +// Name: "create-2", +// Toy: Toy{Name: "Create-HasOneAssociation-Polymorphic-2"}, +// }, { +// Name: "create-3", +// Toy: Toy{Name: "Create-HasOneAssociation-Polymorphic-3"}, +// }} + +// if err := db.Create(pets).Error; err != nil { +// t.Fatalf("errors happened when create: %v", err) +// } + +// for _, pet := range pets { +// checkPet(t, *pet, *pet) +// } +// }) +// } + +// func TestCreateHasManyAssociations(t *testing.T, db *gorm.DB) { +// check := func(t *testing.T, user User, old User) { +// for idx, pet := range user.Pets { +// if pet.ID == 0 { +// t.Errorf("Pet's foreign key should be saved") +// } + +// var result Pet +// db.First(&result, "id = ?", pet.ID) +// if result.Name != pet.Name { +// t.Errorf("Pet's name should be same") +// } else if result.UserID != user.ID { +// t.Errorf("Pet's foreign key should be saved") +// } else if result.Name != old.Pets[idx].Name { +// t.Errorf("Pet's name should be same") +// } +// } +// } + +// t.Run("HasMany", func(t *testing.T) { +// var user = User{ +// Name: "create", +// Age: 18, +// Birthday: Now(), +// Pets: []*Pet{{Name: "pet1"}, {Name: "pet2"}}, +// } + +// if err := db.Create(&user).Error; err != nil { +// t.Fatalf("errors happened when create: %v", err) +// } + +// check(t, user, user) + +// var user2 User +// db.Preload("Pets").Find(&user2, "id = ?", user.ID) +// check(t, user2, user) +// }) + +// t.Run("HasManyForBulkInsert", func(t *testing.T) { +// var users = []User{{ +// Name: "create-1", +// Age: 18, +// Birthday: Now(), +// Pets: []*Pet{{Name: "pet-1-1"}, {Name: "pet-1-2"}}, +// }, { +// Name: "create-2", +// Age: 28, +// Birthday: Now(), +// Pets: []*Pet{{Name: "pet-2-1"}}, +// }, { +// Name: "create-3", +// Age: 38, +// Birthday: Now(), +// Pets: []*Pet{{Name: "pet-3-1"}, {Name: "pet-3-2"}}, +// }} + +// if err := db.Create(&users).Error; err != nil { +// t.Fatalf("errors happened when create: %v", err) +// } + +// var userIDs []uint +// for _, user := range users { +// userIDs = append(userIDs, user.ID) +// check(t, user, user) +// } + +// var users2 []User +// db.Preload("Pets").Find(&users2, "id IN (?)", userIDs) +// for idx, user := range users2 { +// check(t, user, users[idx]) +// } +// }) + +// t.Run("HasManyForBulkInsertPtrData", func(t *testing.T) { +// var users = []*User{{ +// Name: "create-1", +// Age: 18, +// Birthday: Now(), +// Pets: []*Pet{{Name: "pet-1-1"}, {Name: "pet-1-2"}}, +// }, { +// Name: "create-2", +// Age: 28, +// Birthday: Now(), +// Pets: []*Pet{{Name: "pet-2-1"}}, +// }, { +// Name: "create-3", +// Age: 38, +// Birthday: Now(), +// Pets: []*Pet{{Name: "pet-3-1"}, {Name: "pet-3-2"}}, +// }} + +// if err := db.Create(&users).Error; err != nil { +// t.Fatalf("errors happened when create: %v", err) +// } + +// for _, user := range users { +// check(t, *user, *user) +// } +// }) + +// t.Run("HasManyForBulkInsertWithoutPtr", func(t *testing.T) { +// var users = []*User{{ +// Name: "create-1", +// Age: 18, +// Birthday: Now(), +// Pets: []*Pet{{Name: "pet-1-1"}, {Name: "pet-1-2"}}, +// }, { +// Name: "create-2", +// Age: 28, +// Birthday: Now(), +// Pets: []*Pet{{Name: "pet-2-1"}}, +// }, { +// Name: "create-3", +// Age: 38, +// Birthday: Now(), +// Pets: []*Pet{{Name: "pet-3-1"}, {Name: "pet-3-2"}}, +// }} + +// if err := db.Create(users).Error; err != nil { +// t.Fatalf("errors happened when create: %v", err) +// } + +// for _, user := range users { +// check(t, *user, *user) +// } +// }) + +// checkToy := func(t *testing.T, user User, old User) { +// for idx, toy := range user.Toys { +// if toy.ID == 0 { +// t.Fatalf("Failed to create toy #%v", idx) +// } + +// var result Toy +// db.First(&result, "id = ?", toy.ID) +// if result.Name != toy.Name { +// t.Errorf("Failed to query saved toy") +// } else if result.Name != old.Toys[idx].Name { +// t.Errorf("Failed to query saved toy") +// } else if result.OwnerID != fmt.Sprint(user.ID) || result.OwnerType != "users" { +// t.Errorf("Failed to save relation") +// } +// } +// } + +// t.Run("PolymorphicHasMany", func(t *testing.T) { +// var user = User{ +// Name: "create", +// Age: 18, +// Birthday: Now(), +// Toys: []Toy{{Name: "toy1"}, {Name: "toy2"}}, +// } + +// if err := db.Create(&user).Error; err != nil { +// t.Fatalf("errors happened when create: %v", err) +// } + +// checkToy(t, user, user) + +// var user2 User +// db.Preload("Toys").Find(&user2, "id = ?", user.ID) +// check(t, user2, user) +// }) + +// t.Run("PolymorphicHasManyForBulkInsert", func(t *testing.T) { +// var users = []User{{ +// Name: "create-1", +// Age: 18, +// Birthday: Now(), +// Toys: []Toy{{Name: "toy-1-1"}, {Name: "toy-1-2"}}, +// }, { +// Name: "create-2", +// Age: 28, +// Birthday: Now(), +// Toys: []Toy{{Name: "toy-2-1"}, {Name: "toy-2-2"}}, +// }, { +// Name: "create-3", +// Age: 38, +// Birthday: Now(), +// Toys: []Toy{{Name: "toy-3-1"}, {Name: "toy-3-2"}}, +// }} + +// if err := db.Create(&users).Error; err != nil { +// t.Fatalf("errors happened when create: %v", err) +// } + +// var userIDs []uint +// for _, user := range users { +// userIDs = append(userIDs, user.ID) +// checkToy(t, user, user) +// } + +// var users2 []User +// db.Preload("Toys").Find(&users2, "id IN (?)", userIDs) +// for idx, user := range users2 { +// check(t, user, users[idx]) +// } +// }) + +// t.Run("PolymorphicHasManyForBulkInsertPtrData", func(t *testing.T) { +// var users = []*User{{ +// Name: "create-1", +// Age: 18, +// Birthday: Now(), +// Toys: []Toy{{Name: "toy-1-1"}, {Name: "toy-1-2"}}, +// }, { +// Name: "create-2", +// Age: 28, +// Birthday: Now(), +// Toys: []Toy{{Name: "toy-2-1"}, {Name: "toy-2-2"}}, +// }, { +// Name: "create-3", +// Age: 38, +// Birthday: Now(), +// Toys: []Toy{{Name: "toy-3-1"}, {Name: "toy-3-2"}}, +// }} + +// if err := db.Create(&users).Error; err != nil { +// t.Fatalf("errors happened when create: %v", err) +// } + +// for _, user := range users { +// checkToy(t, *user, *user) +// } +// }) + +// t.Run("PolymorphicHasManyForBulkInsertWithoutPtr", func(t *testing.T) { +// var users = []User{{ +// Name: "create-1", +// Age: 18, +// Birthday: Now(), +// Toys: []Toy{{Name: "toy-1-1"}, {Name: "toy-1-2"}}, +// }, { +// Name: "create-2", +// Age: 28, +// Birthday: Now(), +// Toys: []Toy{{Name: "toy-2-1"}, {Name: "toy-2-2"}}, +// }, { +// Name: "create-3", +// Age: 38, +// Birthday: Now(), +// Toys: []Toy{{Name: "toy-3-1"}, {Name: "toy-3-2"}}, +// }} + +// if err := db.Create(users).Error; err != nil { +// t.Fatalf("errors happened when create: %v", err) +// } + +// for _, user := range users { +// checkToy(t, user, user) +// } +// }) +// } + +// func TestCreateMany2ManyAssociations(t *testing.T, db *gorm.DB) { +// check := func(t *testing.T, user User, old User) { +// for idx, language := range user.Languages { +// var result Language +// db.First(&result, "code = ?", language.Code) +// if result.Name != language.Name { +// t.Errorf("Language's name should be same") +// } else if result.Name != old.Languages[idx].Name { +// t.Errorf("Language's name should be same") +// } +// } + +// for idx, f := range user.Friends { +// if f.ID == 0 { +// t.Errorf("Friend's foreign key should be saved") +// } + +// var result User +// db.First(&result, "id = ?", f.ID) +// if result.Name != f.Name { +// t.Errorf("Friend's name should be same") +// } else if result.Name != old.Friends[idx].Name { +// t.Errorf("Language's name should be same") +// } +// } +// } + +// db.Create(&[]Language{{Code: "zh-CN", Name: "Chinese"}, {Code: "en", Name: "English"}}) + +// t.Run("Many2Many", func(t *testing.T) { +// var user = User{ +// Name: "create", +// Age: 18, +// Birthday: Now(), +// Languages: []Language{{Code: "zh-CN", Name: "Chinese"}, {Code: "en", Name: "English"}}, +// Friends: []*User{{Name: "friend-1"}, {Name: "friend-2"}}, +// } + +// if err := db.Create(&user).Error; err != nil { +// t.Fatalf("errors happened when create: %v", err) +// } + +// check(t, user, user) + +// var user2 User +// db.Preload("Languages").Preload("Friends").Find(&user2, "id = ?", user.ID) +// check(t, user2, user) +// }) + +// t.Run("Many2ManyForBulkInsert", func(t *testing.T) { +// var users = []User{ +// { +// Name: "create-1", +// Age: 18, +// Birthday: Now(), +// Languages: []Language{{Code: "zh-CN", Name: "Chinese"}, {Code: "en", Name: "English"}}, +// Friends: []*User{{Name: "friend-1-1"}, {Name: "friend-1-2"}}, +// }, +// { +// Name: "create-2", +// Age: 18, +// Birthday: Now(), +// Languages: []Language{{Code: "zh-CN", Name: "Chinese"}}, +// Friends: []*User{{Name: "friend-2-1"}}, +// }, +// { +// Name: "create-3", +// Age: 18, +// Birthday: Now(), +// Languages: []Language{{Code: "zh-CN", Name: "Chinese"}, {Code: "en", Name: "English"}}, +// Friends: []*User{{Name: "friend-3-1"}, {Name: "friend-3-2"}, {Name: "friend-3-3"}}, +// }, +// } + +// if err := db.Create(&users).Error; err != nil { +// t.Fatalf("errors happened when create: %v", err) +// } + +// var userIDs []uint +// for _, user := range users { +// userIDs = append(userIDs, user.ID) +// check(t, user, user) +// } + +// var users2 []User +// db.Preload("Languages").Preload("Friends").Find(&users2, "id IN (?)", userIDs) +// for idx, user := range users2 { +// check(t, user, users[idx]) +// } +// }) +// } diff --git a/tests/main_test.go b/tests/main_test.go new file mode 100644 index 00000000..7324ed9e --- /dev/null +++ b/tests/main_test.go @@ -0,0 +1,95 @@ +package tests + +import ( + "log" + "math/rand" + "os" + "path/filepath" + "testing" + "time" + + "github.com/jinzhu/gorm" + "github.com/jinzhu/gorm/dialects/mssql" + "github.com/jinzhu/gorm/dialects/mysql" + "github.com/jinzhu/gorm/dialects/postgres" + "github.com/jinzhu/gorm/dialects/sqlite" + "github.com/jinzhu/gorm/logger" +) + +var DB *gorm.DB + +func TestMain(m *testing.M) { + var err error + DB, err = OpenTestConnection() + if err == nil { + RunMigrations() + m.Run() + } else { + log.Printf("failed to connect database, got error %v\n", err) + os.Exit(1) + } +} + +func RunMigrations() { + var err error + allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}} + rand.Seed(time.Now().UnixNano()) + rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] }) + + if err = DB.Migrator().DropTable(allModels...); err != nil { + log.Printf("Failed to drop table, got error %v\n", err) + os.Exit(1) + } + + if err = DB.AutoMigrate(allModels...); err != nil { + log.Printf("Failed to auto migrate, but got error %v\n", err) + os.Exit(1) + } + + for _, m := range allModels { + if !DB.Migrator().HasTable(m) { + log.Printf("Failed to create table for %#v\n", m) + os.Exit(1) + } + } +} + +func OpenTestConnection() (db *gorm.DB, err error) { + dbDSN := os.Getenv("GORM_DSN") + switch os.Getenv("GORM_DIALECT") { + case "mysql": + log.Println("testing mysql...") + if dbDSN == "" { + dbDSN = "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True&loc=Local" + } + db, err = gorm.Open(mysql.Open(dbDSN), &gorm.Config{}) + case "postgres": + log.Println("testing postgres...") + if dbDSN == "" { + dbDSN = "user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable TimeZone=Asia/Shanghai" + } + db, err = gorm.Open(postgres.Open(dbDSN), &gorm.Config{}) + case "mssql": + // CREATE LOGIN gorm WITH PASSWORD = 'LoremIpsum86'; + // CREATE DATABASE gorm; + // USE gorm; + // CREATE USER gorm FROM LOGIN gorm; + // sp_changedbowner 'gorm'; + log.Println("testing mssql...") + if dbDSN == "" { + dbDSN = "sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm" + } + db, err = gorm.Open(mssql.Open(dbDSN), &gorm.Config{}) + default: + log.Println("testing sqlite3...") + db, err = gorm.Open(sqlite.Open(filepath.Join(os.TempDir(), "gorm.db")), &gorm.Config{}) + } + + if debug := os.Getenv("DEBUG"); debug == "true" { + db.Logger.LogMode(logger.Info) + } else if debug == "false" { + db.Logger.LogMode(logger.Error) + } + + return +} diff --git a/tests/tests.go b/tests/tests.go index 87005a71..809d2e39 100644 --- a/tests/tests.go +++ b/tests/tests.go @@ -13,7 +13,7 @@ func Now() *time.Time { } func RunTestsSuit(t *testing.T, db *gorm.DB) { - TestCreate(t, db) + // TestCreate(t, db) TestFind(t, db) TestUpdate(t, db) TestDelete(t, db) From e64785573d533ba6f43e871fb778b0734bc22da0 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 23 May 2020 16:38:13 +0800 Subject: [PATCH 313/881] Add helper methods to check user, pet --- migrator/migrator.go | 6 +- tests/create.go | 132 ++++++++++++++++++++++-- tests/create_test.go | 239 ++++++++++++++++--------------------------- tests/main_test.go | 2 + tests/utils.go | 9 ++ 5 files changed, 228 insertions(+), 160 deletions(-) diff --git a/migrator/migrator.go b/migrator/migrator.go index f581f714..cab266a3 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -542,7 +542,11 @@ func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []i } for _, value := range values { - parseDependence(value, true) + if v, ok := value.(string); ok { + results = append(results, v) + } else { + parseDependence(value, true) + } } for _, name := range modelNames { diff --git a/tests/create.go b/tests/create.go index ec57b8ee..09464674 100644 --- a/tests/create.go +++ b/tests/create.go @@ -2,6 +2,7 @@ package tests import ( "strconv" + "testing" "time" ) @@ -16,7 +17,7 @@ type Config struct { Friends int } -func GetUser(name string, config Config) User { +func GetUser(name string, config Config) *User { var ( birthday = time.Now() user = User{ @@ -43,23 +44,136 @@ func GetUser(name string, config Config) User { } if config.Manager { - manager := GetUser(name+"_manager", Config{}) - user.Manager = &manager + user.Manager = GetUser(name+"_manager", Config{}) } for i := 0; i < config.Team; i++ { - user.Team = append(user.Team, GetUser(name+"_team_"+strconv.Itoa(i+1), Config{})) + user.Team = append(user.Team, *GetUser(name+"_team_"+strconv.Itoa(i+1), Config{})) } for i := 0; i < config.Languages; i++ { - name := "Locale_" + strconv.Itoa(i+0) - user.Languages = append(user.Languages, Language{Code: name, Name: name}) + name := name + "_locale_" + strconv.Itoa(i+0) + language := Language{Code: name, Name: name} + DB.Create(&language) + user.Languages = append(user.Languages, language) } for i := 0; i < config.Friends; i++ { - f := GetUser(name+"_friend_"+strconv.Itoa(i+1), Config{}) - user.Friends = append(user.Friends, &f) + user.Friends = append(user.Friends, GetUser(name+"_friend_"+strconv.Itoa(i+1), Config{})) } - return user + return &user +} + +func CheckPet(t *testing.T, pet Pet, expect Pet) { + AssertObjEqual(t, pet, expect, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "UserID", "Name") + + AssertObjEqual(t, pet.Toy, expect.Toy, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "OwnerID", "OwnerType") + + if expect.Toy.Name != "" && expect.Toy.OwnerType != "pets" { + t.Errorf("toys's OwnerType, expect: %v, got %v", "pets", expect.Toy.OwnerType) + } +} + +func CheckUser(t *testing.T, user User, expect User) { + if user.ID != 0 { + var newUser User + if err := DB.Where("id = ?", user.ID).First(&newUser).Error; err != nil { + t.Fatalf("errors happened when query: %v", err) + } else { + AssertObjEqual(t, newUser, user, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active") + } + } + + AssertObjEqual(t, user, expect, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active") + + t.Run("Account", func(t *testing.T) { + AssertObjEqual(t, user.Account, expect.Account, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "UserID", "Number") + + if user.Account.Number != "" { + if !user.Account.UserID.Valid { + t.Errorf("Account's foreign key should be saved") + } else { + var account Account + DB.First(&account, "user_id = ?", user.ID) + AssertObjEqual(t, account, user.Account, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "UserID", "Number") + } + } + }) + + t.Run("Pets", func(t *testing.T) { + if len(user.Pets) != len(expect.Pets) { + t.Errorf("pets should equal, expect: %v, got %v", len(expect.Pets), len(user.Pets)) + } + + for idx, pet := range user.Pets { + if pet == nil || expect.Pets[idx] == nil { + t.Errorf("pets#%v should equal, expect: %v, got %v", idx, expect.Pets[idx], pet) + } else { + CheckPet(t, *pet, *expect.Pets[idx]) + } + } + }) + + t.Run("Toys", func(t *testing.T) { + if len(user.Toys) != len(expect.Toys) { + t.Errorf("toys should equal, expect: %v, got %v", len(expect.Toys), len(user.Toys)) + } + + for idx, toy := range user.Toys { + if toy.OwnerType != "users" { + t.Errorf("toys's OwnerType, expect: %v, got %v", "users", toy.OwnerType) + } + + AssertObjEqual(t, toy, expect.Toys[idx], "ID", "CreatedAt", "UpdatedAt", "Name", "OwnerID", "OwnerType") + } + }) + + t.Run("Company", func(t *testing.T) { + AssertObjEqual(t, user.Company, expect.Company, "ID", "Name") + }) + + t.Run("Manager", func(t *testing.T) { + if user.Manager != nil { + if user.ManagerID == nil { + t.Errorf("Manager's foreign key should be saved") + } else { + var manager User + DB.First(&manager, "id = ?", *user.ManagerID) + AssertObjEqual(t, manager, user.Manager, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active") + } + } else if user.ManagerID != nil { + t.Errorf("Manager should not be created for zero value, got: %+v", user.ManagerID) + } + }) + + t.Run("Team", func(t *testing.T) { + if len(user.Team) != len(expect.Team) { + t.Errorf("Team should equal, expect: %v, got %v", len(expect.Team), len(user.Team)) + } + + for idx, team := range user.Team { + AssertObjEqual(t, team, expect.Team[idx], "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active") + } + }) + + t.Run("Languages", func(t *testing.T) { + if len(user.Languages) != len(expect.Languages) { + t.Errorf("Languages should equal, expect: %v, got %v", len(expect.Languages), len(user.Languages)) + } + + for idx, language := range user.Languages { + AssertObjEqual(t, language, expect.Languages[idx], "Code", "Name") + } + }) + + t.Run("Friends", func(t *testing.T) { + if len(user.Friends) != len(expect.Friends) { + t.Errorf("Friends should equal, expect: %v, got %v", len(expect.Friends), len(user.Friends)) + } + + for idx, friend := range user.Friends { + AssertObjEqual(t, friend, expect.Friends[idx], "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active") + } + }) } diff --git a/tests/create_test.go b/tests/create_test.go index 471cecf6..9241e0a6 100644 --- a/tests/create_test.go +++ b/tests/create_test.go @@ -7,7 +7,7 @@ import ( ) func TestCreate(t *testing.T) { - var user = GetUser("create", Config{}) + var user = *GetUser("create", Config{}) if err := DB.Create(&user).Error; err != nil { t.Fatalf("errors happened when create: %v", err) @@ -27,165 +27,104 @@ func TestCreate(t *testing.T) { var newUser User if err := DB.Where("id = ?", user.ID).First(&newUser).Error; err != nil { - t.Errorf("errors happened when query: %v", err) + t.Fatalf("errors happened when query: %v", err) } else { - AssertObjEqual(t, newUser, user, "Name", "Age", "Birthday") + CheckUser(t, newUser, user) } } -func TestCreateWithBelongsToAssociations(t *testing.T) { - check := func(t *testing.T, user User, old User) { - if old.Company.Name != "" { - if user.CompanyID == nil { - t.Errorf("Company's foreign key should be saved") - } else { - var company Company - DB.First(&company, "id = ?", *user.CompanyID) - if company.Name != old.Company.Name { - t.Errorf("Company's name should be same") - } else if user.Company.Name != old.Company.Name { - t.Errorf("Company's name should be same") - } - } - } else if user.CompanyID != nil { - t.Errorf("Company should not be created for zero value, got: %+v", user.CompanyID) - } +func TestCreateWithAssociations(t *testing.T) { + var user = *GetUser("create_with_belongs_to", Config{ + Account: true, + Pets: 2, + Toys: 3, + Company: true, + Manager: true, + Team: 4, + Languages: 3, + Friends: 1, + }) - if old.Manager != nil { - if user.ManagerID == nil { - t.Errorf("Manager's foreign key should be saved") - } else { - var manager User - DB.First(&manager, "id = ?", *user.ManagerID) - if manager.Name != user.Manager.Name { - t.Errorf("Manager's name should be same") - } else if user.Manager.Name != old.Manager.Name { - t.Errorf("Manager's name should be same") - } - } - } else if user.ManagerID != nil { - t.Errorf("Manager should not be created for zero value, got: %+v", user.ManagerID) - } + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) } - t.Run("Struct", func(t *testing.T) { - var user = User{ - Name: "create", - Age: 18, - Birthday: Now(), - Company: Company{Name: "company-belongs-to-association"}, - Manager: &User{Name: "manager-belongs-to-association"}, - } + CheckUser(t, user, user) - if err := DB.Create(&user).Error; err != nil { - t.Fatalf("errors happened when create: %v", err) - } - - check(t, user, user) - - var user2 User - DB.Preload("Company").Preload("Manager").Find(&user2, "id = ?", user.ID) - check(t, user2, user) - }) - - t.Run("BulkInsert", func(t *testing.T) { - var users = []User{{ - Name: "create-1", - Age: 18, - Birthday: Now(), - Company: Company{Name: "company-belongs-to-association-1"}, - Manager: &User{Name: "manager-belongs-to-association-1"}, - }, { - Name: "create-2", - Age: 28, - Birthday: Now(), - Company: Company{Name: "company-belongs-to-association-2"}, - }, { - Name: "create-3", - Age: 38, - Birthday: Now(), - Company: Company{Name: "company-belongs-to-association-3"}, - Manager: &User{Name: "manager-belongs-to-association-3"}, - }} - - if err := DB.Create(&users).Error; err != nil { - t.Fatalf("errors happened when create: %v", err) - } - - var userIDs []uint - for _, user := range users { - userIDs = append(userIDs, user.ID) - check(t, user, user) - } - - t.Run("Preload", func(t *testing.T) { - var users2 []User - DB.Preload("Company").Preload("Manager").Find(&users2, "id IN ?", userIDs) - for idx, user := range users2 { - check(t, user, users[idx]) - } - }) - }) - - t.Run("BulkInsertPtrData", func(t *testing.T) { - var users = []*User{{ - Name: "create-1", - Age: 18, - Birthday: Now(), - Company: Company{Name: "company-belongs-to-association-1"}, - Manager: &User{Name: "manager-belongs-to-association-1"}, - }, { - Name: "create-2", - Age: 28, - Birthday: Now(), - Company: Company{Name: "company-belongs-to-association-2"}, - }, { - Name: "create-3", - Age: 38, - Birthday: Now(), - Company: Company{Name: "company-belongs-to-association-3"}, - Manager: &User{Name: "manager-belongs-to-association-3"}, - }} - - if err := DB.Create(&users).Error; err != nil { - t.Fatalf("errors happened when create: %v", err) - } - - for _, user := range users { - check(t, *user, *user) - } - }) - - t.Run("BulkInsertWithoutPtr", func(t *testing.T) { - var users = []*User{{ - Name: "create-1", - Age: 18, - Birthday: Now(), - Company: Company{Name: "company-belongs-to-association-1"}, - Manager: &User{Name: "manager-belongs-to-association-1"}, - }, { - Name: "create-2", - Age: 28, - Birthday: Now(), - Company: Company{Name: "company-belongs-to-association-2"}, - }, { - Name: "create-3", - Age: 38, - Birthday: Now(), - Company: Company{Name: "company-belongs-to-association-3"}, - Manager: &User{Name: "manager-belongs-to-association-3"}, - }} - - if err := DB.Create(users).Error; err != nil { - t.Fatalf("errors happened when create: %v", err) - } - - for _, user := range users { - check(t, *user, *user) - } - }) + var user2 User + DB.Preload("Account").Preload("Pets").Preload("Toys").Preload("Company").Preload("Manager").Preload("Team").Preload("Languages").Find(&user2, "id = ?", user.ID) + CheckUser(t, user2, user) } +// func TestBulkCreateWithBelongsTo(t *testing.T) { +// users := []User{ +// *GetUser("create_with_belongs_to_1", Config{Company: true, Manager: true}), +// *GetUser("create_with_belongs_to_2", Config{Company: true, Manager: false}), +// *GetUser("create_with_belongs_to_3", Config{Company: false, Manager: true}), +// *GetUser("create_with_belongs_to_4", Config{Company: true, Manager: true}), +// } + +// if err := DB.Create(&users).Error; err != nil { +// t.Fatalf("errors happened when create: %v", err) +// } + +// var userIDs []uint +// for _, user := range users { +// userIDs = append(userIDs, user.ID) +// CheckUser(t, user, user) +// } + +// var users2 []User +// DB.Preload("Company").Preload("Manager").Find(&users2, "id IN ?", userIDs) +// for idx, user := range users2 { +// CheckUser(t, user, users[idx]) +// } +// } + +// func TestBulkCreatePtrDataWithBelongsTo(t *testing.T) { +// users := []*User{ +// GetUser("create_with_belongs_to_1", Config{Company: true, Manager: true}), +// GetUser("create_with_belongs_to_2", Config{Company: true, Manager: false}), +// GetUser("create_with_belongs_to_3", Config{Company: false, Manager: true}), +// GetUser("create_with_belongs_to_4", Config{Company: true, Manager: true}), +// } + +// if err := DB.Create(&users).Error; err != nil { +// t.Fatalf("errors happened when create: %v", err) +// } + +// var userIDs []uint +// for _, user := range users { +// userIDs = append(userIDs, user.ID) +// CheckUser(t, *user, *user) +// } + +// var users2 []User +// DB.Preload("Company").Preload("Manager").Find(&users2, "id IN ?", userIDs) +// for idx, user := range users2 { +// CheckUser(t, user, *users[idx]) +// } +// } + +// func TestBulkCreateWithoutPtrWithBelongsTo(t *testing.T) { +// users := []*User{ +// GetUser("create_with_belongs_to_1", Config{Company: true, Manager: true}), +// GetUser("create_with_belongs_to_2", Config{Company: true, Manager: false}), +// GetUser("create_with_belongs_to_3", Config{Company: false, Manager: true}), +// GetUser("create_with_belongs_to_4", Config{Company: true, Manager: true}), +// } + +// if err := DB.Create(&users).Error; err != nil { +// t.Fatalf("errors happened when create: %v", err) +// } + +// var userIDs []uint +// for _, user := range users { +// userIDs = append(userIDs, user.ID) +// CheckUser(t, *user, *user) +// } +// } + // func TestCreateHasOneAssociations(t *testing.T, db *gorm.DB) { // check := func(t *testing.T, user User, old User) { // if user.Account.ID == 0 { diff --git a/tests/main_test.go b/tests/main_test.go index 7324ed9e..3e329454 100644 --- a/tests/main_test.go +++ b/tests/main_test.go @@ -36,6 +36,8 @@ func RunMigrations() { rand.Seed(time.Now().UnixNano()) rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] }) + DB.Migrator().DropTable("user_friends", "user_speak") + if err = DB.Migrator().DropTable(allModels...); err != nil { log.Printf("Failed to drop table, got error %v\n", err) os.Exit(1) diff --git a/tests/utils.go b/tests/utils.go index 9d61c422..cb4e4fcc 100644 --- a/tests/utils.go +++ b/tests/utils.go @@ -29,6 +29,15 @@ func AssertEqual(t *testing.T, got, expect interface{}) { } } + if got == expect { + return + } + + if reflect.Indirect(reflect.ValueOf(got)).IsValid() != reflect.Indirect(reflect.ValueOf(expect)).IsValid() { + t.Errorf("expect: %+v, got %+v", expect, got) + return + } + if got != nil { got = reflect.Indirect(reflect.ValueOf(got)).Interface() } From 2ca4e91d88a3392d4d3de8cebd52360e872b8b9c Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 23 May 2020 18:38:27 +0800 Subject: [PATCH 314/881] Fix LastInsertID with string primary key --- callbacks/associations.go | 3 +- callbacks/create.go | 34 +++++++------- tests/main_test.go | 97 -------------------------------------- tests/tests.go | 99 ++++++++++++++++++++++++++++++++++----- 4 files changed, 106 insertions(+), 127 deletions(-) delete mode 100644 tests/main_test.go diff --git a/callbacks/associations.go b/callbacks/associations.go index a0c296e3..96d9ce22 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -234,7 +234,6 @@ func SaveAfterAssociations(db *gorm.DB) { ref.ForeignKey.Set(joinValue, fv) } } - joins = reflect.Append(joins, joinValue) } @@ -277,7 +276,7 @@ func SaveAfterAssociations(db *gorm.DB) { } if joins.Len() > 0 { - db.Session(&gorm.Session{}).Create(joins.Interface()) + db.Session(&gorm.Session{}).Debug().Create(joins.Interface()) } } } diff --git a/callbacks/create.go b/callbacks/create.go index 9dc8dc67..ff88bc0e 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -56,25 +56,27 @@ func Create(config *Config) func(db *gorm.DB) { if err == nil { if db.Statement.Schema != nil { - if insertID, err := result.LastInsertId(); err == nil { - switch db.Statement.ReflectValue.Kind() { - case reflect.Slice, reflect.Array: - if config.LastInsertIDReversed { - for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- { - db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID) - insertID-- - } - } else { - for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID) - insertID++ + if _, ok := db.Statement.Schema.FieldsWithDefaultDBValue[db.Statement.Schema.PrioritizedPrimaryField.DBName]; ok { + if insertID, err := result.LastInsertId(); err == nil { + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + if config.LastInsertIDReversed { + for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- { + db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID) + insertID-- + } + } else { + for i := 0; i < db.Statement.ReflectValue.Len(); i++ { + db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID) + insertID++ + } } + case reflect.Struct: + db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID) } - case reflect.Struct: - db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID) + } else { + db.AddError(err) } - } else { - db.AddError(err) } } db.RowsAffected, _ = result.RowsAffected() diff --git a/tests/main_test.go b/tests/main_test.go deleted file mode 100644 index 3e329454..00000000 --- a/tests/main_test.go +++ /dev/null @@ -1,97 +0,0 @@ -package tests - -import ( - "log" - "math/rand" - "os" - "path/filepath" - "testing" - "time" - - "github.com/jinzhu/gorm" - "github.com/jinzhu/gorm/dialects/mssql" - "github.com/jinzhu/gorm/dialects/mysql" - "github.com/jinzhu/gorm/dialects/postgres" - "github.com/jinzhu/gorm/dialects/sqlite" - "github.com/jinzhu/gorm/logger" -) - -var DB *gorm.DB - -func TestMain(m *testing.M) { - var err error - DB, err = OpenTestConnection() - if err == nil { - RunMigrations() - m.Run() - } else { - log.Printf("failed to connect database, got error %v\n", err) - os.Exit(1) - } -} - -func RunMigrations() { - var err error - allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}} - rand.Seed(time.Now().UnixNano()) - rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] }) - - DB.Migrator().DropTable("user_friends", "user_speak") - - if err = DB.Migrator().DropTable(allModels...); err != nil { - log.Printf("Failed to drop table, got error %v\n", err) - os.Exit(1) - } - - if err = DB.AutoMigrate(allModels...); err != nil { - log.Printf("Failed to auto migrate, but got error %v\n", err) - os.Exit(1) - } - - for _, m := range allModels { - if !DB.Migrator().HasTable(m) { - log.Printf("Failed to create table for %#v\n", m) - os.Exit(1) - } - } -} - -func OpenTestConnection() (db *gorm.DB, err error) { - dbDSN := os.Getenv("GORM_DSN") - switch os.Getenv("GORM_DIALECT") { - case "mysql": - log.Println("testing mysql...") - if dbDSN == "" { - dbDSN = "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True&loc=Local" - } - db, err = gorm.Open(mysql.Open(dbDSN), &gorm.Config{}) - case "postgres": - log.Println("testing postgres...") - if dbDSN == "" { - dbDSN = "user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable TimeZone=Asia/Shanghai" - } - db, err = gorm.Open(postgres.Open(dbDSN), &gorm.Config{}) - case "mssql": - // CREATE LOGIN gorm WITH PASSWORD = 'LoremIpsum86'; - // CREATE DATABASE gorm; - // USE gorm; - // CREATE USER gorm FROM LOGIN gorm; - // sp_changedbowner 'gorm'; - log.Println("testing mssql...") - if dbDSN == "" { - dbDSN = "sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm" - } - db, err = gorm.Open(mssql.Open(dbDSN), &gorm.Config{}) - default: - log.Println("testing sqlite3...") - db, err = gorm.Open(sqlite.Open(filepath.Join(os.TempDir(), "gorm.db")), &gorm.Config{}) - } - - if debug := os.Getenv("DEBUG"); debug == "true" { - db.Logger.LogMode(logger.Info) - } else if debug == "false" { - db.Logger.LogMode(logger.Error) - } - - return -} diff --git a/tests/tests.go b/tests/tests.go index 809d2e39..1ff700c5 100644 --- a/tests/tests.go +++ b/tests/tests.go @@ -1,24 +1,99 @@ package tests import ( - "testing" + "log" + "math/rand" + "os" + "path/filepath" "time" "github.com/jinzhu/gorm" + "github.com/jinzhu/gorm/dialects/mssql" + "github.com/jinzhu/gorm/dialects/mysql" + "github.com/jinzhu/gorm/dialects/postgres" + "github.com/jinzhu/gorm/dialects/sqlite" + "github.com/jinzhu/gorm/logger" ) +var DB *gorm.DB + +func init() { + var err error + if DB, err = OpenTestConnection(); err == nil { + RunMigrations() + } else { + log.Printf("failed to connect database, got error %v\n", err) + os.Exit(1) + } +} + +func OpenTestConnection() (db *gorm.DB, err error) { + dbDSN := os.Getenv("GORM_DSN") + switch os.Getenv("GORM_DIALECT") { + case "mysql": + log.Println("testing mysql...") + if dbDSN == "" { + dbDSN = "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True&loc=Local" + } + db, err = gorm.Open(mysql.Open(dbDSN), &gorm.Config{}) + case "postgres": + log.Println("testing postgres...") + if dbDSN == "" { + dbDSN = "user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable TimeZone=Asia/Shanghai" + } + db, err = gorm.Open(postgres.Open(dbDSN), &gorm.Config{}) + case "mssql": + // CREATE LOGIN gorm WITH PASSWORD = 'LoremIpsum86'; + // CREATE DATABASE gorm; + // USE gorm; + // CREATE USER gorm FROM LOGIN gorm; + // sp_changedbowner 'gorm'; + log.Println("testing mssql...") + if dbDSN == "" { + dbDSN = "sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm" + } + db, err = gorm.Open(mssql.Open(dbDSN), &gorm.Config{}) + default: + log.Println("testing sqlite3...") + db, err = gorm.Open(sqlite.Open(filepath.Join(os.TempDir(), "gorm.db")), &gorm.Config{}) + } + + if debug := os.Getenv("DEBUG"); debug == "true" { + db.Logger.LogMode(logger.Info) + } else if debug == "false" { + db.Logger.LogMode(logger.Error) + } + + return +} + +func RunMigrations() { + var err error + allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}} + rand.Seed(time.Now().UnixNano()) + rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] }) + + DB.Migrator().DropTable("user_friends", "user_speak") + + if err = DB.Migrator().DropTable(allModels...); err != nil { + log.Printf("Failed to drop table, got error %v\n", err) + os.Exit(1) + } + + if err = DB.AutoMigrate(allModels...); err != nil { + log.Printf("Failed to auto migrate, but got error %v\n", err) + os.Exit(1) + } + + for _, m := range allModels { + if !DB.Migrator().HasTable(m) { + log.Printf("Failed to create table for %#v\n", m) + os.Exit(1) + } + } +} + func Now() *time.Time { now := time.Now() return &now } - -func RunTestsSuit(t *testing.T, db *gorm.DB) { - // TestCreate(t, db) - TestFind(t, db) - TestUpdate(t, db) - TestDelete(t, db) - - TestGroupBy(t, db) - TestJoins(t, db) - TestAssociations(t, db) -} From 5ec4fee79704878e76cd591d5d516b9d55fe987e Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 23 May 2020 21:03:28 +0800 Subject: [PATCH 315/881] Don't preload if foreign keys zero --- association.go | 12 ++++++---- callbacks/associations.go | 2 +- callbacks/preload.go | 24 ++++++++++++------- schema/utils.go | 39 +++++++++++++++++------------- tests/create.go | 2 +- tests/create_test.go | 50 +++++++++++++++++++++------------------ tests/tests.go | 2 +- utils/utils.go | 11 ++++----- 8 files changed, 79 insertions(+), 63 deletions(-) diff --git a/association.go b/association.go index ab9090ac..abcae47d 100644 --- a/association.go +++ b/association.go @@ -101,8 +101,10 @@ func (association *Association) Replace(values ...interface{}) error { } _, values := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields) - column, queryValues := schema.ToQueryValues(foreignKeys, values) - association.DB.Model(modelValue).Where(clause.IN{Column: column, Values: queryValues}).UpdateColumns(updateMap) + if len(values) > 0 { + column, queryValues := schema.ToQueryValues(foreignKeys, values) + association.DB.Model(modelValue).Where(clause.IN{Column: column, Values: queryValues}).UpdateColumns(updateMap) + } case schema.Many2Many: var primaryFields, relPrimaryFields []*schema.Field var foreignKeys, relForeignKeys []string @@ -200,13 +202,13 @@ func (association *Association) Delete(values ...interface{}) error { if _, zero := rel.Field.ValueOf(data); !zero { fieldValue := reflect.Indirect(rel.Field.ReflectValueOf(data)) - fieldValues := make([]reflect.Value, len(relFields)) + fieldValues := make([]interface{}, len(relFields)) switch fieldValue.Kind() { case reflect.Slice, reflect.Array: validFieldValues := reflect.Zero(rel.Field.FieldType) for i := 0; i < fieldValue.Len(); i++ { for idx, field := range relFields { - fieldValues[idx] = field.ReflectValueOf(fieldValue.Index(i)) + fieldValues[idx], _ = field.ValueOf(fieldValue.Index(i)) } if _, ok := relValuesMap[utils.ToStringKey(fieldValues...)]; !ok { @@ -217,7 +219,7 @@ func (association *Association) Delete(values ...interface{}) error { rel.Field.Set(data, validFieldValues) case reflect.Struct: for idx, field := range relFields { - fieldValues[idx] = field.ReflectValueOf(data) + fieldValues[idx], _ = field.ValueOf(data) } if _, ok := relValuesMap[utils.ToStringKey(fieldValues...)]; ok { rel.Field.Set(data, reflect.Zero(rel.FieldSchema.ModelType)) diff --git a/callbacks/associations.go b/callbacks/associations.go index 96d9ce22..ef040b71 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -276,7 +276,7 @@ func SaveAfterAssociations(db *gorm.DB) { } if joins.Len() > 0 { - db.Session(&gorm.Session{}).Debug().Create(joins.Interface()) + db.Session(&gorm.Session{}).Create(joins.Interface()) } } } diff --git a/callbacks/preload.go b/callbacks/preload.go index 9f23a2ca..7e3810b5 100644 --- a/callbacks/preload.go +++ b/callbacks/preload.go @@ -42,22 +42,25 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { } } - joinIdentityMap, joinForeignValues := schema.GetIdentityFieldValuesMap(reflectValue, joinForeignFields) + joinIdentityMap, joinForeignValues := schema.GetIdentityFieldValuesMap(reflectValue, foreignFields) + if len(joinForeignValues) == 0 { + return + } joinResults := rel.JoinTable.MakeSlice().Elem() column, values := schema.ToQueryValues(joinForeignKeys, joinForeignValues) tx.Where(clause.IN{Column: column, Values: values}).Find(joinResults.Addr().Interface()) // convert join identity map to relation identity map - fieldValues := make([]reflect.Value, len(foreignFields)) - joinFieldValues := make([]reflect.Value, len(joinForeignFields)) + fieldValues := make([]interface{}, len(foreignFields)) + joinFieldValues := make([]interface{}, len(joinForeignFields)) for i := 0; i < joinResults.Len(); i++ { - for idx, field := range foreignFields { - fieldValues[idx] = field.ReflectValueOf(joinResults.Index(i)) + for idx, field := range joinForeignFields { + fieldValues[idx], _ = field.ValueOf(joinResults.Index(i)) } - for idx, field := range joinForeignFields { - joinFieldValues[idx] = field.ReflectValueOf(joinResults.Index(i)) + for idx, field := range joinRelForeignFields { + joinFieldValues[idx], _ = field.ValueOf(joinResults.Index(i)) } if results, ok := joinIdentityMap[utils.ToStringKey(fieldValues...)]; ok { @@ -82,16 +85,19 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { } identityMap, foreignValues = schema.GetIdentityFieldValuesMap(reflectValue, foreignFields) + if len(foreignValues) == 0 { + return + } } reflectResults := rel.FieldSchema.MakeSlice().Elem() column, values := schema.ToQueryValues(relForeignKeys, foreignValues) tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), conds...) - fieldValues := make([]reflect.Value, len(foreignFields)) + fieldValues := make([]interface{}, len(foreignFields)) for i := 0; i < reflectResults.Len(); i++ { for idx, field := range relForeignFields { - fieldValues[idx] = field.ReflectValueOf(reflectResults.Index(i)) + fieldValues[idx], _ = field.ValueOf(reflectResults.Index(i)) } for _, data := range identityMap[utils.ToStringKey(fieldValues...)] { diff --git a/schema/utils.go b/schema/utils.go index 72bd149c..ead83cab 100644 --- a/schema/utils.go +++ b/schema/utils.go @@ -89,9 +89,9 @@ func GetRelationsValues(reflectValue reflect.Value, rels []*Relationship) (refle // GetIdentityFieldValuesMap get identity map from fields func GetIdentityFieldValuesMap(reflectValue reflect.Value, fields []*Field) (map[string][]reflect.Value, [][]interface{}) { var ( - fieldValues = make([]reflect.Value, len(fields)) - results = [][]interface{}{} - dataResults = map[string][]reflect.Value{} + results = [][]interface{}{} + dataResults = map[string][]reflect.Value{} + notZero, zero bool ) switch reflectValue.Kind() { @@ -99,28 +99,33 @@ func GetIdentityFieldValuesMap(reflectValue reflect.Value, fields []*Field) (map results = [][]interface{}{make([]interface{}, len(fields))} for idx, field := range fields { - fieldValues[idx] = field.ReflectValueOf(reflectValue) - results[0][idx] = fieldValues[idx].Interface() + results[0][idx], zero = field.ValueOf(reflectValue) + notZero = notZero || !zero } - dataResults[utils.ToStringKey(fieldValues...)] = []reflect.Value{reflectValue} + if !notZero { + return nil, nil + } + + dataResults[utils.ToStringKey(results[0]...)] = []reflect.Value{reflectValue} case reflect.Slice, reflect.Array: + fieldValues := make([]interface{}, len(fields)) + for i := 0; i < reflectValue.Len(); i++ { + notZero = false for idx, field := range fields { - fieldValues[idx] = field.ReflectValueOf(reflectValue.Index(i)) + fieldValues[idx], zero = field.ValueOf(reflectValue.Index(idx)) + notZero = notZero || !zero } - dataKey := utils.ToStringKey(fieldValues...) - if _, ok := dataResults[dataKey]; !ok { - result := make([]interface{}, len(fieldValues)) - for idx, fieldValue := range fieldValues { - result[idx] = fieldValue.Interface() + if notZero { + dataKey := utils.ToStringKey(fieldValues...) + if _, ok := dataResults[dataKey]; !ok { + results = append(results, fieldValues[:]) + dataResults[dataKey] = []reflect.Value{reflectValue.Index(i)} + } else { + dataResults[dataKey] = append(dataResults[dataKey], reflectValue.Index(i)) } - results = append(results, result) - - dataResults[dataKey] = []reflect.Value{reflectValue.Index(i)} - } else { - dataResults[dataKey] = append(dataResults[dataKey], reflectValue.Index(i)) } } } diff --git a/tests/create.go b/tests/create.go index 09464674..0d85a29e 100644 --- a/tests/create.go +++ b/tests/create.go @@ -52,7 +52,7 @@ func GetUser(name string, config Config) *User { } for i := 0; i < config.Languages; i++ { - name := name + "_locale_" + strconv.Itoa(i+0) + name := name + "_locale_" + strconv.Itoa(i+1) language := Language{Code: name, Name: name} DB.Create(&language) user.Languages = append(user.Languages, language) diff --git a/tests/create_test.go b/tests/create_test.go index 9241e0a6..ef9203aa 100644 --- a/tests/create_test.go +++ b/tests/create_test.go @@ -34,7 +34,7 @@ func TestCreate(t *testing.T) { } func TestCreateWithAssociations(t *testing.T) { - var user = *GetUser("create_with_belongs_to", Config{ + var user = *GetUser("create_with_associations", Config{ Account: true, Pets: 2, Toys: 3, @@ -52,34 +52,38 @@ func TestCreateWithAssociations(t *testing.T) { CheckUser(t, user, user) var user2 User - DB.Preload("Account").Preload("Pets").Preload("Toys").Preload("Company").Preload("Manager").Preload("Team").Preload("Languages").Find(&user2, "id = ?", user.ID) + DB.Debug().Preload("Account").Preload("Pets").Preload("Toys").Preload("Company").Preload("Manager").Preload("Team").Preload("Languages").Preload("Friends").Find(&user2, "id = ?", user.ID) CheckUser(t, user2, user) } -// func TestBulkCreateWithBelongsTo(t *testing.T) { -// users := []User{ -// *GetUser("create_with_belongs_to_1", Config{Company: true, Manager: true}), -// *GetUser("create_with_belongs_to_2", Config{Company: true, Manager: false}), -// *GetUser("create_with_belongs_to_3", Config{Company: false, Manager: true}), -// *GetUser("create_with_belongs_to_4", Config{Company: true, Manager: true}), -// } +func TestBulkCreateWithAssociations(t *testing.T) { + users := []User{ + *GetUser("bulk_1", Config{Account: true, Pets: 2, Toys: 3, Company: true, Manager: true, Team: 0, Languages: 1, Friends: 1}), + *GetUser("bulk_2", Config{Account: false, Pets: 2, Toys: 4, Company: false, Manager: false, Team: 1, Languages: 3, Friends: 5}), + *GetUser("bulk_3", Config{Account: true, Pets: 0, Toys: 3, Company: true, Manager: false, Team: 4, Languages: 0, Friends: 1}), + *GetUser("bulk_4", Config{Account: true, Pets: 3, Toys: 0, Company: false, Manager: true, Team: 0, Languages: 3, Friends: 0}), + *GetUser("bulk_5", Config{Account: false, Pets: 0, Toys: 3, Company: true, Manager: false, Team: 1, Languages: 3, Friends: 1}), + *GetUser("bulk_6", Config{Account: true, Pets: 4, Toys: 3, Company: false, Manager: true, Team: 1, Languages: 3, Friends: 0}), + *GetUser("bulk_7", Config{Account: true, Pets: 1, Toys: 3, Company: true, Manager: true, Team: 4, Languages: 3, Friends: 1}), + *GetUser("bulk_8", Config{Account: false, Pets: 0, Toys: 0, Company: false, Manager: false, Team: 0, Languages: 0, Friends: 0}), + } -// if err := DB.Create(&users).Error; err != nil { -// t.Fatalf("errors happened when create: %v", err) -// } + if err := DB.Create(&users).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } -// var userIDs []uint -// for _, user := range users { -// userIDs = append(userIDs, user.ID) -// CheckUser(t, user, user) -// } + var userIDs []uint + for _, user := range users { + userIDs = append(userIDs, user.ID) + CheckUser(t, user, user) + } -// var users2 []User -// DB.Preload("Company").Preload("Manager").Find(&users2, "id IN ?", userIDs) -// for idx, user := range users2 { -// CheckUser(t, user, users[idx]) -// } -// } + var users2 []User + DB.Preload("Account").Preload("Pets").Preload("Toys").Preload("Company").Preload("Manager").Preload("Team").Preload("Languages").Preload("Friends").Find(&users2, "id IN ?", userIDs) + for idx, user := range users2 { + CheckUser(t, user, users[idx]) + } +} // func TestBulkCreatePtrDataWithBelongsTo(t *testing.T) { // users := []*User{ diff --git a/tests/tests.go b/tests/tests.go index 1ff700c5..2b2bfc20 100644 --- a/tests/tests.go +++ b/tests/tests.go @@ -73,7 +73,7 @@ func RunMigrations() { rand.Seed(time.Now().UnixNano()) rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] }) - DB.Migrator().DropTable("user_friends", "user_speak") + DB.Migrator().DropTable("user_friends", "user_speaks") if err = DB.Migrator().DropTable(allModels...); err != nil { log.Printf("Failed to drop table, got error %v\n", err) diff --git a/utils/utils.go b/utils/utils.go index 5d6c9da2..3924e69e 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -41,16 +41,15 @@ func CheckTruth(val interface{}) bool { return !reflect.ValueOf(val).IsZero() } -func ToStringKey(values ...reflect.Value) string { +func ToStringKey(values ...interface{}) string { results := make([]string, len(values)) for idx, value := range values { - rv := reflect.Indirect(value).Interface() - if valuer, ok := rv.(driver.Valuer); ok { - rv, _ = valuer.Value() + if valuer, ok := value.(driver.Valuer); ok { + value, _ = valuer.Value() } - switch v := rv.(type) { + switch v := value.(type) { case string: results[idx] = v case []byte: @@ -58,7 +57,7 @@ func ToStringKey(values ...reflect.Value) string { case uint: results[idx] = strconv.FormatUint(uint64(v), 10) default: - results[idx] = fmt.Sprint(v) + results[idx] = fmt.Sprint(reflect.Indirect(reflect.ValueOf(v)).Interface()) } } From 590f622674a8b956f2b0a0069211b860a12f585a Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 23 May 2020 21:35:12 +0800 Subject: [PATCH 316/881] Refactor create tests --- schema/utils.go | 5 +- tests/create.go | 9 + tests/create_test.go | 716 ++++++------------------------------------- utils/utils.go | 10 +- 4 files changed, 117 insertions(+), 623 deletions(-) diff --git a/schema/utils.go b/schema/utils.go index ead83cab..c47f1984 100644 --- a/schema/utils.go +++ b/schema/utils.go @@ -109,12 +109,11 @@ func GetIdentityFieldValuesMap(reflectValue reflect.Value, fields []*Field) (map dataResults[utils.ToStringKey(results[0]...)] = []reflect.Value{reflectValue} case reflect.Slice, reflect.Array: - fieldValues := make([]interface{}, len(fields)) - for i := 0; i < reflectValue.Len(); i++ { + fieldValues := make([]interface{}, len(fields)) notZero = false for idx, field := range fields { - fieldValues[idx], zero = field.ValueOf(reflectValue.Index(idx)) + fieldValues[idx], zero = field.ValueOf(reflectValue.Index(i)) notZero = notZero || !zero } diff --git a/tests/create.go b/tests/create.go index 0d85a29e..6e5dd2c5 100644 --- a/tests/create.go +++ b/tests/create.go @@ -66,6 +66,15 @@ func GetUser(name string, config Config) *User { } func CheckPet(t *testing.T, pet Pet, expect Pet) { + if pet.ID != 0 { + var newPet Pet + if err := DB.Where("id = ?", pet.ID).First(&newPet).Error; err != nil { + t.Fatalf("errors happened when query: %v", err) + } else { + AssertObjEqual(t, newPet, pet, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "UserID", "Name") + } + } + AssertObjEqual(t, pet, expect, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "UserID", "Name") AssertObjEqual(t, pet.Toy, expect.Toy, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "OwnerID", "OwnerType") diff --git a/tests/create_test.go b/tests/create_test.go index ef9203aa..5b859e99 100644 --- a/tests/create_test.go +++ b/tests/create_test.go @@ -52,7 +52,7 @@ func TestCreateWithAssociations(t *testing.T) { CheckUser(t, user, user) var user2 User - DB.Debug().Preload("Account").Preload("Pets").Preload("Toys").Preload("Company").Preload("Manager").Preload("Team").Preload("Languages").Preload("Friends").Find(&user2, "id = ?", user.ID) + DB.Preload("Account").Preload("Pets").Preload("Toys").Preload("Company").Preload("Manager").Preload("Team").Preload("Languages").Preload("Friends").Find(&user2, "id = ?", user.ID) CheckUser(t, user2, user) } @@ -85,620 +85,100 @@ func TestBulkCreateWithAssociations(t *testing.T) { } } -// func TestBulkCreatePtrDataWithBelongsTo(t *testing.T) { -// users := []*User{ -// GetUser("create_with_belongs_to_1", Config{Company: true, Manager: true}), -// GetUser("create_with_belongs_to_2", Config{Company: true, Manager: false}), -// GetUser("create_with_belongs_to_3", Config{Company: false, Manager: true}), -// GetUser("create_with_belongs_to_4", Config{Company: true, Manager: true}), -// } - -// if err := DB.Create(&users).Error; err != nil { -// t.Fatalf("errors happened when create: %v", err) -// } - -// var userIDs []uint -// for _, user := range users { -// userIDs = append(userIDs, user.ID) -// CheckUser(t, *user, *user) -// } - -// var users2 []User -// DB.Preload("Company").Preload("Manager").Find(&users2, "id IN ?", userIDs) -// for idx, user := range users2 { -// CheckUser(t, user, *users[idx]) -// } -// } - -// func TestBulkCreateWithoutPtrWithBelongsTo(t *testing.T) { -// users := []*User{ -// GetUser("create_with_belongs_to_1", Config{Company: true, Manager: true}), -// GetUser("create_with_belongs_to_2", Config{Company: true, Manager: false}), -// GetUser("create_with_belongs_to_3", Config{Company: false, Manager: true}), -// GetUser("create_with_belongs_to_4", Config{Company: true, Manager: true}), -// } - -// if err := DB.Create(&users).Error; err != nil { -// t.Fatalf("errors happened when create: %v", err) -// } - -// var userIDs []uint -// for _, user := range users { -// userIDs = append(userIDs, user.ID) -// CheckUser(t, *user, *user) -// } -// } - -// func TestCreateHasOneAssociations(t *testing.T, db *gorm.DB) { -// check := func(t *testing.T, user User, old User) { -// if user.Account.ID == 0 { -// t.Errorf("Account should be saved") -// } else if user.Account.UserID.Int64 != int64(user.ID) { -// t.Errorf("Account's foreign key should be saved") -// } else { -// var account Account -// db.First(&account, "id = ?", user.Account.ID) -// if account.Number != user.Account.Number { -// t.Errorf("Account's number should be same") -// } else if user.Account.Number != old.Account.Number { -// t.Errorf("Account's number should be same") -// } -// } -// } - -// t.Run("HasOne", func(t *testing.T) { -// var user = User{ -// Name: "create", -// Age: 18, -// Birthday: Now(), -// Account: Account{Number: "account-has-one-association"}, -// } - -// if err := db.Create(&user).Error; err != nil { -// t.Fatalf("errors happened when create: %v", err) -// } - -// check(t, user, user) - -// var user2 User -// db.Preload("Account").Find(&user2, "id = ?", user.ID) -// check(t, user2, user) -// }) - -// t.Run("HasOneForBulkInsert", func(t *testing.T) { -// var users = []User{{ -// Name: "create-1", -// Age: 18, -// Birthday: Now(), -// Account: Account{Number: "account-has-one-association-1"}, -// }, { -// Name: "create-2", -// Age: 28, -// Birthday: Now(), -// Account: Account{Number: "account-has-one-association-2"}, -// }, { -// Name: "create-3", -// Age: 38, -// Birthday: Now(), -// Account: Account{Number: "account-has-one-association-3"}, -// }} - -// if err := db.Create(&users).Error; err != nil { -// t.Fatalf("errors happened when create: %v", err) -// } - -// var userIDs []uint -// for _, user := range users { -// userIDs = append(userIDs, user.ID) -// check(t, user, user) -// } - -// var users2 []User -// db.Preload("Account").Find(&users2, "id IN (?)", userIDs) -// for idx, user := range users2 { -// check(t, user, users[idx]) -// } -// }) - -// t.Run("HasOneForBulkInsertPtrData", func(t *testing.T) { -// var users = []*User{{ -// Name: "create-1", -// Age: 18, -// Birthday: Now(), -// Account: Account{Number: "account-has-one-association-1"}, -// }, { -// Name: "create-2", -// Age: 28, -// Birthday: Now(), -// Account: Account{Number: "account-has-one-association-2"}, -// }, { -// Name: "create-3", -// Age: 38, -// Birthday: Now(), -// Account: Account{Number: "account-has-one-association-3"}, -// }} - -// if err := db.Create(&users).Error; err != nil { -// t.Fatalf("errors happened when create: %v", err) -// } - -// for _, user := range users { -// check(t, *user, *user) -// } -// }) - -// t.Run("HasOneForBulkInsertWithoutPtr", func(t *testing.T) { -// var users = []User{{ -// Name: "create-1", -// Age: 18, -// Birthday: Now(), -// Account: Account{Number: "account-has-one-association-1"}, -// }, { -// Name: "create-2", -// Age: 28, -// Birthday: Now(), -// Account: Account{Number: "account-has-one-association-2"}, -// }, { -// Name: "create-3", -// Age: 38, -// Birthday: Now(), -// Account: Account{Number: "account-has-one-association-3"}, -// }} - -// if err := db.Create(users).Error; err != nil { -// t.Fatalf("errors happened when create: %v", err) -// } - -// for _, user := range users { -// check(t, user, user) -// } -// }) - -// checkPet := func(t *testing.T, pet Pet, old Pet) { -// if pet.Toy.OwnerID != fmt.Sprint(pet.ID) || pet.Toy.OwnerType != "pets" { -// t.Errorf("Failed to create polymorphic has one association - toy owner id %v, owner type %v", pet.Toy.OwnerID, pet.Toy.OwnerType) -// } else { -// var toy Toy -// db.First(&toy, "owner_id = ? and owner_type = ?", pet.Toy.OwnerID, pet.Toy.OwnerType) -// if toy.Name != pet.Toy.Name { -// t.Errorf("Failed to query saved polymorphic has one association") -// } else if old.Toy.Name != pet.Toy.Name { -// t.Errorf("Failed to query saved polymorphic has one association") -// } -// } -// } - -// t.Run("PolymorphicHasOne", func(t *testing.T) { -// var pet = Pet{ -// Name: "create", -// Toy: Toy{Name: "Create-HasOneAssociation-Polymorphic"}, -// } - -// if err := db.Create(&pet).Error; err != nil { -// t.Fatalf("errors happened when create: %v", err) -// } - -// checkPet(t, pet, pet) - -// var pet2 Pet -// db.Preload("Toy").Find(&pet2, "id = ?", pet.ID) -// checkPet(t, pet2, pet) -// }) - -// t.Run("PolymorphicHasOneForBulkInsert", func(t *testing.T) { -// var pets = []Pet{{ -// Name: "create-1", -// Toy: Toy{Name: "Create-HasOneAssociation-Polymorphic-1"}, -// }, { -// Name: "create-2", -// Toy: Toy{Name: "Create-HasOneAssociation-Polymorphic-2"}, -// }, { -// Name: "create-3", -// Toy: Toy{Name: "Create-HasOneAssociation-Polymorphic-3"}, -// }} - -// if err := db.Create(&pets).Error; err != nil { -// t.Fatalf("errors happened when create: %v", err) -// } - -// var petIDs []uint -// for _, pet := range pets { -// petIDs = append(petIDs, pet.ID) -// checkPet(t, pet, pet) -// } - -// var pets2 []Pet -// db.Preload("Toy").Find(&pets2, "id IN (?)", petIDs) -// for idx, pet := range pets2 { -// checkPet(t, pet, pets[idx]) -// } -// }) - -// t.Run("PolymorphicHasOneForBulkInsertPtrData", func(t *testing.T) { -// var pets = []*Pet{{ -// Name: "create-1", -// Toy: Toy{Name: "Create-HasOneAssociation-Polymorphic-1"}, -// }, { -// Name: "create-2", -// Toy: Toy{Name: "Create-HasOneAssociation-Polymorphic-2"}, -// }, { -// Name: "create-3", -// Toy: Toy{Name: "Create-HasOneAssociation-Polymorphic-3"}, -// }} - -// if err := db.Create(&pets).Error; err != nil { -// t.Fatalf("errors happened when create: %v", err) -// } - -// for _, pet := range pets { -// checkPet(t, *pet, *pet) -// } -// }) - -// t.Run("PolymorphicHasOneForBulkInsertWithoutPtr", func(t *testing.T) { -// var pets = []*Pet{{ -// Name: "create-1", -// Toy: Toy{Name: "Create-HasOneAssociation-Polymorphic-1"}, -// }, { -// Name: "create-2", -// Toy: Toy{Name: "Create-HasOneAssociation-Polymorphic-2"}, -// }, { -// Name: "create-3", -// Toy: Toy{Name: "Create-HasOneAssociation-Polymorphic-3"}, -// }} - -// if err := db.Create(pets).Error; err != nil { -// t.Fatalf("errors happened when create: %v", err) -// } - -// for _, pet := range pets { -// checkPet(t, *pet, *pet) -// } -// }) -// } - -// func TestCreateHasManyAssociations(t *testing.T, db *gorm.DB) { -// check := func(t *testing.T, user User, old User) { -// for idx, pet := range user.Pets { -// if pet.ID == 0 { -// t.Errorf("Pet's foreign key should be saved") -// } - -// var result Pet -// db.First(&result, "id = ?", pet.ID) -// if result.Name != pet.Name { -// t.Errorf("Pet's name should be same") -// } else if result.UserID != user.ID { -// t.Errorf("Pet's foreign key should be saved") -// } else if result.Name != old.Pets[idx].Name { -// t.Errorf("Pet's name should be same") -// } -// } -// } - -// t.Run("HasMany", func(t *testing.T) { -// var user = User{ -// Name: "create", -// Age: 18, -// Birthday: Now(), -// Pets: []*Pet{{Name: "pet1"}, {Name: "pet2"}}, -// } - -// if err := db.Create(&user).Error; err != nil { -// t.Fatalf("errors happened when create: %v", err) -// } - -// check(t, user, user) - -// var user2 User -// db.Preload("Pets").Find(&user2, "id = ?", user.ID) -// check(t, user2, user) -// }) - -// t.Run("HasManyForBulkInsert", func(t *testing.T) { -// var users = []User{{ -// Name: "create-1", -// Age: 18, -// Birthday: Now(), -// Pets: []*Pet{{Name: "pet-1-1"}, {Name: "pet-1-2"}}, -// }, { -// Name: "create-2", -// Age: 28, -// Birthday: Now(), -// Pets: []*Pet{{Name: "pet-2-1"}}, -// }, { -// Name: "create-3", -// Age: 38, -// Birthday: Now(), -// Pets: []*Pet{{Name: "pet-3-1"}, {Name: "pet-3-2"}}, -// }} - -// if err := db.Create(&users).Error; err != nil { -// t.Fatalf("errors happened when create: %v", err) -// } - -// var userIDs []uint -// for _, user := range users { -// userIDs = append(userIDs, user.ID) -// check(t, user, user) -// } - -// var users2 []User -// db.Preload("Pets").Find(&users2, "id IN (?)", userIDs) -// for idx, user := range users2 { -// check(t, user, users[idx]) -// } -// }) - -// t.Run("HasManyForBulkInsertPtrData", func(t *testing.T) { -// var users = []*User{{ -// Name: "create-1", -// Age: 18, -// Birthday: Now(), -// Pets: []*Pet{{Name: "pet-1-1"}, {Name: "pet-1-2"}}, -// }, { -// Name: "create-2", -// Age: 28, -// Birthday: Now(), -// Pets: []*Pet{{Name: "pet-2-1"}}, -// }, { -// Name: "create-3", -// Age: 38, -// Birthday: Now(), -// Pets: []*Pet{{Name: "pet-3-1"}, {Name: "pet-3-2"}}, -// }} - -// if err := db.Create(&users).Error; err != nil { -// t.Fatalf("errors happened when create: %v", err) -// } - -// for _, user := range users { -// check(t, *user, *user) -// } -// }) - -// t.Run("HasManyForBulkInsertWithoutPtr", func(t *testing.T) { -// var users = []*User{{ -// Name: "create-1", -// Age: 18, -// Birthday: Now(), -// Pets: []*Pet{{Name: "pet-1-1"}, {Name: "pet-1-2"}}, -// }, { -// Name: "create-2", -// Age: 28, -// Birthday: Now(), -// Pets: []*Pet{{Name: "pet-2-1"}}, -// }, { -// Name: "create-3", -// Age: 38, -// Birthday: Now(), -// Pets: []*Pet{{Name: "pet-3-1"}, {Name: "pet-3-2"}}, -// }} - -// if err := db.Create(users).Error; err != nil { -// t.Fatalf("errors happened when create: %v", err) -// } - -// for _, user := range users { -// check(t, *user, *user) -// } -// }) - -// checkToy := func(t *testing.T, user User, old User) { -// for idx, toy := range user.Toys { -// if toy.ID == 0 { -// t.Fatalf("Failed to create toy #%v", idx) -// } - -// var result Toy -// db.First(&result, "id = ?", toy.ID) -// if result.Name != toy.Name { -// t.Errorf("Failed to query saved toy") -// } else if result.Name != old.Toys[idx].Name { -// t.Errorf("Failed to query saved toy") -// } else if result.OwnerID != fmt.Sprint(user.ID) || result.OwnerType != "users" { -// t.Errorf("Failed to save relation") -// } -// } -// } - -// t.Run("PolymorphicHasMany", func(t *testing.T) { -// var user = User{ -// Name: "create", -// Age: 18, -// Birthday: Now(), -// Toys: []Toy{{Name: "toy1"}, {Name: "toy2"}}, -// } - -// if err := db.Create(&user).Error; err != nil { -// t.Fatalf("errors happened when create: %v", err) -// } - -// checkToy(t, user, user) - -// var user2 User -// db.Preload("Toys").Find(&user2, "id = ?", user.ID) -// check(t, user2, user) -// }) - -// t.Run("PolymorphicHasManyForBulkInsert", func(t *testing.T) { -// var users = []User{{ -// Name: "create-1", -// Age: 18, -// Birthday: Now(), -// Toys: []Toy{{Name: "toy-1-1"}, {Name: "toy-1-2"}}, -// }, { -// Name: "create-2", -// Age: 28, -// Birthday: Now(), -// Toys: []Toy{{Name: "toy-2-1"}, {Name: "toy-2-2"}}, -// }, { -// Name: "create-3", -// Age: 38, -// Birthday: Now(), -// Toys: []Toy{{Name: "toy-3-1"}, {Name: "toy-3-2"}}, -// }} - -// if err := db.Create(&users).Error; err != nil { -// t.Fatalf("errors happened when create: %v", err) -// } - -// var userIDs []uint -// for _, user := range users { -// userIDs = append(userIDs, user.ID) -// checkToy(t, user, user) -// } - -// var users2 []User -// db.Preload("Toys").Find(&users2, "id IN (?)", userIDs) -// for idx, user := range users2 { -// check(t, user, users[idx]) -// } -// }) - -// t.Run("PolymorphicHasManyForBulkInsertPtrData", func(t *testing.T) { -// var users = []*User{{ -// Name: "create-1", -// Age: 18, -// Birthday: Now(), -// Toys: []Toy{{Name: "toy-1-1"}, {Name: "toy-1-2"}}, -// }, { -// Name: "create-2", -// Age: 28, -// Birthday: Now(), -// Toys: []Toy{{Name: "toy-2-1"}, {Name: "toy-2-2"}}, -// }, { -// Name: "create-3", -// Age: 38, -// Birthday: Now(), -// Toys: []Toy{{Name: "toy-3-1"}, {Name: "toy-3-2"}}, -// }} - -// if err := db.Create(&users).Error; err != nil { -// t.Fatalf("errors happened when create: %v", err) -// } - -// for _, user := range users { -// checkToy(t, *user, *user) -// } -// }) - -// t.Run("PolymorphicHasManyForBulkInsertWithoutPtr", func(t *testing.T) { -// var users = []User{{ -// Name: "create-1", -// Age: 18, -// Birthday: Now(), -// Toys: []Toy{{Name: "toy-1-1"}, {Name: "toy-1-2"}}, -// }, { -// Name: "create-2", -// Age: 28, -// Birthday: Now(), -// Toys: []Toy{{Name: "toy-2-1"}, {Name: "toy-2-2"}}, -// }, { -// Name: "create-3", -// Age: 38, -// Birthday: Now(), -// Toys: []Toy{{Name: "toy-3-1"}, {Name: "toy-3-2"}}, -// }} - -// if err := db.Create(users).Error; err != nil { -// t.Fatalf("errors happened when create: %v", err) -// } - -// for _, user := range users { -// checkToy(t, user, user) -// } -// }) -// } - -// func TestCreateMany2ManyAssociations(t *testing.T, db *gorm.DB) { -// check := func(t *testing.T, user User, old User) { -// for idx, language := range user.Languages { -// var result Language -// db.First(&result, "code = ?", language.Code) -// if result.Name != language.Name { -// t.Errorf("Language's name should be same") -// } else if result.Name != old.Languages[idx].Name { -// t.Errorf("Language's name should be same") -// } -// } - -// for idx, f := range user.Friends { -// if f.ID == 0 { -// t.Errorf("Friend's foreign key should be saved") -// } - -// var result User -// db.First(&result, "id = ?", f.ID) -// if result.Name != f.Name { -// t.Errorf("Friend's name should be same") -// } else if result.Name != old.Friends[idx].Name { -// t.Errorf("Language's name should be same") -// } -// } -// } - -// db.Create(&[]Language{{Code: "zh-CN", Name: "Chinese"}, {Code: "en", Name: "English"}}) - -// t.Run("Many2Many", func(t *testing.T) { -// var user = User{ -// Name: "create", -// Age: 18, -// Birthday: Now(), -// Languages: []Language{{Code: "zh-CN", Name: "Chinese"}, {Code: "en", Name: "English"}}, -// Friends: []*User{{Name: "friend-1"}, {Name: "friend-2"}}, -// } - -// if err := db.Create(&user).Error; err != nil { -// t.Fatalf("errors happened when create: %v", err) -// } - -// check(t, user, user) - -// var user2 User -// db.Preload("Languages").Preload("Friends").Find(&user2, "id = ?", user.ID) -// check(t, user2, user) -// }) - -// t.Run("Many2ManyForBulkInsert", func(t *testing.T) { -// var users = []User{ -// { -// Name: "create-1", -// Age: 18, -// Birthday: Now(), -// Languages: []Language{{Code: "zh-CN", Name: "Chinese"}, {Code: "en", Name: "English"}}, -// Friends: []*User{{Name: "friend-1-1"}, {Name: "friend-1-2"}}, -// }, -// { -// Name: "create-2", -// Age: 18, -// Birthday: Now(), -// Languages: []Language{{Code: "zh-CN", Name: "Chinese"}}, -// Friends: []*User{{Name: "friend-2-1"}}, -// }, -// { -// Name: "create-3", -// Age: 18, -// Birthday: Now(), -// Languages: []Language{{Code: "zh-CN", Name: "Chinese"}, {Code: "en", Name: "English"}}, -// Friends: []*User{{Name: "friend-3-1"}, {Name: "friend-3-2"}, {Name: "friend-3-3"}}, -// }, -// } - -// if err := db.Create(&users).Error; err != nil { -// t.Fatalf("errors happened when create: %v", err) -// } - -// var userIDs []uint -// for _, user := range users { -// userIDs = append(userIDs, user.ID) -// check(t, user, user) -// } - -// var users2 []User -// db.Preload("Languages").Preload("Friends").Find(&users2, "id IN (?)", userIDs) -// for idx, user := range users2 { -// check(t, user, users[idx]) -// } -// }) -// } +func TestBulkCreatePtrDataWithAssociations(t *testing.T) { + users := []*User{ + GetUser("bulk_ptr_1", Config{Account: true, Pets: 2, Toys: 3, Company: true, Manager: true, Team: 0, Languages: 1, Friends: 1}), + GetUser("bulk_ptr_2", Config{Account: false, Pets: 2, Toys: 4, Company: false, Manager: false, Team: 1, Languages: 3, Friends: 5}), + GetUser("bulk_ptr_3", Config{Account: true, Pets: 0, Toys: 3, Company: true, Manager: false, Team: 4, Languages: 0, Friends: 1}), + GetUser("bulk_ptr_4", Config{Account: true, Pets: 3, Toys: 0, Company: false, Manager: true, Team: 0, Languages: 3, Friends: 0}), + GetUser("bulk_ptr_5", Config{Account: false, Pets: 0, Toys: 3, Company: true, Manager: false, Team: 1, Languages: 3, Friends: 1}), + GetUser("bulk_ptr_6", Config{Account: true, Pets: 4, Toys: 3, Company: false, Manager: true, Team: 1, Languages: 3, Friends: 0}), + GetUser("bulk_ptr_7", Config{Account: true, Pets: 1, Toys: 3, Company: true, Manager: true, Team: 4, Languages: 3, Friends: 1}), + GetUser("bulk_ptr_8", Config{Account: false, Pets: 0, Toys: 0, Company: false, Manager: false, Team: 0, Languages: 0, Friends: 0}), + } + + if err := DB.Create(&users).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + var userIDs []uint + for _, user := range users { + userIDs = append(userIDs, user.ID) + CheckUser(t, *user, *user) + } + + var users2 []User + DB.Preload("Account").Preload("Pets").Preload("Toys").Preload("Company").Preload("Manager").Preload("Team").Preload("Languages").Preload("Friends").Find(&users2, "id IN ?", userIDs) + for idx, user := range users2 { + CheckUser(t, user, *users[idx]) + } +} + +func TestPolymorphicHasOne(t *testing.T) { + t.Run("Struct", func(t *testing.T) { + var pet = Pet{ + Name: "PolymorphicHasOne", + Toy: Toy{Name: "Toy-PolymorphicHasOne"}, + } + + if err := DB.Create(&pet).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + CheckPet(t, pet, pet) + + var pet2 Pet + DB.Preload("Toy").Find(&pet2, "id = ?", pet.ID) + CheckPet(t, pet2, pet) + }) + + t.Run("Slice", func(t *testing.T) { + var pets = []Pet{{ + Name: "PolymorphicHasOne-Slice-1", + Toy: Toy{Name: "Toy-PolymorphicHasOne-Slice-1"}, + }, { + Name: "PolymorphicHasOne-Slice-2", + Toy: Toy{Name: "Toy-PolymorphicHasOne-Slice-2"}, + }, { + Name: "PolymorphicHasOne-Slice-3", + Toy: Toy{Name: "Toy-PolymorphicHasOne-Slice-3"}, + }} + + if err := DB.Create(&pets).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + var petIDs []uint + for _, pet := range pets { + petIDs = append(petIDs, pet.ID) + CheckPet(t, pet, pet) + } + + var pets2 []Pet + DB.Preload("Toy").Find(&pets2, "id IN ?", petIDs) + for idx, pet := range pets2 { + CheckPet(t, pet, pets[idx]) + } + }) + + t.Run("SliceOfPtr", func(t *testing.T) { + var pets = []*Pet{{ + Name: "PolymorphicHasOne-Slice-1", + Toy: Toy{Name: "Toy-PolymorphicHasOne-Slice-1"}, + }, { + Name: "PolymorphicHasOne-Slice-2", + Toy: Toy{Name: "Toy-PolymorphicHasOne-Slice-2"}, + }, { + Name: "PolymorphicHasOne-Slice-3", + Toy: Toy{Name: "Toy-PolymorphicHasOne-Slice-3"}, + }} + + if err := DB.Create(&pets).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + for _, pet := range pets { + CheckPet(t, *pet, *pet) + } + }) +} diff --git a/utils/utils.go b/utils/utils.go index 3924e69e..e177999e 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -3,6 +3,7 @@ package utils import ( "database/sql/driver" "fmt" + "path/filepath" "reflect" "regexp" "runtime" @@ -11,8 +12,13 @@ import ( "unicode" ) -var goSrcRegexp = regexp.MustCompile(`jinzhu/gorm(@.*)?/.*.go`) -var goTestRegexp = regexp.MustCompile(`jinzhu/gorm(@.*)?/.*test.go`) +var goSrcRegexp, goTestRegexp *regexp.Regexp + +func init() { + _, file, _, _ := runtime.Caller(0) + goSrcRegexp = regexp.MustCompile(filepath.Join(filepath.Dir(filepath.Dir(file)), ".*.go")) + goTestRegexp = regexp.MustCompile(filepath.Join(filepath.Dir(filepath.Dir(file)), ".*test.go")) +} func FileWithLineNum() string { for i := 2; i < 15; i++ { From f0a442adff91e70a5f85cb50b4dc27bd3c189714 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 23 May 2020 23:50:48 +0800 Subject: [PATCH 317/881] Refactor tests --- callbacks/helper.go | 4 +- finisher_api.go | 3 + logger/sql.go | 2 +- tests/associations.go | 73 ----- tests/associations_test.go | 24 ++ tests/create.go | 188 ------------- tests/delete.go | 64 ----- tests/delete_test.go | 48 ++++ tests/group_by.go | 62 ----- tests/group_by_test.go | 57 ++++ tests/joins.go | 81 ------ tests/joins_test.go | 55 ++++ tests/{migrate.go => migrate_test.go} | 12 +- tests/query.go | 95 ------- tests/query_test.go | 82 ++++++ tests/update.go | 382 -------------------------- tests/update_test.go | 226 +++++++++++++++ tests/utils.go | 232 +++++++++++++++- 18 files changed, 734 insertions(+), 956 deletions(-) delete mode 100644 tests/associations.go create mode 100644 tests/associations_test.go delete mode 100644 tests/create.go delete mode 100644 tests/delete.go create mode 100644 tests/delete_test.go delete mode 100644 tests/group_by.go create mode 100644 tests/group_by_test.go delete mode 100644 tests/joins.go create mode 100644 tests/joins_test.go rename tests/{migrate.go => migrate_test.go} (67%) delete mode 100644 tests/query.go create mode 100644 tests/query_test.go delete mode 100644 tests/update.go create mode 100644 tests/update_test.go diff --git a/callbacks/helper.go b/callbacks/helper.go index 092c9c37..43e90b8a 100644 --- a/callbacks/helper.go +++ b/callbacks/helper.go @@ -10,10 +10,12 @@ import ( // SelectAndOmitColumns get select and omit columns, select -> true, omit -> false func SelectAndOmitColumns(stmt *gorm.Statement, requireCreate, requireUpdate bool) (map[string]bool, bool) { results := map[string]bool{} + notRestricted := false // select columns for _, column := range stmt.Selects { if column == "*" { + notRestricted = true for _, dbName := range stmt.Schema.DBNames { results[dbName] = true } @@ -51,7 +53,7 @@ func SelectAndOmitColumns(stmt *gorm.Statement, requireCreate, requireUpdate boo } } - return results, len(stmt.Selects) > 0 + return results, !notRestricted && len(stmt.Selects) > 0 } // ConvertMapToValuesForCreate convert map to values diff --git a/finisher_api.go b/finisher_api.go index 9e29e327..1b2a7e29 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -35,6 +35,9 @@ func (db *DB) Save(value interface{}) (tx *DB) { tx.Statement.AddClause(where) } + if len(tx.Statement.Selects) == 0 { + tx.Statement.Selects = append(tx.Statement.Selects, "*") + } tx.callbacks.Update().Execute(tx) return } diff --git a/logger/sql.go b/logger/sql.go index 9c0f54d7..219ae301 100644 --- a/logger/sql.go +++ b/logger/sql.go @@ -53,7 +53,7 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, v } else { rv := reflect.ValueOf(v) - if !rv.IsValid() { + if !rv.IsValid() || rv.IsNil() { vars[idx] = "NULL" } else if rv.Kind() == reflect.Ptr && !rv.IsZero() { convertParams(reflect.Indirect(rv).Interface(), idx) diff --git a/tests/associations.go b/tests/associations.go deleted file mode 100644 index 7e93e81e..00000000 --- a/tests/associations.go +++ /dev/null @@ -1,73 +0,0 @@ -package tests - -import ( - "testing" - - "github.com/jinzhu/gorm" -) - -func TestAssociations(t *testing.T, db *gorm.DB) { - db.Migrator().DropTable(&Account{}, &Company{}, &Pet{}, &Toy{}, &Language{}) - db.Migrator().AutoMigrate(&Account{}, &Company{}, &Pet{}, &Toy{}, &Language{}) - - TestBelongsToAssociations(t, db) -} - -func TestBelongsToAssociations(t *testing.T, db *gorm.DB) { - check := func(t *testing.T, user User, old User) { - if old.Company.Name != "" { - if user.CompanyID == nil { - t.Errorf("Company's foreign key should be saved") - } else { - var company Company - db.First(&company, "id = ?", *user.CompanyID) - if company.Name != old.Company.Name { - t.Errorf("Company's name should be same, expects: %v, got %v", old.Company.Name, user.Company.Name) - } else if user.Company.Name != old.Company.Name { - t.Errorf("Company's name should be same, expects: %v, got %v", old.Company.Name, user.Company.Name) - } - } - } else if user.CompanyID != nil { - t.Errorf("Company should not be created for zero value, got: %+v", user.CompanyID) - } - - if old.Manager != nil { - if user.ManagerID == nil { - t.Errorf("Manager's foreign key should be saved") - } else { - var manager User - db.First(&manager, "id = ?", *user.ManagerID) - if manager.Name != user.Manager.Name { - t.Errorf("Manager's name should be same") - } else if user.Manager.Name != old.Manager.Name { - t.Errorf("Manager's name should be same") - } - } - } else if user.ManagerID != nil { - t.Errorf("Manager should not be created for zero value, got: %+v", user.ManagerID) - } - } - - t.Run("BelongsTo", func(t *testing.T) { - var user = User{ - Name: "create", - Age: 18, - Birthday: Now(), - Company: Company{Name: "company-belongs-to-association"}, - Manager: &User{Name: "manager-belongs-to-association"}, - } - - if err := db.Create(&user).Error; err != nil { - t.Fatalf("errors happened when create: %v", err) - } - - check(t, user, user) - - var user2 User - db.Find(&user2, "id = ?", user.ID) - db.Model(&user2).Association("Company").Find(&user2.Company) - user2.Manager = &User{} - db.Model(&user2).Association("Manager").Find(user2.Manager) - check(t, user2, user) - }) -} diff --git a/tests/associations_test.go b/tests/associations_test.go new file mode 100644 index 00000000..dc88ee03 --- /dev/null +++ b/tests/associations_test.go @@ -0,0 +1,24 @@ +package tests_test + +import ( + "testing" + + . "github.com/jinzhu/gorm/tests" +) + +func TestAssociationForBelongsTo(t *testing.T) { + var user = *GetUser("belongs-to", Config{Company: true, Manager: true}) + + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + CheckUser(t, user, user) + + var user2 User + DB.Find(&user2, "id = ?", user.ID) + DB.Model(&user2).Association("Company").Find(&user2.Company) + user2.Manager = &User{} + DB.Model(&user2).Association("Manager").Find(user2.Manager) + CheckUser(t, user2, user) +} diff --git a/tests/create.go b/tests/create.go deleted file mode 100644 index 6e5dd2c5..00000000 --- a/tests/create.go +++ /dev/null @@ -1,188 +0,0 @@ -package tests - -import ( - "strconv" - "testing" - "time" -) - -type Config struct { - Account bool - Pets int - Toys int - Company bool - Manager bool - Team int - Languages int - Friends int -} - -func GetUser(name string, config Config) *User { - var ( - birthday = time.Now() - user = User{ - Name: name, - Age: 18, - Birthday: &birthday, - } - ) - - if config.Account { - user.Account = Account{Number: name + "_account"} - } - - for i := 0; i < config.Pets; i++ { - user.Pets = append(user.Pets, &Pet{Name: name + "_pet_" + strconv.Itoa(i+1)}) - } - - for i := 0; i < config.Toys; i++ { - user.Toys = append(user.Toys, Toy{Name: name + "_toy_" + strconv.Itoa(i+1)}) - } - - if config.Company { - user.Company = Company{Name: "company-" + name} - } - - if config.Manager { - user.Manager = GetUser(name+"_manager", Config{}) - } - - for i := 0; i < config.Team; i++ { - user.Team = append(user.Team, *GetUser(name+"_team_"+strconv.Itoa(i+1), Config{})) - } - - for i := 0; i < config.Languages; i++ { - name := name + "_locale_" + strconv.Itoa(i+1) - language := Language{Code: name, Name: name} - DB.Create(&language) - user.Languages = append(user.Languages, language) - } - - for i := 0; i < config.Friends; i++ { - user.Friends = append(user.Friends, GetUser(name+"_friend_"+strconv.Itoa(i+1), Config{})) - } - - return &user -} - -func CheckPet(t *testing.T, pet Pet, expect Pet) { - if pet.ID != 0 { - var newPet Pet - if err := DB.Where("id = ?", pet.ID).First(&newPet).Error; err != nil { - t.Fatalf("errors happened when query: %v", err) - } else { - AssertObjEqual(t, newPet, pet, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "UserID", "Name") - } - } - - AssertObjEqual(t, pet, expect, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "UserID", "Name") - - AssertObjEqual(t, pet.Toy, expect.Toy, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "OwnerID", "OwnerType") - - if expect.Toy.Name != "" && expect.Toy.OwnerType != "pets" { - t.Errorf("toys's OwnerType, expect: %v, got %v", "pets", expect.Toy.OwnerType) - } -} - -func CheckUser(t *testing.T, user User, expect User) { - if user.ID != 0 { - var newUser User - if err := DB.Where("id = ?", user.ID).First(&newUser).Error; err != nil { - t.Fatalf("errors happened when query: %v", err) - } else { - AssertObjEqual(t, newUser, user, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active") - } - } - - AssertObjEqual(t, user, expect, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active") - - t.Run("Account", func(t *testing.T) { - AssertObjEqual(t, user.Account, expect.Account, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "UserID", "Number") - - if user.Account.Number != "" { - if !user.Account.UserID.Valid { - t.Errorf("Account's foreign key should be saved") - } else { - var account Account - DB.First(&account, "user_id = ?", user.ID) - AssertObjEqual(t, account, user.Account, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "UserID", "Number") - } - } - }) - - t.Run("Pets", func(t *testing.T) { - if len(user.Pets) != len(expect.Pets) { - t.Errorf("pets should equal, expect: %v, got %v", len(expect.Pets), len(user.Pets)) - } - - for idx, pet := range user.Pets { - if pet == nil || expect.Pets[idx] == nil { - t.Errorf("pets#%v should equal, expect: %v, got %v", idx, expect.Pets[idx], pet) - } else { - CheckPet(t, *pet, *expect.Pets[idx]) - } - } - }) - - t.Run("Toys", func(t *testing.T) { - if len(user.Toys) != len(expect.Toys) { - t.Errorf("toys should equal, expect: %v, got %v", len(expect.Toys), len(user.Toys)) - } - - for idx, toy := range user.Toys { - if toy.OwnerType != "users" { - t.Errorf("toys's OwnerType, expect: %v, got %v", "users", toy.OwnerType) - } - - AssertObjEqual(t, toy, expect.Toys[idx], "ID", "CreatedAt", "UpdatedAt", "Name", "OwnerID", "OwnerType") - } - }) - - t.Run("Company", func(t *testing.T) { - AssertObjEqual(t, user.Company, expect.Company, "ID", "Name") - }) - - t.Run("Manager", func(t *testing.T) { - if user.Manager != nil { - if user.ManagerID == nil { - t.Errorf("Manager's foreign key should be saved") - } else { - var manager User - DB.First(&manager, "id = ?", *user.ManagerID) - AssertObjEqual(t, manager, user.Manager, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active") - } - } else if user.ManagerID != nil { - t.Errorf("Manager should not be created for zero value, got: %+v", user.ManagerID) - } - }) - - t.Run("Team", func(t *testing.T) { - if len(user.Team) != len(expect.Team) { - t.Errorf("Team should equal, expect: %v, got %v", len(expect.Team), len(user.Team)) - } - - for idx, team := range user.Team { - AssertObjEqual(t, team, expect.Team[idx], "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active") - } - }) - - t.Run("Languages", func(t *testing.T) { - if len(user.Languages) != len(expect.Languages) { - t.Errorf("Languages should equal, expect: %v, got %v", len(expect.Languages), len(user.Languages)) - } - - for idx, language := range user.Languages { - AssertObjEqual(t, language, expect.Languages[idx], "Code", "Name") - } - }) - - t.Run("Friends", func(t *testing.T) { - if len(user.Friends) != len(expect.Friends) { - t.Errorf("Friends should equal, expect: %v, got %v", len(expect.Friends), len(user.Friends)) - } - - for idx, friend := range user.Friends { - AssertObjEqual(t, friend, expect.Friends[idx], "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active") - } - }) -} diff --git a/tests/delete.go b/tests/delete.go deleted file mode 100644 index 45701ff0..00000000 --- a/tests/delete.go +++ /dev/null @@ -1,64 +0,0 @@ -package tests - -import ( - "errors" - "testing" - - "github.com/jinzhu/gorm" -) - -func TestDelete(t *testing.T, db *gorm.DB) { - db.Migrator().DropTable(&User{}) - db.AutoMigrate(&User{}) - - t.Run("Delete", func(t *testing.T) { - var users = []User{{ - Name: "find", - Age: 1, - Birthday: Now(), - }, { - Name: "find", - Age: 2, - Birthday: Now(), - }, { - Name: "find", - Age: 3, - Birthday: Now(), - }} - - if err := db.Create(&users).Error; err != nil { - t.Errorf("errors happened when create: %v", err) - } - - for _, user := range users { - if user.ID == 0 { - t.Fatalf("user's primary key should has value after create, got : %v", user.ID) - } - } - - if err := db.Delete(&users[1]).Error; err != nil { - t.Errorf("errors happened when delete: %v", err) - } - - var result User - if err := db.Where("id = ?", users[1].ID).First(&result).Error; err == nil || !errors.Is(err, gorm.ErrRecordNotFound) { - t.Errorf("should returns record not found error, but got %v", err) - } - - for _, user := range []User{users[0], users[2]} { - if err := db.Where("id = ?", user.ID).First(&result).Error; err != nil { - t.Errorf("no error should returns when query %v, but got %v", user.ID, err) - } - } - - if err := db.Delete(&User{}).Error; err == nil || !errors.Is(err, gorm.ErrMissingWhereClause) { - t.Errorf("should returns missing WHERE clause while deleting error") - } - - for _, user := range []User{users[0], users[2]} { - if err := db.Where("id = ?", user.ID).First(&result).Error; err != nil { - t.Errorf("no error should returns when query %v, but got %v", user.ID, err) - } - } - }) -} diff --git a/tests/delete_test.go b/tests/delete_test.go new file mode 100644 index 00000000..8be072d3 --- /dev/null +++ b/tests/delete_test.go @@ -0,0 +1,48 @@ +package tests_test + +import ( + "errors" + "testing" + + "github.com/jinzhu/gorm" + . "github.com/jinzhu/gorm/tests" +) + +func TestDelete(t *testing.T) { + var users = []User{*GetUser("delete", Config{}), *GetUser("delete", Config{}), *GetUser("delete", Config{})} + + if err := DB.Create(&users).Error; err != nil { + t.Errorf("errors happened when create: %v", err) + } + + for _, user := range users { + if user.ID == 0 { + t.Fatalf("user's primary key should has value after create, got : %v", user.ID) + } + } + + if err := DB.Delete(&users[1]).Error; err != nil { + t.Errorf("errors happened when delete: %v", err) + } + + var result User + if err := DB.Where("id = ?", users[1].ID).First(&result).Error; err == nil || !errors.Is(err, gorm.ErrRecordNotFound) { + t.Errorf("should returns record not found error, but got %v", err) + } + + for _, user := range []User{users[0], users[2]} { + if err := DB.Where("id = ?", user.ID).First(&result).Error; err != nil { + t.Errorf("no error should returns when query %v, but got %v", user.ID, err) + } + } + + if err := DB.Delete(&User{}).Error; err == nil || !errors.Is(err, gorm.ErrMissingWhereClause) { + t.Errorf("should returns missing WHERE clause while deleting error") + } + + for _, user := range []User{users[0], users[2]} { + if err := DB.Where("id = ?", user.ID).First(&result).Error; err != nil { + t.Errorf("no error should returns when query %v, but got %v", user.ID, err) + } + } +} diff --git a/tests/group_by.go b/tests/group_by.go deleted file mode 100644 index b0bb4155..00000000 --- a/tests/group_by.go +++ /dev/null @@ -1,62 +0,0 @@ -package tests - -import ( - "testing" - - "github.com/jinzhu/gorm" -) - -func TestGroupBy(t *testing.T, db *gorm.DB) { - db.Migrator().DropTable(&User{}) - db.AutoMigrate(&User{}) - - t.Run("GroupBy", func(t *testing.T) { - var users = []User{{ - Name: "groupby", - Age: 10, - Birthday: Now(), - }, { - Name: "groupby", - Age: 20, - Birthday: Now(), - }, { - Name: "groupby", - Age: 30, - Birthday: Now(), - }, { - Name: "groupby1", - Age: 110, - Birthday: Now(), - }, { - Name: "groupby1", - Age: 220, - Birthday: Now(), - }, { - Name: "groupby1", - Age: 330, - Birthday: Now(), - }} - - if err := db.Create(&users).Error; err != nil { - t.Errorf("errors happened when create: %v", err) - } - - var name string - var total int - if err := db.Model(&User{}).Select("name, sum(age)").Where("name = ?", "groupby").Group("name").Row().Scan(&name, &total); err != nil { - t.Errorf("no error should happen, but got %v", err) - } - - if name != "groupby" || total != 60 { - t.Errorf("name should be groupby, but got %v, total should be 60, but got %v", name, total) - } - - if err := db.Model(&User{}).Select("name, sum(age) as total").Where("name LIKE ?", "groupby%").Group("name").Having("name = ?", "groupby1").Row().Scan(&name, &total); err != nil { - t.Errorf("no error should happen, but got %v", err) - } - - if name != "groupby1" || total != 660 { - t.Errorf("name should be groupby, but got %v, total should be 660, but got %v", name, total) - } - }) -} diff --git a/tests/group_by_test.go b/tests/group_by_test.go new file mode 100644 index 00000000..66a733aa --- /dev/null +++ b/tests/group_by_test.go @@ -0,0 +1,57 @@ +package tests_test + +import ( + "testing" + + . "github.com/jinzhu/gorm/tests" +) + +func TestGroupBy(t *testing.T) { + var users = []User{{ + Name: "groupby", + Age: 10, + Birthday: Now(), + }, { + Name: "groupby", + Age: 20, + Birthday: Now(), + }, { + Name: "groupby", + Age: 30, + Birthday: Now(), + }, { + Name: "groupby1", + Age: 110, + Birthday: Now(), + }, { + Name: "groupby1", + Age: 220, + Birthday: Now(), + }, { + Name: "groupby1", + Age: 330, + Birthday: Now(), + }} + + if err := DB.Create(&users).Error; err != nil { + t.Errorf("errors happened when create: %v", err) + } + + var name string + var total int + if err := DB.Model(&User{}).Select("name, sum(age)").Where("name = ?", "groupby").Group("name").Row().Scan(&name, &total); err != nil { + t.Errorf("no error should happen, but got %v", err) + } + + if name != "groupby" || total != 60 { + t.Errorf("name should be groupby, but got %v, total should be 60, but got %v", name, total) + } + + if err := DB.Model(&User{}).Select("name, sum(age) as total").Where("name LIKE ?", "groupby%").Group("name").Having("name = ?", "groupby1").Row().Scan(&name, &total); err != nil { + t.Errorf("no error should happen, but got %v", err) + } + + if name != "groupby1" || total != 660 { + t.Errorf("name should be groupby, but got %v, total should be 660, but got %v", name, total) + } +} diff --git a/tests/joins.go b/tests/joins.go deleted file mode 100644 index 86f9f104..00000000 --- a/tests/joins.go +++ /dev/null @@ -1,81 +0,0 @@ -package tests - -import ( - "testing" - - "github.com/jinzhu/gorm" -) - -func TestJoins(t *testing.T, db *gorm.DB) { - db.Migrator().DropTable(&User{}, &Account{}, &Company{}) - db.AutoMigrate(&User{}, &Account{}, &Company{}) - - check := func(t *testing.T, oldUser, newUser User) { - if newUser.Company.ID != oldUser.Company.ID { - t.Errorf("Company is not equal when load with joins, loaded company id: %v", newUser.Company.ID) - } - - if newUser.Manager == nil || newUser.Manager.ID != oldUser.Manager.ID { - t.Errorf("Manager is not equal when load with joins: loaded manager: %+v", newUser.Manager) - } - - if newUser.Account.ID != oldUser.Account.ID { - t.Errorf("Account is not equal when load with joins, loaded account id: %v, expect: %v", newUser.Account.ID, oldUser.Account.ID) - } - } - - t.Run("Joins", func(t *testing.T) { - user := User{ - Name: "joins-1", - Company: Company{Name: "company"}, - Manager: &User{Name: "manager"}, - Account: Account{Number: "account-has-one-association"}, - } - - db.Create(&user) - - var user2 User - if err := db.Joins("Company").Joins("Manager").Joins("Account").First(&user2, "users.name = ?", user.Name).Error; err != nil { - t.Fatalf("Failed to load with joins, got error: %v", err) - } - - check(t, user, user2) - }) - - t.Run("JoinsForSlice", func(t *testing.T) { - users := []User{{ - Name: "slice-joins-1", - Company: Company{Name: "company"}, - Manager: &User{Name: "manager"}, - Account: Account{Number: "account-has-one-association"}, - }, { - Name: "slice-joins-2", - Company: Company{Name: "company2"}, - Manager: &User{Name: "manager2"}, - Account: Account{Number: "account-has-one-association2"}, - }, { - Name: "slice-joins-3", - Company: Company{Name: "company3"}, - Manager: &User{Name: "manager3"}, - Account: Account{Number: "account-has-one-association3"}, - }} - - db.Create(&users) - - var users2 []User - if err := db.Joins("Company").Joins("Manager").Joins("Account").Find(&users2, "users.name LIKE ?", "slice-joins%").Error; err != nil { - t.Fatalf("Failed to load with joins, got error: %v", err) - } else if len(users2) != len(users) { - t.Fatalf("Failed to load join users, got: %v, expect: %v", len(users2), len(users)) - } - - for _, u2 := range users2 { - for _, u := range users { - if u.Name == u2.Name { - check(t, u, u2) - continue - } - } - } - }) -} diff --git a/tests/joins_test.go b/tests/joins_test.go new file mode 100644 index 00000000..556130ee --- /dev/null +++ b/tests/joins_test.go @@ -0,0 +1,55 @@ +package tests_test + +import ( + "sort" + "testing" + + . "github.com/jinzhu/gorm/tests" +) + +func TestJoins(t *testing.T) { + user := *GetUser("joins-1", Config{Company: true, Manager: true, Account: true}) + + DB.Create(&user) + + var user2 User + if err := DB.Joins("Company").Joins("Manager").Joins("Account").First(&user2, "users.name = ?", user.Name).Error; err != nil { + t.Fatalf("Failed to load with joins, got error: %v", err) + } + + CheckUser(t, user2, user) +} + +func TestJoinsForSlice(t *testing.T) { + users := []User{ + *GetUser("slice-joins-1", Config{Company: true, Manager: true, Account: true}), + *GetUser("slice-joins-2", Config{Company: true, Manager: true, Account: true}), + *GetUser("slice-joins-3", Config{Company: true, Manager: true, Account: true}), + } + + DB.Create(&users) + + var userIDs []uint + for _, user := range users { + userIDs = append(userIDs, user.ID) + } + + var users2 []User + if err := DB.Joins("Company").Joins("Manager").Joins("Account").Find(&users2, "users.id IN ?", userIDs).Error; err != nil { + t.Fatalf("Failed to load with joins, got error: %v", err) + } else if len(users2) != len(users) { + t.Fatalf("Failed to load join users, got: %v, expect: %v", len(users2), len(users)) + } + + sort.Slice(users2, func(i, j int) bool { + return users2[i].ID > users2[j].ID + }) + + sort.Slice(users, func(i, j int) bool { + return users[i].ID > users[j].ID + }) + + for idx, user := range users { + CheckUser(t, user, users2[idx]) + } +} diff --git a/tests/migrate.go b/tests/migrate_test.go similarity index 67% rename from tests/migrate.go rename to tests/migrate_test.go index fa8a89e8..917fba75 100644 --- a/tests/migrate.go +++ b/tests/migrate_test.go @@ -1,28 +1,28 @@ -package tests +package tests_test import ( "math/rand" "testing" "time" - "github.com/jinzhu/gorm" + . "github.com/jinzhu/gorm/tests" ) -func TestMigrate(t *testing.T, db *gorm.DB) { +func TestMigrate(t *testing.T) { allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}} rand.Seed(time.Now().UnixNano()) rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] }) - if err := db.Migrator().DropTable(allModels...); err != nil { + if err := DB.Migrator().DropTable(allModels...); err != nil { t.Errorf("Failed to drop table, got error %v", err) } - if err := db.AutoMigrate(allModels...); err != nil { + if err := DB.AutoMigrate(allModels...); err != nil { t.Errorf("Failed to auto migrate, but got error %v", err) } for _, m := range allModels { - if !db.Migrator().HasTable(m) { + if !DB.Migrator().HasTable(m) { t.Errorf("Failed to create table for %#v", m) } } diff --git a/tests/query.go b/tests/query.go deleted file mode 100644 index 5eabfb48..00000000 --- a/tests/query.go +++ /dev/null @@ -1,95 +0,0 @@ -package tests - -import ( - "reflect" - "strconv" - "testing" - - "github.com/jinzhu/gorm" -) - -func TestFind(t *testing.T, db *gorm.DB) { - db.Migrator().DropTable(&User{}) - db.AutoMigrate(&User{}) - - t.Run("Find", func(t *testing.T) { - var users = []User{{ - Name: "find", - Age: 1, - Birthday: Now(), - }, { - Name: "find", - Age: 2, - Birthday: Now(), - }, { - Name: "find", - Age: 3, - Birthday: Now(), - }} - - if err := db.Create(&users).Error; err != nil { - t.Fatalf("errors happened when create users: %v", err) - } - - t.Run("First", func(t *testing.T) { - var first User - if err := db.Where("name = ?", "find").First(&first).Error; err != nil { - t.Errorf("errors happened when query first: %v", err) - } else { - AssertObjEqual(t, first, users[0], "Name", "Age", "Birthday") - } - }) - - t.Run("Last", func(t *testing.T) { - var last User - if err := db.Where("name = ?", "find").Last(&last).Error; err != nil { - t.Errorf("errors happened when query last: %v", err) - } else { - AssertObjEqual(t, last, users[2], "Name", "Age", "Birthday") - } - }) - - var all []User - if err := db.Where("name = ?", "find").Find(&all).Error; err != nil || len(all) != 3 { - t.Errorf("errors happened when query find: %v, length: %v", err, len(all)) - } else { - for idx, user := range users { - t.Run("FindAll#"+strconv.Itoa(idx+1), func(t *testing.T) { - AssertObjEqual(t, all[idx], user, "Name", "Age", "Birthday") - }) - } - } - - t.Run("FirstMap", func(t *testing.T) { - var first = map[string]interface{}{} - if err := db.Model(&User{}).Where("name = ?", "find").First(first).Error; err != nil { - t.Errorf("errors happened when query first: %v", err) - } else { - for _, name := range []string{"Name", "Age", "Birthday"} { - t.Run(name, func(t *testing.T) { - dbName := db.NamingStrategy.ColumnName("", name) - reflectValue := reflect.Indirect(reflect.ValueOf(users[0])) - AssertEqual(t, first[dbName], reflectValue.FieldByName(name).Interface()) - }) - } - } - }) - - var allMap = []map[string]interface{}{} - if err := db.Model(&User{}).Where("name = ?", "find").Find(&allMap).Error; err != nil { - t.Errorf("errors happened when query first: %v", err) - } else { - for idx, user := range users { - t.Run("FindAllMap#"+strconv.Itoa(idx+1), func(t *testing.T) { - for _, name := range []string{"Name", "Age", "Birthday"} { - t.Run(name, func(t *testing.T) { - dbName := db.NamingStrategy.ColumnName("", name) - reflectValue := reflect.Indirect(reflect.ValueOf(user)) - AssertEqual(t, allMap[idx][dbName], reflectValue.FieldByName(name).Interface()) - }) - } - }) - } - } - }) -} diff --git a/tests/query_test.go b/tests/query_test.go new file mode 100644 index 00000000..4388066f --- /dev/null +++ b/tests/query_test.go @@ -0,0 +1,82 @@ +package tests_test + +import ( + "reflect" + "strconv" + "testing" + + . "github.com/jinzhu/gorm/tests" +) + +func TestFind(t *testing.T) { + var users = []User{ + *GetUser("find", Config{}), + *GetUser("find", Config{}), + *GetUser("find", Config{}), + } + + if err := DB.Create(&users).Error; err != nil { + t.Fatalf("errors happened when create users: %v", err) + } + + t.Run("First", func(t *testing.T) { + var first User + if err := DB.Where("name = ?", "find").First(&first).Error; err != nil { + t.Errorf("errors happened when query first: %v", err) + } else { + CheckUser(t, first, users[0]) + } + }) + + t.Run("Last", func(t *testing.T) { + var last User + if err := DB.Where("name = ?", "find").Last(&last).Error; err != nil { + t.Errorf("errors happened when query last: %v", err) + } else { + CheckUser(t, last, users[2]) + } + }) + + var all []User + if err := DB.Where("name = ?", "find").Find(&all).Error; err != nil || len(all) != 3 { + t.Errorf("errors happened when query find: %v, length: %v", err, len(all)) + } else { + for idx, user := range users { + t.Run("FindAll#"+strconv.Itoa(idx+1), func(t *testing.T) { + CheckUser(t, all[idx], user) + }) + } + } + + t.Run("FirstMap", func(t *testing.T) { + var first = map[string]interface{}{} + if err := DB.Model(&User{}).Where("name = ?", "find").First(first).Error; err != nil { + t.Errorf("errors happened when query first: %v", err) + } else { + for _, name := range []string{"Name", "Age", "Birthday"} { + t.Run(name, func(t *testing.T) { + dbName := DB.NamingStrategy.ColumnName("", name) + reflectValue := reflect.Indirect(reflect.ValueOf(users[0])) + AssertEqual(t, first[dbName], reflectValue.FieldByName(name).Interface()) + }) + } + } + }) + + var allMap = []map[string]interface{}{} + if err := DB.Model(&User{}).Where("name = ?", "find").Find(&allMap).Error; err != nil { + t.Errorf("errors happened when query first: %v", err) + } else { + for idx, user := range users { + t.Run("FindAllMap#"+strconv.Itoa(idx+1), func(t *testing.T) { + for _, name := range []string{"Name", "Age", "Birthday"} { + t.Run(name, func(t *testing.T) { + dbName := DB.NamingStrategy.ColumnName("", name) + reflectValue := reflect.Indirect(reflect.ValueOf(user)) + AssertEqual(t, allMap[idx][dbName], reflectValue.FieldByName(name).Interface()) + }) + } + }) + } + } +} diff --git a/tests/update.go b/tests/update.go deleted file mode 100644 index 82a2dc8b..00000000 --- a/tests/update.go +++ /dev/null @@ -1,382 +0,0 @@ -package tests - -import ( - "fmt" - "testing" - "time" - - "github.com/jinzhu/gorm" -) - -func TestUpdate(t *testing.T, db *gorm.DB) { - db.Migrator().DropTable(&User{}) - db.AutoMigrate(&User{}) - - t.Run("Update", func(t *testing.T) { - var ( - users = []*User{{ - Name: "update-before", - Age: 1, - Birthday: Now(), - }, { - Name: "update", - Age: 18, - Birthday: Now(), - }, { - Name: "update-after", - Age: 1, - Birthday: Now(), - }} - user = users[1] - lastUpdatedAt time.Time - ) - - checkUpdatedTime := func(name string, n time.Time) { - if n.UnixNano() == lastUpdatedAt.UnixNano() { - t.Errorf("%v: user's updated at should be changed, but got %v, was %v", name, n, lastUpdatedAt) - } - lastUpdatedAt = n - } - - checkOtherData := func(name string) { - var beforeUser, afterUser User - if err := db.Where("id = ?", users[0].ID).First(&beforeUser).Error; err != nil { - t.Errorf("errors happened when query before user: %v", err) - } - t.Run(name, func(t *testing.T) { - AssertObjEqual(t, beforeUser, users[0], "Name", "Age", "Birthday") - }) - - if err := db.Where("id = ?", users[2].ID).First(&afterUser).Error; err != nil { - t.Errorf("errors happened when query after user: %v", err) - } - t.Run(name, func(t *testing.T) { - AssertObjEqual(t, afterUser, users[2], "Name", "Age", "Birthday") - }) - } - - if err := db.Create(&users).Error; err != nil { - t.Fatalf("errors happened when create: %v", err) - } else if user.ID == 0 { - t.Fatalf("user's primary value should not zero, %v", user.ID) - } else if user.UpdatedAt.IsZero() { - t.Fatalf("user's updated at should not zero, %v", user.UpdatedAt) - } - lastUpdatedAt = user.UpdatedAt - - if err := db.Model(user).Update("Age", 10).Error; err != nil { - t.Errorf("errors happened when update: %v", err) - } else if user.Age != 10 { - t.Errorf("Age should equals to 10, but got %v", user.Age) - } - checkUpdatedTime("Update", user.UpdatedAt) - checkOtherData("Update") - - var result User - if err := db.Where("id = ?", user.ID).First(&result).Error; err != nil { - t.Errorf("errors happened when query: %v", err) - } else { - AssertObjEqual(t, result, user, "Name", "Age", "Birthday") - } - - values := map[string]interface{}{"Active": true, "age": 5} - if err := db.Model(user).Updates(values).Error; err != nil { - t.Errorf("errors happened when update: %v", err) - } else if user.Age != 5 { - t.Errorf("Age should equals to 5, but got %v", user.Age) - } else if user.Active != true { - t.Errorf("Active should be true, but got %v", user.Active) - } - checkUpdatedTime("Updates with map", user.UpdatedAt) - checkOtherData("Updates with map") - - var result2 User - if err := db.Where("id = ?", user.ID).First(&result2).Error; err != nil { - t.Errorf("errors happened when query: %v", err) - } else { - AssertObjEqual(t, result2, user, "Name", "Age", "Birthday") - } - - if err := db.Model(user).Updates(User{Age: 2}).Error; err != nil { - t.Errorf("errors happened when update: %v", err) - } else if user.Age != 2 { - t.Errorf("Age should equals to 2, but got %v", user.Age) - } - checkUpdatedTime("Updates with struct", user.UpdatedAt) - checkOtherData("Updates with struct") - - var result3 User - if err := db.Where("id = ?", user.ID).First(&result3).Error; err != nil { - t.Errorf("errors happened when query: %v", err) - } else { - AssertObjEqual(t, result3, user, "Name", "Age", "Birthday") - } - - user.Active = false - user.Age = 1 - if err := db.Save(user).Error; err != nil { - t.Errorf("errors happened when update: %v", err) - } else if user.Age != 1 { - t.Errorf("Age should equals to 1, but got %v", user.Age) - } else if user.Active != false { - t.Errorf("Active should equals to false, but got %v", user.Active) - } - checkUpdatedTime("Save", user.UpdatedAt) - checkOtherData("Save") - - var result4 User - if err := db.Where("id = ?", user.ID).First(&result4).Error; err != nil { - t.Errorf("errors happened when query: %v", err) - } else { - AssertObjEqual(t, result4, user, "Name", "Age", "Birthday") - } - - TestUpdateAssociations(t, db) - }) -} - -func TestUpdateAssociations(t *testing.T, db *gorm.DB) { - db.Migrator().DropTable(&Account{}, &Company{}, &Pet{}, &Toy{}, &Language{}) - db.Migrator().AutoMigrate(&Account{}, &Company{}, &Pet{}, &Toy{}, &Language{}) - - TestUpdateBelongsToAssociations(t, db) - TestUpdateHasOneAssociations(t, db) - TestUpdateHasManyAssociations(t, db) - TestUpdateMany2ManyAssociations(t, db) -} - -func TestUpdateBelongsToAssociations(t *testing.T, db *gorm.DB) { - check := func(t *testing.T, user User) { - if user.Company.Name != "" { - if user.CompanyID == nil { - t.Errorf("Company's foreign key should be saved") - } else { - var company Company - db.First(&company, "id = ?", *user.CompanyID) - if company.Name != user.Company.Name { - t.Errorf("Company's name should be same") - } - } - } else if user.CompanyID != nil { - t.Errorf("Company should not be created for zero value, got: %+v", user.CompanyID) - } - - if user.Manager != nil { - if user.ManagerID == nil { - t.Errorf("Manager's foreign key should be saved") - } else { - var manager User - db.First(&manager, "id = ?", *user.ManagerID) - if manager.Name != user.Manager.Name { - t.Errorf("Manager's name should be same") - } - } - } else if user.ManagerID != nil { - t.Errorf("Manager should not be created for zero value, got: %+v", user.ManagerID) - } - } - - t.Run("BelongsTo", func(t *testing.T) { - var user = User{ - Name: "create", - Age: 18, - Birthday: Now(), - } - - if err := db.Create(&user).Error; err != nil { - t.Fatalf("errors happened when create: %v", err) - } - - user.Company = Company{Name: "company-belongs-to-association"} - user.Manager = &User{Name: "manager-belongs-to-association"} - if err := db.Save(&user).Error; err != nil { - t.Fatalf("errors happened when update: %v", err) - } - - check(t, user) - }) -} - -func TestUpdateHasOneAssociations(t *testing.T, db *gorm.DB) { - check := func(t *testing.T, user User) { - if user.Account.ID == 0 { - t.Errorf("Account should be saved") - } else if user.Account.UserID.Int64 != int64(user.ID) { - t.Errorf("Account's foreign key should be saved") - } else { - var account Account - db.First(&account, "id = ?", user.Account.ID) - if account.Number != user.Account.Number { - t.Errorf("Account's number should be sme") - } - } - } - - t.Run("HasOne", func(t *testing.T) { - var user = User{ - Name: "create", - Age: 18, - Birthday: Now(), - } - - if err := db.Create(&user).Error; err != nil { - t.Fatalf("errors happened when create: %v", err) - } - - user.Account = Account{Number: "account-has-one-association"} - - if err := db.Save(&user).Error; err != nil { - t.Fatalf("errors happened when update: %v", err) - } - - check(t, user) - }) - - checkPet := func(t *testing.T, pet Pet) { - if pet.Toy.OwnerID != fmt.Sprint(pet.ID) || pet.Toy.OwnerType != "pets" { - t.Errorf("Failed to create polymorphic has one association - toy owner id %v, owner type %v", pet.Toy.OwnerID, pet.Toy.OwnerType) - } else { - var toy Toy - db.First(&toy, "owner_id = ? and owner_type = ?", pet.Toy.OwnerID, pet.Toy.OwnerType) - if toy.Name != pet.Toy.Name { - t.Errorf("Failed to query saved polymorphic has one association") - } - } - } - - t.Run("PolymorphicHasOne", func(t *testing.T) { - var pet = Pet{ - Name: "create", - } - - if err := db.Create(&pet).Error; err != nil { - t.Fatalf("errors happened when create: %v", err) - } - - pet.Toy = Toy{Name: "Update-HasOneAssociation-Polymorphic"} - - if err := db.Save(&pet).Error; err != nil { - t.Fatalf("errors happened when create: %v", err) - } - - checkPet(t, pet) - }) -} - -func TestUpdateHasManyAssociations(t *testing.T, db *gorm.DB) { - check := func(t *testing.T, user User) { - for _, pet := range user.Pets { - if pet.ID == 0 { - t.Errorf("Pet's foreign key should be saved") - } - - var result Pet - db.First(&result, "id = ?", pet.ID) - if result.Name != pet.Name { - t.Errorf("Pet's name should be same") - } else if result.UserID != user.ID { - t.Errorf("Pet's foreign key should be saved") - } - } - } - - t.Run("HasMany", func(t *testing.T) { - var user = User{ - Name: "create", - Age: 18, - Birthday: Now(), - } - - if err := db.Create(&user).Error; err != nil { - t.Fatalf("errors happened when create: %v", err) - } - - user.Pets = []*Pet{{Name: "pet1"}, {Name: "pet2"}} - if err := db.Save(&user).Error; err != nil { - t.Fatalf("errors happened when update: %v", err) - } - - check(t, user) - }) - - checkToy := func(t *testing.T, user User) { - for idx, toy := range user.Toys { - if toy.ID == 0 { - t.Fatalf("Failed to create toy #%v", idx) - } - - var result Toy - db.First(&result, "id = ?", toy.ID) - if result.Name != toy.Name { - t.Errorf("Failed to query saved toy") - } else if result.OwnerID != fmt.Sprint(user.ID) || result.OwnerType != "users" { - t.Errorf("Failed to save relation") - } - } - } - - t.Run("PolymorphicHasMany", func(t *testing.T) { - var user = User{ - Name: "create", - Age: 18, - Birthday: Now(), - } - - if err := db.Create(&user).Error; err != nil { - t.Fatalf("errors happened when create: %v", err) - } - - user.Toys = []Toy{{Name: "toy1"}, {Name: "toy2"}} - if err := db.Save(&user).Error; err != nil { - t.Fatalf("errors happened when update: %v", err) - } - - checkToy(t, user) - }) -} - -func TestUpdateMany2ManyAssociations(t *testing.T, db *gorm.DB) { - check := func(t *testing.T, user User) { - for _, language := range user.Languages { - var result Language - db.First(&result, "code = ?", language.Code) - // TODO - // if result.Name != language.Name { - // t.Errorf("Language's name should be same") - // } - } - - for _, f := range user.Friends { - if f.ID == 0 { - t.Errorf("Friend's foreign key should be saved") - } - - var result User - db.First(&result, "id = ?", f.ID) - if result.Name != f.Name { - t.Errorf("Friend's name should be same") - } - } - } - - t.Run("Many2Many", func(t *testing.T) { - var user = User{ - Name: "create", - Age: 18, - Birthday: Now(), - } - - if err := db.Create(&user).Error; err != nil { - t.Fatalf("errors happened when create: %v", err) - } - - user.Languages = []Language{{Code: "zh-CN", Name: "Chinese"}, {Code: "en", Name: "English"}} - user.Friends = []*User{{Name: "friend-1"}, {Name: "friend-2"}} - - if err := db.Save(&user).Error; err != nil { - t.Fatalf("errors happened when create: %v", err) - } - - check(t, user) - }) -} diff --git a/tests/update_test.go b/tests/update_test.go new file mode 100644 index 00000000..10835f97 --- /dev/null +++ b/tests/update_test.go @@ -0,0 +1,226 @@ +package tests_test + +import ( + "testing" + "time" + + . "github.com/jinzhu/gorm/tests" +) + +func TestUpdate(t *testing.T) { + var ( + users = []*User{ + GetUser("update-1", Config{}), + GetUser("update-2", Config{}), + GetUser("update-3", Config{}), + } + user = users[1] + lastUpdatedAt time.Time + ) + + checkUpdatedTime := func(name string, n time.Time) { + if n.UnixNano() == lastUpdatedAt.UnixNano() { + t.Errorf("%v: user's updated at should be changed, but got %v, was %v", name, n, lastUpdatedAt) + } + lastUpdatedAt = n + } + + checkOtherData := func(name string) { + var first, last User + if err := DB.Where("id = ?", users[0].ID).First(&first).Error; err != nil { + t.Errorf("errors happened when query before user: %v", err) + } + CheckUser(t, first, *users[0]) + + if err := DB.Where("id = ?", users[2].ID).First(&last).Error; err != nil { + t.Errorf("errors happened when query after user: %v", err) + } + CheckUser(t, last, *users[2]) + } + + if err := DB.Create(&users).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } else if user.ID == 0 { + t.Fatalf("user's primary value should not zero, %v", user.ID) + } else if user.UpdatedAt.IsZero() { + t.Fatalf("user's updated at should not zero, %v", user.UpdatedAt) + } + lastUpdatedAt = user.UpdatedAt + + if err := DB.Model(user).Update("Age", 10).Error; err != nil { + t.Errorf("errors happened when update: %v", err) + } else if user.Age != 10 { + t.Errorf("Age should equals to 10, but got %v", user.Age) + } + checkUpdatedTime("Update", user.UpdatedAt) + checkOtherData("Update") + + var result User + if err := DB.Where("id = ?", user.ID).First(&result).Error; err != nil { + t.Errorf("errors happened when query: %v", err) + } else { + CheckUser(t, result, *user) + } + + values := map[string]interface{}{"Active": true, "age": 5} + if err := DB.Model(user).Updates(values).Error; err != nil { + t.Errorf("errors happened when update: %v", err) + } else if user.Age != 5 { + t.Errorf("Age should equals to 5, but got %v", user.Age) + } else if user.Active != true { + t.Errorf("Active should be true, but got %v", user.Active) + } + checkUpdatedTime("Updates with map", user.UpdatedAt) + checkOtherData("Updates with map") + + var result2 User + if err := DB.Where("id = ?", user.ID).First(&result2).Error; err != nil { + t.Errorf("errors happened when query: %v", err) + } else { + CheckUser(t, result2, *user) + } + + if err := DB.Model(user).Updates(User{Age: 2}).Error; err != nil { + t.Errorf("errors happened when update: %v", err) + } else if user.Age != 2 { + t.Errorf("Age should equals to 2, but got %v", user.Age) + } + checkUpdatedTime("Updates with struct", user.UpdatedAt) + checkOtherData("Updates with struct") + + var result3 User + if err := DB.Where("id = ?", user.ID).First(&result3).Error; err != nil { + t.Errorf("errors happened when query: %v", err) + } else { + CheckUser(t, result3, *user) + } + + user.Active = false + user.Age = 1 + if err := DB.Save(user).Error; err != nil { + t.Errorf("errors happened when update: %v", err) + } else if user.Age != 1 { + t.Errorf("Age should equals to 1, but got %v", user.Age) + } else if user.Active != false { + t.Errorf("Active should equals to false, but got %v", user.Active) + } + checkUpdatedTime("Save", user.UpdatedAt) + checkOtherData("Save") + + var result4 User + if err := DB.Where("id = ?", user.ID).First(&result4).Error; err != nil { + t.Errorf("errors happened when query: %v", err) + } else { + CheckUser(t, result4, *user) + } +} + +func TestUpdateBelongsTo(t *testing.T) { + var user = *GetUser("update-belongs-to", Config{}) + + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + user.Company = Company{Name: "company-belongs-to-association"} + user.Manager = &User{Name: "manager-belongs-to-association"} + if err := DB.Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user2 User + DB.Preload("Company").Preload("Manager").Find(&user2, "id = ?", user.ID) + CheckUser(t, user2, user) +} + +func TestUpdateHasOne(t *testing.T) { + var user = *GetUser("update-has-one", Config{}) + + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + user.Account = Account{Number: "account-has-one-association"} + + if err := DB.Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user2 User + DB.Preload("Account").Find(&user2, "id = ?", user.ID) + CheckUser(t, user2, user) + + t.Run("Polymorphic", func(t *testing.T) { + var pet = Pet{Name: "create"} + + if err := DB.Create(&pet).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + pet.Toy = Toy{Name: "Update-HasOneAssociation-Polymorphic"} + + if err := DB.Save(&pet).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + var pet2 Pet + DB.Preload("Toy").Find(&pet2, "id = ?", pet.ID) + CheckPet(t, pet2, pet) + }) +} + +func TestUpdateHasManyAssociations(t *testing.T) { + var user = *GetUser("update-has-many", Config{}) + + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + user.Pets = []*Pet{{Name: "pet1"}, {Name: "pet2"}} + if err := DB.Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user2 User + DB.Preload("Pets").Find(&user2, "id = ?", user.ID) + CheckUser(t, user2, user) + + t.Run("Polymorphic", func(t *testing.T) { + var user = *GetUser("update-has-many", Config{}) + + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + user.Toys = []Toy{{Name: "toy1"}, {Name: "toy2"}} + if err := DB.Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user2 User + DB.Preload("Toys").Find(&user2, "id = ?", user.ID) + CheckUser(t, user2, user) + }) +} + +func TestUpdateMany2ManyAssociations(t *testing.T) { + var user = *GetUser("update-many2many", Config{}) + + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + user.Languages = []Language{{Code: "zh-CN", Name: "Chinese"}, {Code: "en", Name: "English"}} + for _, lang := range user.Languages { + DB.Create(&lang) + } + user.Friends = []*User{{Name: "friend-1"}, {Name: "friend-2"}} + + if err := DB.Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user2 User + DB.Preload("Languages").Preload("Friends").Find(&user2, "id = ?", user.ID) + CheckUser(t, user2, user) +} diff --git a/tests/utils.go b/tests/utils.go index cb4e4fcc..001d77e9 100644 --- a/tests/utils.go +++ b/tests/utils.go @@ -2,10 +2,74 @@ package tests import ( "reflect" + "sort" + "strconv" + "strings" "testing" "time" + + "github.com/jinzhu/gorm/utils" ) +type Config struct { + Account bool + Pets int + Toys int + Company bool + Manager bool + Team int + Languages int + Friends int +} + +func GetUser(name string, config Config) *User { + var ( + birthday = time.Now() + user = User{ + Name: name, + Age: 18, + Birthday: &birthday, + } + ) + + if config.Account { + user.Account = Account{Number: name + "_account"} + } + + for i := 0; i < config.Pets; i++ { + user.Pets = append(user.Pets, &Pet{Name: name + "_pet_" + strconv.Itoa(i+1)}) + } + + for i := 0; i < config.Toys; i++ { + user.Toys = append(user.Toys, Toy{Name: name + "_toy_" + strconv.Itoa(i+1)}) + } + + if config.Company { + user.Company = Company{Name: "company-" + name} + } + + if config.Manager { + user.Manager = GetUser(name+"_manager", Config{}) + } + + for i := 0; i < config.Team; i++ { + user.Team = append(user.Team, *GetUser(name+"_team_"+strconv.Itoa(i+1), Config{})) + } + + for i := 0; i < config.Languages; i++ { + name := name + "_locale_" + strconv.Itoa(i+1) + language := Language{Code: name, Name: name} + DB.Create(&language) + user.Languages = append(user.Languages, language) + } + + for i := 0; i < config.Friends; i++ { + user.Friends = append(user.Friends, GetUser(name+"_friend_"+strconv.Itoa(i+1), Config{})) + } + + return &user +} + func AssertObjEqual(t *testing.T, r, e interface{}, names ...string) { for _, name := range names { got := reflect.Indirect(reflect.ValueOf(r)).FieldByName(name).Interface() @@ -21,11 +85,12 @@ func AssertEqual(t *testing.T, got, expect interface{}) { isEqual := func() { if curTime, ok := got.(time.Time); ok { format := "2006-01-02T15:04:05Z07:00" + if curTime.Format(format) != expect.(time.Time).Format(format) { - t.Errorf("expect: %v, got %v", expect.(time.Time).Format(format), curTime.Format(format)) + t.Errorf("%v: expect: %v, got %v", utils.FileWithLineNum(), expect.(time.Time).Format(format), curTime.Format(format)) } } else if got != expect { - t.Errorf("expect: %#v, got %#v", expect, got) + t.Errorf("%v: expect: %#v, got %#v", utils.FileWithLineNum(), expect, got) } } @@ -34,7 +99,7 @@ func AssertEqual(t *testing.T, got, expect interface{}) { } if reflect.Indirect(reflect.ValueOf(got)).IsValid() != reflect.Indirect(reflect.ValueOf(expect)).IsValid() { - t.Errorf("expect: %+v, got %+v", expect, got) + t.Errorf("%v: expect: %+v, got %+v", utils.FileWithLineNum(), expect, got) return } @@ -55,3 +120,164 @@ func AssertEqual(t *testing.T, got, expect interface{}) { } } } + +func CheckPet(t *testing.T, pet Pet, expect Pet) { + if pet.ID != 0 { + var newPet Pet + if err := DB.Where("id = ?", pet.ID).First(&newPet).Error; err != nil { + t.Fatalf("errors happened when query: %v", err) + } else { + AssertObjEqual(t, newPet, pet, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "UserID", "Name") + } + } + + AssertObjEqual(t, pet, expect, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "UserID", "Name") + + AssertObjEqual(t, pet.Toy, expect.Toy, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "OwnerID", "OwnerType") + + if expect.Toy.Name != "" && expect.Toy.OwnerType != "pets" { + t.Errorf("toys's OwnerType, expect: %v, got %v", "pets", expect.Toy.OwnerType) + } +} + +func CheckUser(t *testing.T, user User, expect User) { + if user.ID != 0 { + var newUser User + if err := DB.Where("id = ?", user.ID).First(&newUser).Error; err != nil { + t.Fatalf("errors happened when query: %v", err) + } else { + AssertObjEqual(t, newUser, user, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active") + } + } + + AssertObjEqual(t, user, expect, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active") + + t.Run("Account", func(t *testing.T) { + AssertObjEqual(t, user.Account, expect.Account, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "UserID", "Number") + + if user.Account.Number != "" { + if !user.Account.UserID.Valid { + t.Errorf("Account's foreign key should be saved") + } else { + var account Account + DB.First(&account, "user_id = ?", user.ID) + AssertObjEqual(t, account, user.Account, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "UserID", "Number") + } + } + }) + + t.Run("Pets", func(t *testing.T) { + if len(user.Pets) != len(expect.Pets) { + t.Errorf("pets should equal, expect: %v, got %v", len(expect.Pets), len(user.Pets)) + } + + sort.Slice(user.Pets, func(i, j int) bool { + return user.Pets[i].ID > user.Pets[j].ID + }) + + sort.Slice(expect.Pets, func(i, j int) bool { + return expect.Pets[i].ID > expect.Pets[j].ID + }) + + for idx, pet := range user.Pets { + if pet == nil || expect.Pets[idx] == nil { + t.Errorf("pets#%v should equal, expect: %v, got %v", idx, expect.Pets[idx], pet) + } else { + CheckPet(t, *pet, *expect.Pets[idx]) + } + } + }) + + t.Run("Toys", func(t *testing.T) { + if len(user.Toys) != len(expect.Toys) { + t.Errorf("toys should equal, expect: %v, got %v", len(expect.Toys), len(user.Toys)) + } + + sort.Slice(user.Toys, func(i, j int) bool { + return user.Toys[i].ID > user.Toys[j].ID + }) + + sort.Slice(expect.Toys, func(i, j int) bool { + return expect.Toys[i].ID > expect.Toys[j].ID + }) + + for idx, toy := range user.Toys { + if toy.OwnerType != "users" { + t.Errorf("toys's OwnerType, expect: %v, got %v", "users", toy.OwnerType) + } + + AssertObjEqual(t, toy, expect.Toys[idx], "ID", "CreatedAt", "UpdatedAt", "Name", "OwnerID", "OwnerType") + } + }) + + t.Run("Company", func(t *testing.T) { + AssertObjEqual(t, user.Company, expect.Company, "ID", "Name") + }) + + t.Run("Manager", func(t *testing.T) { + if user.Manager != nil { + if user.ManagerID == nil { + t.Errorf("Manager's foreign key should be saved") + } else { + var manager User + DB.First(&manager, "id = ?", *user.ManagerID) + AssertObjEqual(t, manager, user.Manager, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active") + } + } else if user.ManagerID != nil { + t.Errorf("Manager should not be created for zero value, got: %+v", user.ManagerID) + } + }) + + t.Run("Team", func(t *testing.T) { + if len(user.Team) != len(expect.Team) { + t.Errorf("Team should equal, expect: %v, got %v", len(expect.Team), len(user.Team)) + } + + sort.Slice(user.Team, func(i, j int) bool { + return user.Team[i].ID > user.Team[j].ID + }) + + sort.Slice(expect.Team, func(i, j int) bool { + return expect.Team[i].ID > expect.Team[j].ID + }) + + for idx, team := range user.Team { + AssertObjEqual(t, team, expect.Team[idx], "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active") + } + }) + + t.Run("Languages", func(t *testing.T) { + if len(user.Languages) != len(expect.Languages) { + t.Errorf("Languages should equal, expect: %v, got %v", len(expect.Languages), len(user.Languages)) + } + + sort.Slice(user.Languages, func(i, j int) bool { + return strings.Compare(user.Languages[i].Code, user.Languages[j].Code) > 0 + }) + + sort.Slice(expect.Languages, func(i, j int) bool { + return strings.Compare(expect.Languages[i].Code, expect.Languages[j].Code) > 0 + }) + for idx, language := range user.Languages { + AssertObjEqual(t, language, expect.Languages[idx], "Code", "Name") + } + }) + + t.Run("Friends", func(t *testing.T) { + if len(user.Friends) != len(expect.Friends) { + t.Errorf("Friends should equal, expect: %v, got %v", len(expect.Friends), len(user.Friends)) + } + + sort.Slice(user.Friends, func(i, j int) bool { + return user.Friends[i].ID > user.Friends[j].ID + }) + + sort.Slice(expect.Friends, func(i, j int) bool { + return expect.Friends[i].ID > expect.Friends[j].ID + }) + + for idx, friend := range user.Friends { + AssertObjEqual(t, friend, expect.Friends[idx], "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active") + } + }) +} From e60a8d54ff609e9ba74c2335b22a7c36decaa5fb Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 24 May 2020 00:52:25 +0800 Subject: [PATCH 318/881] Test Nested Preload --- callbacks/preload.go | 6 ++--- schema/schema.go | 2 +- schema/utils.go | 12 ++++++--- tests/preload_test.go | 58 +++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 70 insertions(+), 8 deletions(-) create mode 100644 tests/preload_test.go diff --git a/callbacks/preload.go b/callbacks/preload.go index 7e3810b5..f48777c2 100644 --- a/callbacks/preload.go +++ b/callbacks/preload.go @@ -22,7 +22,7 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { ) if len(rels) > 1 { - reflectValue = schema.GetRelationsValues(reflectValue, rels[:len(rels)]) + reflectValue = schema.GetRelationsValues(reflectValue, rels[:len(rels)-1]) } if rel.JoinTable != nil { @@ -107,9 +107,9 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { rel.Field.Set(data, reflectResults.Index(i).Interface()) case reflect.Slice, reflect.Array: if reflectFieldValue.Type().Elem().Kind() == reflect.Ptr { - rel.Field.Set(data, reflect.Append(reflectFieldValue, reflectResults.Index(i).Addr()).Interface()) - } else { rel.Field.Set(data, reflect.Append(reflectFieldValue, reflectResults.Index(i)).Interface()) + } else { + rel.Field.Set(data, reflect.Append(reflectFieldValue, reflectResults.Index(i).Elem()).Interface()) } } } diff --git a/schema/schema.go b/schema/schema.go index 79faae12..caae55ac 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -49,7 +49,7 @@ func (schema Schema) String() string { } func (schema Schema) MakeSlice() reflect.Value { - slice := reflect.MakeSlice(reflect.SliceOf(schema.ModelType), 0, 0) + slice := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(schema.ModelType)), 0, 0) results := reflect.New(slice.Type()) results.Elem().Set(slice) return results diff --git a/schema/utils.go b/schema/utils.go index c47f1984..f7808f0e 100644 --- a/schema/utils.go +++ b/schema/utils.go @@ -55,17 +55,21 @@ func removeSettingFromTag(tag reflect.StructTag, name string) reflect.StructTag // GetRelationsValues get relations's values from a reflect value func GetRelationsValues(reflectValue reflect.Value, rels []*Relationship) (reflectResults reflect.Value) { for _, rel := range rels { - reflectResults = reflect.MakeSlice(reflect.SliceOf(rel.FieldSchema.ModelType), 0, 0) + reflectResults = reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(rel.FieldSchema.ModelType)), 0, 0) appendToResults := func(value reflect.Value) { if _, isZero := rel.Field.ValueOf(value); !isZero { result := reflect.Indirect(rel.Field.ReflectValueOf(value)) switch result.Kind() { case reflect.Struct: - reflectResults = reflect.Append(reflectResults, result) + reflectResults = reflect.Append(reflectResults, result.Addr()) case reflect.Slice, reflect.Array: - for i := 0; i < value.Len(); i++ { - reflectResults = reflect.Append(reflectResults, reflect.Indirect(value.Index(i))) + for i := 0; i < result.Len(); i++ { + if result.Index(i).Kind() == reflect.Ptr { + reflectResults = reflect.Append(reflectResults, result.Index(i)) + } else { + reflectResults = reflect.Append(reflectResults, result.Index(i).Addr()) + } } } } diff --git a/tests/preload_test.go b/tests/preload_test.go new file mode 100644 index 00000000..74f21f55 --- /dev/null +++ b/tests/preload_test.go @@ -0,0 +1,58 @@ +package tests_test + +import ( + "strconv" + "testing" + + . "github.com/jinzhu/gorm/tests" +) + +func TestNestedPreload(t *testing.T) { + var user = *GetUser("nested_preload", Config{Pets: 2}) + + for idx, pet := range user.Pets { + pet.Toy = Toy{Name: "toy_nested_preload_" + strconv.Itoa(idx+1)} + } + + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + var user2 User + DB.Preload("Pets.Toy").Find(&user2, "id = ?", user.ID) + + CheckUser(t, user2, user) +} + +func TestNestedPreloadForSlice(t *testing.T) { + var users = []User{ + *GetUser("slice_nested_preload_1", Config{Pets: 2}), + *GetUser("slice_nested_preload_2", Config{Pets: 0}), + *GetUser("slice_nested_preload_3", Config{Pets: 3}), + } + + for _, user := range users { + for idx, pet := range user.Pets { + pet.Toy = Toy{Name: user.Name + "_toy_nested_preload_" + strconv.Itoa(idx+1)} + } + } + + if err := DB.Create(&users).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + var userIDs []uint + for _, user := range users { + userIDs = append(userIDs, user.ID) + } + + var users2 []User + DB.Preload("Pets.Toy").Find(&users2, "id IN ?", userIDs) + + for idx, user := range users2 { + CheckUser(t, user, users[idx]) + } +} + +func TestPreloadWithConds(t *testing.T) { +} From 1c39ac921b3cdc38974092a538649b15331ccdb4 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 24 May 2020 01:16:08 +0800 Subject: [PATCH 319/881] Test preload with conds --- tests/preload_test.go | 80 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 80 insertions(+) diff --git a/tests/preload_test.go b/tests/preload_test.go index 74f21f55..b14c5b90 100644 --- a/tests/preload_test.go +++ b/tests/preload_test.go @@ -1,9 +1,11 @@ package tests_test import ( + "sort" "strconv" "testing" + "github.com/jinzhu/gorm/clause" . "github.com/jinzhu/gorm/tests" ) @@ -55,4 +57,82 @@ func TestNestedPreloadForSlice(t *testing.T) { } func TestPreloadWithConds(t *testing.T) { + var users = []User{ + *GetUser("slice_nested_preload_1", Config{Account: true}), + *GetUser("slice_nested_preload_2", Config{Account: false}), + *GetUser("slice_nested_preload_3", Config{Account: true}), + } + + if err := DB.Create(&users).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + var userIDs []uint + for _, user := range users { + userIDs = append(userIDs, user.ID) + } + + var users2 []User + DB.Preload("Account", clause.Eq{Column: "number", Value: users[0].Account.Number}).Find(&users2, "id IN ?", userIDs) + sort.Slice(users2, func(i, j int) bool { + return users2[i].ID < users2[j].ID + }) + + for idx, user := range users2[1:2] { + if user.Account.Number != "" { + t.Errorf("No account should found for user %v but got %v", idx+2, user.Account.Number) + } + } + + CheckUser(t, users2[0], users[0]) +} + +func TestNestedPreloadWithConds(t *testing.T) { + var users = []User{ + *GetUser("slice_nested_preload_1", Config{Pets: 2}), + *GetUser("slice_nested_preload_2", Config{Pets: 0}), + *GetUser("slice_nested_preload_3", Config{Pets: 3}), + } + + for _, user := range users { + for idx, pet := range user.Pets { + pet.Toy = Toy{Name: user.Name + "_toy_nested_preload_" + strconv.Itoa(idx+1)} + } + } + + if err := DB.Create(&users).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + var userIDs []uint + for _, user := range users { + userIDs = append(userIDs, user.ID) + } + + var users2 []User + DB.Preload("Pets.Toy", "name like ?", `%preload_3`).Find(&users2, "id IN ?", userIDs) + + for idx, user := range users2[0:2] { + for _, pet := range user.Pets { + if pet.Toy.Name != "" { + t.Errorf("No toy should for user %v's pet %v but got %v", idx+1, pet.Name, pet.Toy.Name) + } + } + } + + if len(users2[2].Pets) != 3 { + t.Errorf("Invalid pet toys found for user 3 got %v", len(users2[2].Pets)) + } else { + sort.Slice(users2[2].Pets, func(i, j int) bool { + return users2[2].Pets[i].ID < users2[2].Pets[j].ID + }) + + for _, pet := range users2[2].Pets[0:2] { + if pet.Toy.Name != "" { + t.Errorf("No toy should for user %v's pet %v but got %v", 3, pet.Name, pet.Toy.Name) + } + } + + CheckPet(t, *users2[2].Pets[2], *users[2].Pets[2]) + } } From cbc4a8114026692f8f1720087f674f2f4e4df3f6 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 24 May 2020 11:32:59 +0800 Subject: [PATCH 320/881] Add Count tests --- association.go | 7 ++--- callbacks.go | 3 ++- callbacks/query.go | 7 ++++- callbacks/scan.go | 5 ++++ clause/values.go | 3 --- finisher_api.go | 13 ++++++++- statement.go | 54 ++++++++++++++++++++------------------ tests/associations_test.go | 8 ++++++ tests/count_test.go | 42 +++++++++++++++++++++++++++++ 9 files changed, 108 insertions(+), 34 deletions(-) create mode 100644 tests/count_test.go diff --git a/association.go b/association.go index abcae47d..bd2a7cdd 100644 --- a/association.go +++ b/association.go @@ -247,11 +247,12 @@ func (association *Association) Clear() error { return association.Replace() } -func (association *Association) Count() (count int) { +func (association *Association) Count() (count int64) { if association.Error == nil { var ( - tx = association.DB - conds = association.Relationship.ToQueryConditions(tx.Statement.ReflectValue) + conds = association.Relationship.ToQueryConditions(association.DB.Statement.ReflectValue) + modelValue = reflect.New(association.Relationship.FieldSchema.ModelType).Interface() + tx = association.DB.Model(modelValue) ) if association.Relationship.JoinTable != nil { diff --git a/callbacks.go b/callbacks.go index 61cebc81..629b90aa 100644 --- a/callbacks.go +++ b/callbacks.go @@ -73,6 +73,7 @@ func (cs *callbacks) Raw() *processor { func (p *processor) Execute(db *DB) { curTime := time.Now() + db.RowsAffected = 0 if stmt := db.Statement; stmt != nil { if stmt.Model == nil { stmt.Model = stmt.Dest @@ -102,7 +103,7 @@ func (p *processor) Execute(db *DB) { }, db.Error) stmt.reinit() - db.Config.statementPool.Put(stmt) + // db.Config.statementPool.Put(stmt) } } diff --git a/callbacks/query.go b/callbacks/query.go index 4a89c575..95b5ead3 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -21,6 +21,11 @@ func Query(db *gorm.DB) { clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ Name: f.DBName, }) + } else { + clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ + Name: name, + Raw: true, + }) } } } @@ -85,7 +90,7 @@ func Query(db *gorm.DB) { db.Statement.AddClauseIfNotExists(clause.From{}) } - db.Statement.AddClauseIfNotExists(clauseSelect) + db.Statement.AddClause(clauseSelect) db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR") } diff --git a/callbacks/scan.go b/callbacks/scan.go index 6ea8bf23..9ffcab4a 100644 --- a/callbacks/scan.go +++ b/callbacks/scan.go @@ -49,6 +49,11 @@ func Scan(rows *sql.Rows, db *gorm.DB) { } *dest = append(*dest, v) } + case *int, *int64, *uint, *uint64: + for rows.Next() { + db.RowsAffected++ + rows.Scan(dest) + } default: switch db.Statement.ReflectValue.Kind() { case reflect.Slice, reflect.Array: diff --git a/clause/values.go b/clause/values.go index a997fc26..b2f5421b 100644 --- a/clause/values.go +++ b/clause/values.go @@ -41,8 +41,5 @@ func (values Values) Build(builder Builder) { // MergeClause merge values clauses func (values Values) MergeClause(clause *Clause) { clause.Name = "" - if v, ok := clause.Expression.(Values); ok { - values.Values = append(v.Values, values.Values...) - } clause.Expression = values } diff --git a/finisher_api.go b/finisher_api.go index 1b2a7e29..6a787576 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -145,8 +145,19 @@ func (db *DB) Delete(value interface{}, conds ...interface{}) (tx *DB) { return } -func (db *DB) Count(value interface{}) (tx *DB) { +func (db *DB) Count(count *int64) (tx *DB) { tx = db.getInstance() + if len(tx.Statement.Selects) == 0 { + tx.Statement.Selects = []string{"count(1)"} + } + if tx.Statement.Model == nil { + tx.Statement.Model = tx.Statement.Dest + } + tx.Statement.Dest = count + tx.callbacks.Query().Execute(tx) + if db.RowsAffected != 1 { + *count = db.RowsAffected + } return } diff --git a/statement.go b/statement.go index 1ea5a56c..0abf7a7e 100644 --- a/statement.go +++ b/statement.go @@ -63,6 +63,8 @@ func (stmt Statement) QuoteTo(writer clause.Writer, field interface{}) { case clause.Table: if v.Name == clause.CurrentTable { stmt.DB.Dialector.QuoteTo(writer, stmt.Table) + } else if v.Raw { + writer.WriteString(v.Name) } else { stmt.DB.Dialector.QuoteTo(writer, v.Name) } @@ -85,6 +87,8 @@ func (stmt Statement) QuoteTo(writer clause.Writer, field interface{}) { if stmt.Schema != nil && stmt.Schema.PrioritizedPrimaryField != nil { stmt.DB.Dialector.QuoteTo(writer, stmt.Schema.PrioritizedPrimaryField.DBName) } + } else if v.Raw { + writer.WriteString(v.Name) } else { stmt.DB.Dialector.QuoteTo(writer, v.Name) } @@ -275,33 +279,33 @@ func (stmt *Statement) Parse(value interface{}) (err error) { } func (stmt *Statement) reinit() { - stmt.Table = "" - stmt.Model = nil - stmt.Selects = nil - stmt.Omits = nil - stmt.ConnPool = stmt.DB.Config.ConnPool - stmt.Schema = nil - stmt.Context = context.Background() - stmt.RaiseErrorOnNotFound = false + // stmt.Table = "" + // stmt.Model = nil + // stmt.Selects = nil + // stmt.Omits = nil + // stmt.ConnPool = stmt.DB.Config.ConnPool + // stmt.Context = context.Background() + // stmt.RaiseErrorOnNotFound = false + // for k := range stmt.Clauses { + // delete(stmt.Clauses, k) + // } + + // for k := range stmt.Joins { + // delete(stmt.Joins, k) + // } + + // for k := range stmt.Preloads { + // delete(stmt.Preloads, k) + // } + + // stmt.Settings.Range(func(k, _ interface{}) bool { + // stmt.Settings.Delete(k) + // return true + // }) + + stmt.Schema = nil stmt.SQL.Reset() stmt.Vars = nil stmt.NamedVars = nil - - for k := range stmt.Clauses { - delete(stmt.Clauses, k) - } - - for k := range stmt.Joins { - delete(stmt.Joins, k) - } - - for k := range stmt.Preloads { - delete(stmt.Preloads, k) - } - - stmt.Settings.Range(func(k, _ interface{}) bool { - stmt.Settings.Delete(k) - return true - }) } diff --git a/tests/associations_test.go b/tests/associations_test.go index dc88ee03..845ee65e 100644 --- a/tests/associations_test.go +++ b/tests/associations_test.go @@ -21,4 +21,12 @@ func TestAssociationForBelongsTo(t *testing.T) { user2.Manager = &User{} DB.Model(&user2).Association("Manager").Find(user2.Manager) CheckUser(t, user2, user) + + if count := DB.Model(&user).Association("Company").Count(); count != 1 { + t.Errorf("invalid company count, got %v", count) + } + + if count := DB.Model(&user).Association("Manager").Count(); count != 1 { + t.Errorf("invalid manager count, got %v", count) + } } diff --git a/tests/count_test.go b/tests/count_test.go new file mode 100644 index 00000000..960db167 --- /dev/null +++ b/tests/count_test.go @@ -0,0 +1,42 @@ +package tests_test + +import ( + "fmt" + "testing" + + . "github.com/jinzhu/gorm/tests" +) + +func TestCount(t *testing.T) { + var ( + user1 = *GetUser("count-1", Config{}) + user2 = *GetUser("count-2", Config{}) + user3 = *GetUser("count-3", Config{}) + users []User + count, count1, count2 int64 + ) + + DB.Save(&user1).Save(&user2).Save(&user3) + + if err := DB.Where("name = ?", user1.Name).Or("name = ?", user3.Name).Find(&users).Count(&count).Error; err != nil { + t.Errorf(fmt.Sprintf("Count should work, but got err %v", err)) + } + + if count != int64(len(users)) { + t.Errorf("Count() method should get correct value, expect: %v, got %v", count, len(users)) + } + + DB.Model(&User{}).Where("name = ?", user1.Name).Count(&count1).Or("name in ?", []string{user2.Name, user3.Name}).Count(&count2) + if count1 != 1 || count2 != 3 { + t.Errorf("multiple count in chain should works") + } + + var count3 int64 + if err := DB.Model(&User{}).Where("name in ?", []string{user2.Name, user2.Name, user3.Name}).Group("id").Count(&count3).Error; err != nil { + t.Errorf("No error should happen when count with group, but got %v", err) + } + + if count3 != 2 { + t.Errorf("Should get correct count for count with group, but got %v", count3) + } +} From 91a695893c4c5c5e830631fa58d63b9a26d50aed Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 24 May 2020 17:24:23 +0800 Subject: [PATCH 321/881] Test Association For BelongsTo --- association.go | 79 +++++++++++++++++------- callbacks/associations.go | 2 +- callbacks/helper.go | 2 +- callbacks/update.go | 30 +++++++-- gorm.go | 11 ++-- schema/field.go | 29 +++++---- schema/relationship.go | 1 + statement.go | 33 ++++++++++ tests/associations_test.go | 121 +++++++++++++++++++++++++++++++++++++ tests/count_test.go | 2 +- 10 files changed, 265 insertions(+), 45 deletions(-) diff --git a/association.go b/association.go index bd2a7cdd..c179a148 100644 --- a/association.go +++ b/association.go @@ -19,8 +19,10 @@ type Association struct { func (db *DB) Association(column string) *Association { association := &Association{DB: db} + table := db.Statement.Table if err := db.Statement.Parse(db.Statement.Model); err == nil { + db.Statement.Table = table association.Relationship = db.Statement.Schema.Relationships.Relations[column] if association.Relationship == nil { @@ -83,6 +85,16 @@ func (association *Association) Replace(values ...interface{}) error { rel := association.Relationship switch rel.Type { + case schema.BelongsTo: + if len(values) == 0 { + updateMap := map[string]interface{}{} + + for _, ref := range rel.References { + updateMap[ref.ForeignKey.DBName] = nil + } + + association.DB.UpdateColumns(updateMap) + } case schema.HasOne, schema.HasMany: var ( primaryFields []*schema.Field @@ -90,6 +102,9 @@ func (association *Association) Replace(values ...interface{}) error { updateMap = map[string]interface{}{} modelValue = reflect.New(rel.FieldSchema.ModelType).Interface() ) + if rel.Type == schema.BelongsTo { + modelValue = reflect.New(rel.Schema.ModelType).Interface() + } for _, ref := range rel.References { if ref.OwnPrimaryKey { @@ -101,7 +116,7 @@ func (association *Association) Replace(values ...interface{}) error { } _, values := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields) - if len(values) > 0 { + if len(values) == 0 { column, queryValues := schema.ToQueryValues(foreignKeys, values) association.DB.Model(modelValue).Where(clause.IN{Column: column, Values: queryValues}).UpdateColumns(updateMap) } @@ -158,13 +173,13 @@ func (association *Association) Replace(values ...interface{}) error { func (association *Association) Delete(values ...interface{}) error { if association.Error == nil { var ( - tx = association.DB - rel = association.Relationship - reflectValue = tx.Statement.ReflectValue - conds = rel.ToQueryConditions(reflectValue) - relFields []*schema.Field - foreignKeys []string - updateAttrs = map[string]interface{}{} + tx = association.DB + rel = association.Relationship + reflectValue = tx.Statement.ReflectValue + relFields []*schema.Field + foreignKeyFields []*schema.Field + foreignKeys []string + updateAttrs = map[string]interface{}{} ) for _, ref := range rel.References { @@ -174,6 +189,7 @@ func (association *Association) Delete(values ...interface{}) error { relFields = append(relFields, ref.ForeignKey) } else { relFields = append(relFields, ref.PrimaryKey) + foreignKeyFields = append(foreignKeyFields, ref.ForeignKey) } foreignKeys = append(foreignKeys, ref.ForeignKey.DBName) @@ -189,11 +205,14 @@ func (association *Association) Delete(values ...interface{}) error { switch rel.Type { case schema.HasOne, schema.HasMany: modelValue := reflect.New(rel.FieldSchema.ModelType).Interface() + conds := rel.ToQueryConditions(reflectValue) tx.Model(modelValue).Clauses(clause.Where{Exprs: conds}).UpdateColumns(updateAttrs) case schema.BelongsTo: - tx.Clauses(clause.Where{Exprs: conds}).UpdateColumns(updateAttrs) + modelValue := reflect.New(rel.Schema.ModelType).Interface() + tx.Model(modelValue).UpdateColumns(updateAttrs) case schema.Many2Many: modelValue := reflect.New(rel.JoinTable.ModelType).Interface() + conds := rel.ToQueryConditions(reflectValue) tx.Clauses(clause.Where{Exprs: conds}).Delete(modelValue) } @@ -216,13 +235,16 @@ func (association *Association) Delete(values ...interface{}) error { } } - rel.Field.Set(data, validFieldValues) + rel.Field.Set(data, validFieldValues.Interface()) case reflect.Struct: for idx, field := range relFields { - fieldValues[idx], _ = field.ValueOf(data) + fieldValues[idx], _ = field.ValueOf(fieldValue) } if _, ok := relValuesMap[utils.ToStringKey(fieldValues...)]; ok { - rel.Field.Set(data, reflect.Zero(rel.FieldSchema.ModelType)) + rel.Field.Set(data, reflect.Zero(rel.FieldSchema.ModelType).Interface()) + for _, field := range foreignKeyFields { + field.Set(data, reflect.Zero(field.FieldType).Interface()) + } } } } @@ -275,7 +297,11 @@ func (association *Association) Count() (count int64) { } func (association *Association) saveAssociation(clear bool, values ...interface{}) { - reflectValue := association.DB.Statement.ReflectValue + var ( + reflectValue = association.DB.Statement.ReflectValue + assignBacks = [][2]reflect.Value{} + assignBack = association.Relationship.Field.FieldType.Kind() == reflect.Struct + ) appendToRelations := func(source, rv reflect.Value, clear bool) { switch association.Relationship.Type { @@ -283,10 +309,16 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ switch rv.Kind() { case reflect.Slice, reflect.Array: if rv.Len() > 0 { - association.Error = association.Relationship.Field.Set(source, rv.Index(0)) + association.Error = association.Relationship.Field.Set(source, rv.Index(0).Addr().Interface()) + if assignBack { + assignBacks = append(assignBacks, [2]reflect.Value{source, rv.Index(0)}) + } } case reflect.Struct: - association.Error = association.Relationship.Field.Set(source, rv) + association.Error = association.Relationship.Field.Set(source, rv.Addr().Interface()) + if assignBack { + assignBacks = append(assignBacks, [2]reflect.Value{source, rv}) + } } case schema.HasMany, schema.Many2Many: elemType := association.Relationship.Field.IndirectFieldType.Elem() @@ -315,7 +347,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ } if association.Error == nil { - association.Error = association.Relationship.Field.Set(source, fieldValue) + association.Error = association.Relationship.Field.Set(source, fieldValue.Addr().Interface()) } } } @@ -333,7 +365,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ if len(values) != reflectValue.Len() { if clear && len(values) == 0 { for i := 0; i < reflectValue.Len(); i++ { - association.Relationship.Field.Set(reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType)) + association.Relationship.Field.Set(reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType).Interface()) } break } @@ -349,19 +381,24 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ } case reflect.Struct: if clear && len(values) == 0 { - association.Relationship.Field.Set(reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType)) + association.Relationship.Field.Set(reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType).Interface()) } for idx, value := range values { - appendToRelations(reflectValue, reflect.Indirect(reflect.ValueOf(value)), clear && idx == 0) + rv := reflect.Indirect(reflect.ValueOf(value)) + appendToRelations(reflectValue, rv, clear && idx == 0) } _, hasZero = association.DB.Statement.Schema.PrioritizedPrimaryField.ValueOf(reflectValue) } if hasZero { - association.DB.Save(reflectValue.Interface()) + association.DB.Save(reflectValue.Addr().Interface()) } else { - association.DB.Select(selectedColumns).Save(reflectValue.Interface()) + association.DB.Select(selectedColumns).Save(reflectValue.Addr().Interface()) + } + + for _, assignBack := range assignBacks { + reflect.Indirect(assignBack[1]).Set(association.Relationship.Field.ReflectValueOf(assignBack[0])) } } diff --git a/callbacks/associations.go b/callbacks/associations.go index ef040b71..37addd60 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -73,8 +73,8 @@ func SaveBeforeAssociations(db *gorm.DB) { if _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(rv); isZero { db.Session(&gorm.Session{}).Create(rv.Interface()) - setupReferences(db.Statement.ReflectValue, rv) } + setupReferences(db.Statement.ReflectValue, rv) } } } diff --git a/callbacks/helper.go b/callbacks/helper.go index 43e90b8a..8da74690 100644 --- a/callbacks/helper.go +++ b/callbacks/helper.go @@ -22,7 +22,7 @@ func SelectAndOmitColumns(stmt *gorm.Statement, requireCreate, requireUpdate boo break } - if field := stmt.Schema.LookUpField(column); field != nil { + if field := stmt.Schema.LookUpField(column); field != nil && field.DBName != "" { results[field.DBName] = true } else { results[column] = true diff --git a/callbacks/update.go b/callbacks/update.go index 53c646e9..be9fe30a 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -7,6 +7,7 @@ import ( "github.com/jinzhu/gorm" "github.com/jinzhu/gorm/clause" + "github.com/jinzhu/gorm/schema" ) func BeforeUpdate(db *gorm.DB) { @@ -91,8 +92,27 @@ func AfterUpdate(db *gorm.DB) { // ConvertToAssignments convert to update assignments func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { - selectColumns, restricted := SelectAndOmitColumns(stmt, false, true) - reflectModelValue := reflect.ValueOf(stmt.Model) + var ( + selectColumns, restricted = SelectAndOmitColumns(stmt, false, true) + reflectModelValue = reflect.Indirect(reflect.ValueOf(stmt.Model)) + assignValue func(field *schema.Field, value interface{}) + ) + + switch reflectModelValue.Kind() { + case reflect.Slice, reflect.Array: + assignValue = func(field *schema.Field, value interface{}) { + for i := 0; i < reflectModelValue.Len(); i++ { + field.Set(reflectModelValue.Index(i), value) + } + } + case reflect.Struct: + assignValue = func(field *schema.Field, value interface{}) { + field.Set(reflectModelValue, value) + } + default: + assignValue = func(field *schema.Field, value interface{}) { + } + } switch value := stmt.Dest.(type) { case map[string]interface{}: @@ -111,7 +131,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { value[k] = time.Now() } set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value[k]}) - field.Set(reflectModelValue, value[k]) + assignValue(field, value[k]) } } else if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) { set = append(set, clause.Assignment{Column: clause.Column{Name: k}, Value: value[k]}) @@ -122,7 +142,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { if field.AutoUpdateTime > 0 && value[field.Name] == nil && value[field.DBName] == nil { now := time.Now() set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now}) - field.Set(reflectModelValue, now) + assignValue(field, now) } } default: @@ -140,7 +160,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { if ok || !isZero { set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value}) - field.Set(reflectModelValue, value) + assignValue(field, value) } } } else { diff --git a/gorm.go b/gorm.go index f8c944af..1fa69383 100644 --- a/gorm.go +++ b/gorm.go @@ -105,11 +105,12 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) { func (db *DB) Session(config *Session) *DB { var ( tx = db.getInstance() + stmt = tx.Statement.clone() txConfig = *tx.Config ) if config.Context != nil { - tx.Statement.Context = config.Context + stmt.Context = config.Context } if config.Logger != nil { @@ -120,9 +121,11 @@ func (db *DB) Session(config *Session) *DB { txConfig.NowFunc = config.NowFunc } - tx.Config = &txConfig - tx.clone = true - return tx + return &DB{ + Config: &txConfig, + Statement: stmt, + clone: true, + } } // WithContext change current instance db's context to ctx diff --git a/schema/field.go b/schema/field.go index 7b37733b..9a5f1fc6 100644 --- a/schema/field.go +++ b/schema/field.go @@ -372,19 +372,24 @@ func (field *Field) setupValuerAndSetter() { } recoverFunc := func(value reflect.Value, v interface{}, setter func(reflect.Value, interface{}) error) (err error) { - reflectV := reflect.ValueOf(v) - if reflectV.Type().ConvertibleTo(field.FieldType) { - field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType)) - } else if valuer, ok := v.(driver.Valuer); ok { - if v, err = valuer.Value(); err == nil { - return setter(value, v) - } - } else if field.FieldType.Kind() == reflect.Ptr && reflectV.Type().ConvertibleTo(field.FieldType.Elem()) { - field.ReflectValueOf(value).Elem().Set(reflectV.Convert(field.FieldType.Elem())) - } else if reflectV.Kind() == reflect.Ptr { - return field.Set(value, reflectV.Elem().Interface()) + if v == nil { + field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) } else { - return fmt.Errorf("failed to set value %+v to field %v", v, field.Name) + reflectV := reflect.ValueOf(v) + + if reflectV.Type().ConvertibleTo(field.FieldType) { + field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType)) + } else if valuer, ok := v.(driver.Valuer); ok { + if v, err = valuer.Value(); err == nil { + return setter(value, v) + } + } else if field.FieldType.Kind() == reflect.Ptr && reflectV.Type().ConvertibleTo(field.FieldType.Elem()) { + field.ReflectValueOf(value).Elem().Set(reflectV.Convert(field.FieldType.Elem())) + } else if reflectV.Kind() == reflect.Ptr { + return field.Set(value, reflectV.Elem().Interface()) + } else { + return fmt.Errorf("failed to set value %+v to field %v", v, field.Name) + } } return err } diff --git a/schema/relationship.go b/schema/relationship.go index 59aaa7e4..d10bfe30 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -387,6 +387,7 @@ func (rel *Relationship) ToQueryConditions(reflectValue reflect.Value) (conds [] _, foreignValues := GetIdentityFieldValuesMap(reflectValue, foreignFields) column, values := ToQueryValues(relForeignKeys, foreignValues) + conds = append(conds, clause.IN{Column: column, Values: values}) return } diff --git a/statement.go b/statement.go index 0abf7a7e..d37622dd 100644 --- a/statement.go +++ b/statement.go @@ -278,6 +278,39 @@ func (stmt *Statement) Parse(value interface{}) (err error) { return err } +func (stmt *Statement) clone() *Statement { + newStmt := &Statement{ + DB: stmt.DB, + Table: stmt.Table, + Model: stmt.Model, + Dest: stmt.Dest, + ReflectValue: stmt.ReflectValue, + Clauses: map[string]clause.Clause{}, + Selects: stmt.Selects, + Omits: stmt.Omits, + Joins: map[string][]interface{}{}, + Preloads: map[string][]interface{}{}, + ConnPool: stmt.ConnPool, + Schema: stmt.Schema, + Context: stmt.Context, + RaiseErrorOnNotFound: stmt.RaiseErrorOnNotFound, + } + + for k, c := range stmt.Clauses { + newStmt.Clauses[k] = c + } + + for k, p := range stmt.Preloads { + newStmt.Preloads[k] = p + } + + for k, j := range stmt.Joins { + newStmt.Joins[k] = j + } + + return newStmt +} + func (stmt *Statement) reinit() { // stmt.Table = "" // stmt.Model = nil diff --git a/tests/associations_test.go b/tests/associations_test.go index 845ee65e..159f7f3a 100644 --- a/tests/associations_test.go +++ b/tests/associations_test.go @@ -15,6 +15,7 @@ func TestAssociationForBelongsTo(t *testing.T) { CheckUser(t, user, user) + // Find var user2 User DB.Find(&user2, "id = ?", user.ID) DB.Model(&user2).Association("Company").Find(&user2.Company) @@ -22,6 +23,7 @@ func TestAssociationForBelongsTo(t *testing.T) { DB.Model(&user2).Association("Manager").Find(user2.Manager) CheckUser(t, user2, user) + // Count if count := DB.Model(&user).Association("Company").Count(); count != 1 { t.Errorf("invalid company count, got %v", count) } @@ -29,4 +31,123 @@ func TestAssociationForBelongsTo(t *testing.T) { if count := DB.Model(&user).Association("Manager").Count(); count != 1 { t.Errorf("invalid manager count, got %v", count) } + + // Append + var company = Company{Name: "company-belongs-to-append"} + var manager = GetUser("manager-belongs-to-append", Config{}) + + if err := DB.Model(&user2).Association("Company").Append(&company); err != nil { + t.Fatalf("Error happened when append Company, got %v", err) + } + + if company.ID == 0 { + t.Fatalf("Company's ID should be created") + } + + if err := DB.Model(&user2).Association("Manager").Append(manager); err != nil { + t.Fatalf("Error happened when append Manager, got %v", err) + } + + if manager.ID == 0 { + t.Fatalf("Manager's ID should be created") + } + + user.Company = company + user.Manager = manager + user.CompanyID = &company.ID + user.ManagerID = &manager.ID + CheckUser(t, user2, user) + + // Replace + var company2 = Company{Name: "company-belongs-to-replace"} + var manager2 = GetUser("manager-belongs-to-replace", Config{}) + + if err := DB.Model(&user2).Association("Company").Replace(&company2); err != nil { + t.Fatalf("Error happened when replace Company, got %v", err) + } + + if company2.ID == 0 { + t.Fatalf("Company's ID should be created") + } + + if err := DB.Model(&user2).Association("Manager").Replace(manager2); err != nil { + t.Fatalf("Error happened when replace Manager, got %v", err) + } + + if manager2.ID == 0 { + t.Fatalf("Manager's ID should be created") + } + + user.Company = company2 + user.Manager = manager2 + user.CompanyID = &company2.ID + user.ManagerID = &manager2.ID + CheckUser(t, user2, user) + + // Delete + if err := DB.Model(&user2).Association("Company").Delete(&Company{}); err != nil { + t.Fatalf("Error happened when delete Company, got %v", err) + } + + if count := DB.Model(&user2).Association("Company").Count(); count != 1 { + t.Errorf("Invalid company count after delete non-existing association, got %v", count) + } + + if err := DB.Model(&user2).Association("Company").Delete(&company2); err != nil { + t.Fatalf("Error happened when delete Company, got %v", err) + } + + if count := DB.Model(&user2).Association("Company").Count(); count != 0 { + t.Errorf("Invalid company count after delete, got %v", count) + } + + if err := DB.Model(&user2).Association("Manager").Delete(&User{}); err != nil { + t.Fatalf("Error happened when delete Manager, got %v", err) + } + + if count := DB.Model(&user2).Association("Manager").Count(); count != 1 { + t.Errorf("Invalid manager count after delete non-existing association, got %v", count) + } + + if err := DB.Model(&user2).Association("Manager").Delete(manager2); err != nil { + t.Fatalf("Error happened when delete Manager, got %v", err) + } + + if count := DB.Model(&user2).Association("Manager").Count(); count != 0 { + t.Errorf("Invalid manager count after delete, got %v", count) + } + + // Prepare Data + if err := DB.Model(&user2).Association("Company").Append(&company); err != nil { + t.Fatalf("Error happened when append Company, got %v", err) + } + + if err := DB.Model(&user2).Association("Manager").Append(manager); err != nil { + t.Fatalf("Error happened when append Manager, got %v", err) + } + + if count := DB.Model(&user2).Association("Company").Count(); count != 1 { + t.Errorf("Invalid company count after append, got %v", count) + } + + if count := DB.Model(&user2).Association("Manager").Count(); count != 1 { + t.Errorf("Invalid manager count after append, got %v", count) + } + + // Clear + if err := DB.Model(&user2).Association("Company").Clear(); err != nil { + t.Errorf("Error happened when clear Company, got %v", err) + } + + if err := DB.Model(&user2).Association("Manager").Clear(); err != nil { + t.Errorf("Error happened when clear Manager, got %v", err) + } + + if count := DB.Model(&user2).Association("Company").Count(); count != 0 { + t.Errorf("Invalid company count after clear, got %v", count) + } + + if count := DB.Model(&user2).Association("Manager").Count(); count != 0 { + t.Errorf("Invalid manager count after clear, got %v", count) + } } diff --git a/tests/count_test.go b/tests/count_test.go index 960db167..257959c3 100644 --- a/tests/count_test.go +++ b/tests/count_test.go @@ -33,7 +33,7 @@ func TestCount(t *testing.T) { var count3 int64 if err := DB.Model(&User{}).Where("name in ?", []string{user2.Name, user2.Name, user3.Name}).Group("id").Count(&count3).Error; err != nil { - t.Errorf("No error should happen when count with group, but got %v", err) + t.Errorf("Error happened when count with group, but got %v", err) } if count3 != 2 { From 2db33730b63a3680b5fe108e3d9f07de2d3c1671 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 24 May 2020 20:44:37 +0800 Subject: [PATCH 322/881] Add Slice Association for BelongsTo --- association.go | 20 +++++-- callbacks/update.go | 26 +++++++-- errors.go | 2 + finisher_api.go | 17 +++--- tests/associations_test.go | 107 ++++++++++++++++++++++++------------- 5 files changed, 122 insertions(+), 50 deletions(-) diff --git a/association.go b/association.go index c179a148..ff1e155f 100644 --- a/association.go +++ b/association.go @@ -366,6 +366,11 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ if clear && len(values) == 0 { for i := 0; i < reflectValue.Len(); i++ { association.Relationship.Field.Set(reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType).Interface()) + for _, ref := range association.Relationship.References { + if !ref.OwnPrimaryKey { + ref.ForeignKey.Set(reflectValue.Index(i), reflect.Zero(ref.ForeignKey.FieldType).Interface()) + } + } } break } @@ -382,6 +387,11 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ case reflect.Struct: if clear && len(values) == 0 { association.Relationship.Field.Set(reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType).Interface()) + for _, ref := range association.Relationship.References { + if !ref.OwnPrimaryKey { + ref.ForeignKey.Set(reflectValue, reflect.Zero(ref.ForeignKey.FieldType).Interface()) + } + } } for idx, value := range values { @@ -392,10 +402,12 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ _, hasZero = association.DB.Statement.Schema.PrioritizedPrimaryField.ValueOf(reflectValue) } - if hasZero { - association.DB.Save(reflectValue.Addr().Interface()) - } else { - association.DB.Select(selectedColumns).Save(reflectValue.Addr().Interface()) + if len(values) > 0 { + if hasZero { + association.DB.Create(reflectValue.Addr().Interface()) + } else { + association.DB.Select(selectedColumns).Model(nil).Save(reflectValue.Addr().Interface()) + } } for _, assignBack := range assignBacks { diff --git a/callbacks/update.go b/callbacks/update.go index be9fe30a..6a59e487 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -173,10 +173,28 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { } if stmt.Dest != stmt.Model { - reflectValue := reflect.ValueOf(stmt.Model) - for _, field := range stmt.Schema.PrimaryFields { - if value, isZero := field.ValueOf(reflectValue); !isZero { - stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}}) + reflectValue := reflect.Indirect(reflect.ValueOf(stmt.Model)) + switch reflectValue.Kind() { + case reflect.Slice, reflect.Array: + var priamryKeyExprs []clause.Expression + for i := 0; i < reflectValue.Len(); i++ { + var exprs = make([]clause.Expression, len(stmt.Schema.PrimaryFields)) + var notZero bool + for idx, field := range stmt.Schema.PrimaryFields { + value, isZero := field.ValueOf(reflectValue.Index(i)) + exprs[idx] = clause.Eq{Column: field.DBName, Value: value} + notZero = notZero || !isZero + } + if notZero { + priamryKeyExprs = append(priamryKeyExprs, clause.And(exprs...)) + } + } + stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(priamryKeyExprs...)}}) + case reflect.Struct: + for _, field := range stmt.Schema.PrimaryFields { + if value, isZero := field.ValueOf(reflectValue); !isZero { + stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}}) + } } } } diff --git a/errors.go b/errors.go index a990cc4a..4f2bd4fa 100644 --- a/errors.go +++ b/errors.go @@ -19,4 +19,6 @@ var ( ErrMissingWhereClause = errors.New("missing WHERE clause while deleting") // ErrUnsupportedRelation unsupported relations ErrUnsupportedRelation = errors.New("unsupported relations") + // ErrPtrStructSupported only ptr of struct supported + ErrPtrStructSupported = errors.New("only ptr of struct supported") ) diff --git a/finisher_api.go b/finisher_api.go index 6a787576..c64ecdda 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -23,12 +23,17 @@ func (db *DB) Save(value interface{}) (tx *DB) { if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil { where := clause.Where{Exprs: make([]clause.Expression, len(tx.Statement.Schema.PrimaryFields))} - reflectValue := reflect.ValueOf(value) - for idx, pf := range tx.Statement.Schema.PrimaryFields { - if pv, isZero := pf.ValueOf(reflectValue); isZero { - tx.callbacks.Create().Execute(tx) - where.Exprs[idx] = clause.Eq{Column: pf.DBName, Value: pv} - return + reflectValue := reflect.Indirect(reflect.ValueOf(value)) + switch reflectValue.Kind() { + case reflect.Slice, reflect.Array: + tx.AddError(ErrPtrStructSupported) + case reflect.Struct: + for idx, pf := range tx.Statement.Schema.PrimaryFields { + if pv, isZero := pf.ValueOf(reflectValue); isZero { + tx.callbacks.Create().Execute(tx) + where.Exprs[idx] = clause.Eq{Column: pf.DBName, Value: pv} + return + } } } diff --git a/tests/associations_test.go b/tests/associations_test.go index 159f7f3a..77a5ce47 100644 --- a/tests/associations_test.go +++ b/tests/associations_test.go @@ -6,7 +6,26 @@ import ( . "github.com/jinzhu/gorm/tests" ) -func TestAssociationForBelongsTo(t *testing.T) { +func AssertAssociationCount(t *testing.T, data interface{}, name string, result int64, reason string) { + if count := DB.Model(data).Association(name).Count(); count != result { + t.Errorf("invalid %v count %v, expects: %v got %v", name, reason, result, count) + } + + var newUser User + if user, ok := data.(User); ok { + DB.Find(&newUser, "id = ?", user.ID) + } else if user, ok := data.(*User); ok { + DB.Find(&newUser, "id = ?", user.ID) + } + + if newUser.ID != 0 { + if count := DB.Model(&newUser).Association(name).Count(); count != result { + t.Errorf("invalid %v count %v, expects: %v got %v", name, reason, result, count) + } + } +} + +func TestBelongsToAssociation(t *testing.T) { var user = *GetUser("belongs-to", Config{Company: true, Manager: true}) if err := DB.Create(&user).Error; err != nil { @@ -24,13 +43,8 @@ func TestAssociationForBelongsTo(t *testing.T) { CheckUser(t, user2, user) // Count - if count := DB.Model(&user).Association("Company").Count(); count != 1 { - t.Errorf("invalid company count, got %v", count) - } - - if count := DB.Model(&user).Association("Manager").Count(); count != 1 { - t.Errorf("invalid manager count, got %v", count) - } + AssertAssociationCount(t, user, "Company", 1, "") + AssertAssociationCount(t, user, "Manager", 1, "") // Append var company = Company{Name: "company-belongs-to-append"} @@ -58,6 +72,9 @@ func TestAssociationForBelongsTo(t *testing.T) { user.ManagerID = &manager.ID CheckUser(t, user2, user) + AssertAssociationCount(t, user2, "Company", 1, "AfterAppend") + AssertAssociationCount(t, user2, "Manager", 1, "AfterAppend") + // Replace var company2 = Company{Name: "company-belongs-to-replace"} var manager2 = GetUser("manager-belongs-to-replace", Config{}) @@ -84,40 +101,31 @@ func TestAssociationForBelongsTo(t *testing.T) { user.ManagerID = &manager2.ID CheckUser(t, user2, user) + AssertAssociationCount(t, user2, "Company", 1, "AfterReplace") + AssertAssociationCount(t, user2, "Manager", 1, "AfterReplace") + // Delete if err := DB.Model(&user2).Association("Company").Delete(&Company{}); err != nil { t.Fatalf("Error happened when delete Company, got %v", err) } - - if count := DB.Model(&user2).Association("Company").Count(); count != 1 { - t.Errorf("Invalid company count after delete non-existing association, got %v", count) - } + AssertAssociationCount(t, user2, "Company", 1, "after delete non-existing data") if err := DB.Model(&user2).Association("Company").Delete(&company2); err != nil { t.Fatalf("Error happened when delete Company, got %v", err) } - - if count := DB.Model(&user2).Association("Company").Count(); count != 0 { - t.Errorf("Invalid company count after delete, got %v", count) - } + AssertAssociationCount(t, user2, "Company", 0, "after delete") if err := DB.Model(&user2).Association("Manager").Delete(&User{}); err != nil { t.Fatalf("Error happened when delete Manager, got %v", err) } - - if count := DB.Model(&user2).Association("Manager").Count(); count != 1 { - t.Errorf("Invalid manager count after delete non-existing association, got %v", count) - } + AssertAssociationCount(t, user2, "Manager", 1, "after delete non-existing data") if err := DB.Model(&user2).Association("Manager").Delete(manager2); err != nil { t.Fatalf("Error happened when delete Manager, got %v", err) } + AssertAssociationCount(t, user2, "Manager", 0, "after delete") - if count := DB.Model(&user2).Association("Manager").Count(); count != 0 { - t.Errorf("Invalid manager count after delete, got %v", count) - } - - // Prepare Data + // Prepare Data for Clear if err := DB.Model(&user2).Association("Company").Append(&company); err != nil { t.Fatalf("Error happened when append Company, got %v", err) } @@ -126,13 +134,8 @@ func TestAssociationForBelongsTo(t *testing.T) { t.Fatalf("Error happened when append Manager, got %v", err) } - if count := DB.Model(&user2).Association("Company").Count(); count != 1 { - t.Errorf("Invalid company count after append, got %v", count) - } - - if count := DB.Model(&user2).Association("Manager").Count(); count != 1 { - t.Errorf("Invalid manager count after append, got %v", count) - } + AssertAssociationCount(t, user2, "Company", 1, "after prepare data") + AssertAssociationCount(t, user2, "Manager", 1, "after prepare data") // Clear if err := DB.Model(&user2).Association("Company").Clear(); err != nil { @@ -143,11 +146,43 @@ func TestAssociationForBelongsTo(t *testing.T) { t.Errorf("Error happened when clear Manager, got %v", err) } - if count := DB.Model(&user2).Association("Company").Count(); count != 0 { - t.Errorf("Invalid company count after clear, got %v", count) + AssertAssociationCount(t, user2, "Company", 0, "after clear") + AssertAssociationCount(t, user2, "Manager", 0, "after clear") +} + +func TestBelongsToAssociationForSlice(t *testing.T) { + var users = []User{ + *GetUser("slice-belongs-to-1", Config{Company: true, Manager: true}), + *GetUser("slice-belongs-to-2", Config{Company: true, Manager: false}), + *GetUser("slice-belongs-to-3", Config{Company: true, Manager: true}), } - if count := DB.Model(&user2).Association("Manager").Count(); count != 0 { - t.Errorf("Invalid manager count after clear, got %v", count) + DB.Create(&users) + + AssertAssociationCount(t, users, "Company", 3, "") + AssertAssociationCount(t, users, "Manager", 2, "") + + // Find + var companies []Company + if DB.Model(users).Association("Company").Find(&companies); len(companies) != 3 { + t.Errorf("companies count should be %v, but got %v", 3, len(companies)) } + + var managers []User + if DB.Model(users).Association("Manager").Find(&managers); len(managers) != 2 { + t.Errorf("managers count should be %v, but got %v", 2, len(managers)) + } + + // Append + + // Replace + + // Delete + + // Clear + DB.Model(&users).Association("Company").Clear() + AssertAssociationCount(t, users, "Company", 0, "After Clear") + + DB.Model(&users).Association("Manager").Clear() + AssertAssociationCount(t, users, "Manager", 0, "After Clear") } From 677c745b620bdfc114ae87495f49fee2200a3008 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 24 May 2020 21:46:33 +0800 Subject: [PATCH 323/881] Test shared association --- association.go | 27 ++++++++++++------- tests/associations_test.go | 53 +++++++++++++++++++++++++++++++++++--- 2 files changed, 67 insertions(+), 13 deletions(-) diff --git a/association.go b/association.go index ff1e155f..47ec500e 100644 --- a/association.go +++ b/association.go @@ -195,6 +195,8 @@ func (association *Association) Delete(values ...interface{}) error { foreignKeys = append(foreignKeys, ref.ForeignKey.DBName) updateAttrs[ref.ForeignKey.DBName] = nil } + } else { + tx.Where(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryKey}) } } @@ -208,6 +210,15 @@ func (association *Association) Delete(values ...interface{}) error { conds := rel.ToQueryConditions(reflectValue) tx.Model(modelValue).Clauses(clause.Where{Exprs: conds}).UpdateColumns(updateAttrs) case schema.BelongsTo: + primaryKeys := []string{} + for _, field := range rel.Schema.PrimaryFields { + primaryKeys = append(primaryKeys, field.DBName) + } + _, queryValues := schema.GetIdentityFieldValuesMap(reflectValue, rel.Schema.PrimaryFields) + if column, values := schema.ToQueryValues(primaryKeys, queryValues); len(values) > 0 { + tx.Where(clause.IN{Column: column, Values: values}) + } + modelValue := reflect.New(rel.Schema.ModelType).Interface() tx.Model(modelValue).UpdateColumns(updateAttrs) case schema.Many2Many: @@ -353,7 +364,6 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ } selectedColumns := []string{association.Relationship.Name} - hasZero := false for _, ref := range association.Relationship.References { if !ref.OwnPrimaryKey { selectedColumns = append(selectedColumns, ref.ForeignKey.Name) @@ -375,13 +385,16 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ break } association.Error = errors.New("invalid association values, length doesn't match") + return } for i := 0; i < reflectValue.Len(); i++ { appendToRelations(reflectValue.Index(i), reflect.Indirect(reflect.ValueOf(values[i])), clear) - if !hasZero { - _, hasZero = association.DB.Statement.Schema.PrioritizedPrimaryField.ValueOf(reflectValue.Index(i)) + if len(values) > 0 { + // TODO support save slice data, sql with case + err := association.DB.Session(&Session{}).Select(selectedColumns).Model(nil).Save(reflectValue.Index(i).Addr().Interface()).Error + association.DB.AddError(err) } } case reflect.Struct: @@ -399,13 +412,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ appendToRelations(reflectValue, rv, clear && idx == 0) } - _, hasZero = association.DB.Statement.Schema.PrioritizedPrimaryField.ValueOf(reflectValue) - } - - if len(values) > 0 { - if hasZero { - association.DB.Create(reflectValue.Addr().Interface()) - } else { + if len(values) > 0 { association.DB.Select(selectedColumns).Model(nil).Save(reflectValue.Addr().Interface()) } } diff --git a/tests/associations_test.go b/tests/associations_test.go index 77a5ce47..c67e79c8 100644 --- a/tests/associations_test.go +++ b/tests/associations_test.go @@ -164,20 +164,49 @@ func TestBelongsToAssociationForSlice(t *testing.T) { // Find var companies []Company - if DB.Model(users).Association("Company").Find(&companies); len(companies) != 3 { + if DB.Model(&users).Association("Company").Find(&companies); len(companies) != 3 { t.Errorf("companies count should be %v, but got %v", 3, len(companies)) } var managers []User - if DB.Model(users).Association("Manager").Find(&managers); len(managers) != 2 { + if DB.Model(&users).Association("Manager").Find(&managers); len(managers) != 2 { t.Errorf("managers count should be %v, but got %v", 2, len(managers)) } // Append + DB.Model(&users).Association("Company").Append( + &Company{Name: "company-slice-append-1"}, + &Company{Name: "company-slice-append-2"}, + &Company{Name: "company-slice-append-3"}, + ) - // Replace + AssertAssociationCount(t, users, "Company", 3, "After Append") + + DB.Model(&users).Association("Manager").Append( + GetUser("manager-slice-belongs-to-1", Config{}), + GetUser("manager-slice-belongs-to-2", Config{}), + GetUser("manager-slice-belongs-to-3", Config{}), + ) + AssertAssociationCount(t, users, "Manager", 3, "After Append") + + if err := DB.Model(&users).Association("Manager").Append( + GetUser("manager-slice-belongs-to-test-1", Config{}), + ).Error; err == nil { + t.Errorf("unmatched length when update user's manager") + } + + // Replace -> same as append // Delete + if err := DB.Model(&users).Association("Company").Delete(&users[0].Company); err != nil { + t.Errorf("no error should happend when deleting company, but got %v", err) + } + + if users[0].CompanyID != nil || users[0].Company.ID != 0 { + t.Errorf("users[0]'s company should be deleted'") + } + + AssertAssociationCount(t, users, "Company", 2, "After Delete") // Clear DB.Model(&users).Association("Company").Clear() @@ -185,4 +214,22 @@ func TestBelongsToAssociationForSlice(t *testing.T) { DB.Model(&users).Association("Manager").Clear() AssertAssociationCount(t, users, "Manager", 0, "After Clear") + + // shared company + company := Company{Name: "shared"} + if err := DB.Model(&users[0]).Association("Company").Append(&company); err != nil { + t.Errorf("Error happened when append company to user, got %v", err) + } + + if err := DB.Model(&users[1]).Association("Company").Append(&company); err != nil { + t.Errorf("Error happened when append company to user, got %v", err) + } + + if users[0].CompanyID == nil || users[1].CompanyID == nil || *users[0].CompanyID != *users[1].CompanyID { + t.Errorf("user's company id should exists and equal, but its: %v, %v", users[0].CompanyID, users[1].CompanyID) + } + + DB.Model(&users[0]).Association("Company").Delete(&company) + AssertAssociationCount(t, users[0], "Company", 0, "After Delete") + AssertAssociationCount(t, users[1], "Company", 1, "After other user Delete") } From 68a7a8207a39ba3df10945bd6a5af486ecd88f73 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 24 May 2020 22:52:16 +0800 Subject: [PATCH 324/881] Test HasOne Association --- association.go | 32 +++++++++------- callbacks/associations.go | 6 +++ callbacks/update.go | 7 +++- clause/expression.go | 2 + schema/field.go | 40 ++++++++++++-------- tests/associations_test.go | 76 ++++++++++++++++++++++++++++++++++++++ 6 files changed, 133 insertions(+), 30 deletions(-) diff --git a/association.go b/association.go index 47ec500e..c90258ec 100644 --- a/association.go +++ b/association.go @@ -97,28 +97,34 @@ func (association *Association) Replace(values ...interface{}) error { } case schema.HasOne, schema.HasMany: var ( - primaryFields []*schema.Field - foreignKeys []string - updateMap = map[string]interface{}{} - modelValue = reflect.New(rel.FieldSchema.ModelType).Interface() + tx = association.DB + primaryFields []*schema.Field + foreignKeys []string + updateMap = map[string]interface{}{} + relPrimaryKeys = []string{} + relValues = schema.GetRelationsValues(reflectValue, []*schema.Relationship{rel}) + modelValue = reflect.New(rel.FieldSchema.ModelType).Interface() ) - if rel.Type == schema.BelongsTo { - modelValue = reflect.New(rel.Schema.ModelType).Interface() + + for _, field := range rel.FieldSchema.PrimaryFields { + relPrimaryKeys = append(relPrimaryKeys, field.DBName) + } + if _, qvs := schema.GetIdentityFieldValuesMap(relValues, rel.FieldSchema.PrimaryFields); len(qvs) > 0 { + if column, values := schema.ToQueryValues(relPrimaryKeys, qvs); len(values) > 0 { + tx = tx.Not(clause.IN{Column: column, Values: values}) + } } for _, ref := range rel.References { if ref.OwnPrimaryKey { primaryFields = append(primaryFields, ref.PrimaryKey) - } else { foreignKeys = append(foreignKeys, ref.ForeignKey.DBName) updateMap[ref.ForeignKey.DBName] = nil } } - - _, values := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields) - if len(values) == 0 { - column, queryValues := schema.ToQueryValues(foreignKeys, values) - association.DB.Model(modelValue).Where(clause.IN{Column: column, Values: queryValues}).UpdateColumns(updateMap) + if _, qvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields); len(qvs) > 0 { + column, values := schema.ToQueryValues(foreignKeys, qvs) + tx.Model(modelValue).Where(clause.IN{Column: column, Values: values}).UpdateColumns(updateMap) } case schema.Many2Many: var primaryFields, relPrimaryFields []*schema.Field @@ -413,7 +419,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ } if len(values) > 0 { - association.DB.Select(selectedColumns).Model(nil).Save(reflectValue.Addr().Interface()) + association.DB.Session(&Session{}).Select(selectedColumns).Model(nil).Save(reflectValue.Addr().Interface()) } } diff --git a/callbacks/associations.go b/callbacks/associations.go index 37addd60..2342f110 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -124,6 +124,8 @@ func SaveAfterAssociations(db *gorm.DB) { if _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(rv); isZero { elems = reflect.Append(elems, rv) + } else { + db.Session(&gorm.Session{}).Save(rv.Interface()) } } } @@ -149,6 +151,8 @@ func SaveAfterAssociations(db *gorm.DB) { if _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(f); isZero { db.Session(&gorm.Session{}).Create(f.Interface()) + } else { + db.Session(&gorm.Session{}).Save(f.Interface()) } } } @@ -187,6 +191,8 @@ func SaveAfterAssociations(db *gorm.DB) { } else { elems = reflect.Append(elems, elem.Addr()) } + } else { + db.Session(&gorm.Session{}).Save(elem.Interface()) } } } diff --git a/callbacks/update.go b/callbacks/update.go index 6a59e487..f9b20981 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -45,7 +45,11 @@ func BeforeUpdate(db *gorm.DB) { func Update(db *gorm.DB) { db.Statement.AddClauseIfNotExists(clause.Update{}) - db.Statement.AddClause(ConvertToAssignments(db.Statement)) + if set := ConvertToAssignments(db.Statement); len(set) != 0 { + db.Statement.AddClause(set) + } else { + return + } db.Statement.Build("UPDATE", "SET", "WHERE") result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) @@ -198,5 +202,6 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { } } } + return } diff --git a/clause/expression.go b/clause/expression.go index 8150f838..872736ce 100644 --- a/clause/expression.go +++ b/clause/expression.go @@ -55,9 +55,11 @@ func (in IN) NegationBuild(builder Builder) { switch len(in.Values) { case 0: case 1: + builder.WriteQuoted(in.Column) builder.WriteString(" <> ") builder.AddVar(builder, in.Values...) default: + builder.WriteQuoted(in.Column) builder.WriteString(" NOT IN (") builder.AddVar(builder, in.Values...) builder.WriteByte(')') diff --git a/schema/field.go b/schema/field.go index 9a5f1fc6..8b8b190d 100644 --- a/schema/field.go +++ b/schema/field.go @@ -603,32 +603,40 @@ func (field *Field) setupValuerAndSetter() { if _, ok := fieldValue.Interface().(sql.Scanner); ok { // struct scanner field.Set = func(value reflect.Value, v interface{}) (err error) { - reflectV := reflect.ValueOf(v) - if reflectV.Type().ConvertibleTo(field.FieldType) { - field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType)) - } else if valuer, ok := v.(driver.Valuer); ok { - if v, err = valuer.Value(); err == nil { + if v == nil { + field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) + } else { + reflectV := reflect.ValueOf(v) + if reflectV.Type().ConvertibleTo(field.FieldType) { + field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType)) + } else if valuer, ok := v.(driver.Valuer); ok { + if v, err = valuer.Value(); err == nil { + err = field.ReflectValueOf(value).Addr().Interface().(sql.Scanner).Scan(v) + } + } else { err = field.ReflectValueOf(value).Addr().Interface().(sql.Scanner).Scan(v) } - } else { - err = field.ReflectValueOf(value).Addr().Interface().(sql.Scanner).Scan(v) } return } } else if _, ok := fieldValue.Elem().Interface().(sql.Scanner); ok { // pointer scanner field.Set = func(value reflect.Value, v interface{}) (err error) { - reflectV := reflect.ValueOf(v) - if reflectV.Type().ConvertibleTo(field.FieldType) { - field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType)) - } else if reflectV.Type().ConvertibleTo(field.FieldType.Elem()) { - field.ReflectValueOf(value).Elem().Set(reflectV.Convert(field.FieldType.Elem())) - } else if valuer, ok := v.(driver.Valuer); ok { - if v, err = valuer.Value(); err == nil { + if v == nil { + field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) + } else { + reflectV := reflect.ValueOf(v) + if reflectV.Type().ConvertibleTo(field.FieldType) { + field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType)) + } else if reflectV.Type().ConvertibleTo(field.FieldType.Elem()) { + field.ReflectValueOf(value).Elem().Set(reflectV.Convert(field.FieldType.Elem())) + } else if valuer, ok := v.(driver.Valuer); ok { + if v, err = valuer.Value(); err == nil { + err = field.ReflectValueOf(value).Interface().(sql.Scanner).Scan(v) + } + } else { err = field.ReflectValueOf(value).Interface().(sql.Scanner).Scan(v) } - } else { - err = field.ReflectValueOf(value).Interface().(sql.Scanner).Scan(v) } return } diff --git a/tests/associations_test.go b/tests/associations_test.go index c67e79c8..137b2c50 100644 --- a/tests/associations_test.go +++ b/tests/associations_test.go @@ -233,3 +233,79 @@ func TestBelongsToAssociationForSlice(t *testing.T) { AssertAssociationCount(t, users[0], "Company", 0, "After Delete") AssertAssociationCount(t, users[1], "Company", 1, "After other user Delete") } + +func TestHasOneAssociation(t *testing.T) { + var user = *GetUser("hasone", Config{Account: true}) + + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + CheckUser(t, user, user) + + // Find + var user2 User + DB.Find(&user2, "id = ?", user.ID) + DB.Model(&user2).Association("Account").Find(&user2.Account) + CheckUser(t, user2, user) + + // Count + AssertAssociationCount(t, user, "Account", 1, "") + + // Append + var account = Account{Number: "account-has-one-append"} + + if err := DB.Model(&user2).Association("Account").Append(&account); err != nil { + t.Fatalf("Error happened when append account, got %v", err) + } + + if account.ID == 0 { + t.Fatalf("Account's ID should be created") + } + + user.Account = account + CheckUser(t, user2, user) + + AssertAssociationCount(t, user, "Account", 1, "AfterAppend") + + // Replace + var account2 = Account{Number: "account-has-one-replace"} + + if err := DB.Model(&user2).Association("Account").Replace(&account2); err != nil { + t.Fatalf("Error happened when append Account, got %v", err) + } + + if account2.ID == 0 { + t.Fatalf("account2's ID should be created") + } + + user.Account = account2 + CheckUser(t, user2, user) + + AssertAssociationCount(t, user2, "Account", 1, "AfterReplace") + + // Delete + if err := DB.Model(&user2).Association("Account").Delete(&Company{}); err != nil { + t.Fatalf("Error happened when delete account, got %v", err) + } + AssertAssociationCount(t, user2, "Account", 1, "after delete non-existing data") + + if err := DB.Model(&user2).Association("Account").Delete(&account2); err != nil { + t.Fatalf("Error happened when delete Account, got %v", err) + } + AssertAssociationCount(t, user2, "Account", 0, "after delete") + + // Prepare Data for Clear + if err := DB.Model(&user2).Association("Account").Append(&account); err != nil { + t.Fatalf("Error happened when append Account, got %v", err) + } + + AssertAssociationCount(t, user2, "Account", 1, "after prepare data") + + // Clear + if err := DB.Model(&user2).Association("Account").Clear(); err != nil { + t.Errorf("Error happened when clear Account, got %v", err) + } + + AssertAssociationCount(t, user2, "Account", 0, "after clear") +} From 6a0ef985ffb3c600da7449376453eb23692c6c05 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 24 May 2020 23:28:06 +0800 Subject: [PATCH 325/881] Test Polymorphic HasOne Association --- association.go | 10 ++--- tests/associations_test.go | 78 +++++++++++++++++++++++++++++++++++++- 2 files changed, 81 insertions(+), 7 deletions(-) diff --git a/association.go b/association.go index c90258ec..4d240418 100644 --- a/association.go +++ b/association.go @@ -179,9 +179,9 @@ func (association *Association) Replace(values ...interface{}) error { func (association *Association) Delete(values ...interface{}) error { if association.Error == nil { var ( - tx = association.DB + reflectValue = association.DB.Statement.ReflectValue rel = association.Relationship - reflectValue = tx.Statement.ReflectValue + tx = association.DB relFields []*schema.Field foreignKeyFields []*schema.Field foreignKeys []string @@ -201,14 +201,12 @@ func (association *Association) Delete(values ...interface{}) error { foreignKeys = append(foreignKeys, ref.ForeignKey.DBName) updateAttrs[ref.ForeignKey.DBName] = nil } - } else { - tx.Where(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryKey}) } } relValuesMap, relQueryValues := schema.GetIdentityFieldValuesMapFromValues(values, relFields) column, values := schema.ToQueryValues(foreignKeys, relQueryValues) - tx.Where(clause.IN{Column: column, Values: values}) + tx = tx.Session(&Session{}).Where(clause.IN{Column: column, Values: values}) switch rel.Type { case schema.HasOne, schema.HasMany: @@ -407,7 +405,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ if clear && len(values) == 0 { association.Relationship.Field.Set(reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType).Interface()) for _, ref := range association.Relationship.References { - if !ref.OwnPrimaryKey { + if !ref.OwnPrimaryKey && ref.PrimaryValue == "" { ref.ForeignKey.Set(reflectValue, reflect.Zero(ref.ForeignKey.FieldType).Interface()) } } diff --git a/tests/associations_test.go b/tests/associations_test.go index 137b2c50..0b131450 100644 --- a/tests/associations_test.go +++ b/tests/associations_test.go @@ -285,7 +285,7 @@ func TestHasOneAssociation(t *testing.T) { AssertAssociationCount(t, user2, "Account", 1, "AfterReplace") // Delete - if err := DB.Model(&user2).Association("Account").Delete(&Company{}); err != nil { + if err := DB.Model(&user2).Association("Account").Delete(&Account{}); err != nil { t.Fatalf("Error happened when delete account, got %v", err) } AssertAssociationCount(t, user2, "Account", 1, "after delete non-existing data") @@ -309,3 +309,79 @@ func TestHasOneAssociation(t *testing.T) { AssertAssociationCount(t, user2, "Account", 0, "after clear") } + +func TestPolymorphicHasOneAssociation(t *testing.T) { + var pet = Pet{Name: "hasone", Toy: Toy{Name: "toy-has-one"}} + + if err := DB.Create(&pet).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + CheckPet(t, pet, pet) + + // Find + var pet2 Pet + DB.Find(&pet2, "id = ?", pet.ID) + DB.Model(&pet2).Association("Toy").Find(&pet2.Toy) + CheckPet(t, pet2, pet) + + // Count + AssertAssociationCount(t, pet, "Toy", 1, "") + + // Append + var toy = Toy{Name: "toy-has-one-append"} + + if err := DB.Model(&pet2).Association("Toy").Append(&toy); err != nil { + t.Fatalf("Error happened when append toy, got %v", err) + } + + if toy.ID == 0 { + t.Fatalf("Toy's ID should be created") + } + + pet.Toy = toy + CheckPet(t, pet2, pet) + + AssertAssociationCount(t, pet, "Toy", 1, "AfterAppend") + + // Replace + var toy2 = Toy{Name: "toy-has-one-replace"} + + if err := DB.Model(&pet2).Association("Toy").Replace(&toy2); err != nil { + t.Fatalf("Error happened when append Toy, got %v", err) + } + + if toy2.ID == 0 { + t.Fatalf("toy2's ID should be created") + } + + pet.Toy = toy2 + CheckPet(t, pet2, pet) + + AssertAssociationCount(t, pet2, "Toy", 1, "AfterReplace") + + // Delete + if err := DB.Model(&pet2).Association("Toy").Delete(&Toy{}); err != nil { + t.Fatalf("Error happened when delete toy, got %v", err) + } + AssertAssociationCount(t, pet2, "Toy", 1, "after delete non-existing data") + + if err := DB.Model(&pet2).Association("Toy").Delete(&toy2); err != nil { + t.Fatalf("Error happened when delete Toy, got %v", err) + } + AssertAssociationCount(t, pet2, "Toy", 0, "after delete") + + // Prepare Data for Clear + if err := DB.Model(&pet2).Association("Toy").Append(&toy); err != nil { + t.Fatalf("Error happened when append Toy, got %v", err) + } + + AssertAssociationCount(t, pet2, "Toy", 1, "after prepare data") + + // Clear + if err := DB.Model(&pet2).Association("Toy").Clear(); err != nil { + t.Errorf("Error happened when clear Toy, got %v", err) + } + + AssertAssociationCount(t, pet2, "Toy", 0, "after clear") +} From 91eaf0bb2113fbe74aeb0051510cda6c57326544 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 24 May 2020 23:43:42 +0800 Subject: [PATCH 326/881] Test HasOne Association for Slice --- association.go | 2 +- tests/associations_test.go | 82 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 83 insertions(+), 1 deletion(-) diff --git a/association.go b/association.go index 4d240418..f65e77c2 100644 --- a/association.go +++ b/association.go @@ -381,7 +381,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ for i := 0; i < reflectValue.Len(); i++ { association.Relationship.Field.Set(reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType).Interface()) for _, ref := range association.Relationship.References { - if !ref.OwnPrimaryKey { + if !ref.OwnPrimaryKey && ref.PrimaryValue == "" { ref.ForeignKey.Set(reflectValue.Index(i), reflect.Zero(ref.ForeignKey.FieldType).Interface()) } } diff --git a/tests/associations_test.go b/tests/associations_test.go index 0b131450..2b81a719 100644 --- a/tests/associations_test.go +++ b/tests/associations_test.go @@ -310,6 +310,47 @@ func TestHasOneAssociation(t *testing.T) { AssertAssociationCount(t, user2, "Account", 0, "after clear") } +func TestHasOneAssociationForSlice(t *testing.T) { + var users = []User{ + *GetUser("slice-hasone-1", Config{Account: true}), + *GetUser("slice-hasone-2", Config{Account: false}), + *GetUser("slice-hasone-3", Config{Account: true}), + } + + DB.Create(&users) + + // Count + AssertAssociationCount(t, users, "Account", 2, "") + + // Find + var accounts []Account + if DB.Model(&users).Association("Account").Find(&accounts); len(accounts) != 2 { + t.Errorf("accounts count should be %v, but got %v", 3, len(accounts)) + } + + // Append + DB.Model(&users).Association("Account").Append( + &Account{Number: "account-slice-append-1"}, + &Account{Number: "account-slice-append-2"}, + &Account{Number: "account-slice-append-3"}, + ) + + AssertAssociationCount(t, users, "Account", 3, "After Append") + + // Replace -> same as append + + // Delete + if err := DB.Model(&users).Association("Account").Delete(&users[0].Account); err != nil { + t.Errorf("no error should happend when deleting account, but got %v", err) + } + + AssertAssociationCount(t, users, "Account", 2, "after delete") + + // Clear + DB.Model(&users).Association("Account").Clear() + AssertAssociationCount(t, users, "Account", 0, "After Clear") +} + func TestPolymorphicHasOneAssociation(t *testing.T) { var pet = Pet{Name: "hasone", Toy: Toy{Name: "toy-has-one"}} @@ -385,3 +426,44 @@ func TestPolymorphicHasOneAssociation(t *testing.T) { AssertAssociationCount(t, pet2, "Toy", 0, "after clear") } + +func TestPolymorphicHasOneAssociationForSlice(t *testing.T) { + var pets = []Pet{ + {Name: "hasone-1", Toy: Toy{Name: "toy-has-one"}}, + {Name: "hasone-2", Toy: Toy{}}, + {Name: "hasone-3", Toy: Toy{Name: "toy-has-one"}}, + } + + DB.Create(&pets) + + // Count + AssertAssociationCount(t, pets, "Toy", 2, "") + + // Find + var toys []Toy + if DB.Model(&pets).Association("Toy").Find(&toys); len(toys) != 2 { + t.Errorf("toys count should be %v, but got %v", 3, len(toys)) + } + + // Append + DB.Model(&pets).Association("Toy").Append( + &Toy{Name: "toy-slice-append-1"}, + &Toy{Name: "toy-slice-append-2"}, + &Toy{Name: "toy-slice-append-3"}, + ) + + AssertAssociationCount(t, pets, "Toy", 3, "After Append") + + // Replace -> same as append + + // Delete + if err := DB.Model(&pets).Association("Toy").Delete(&pets[0].Toy); err != nil { + t.Errorf("no error should happend when deleting toy, but got %v", err) + } + + AssertAssociationCount(t, pets, "Toy", 2, "after delete") + + // Clear + DB.Model(&pets).Association("Toy").Clear() + AssertAssociationCount(t, pets, "Toy", 0, "After Clear") +} From 5d9b57cc4e5e1df2067e6ea6384f049e57b39200 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 25 May 2020 11:11:09 +0800 Subject: [PATCH 327/881] Test HasMany Association --- association.go | 99 ++++++++++++------------ schema/schema.go | 5 ++ tests/associations_test.go | 149 +++++++++++++++++++++++++++++++++++++ 3 files changed, 207 insertions(+), 46 deletions(-) diff --git a/association.go b/association.go index f65e77c2..9405d962 100644 --- a/association.go +++ b/association.go @@ -179,69 +179,71 @@ func (association *Association) Replace(values ...interface{}) error { func (association *Association) Delete(values ...interface{}) error { if association.Error == nil { var ( - reflectValue = association.DB.Statement.ReflectValue - rel = association.Relationship - tx = association.DB - relFields []*schema.Field - foreignKeyFields []*schema.Field - foreignKeys []string - updateAttrs = map[string]interface{}{} + reflectValue = association.DB.Statement.ReflectValue + rel = association.Relationship + tx = association.DB + primaryFields, foreignFields []*schema.Field + foreignKeys []string + updateAttrs = map[string]interface{}{} ) for _, ref := range rel.References { if ref.PrimaryValue == "" { - if rel.JoinTable == nil || !ref.OwnPrimaryKey { - if ref.OwnPrimaryKey { - relFields = append(relFields, ref.ForeignKey) - } else { - relFields = append(relFields, ref.PrimaryKey) - foreignKeyFields = append(foreignKeyFields, ref.ForeignKey) - } - - foreignKeys = append(foreignKeys, ref.ForeignKey.DBName) - updateAttrs[ref.ForeignKey.DBName] = nil - } + primaryFields = append(primaryFields, ref.PrimaryKey) + foreignFields = append(foreignFields, ref.ForeignKey) + foreignKeys = append(foreignKeys, ref.ForeignKey.DBName) + updateAttrs[ref.ForeignKey.DBName] = nil } } - relValuesMap, relQueryValues := schema.GetIdentityFieldValuesMapFromValues(values, relFields) - column, values := schema.ToQueryValues(foreignKeys, relQueryValues) - tx = tx.Session(&Session{}).Where(clause.IN{Column: column, Values: values}) - switch rel.Type { case schema.HasOne, schema.HasMany: - modelValue := reflect.New(rel.FieldSchema.ModelType).Interface() - conds := rel.ToQueryConditions(reflectValue) - tx.Model(modelValue).Clauses(clause.Where{Exprs: conds}).UpdateColumns(updateAttrs) - case schema.BelongsTo: - primaryKeys := []string{} - for _, field := range rel.Schema.PrimaryFields { - primaryKeys = append(primaryKeys, field.DBName) - } - _, queryValues := schema.GetIdentityFieldValuesMap(reflectValue, rel.Schema.PrimaryFields) - if column, values := schema.ToQueryValues(primaryKeys, queryValues); len(values) > 0 { - tx.Where(clause.IN{Column: column, Values: values}) - } + var ( + modelValue = reflect.New(rel.FieldSchema.ModelType).Interface() + _, queryValues = schema.GetIdentityFieldValuesMap(reflectValue, primaryFields) + _, relQueryValues = schema.GetIdentityFieldValuesMapFromValues(values, rel.FieldSchema.PrimaryFields) + ) - modelValue := reflect.New(rel.Schema.ModelType).Interface() - tx.Model(modelValue).UpdateColumns(updateAttrs) + column, values := schema.ToQueryValues(foreignKeys, queryValues) + relColumn, relValues := schema.ToQueryValues(rel.FieldSchema.PrimaryFieldDBNames, relQueryValues) + + tx.Session(&Session{}).Model(modelValue).Clauses( + clause.IN{Column: column, Values: values}, + clause.IN{Column: relColumn, Values: relValues}, + ).UpdateColumns(updateAttrs) + case schema.BelongsTo: + var ( + modelValue = reflect.New(rel.Schema.ModelType).Interface() + _, queryValues = schema.GetIdentityFieldValuesMap(reflectValue, rel.Schema.PrimaryFields) + _, relQueryValues = schema.GetIdentityFieldValuesMapFromValues(values, primaryFields) + ) + + column, values := schema.ToQueryValues(rel.Schema.PrimaryFieldDBNames, queryValues) + relColumn, relValues := schema.ToQueryValues(foreignKeys, relQueryValues) + + tx.Session(&Session{}).Model(modelValue).Clauses( + clause.IN{Column: column, Values: values}, + clause.IN{Column: relColumn, Values: relValues}, + ).UpdateColumns(updateAttrs) case schema.Many2Many: modelValue := reflect.New(rel.JoinTable.ModelType).Interface() conds := rel.ToQueryConditions(reflectValue) tx.Clauses(clause.Where{Exprs: conds}).Delete(modelValue) } + relValuesMap, _ := schema.GetIdentityFieldValuesMapFromValues(values, rel.FieldSchema.PrimaryFields) + if tx.Error == nil { cleanUpDeletedRelations := func(data reflect.Value) { if _, zero := rel.Field.ValueOf(data); !zero { fieldValue := reflect.Indirect(rel.Field.ReflectValueOf(data)) - fieldValues := make([]interface{}, len(relFields)) + fieldValues := make([]interface{}, len(rel.FieldSchema.PrimaryFields)) switch fieldValue.Kind() { case reflect.Slice, reflect.Array: validFieldValues := reflect.Zero(rel.Field.FieldType) for i := 0; i < fieldValue.Len(); i++ { - for idx, field := range relFields { + for idx, field := range rel.FieldSchema.PrimaryFields { fieldValues[idx], _ = field.ValueOf(fieldValue.Index(i)) } @@ -252,13 +254,18 @@ func (association *Association) Delete(values ...interface{}) error { rel.Field.Set(data, validFieldValues.Interface()) case reflect.Struct: - for idx, field := range relFields { + for idx, field := range rel.FieldSchema.PrimaryFields { fieldValues[idx], _ = field.ValueOf(fieldValue) } if _, ok := relValuesMap[utils.ToStringKey(fieldValues...)]; ok { rel.Field.Set(data, reflect.Zero(rel.FieldSchema.ModelType).Interface()) - for _, field := range foreignKeyFields { - field.Set(data, reflect.Zero(field.FieldType).Interface()) + for _, ref := range rel.References { + if ref.OwnPrimaryKey { + ref.ForeignKey.Set(fieldValue, reflect.Zero(ref.ForeignKey.FieldType).Interface()) + } else if ref.PrimaryValue == "" { + // FIXME + ref.ForeignKey.Set(data, reflect.Zero(ref.ForeignKey.FieldType).Interface()) + } } } } @@ -337,9 +344,9 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ } case schema.HasMany, schema.Many2Many: elemType := association.Relationship.Field.IndirectFieldType.Elem() - fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(reflectValue)) + fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(source)) if clear { - fieldValue = reflect.New(association.Relationship.Field.IndirectFieldType) + fieldValue = reflect.New(association.Relationship.Field.IndirectFieldType).Elem() } appendToFieldValues := func(ev reflect.Value) { @@ -355,14 +362,14 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ switch rv.Kind() { case reflect.Slice, reflect.Array: for i := 0; i < rv.Len(); i++ { - appendToFieldValues(reflect.Indirect(rv.Index(i))) + appendToFieldValues(reflect.Indirect(rv.Index(i)).Addr()) } case reflect.Struct: - appendToFieldValues(rv) + appendToFieldValues(rv.Addr()) } if association.Error == nil { - association.Error = association.Relationship.Field.Set(source, fieldValue.Addr().Interface()) + association.Error = association.Relationship.Field.Set(source, fieldValue.Interface()) } } } diff --git a/schema/schema.go b/schema/schema.go index caae55ac..e66084a3 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -22,6 +22,7 @@ type Schema struct { PrioritizedPrimaryField *Field DBNames []string PrimaryFields []*Field + PrimaryFieldDBNames []string Fields []*Field FieldsByName map[string]*Field FieldsByDBName map[string]*Field @@ -165,6 +166,10 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) } } + for _, field := range schema.PrimaryFields { + schema.PrimaryFieldDBNames = append(schema.PrimaryFieldDBNames, field.DBName) + } + schema.FieldsWithDefaultDBValue = map[string]*Field{} for db, field := range schema.FieldsByDBName { if field.HasDefaultValue && field.DefaultValueInterface == nil { diff --git a/tests/associations_test.go b/tests/associations_test.go index 2b81a719..08733005 100644 --- a/tests/associations_test.go +++ b/tests/associations_test.go @@ -467,3 +467,152 @@ func TestPolymorphicHasOneAssociationForSlice(t *testing.T) { DB.Model(&pets).Association("Toy").Clear() AssertAssociationCount(t, pets, "Toy", 0, "After Clear") } + +func TestHasManyAssociation(t *testing.T) { + var user = *GetUser("hasmany", Config{Pets: 2}) + + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + CheckUser(t, user, user) + + // Find + var user2 User + DB.Find(&user2, "id = ?", user.ID) + DB.Model(&user2).Association("Pets").Find(&user2.Pets) + CheckUser(t, user2, user) + + // Count + AssertAssociationCount(t, user, "Pets", 2, "") + + // Append + var pet = Pet{Name: "pet-has-many-append"} + + if err := DB.Model(&user2).Association("Pets").Append(&pet); err != nil { + t.Fatalf("Error happened when append account, got %v", err) + } + + if pet.ID == 0 { + t.Fatalf("Pet's ID should be created") + } + + user.Pets = append(user.Pets, &pet) + CheckUser(t, user2, user) + + AssertAssociationCount(t, user, "Pets", 3, "AfterAppend") + + var pets = []Pet{{Name: "pet-has-many-append-1-1"}, {Name: "pet-has-many-append-1-1"}} + + if err := DB.Model(&user2).Association("Pets").Append(&pets); err != nil { + t.Fatalf("Error happened when append pet, got %v", err) + } + + for _, pet := range pets { + var pet = pet + if pet.ID == 0 { + t.Fatalf("Pet's ID should be created") + } + + user.Pets = append(user.Pets, &pet) + } + + CheckUser(t, user2, user) + + AssertAssociationCount(t, user, "Pets", 5, "AfterAppendSlice") + + // Replace + var pet2 = Pet{Name: "pet-has-many-replace"} + + if err := DB.Model(&user2).Association("Pets").Replace(&pet2); err != nil { + t.Fatalf("Error happened when append pet, got %v", err) + } + + if pet2.ID == 0 { + t.Fatalf("pet2's ID should be created") + } + + user.Pets = []*Pet{&pet2} + CheckUser(t, user2, user) + + AssertAssociationCount(t, user2, "Pets", 1, "AfterReplace") + + // Delete + if err := DB.Model(&user2).Association("Pets").Delete(&Pet{}); err != nil { + t.Fatalf("Error happened when delete pet, got %v", err) + } + AssertAssociationCount(t, user2, "Pets", 1, "after delete non-existing data") + + if err := DB.Model(&user2).Association("Pets").Delete(&pet2); err != nil { + t.Fatalf("Error happened when delete Pets, got %v", err) + } + AssertAssociationCount(t, user2, "Pets", 0, "after delete") + + // Prepare Data for Clear + if err := DB.Model(&user2).Association("Pets").Append(&pet); err != nil { + t.Fatalf("Error happened when append Pets, got %v", err) + } + + AssertAssociationCount(t, user2, "Pets", 1, "after prepare data") + + // Clear + if err := DB.Model(&user2).Association("Pets").Clear(); err != nil { + t.Errorf("Error happened when clear Pets, got %v", err) + } + + AssertAssociationCount(t, user2, "Pets", 0, "after clear") +} + +func TestHasManyAssociationForSlice(t *testing.T) { + var users = []User{ + *GetUser("slice-hasmany-1", Config{Pets: 2}), + *GetUser("slice-hasmany-2", Config{Pets: 0}), + *GetUser("slice-hasmany-3", Config{Pets: 4}), + } + + DB.Create(&users) + + // Count + AssertAssociationCount(t, users, "Pets", 6, "") + + // Find + var pets []Pet + if DB.Model(&users).Association("Pets").Find(&pets); len(pets) != 6 { + t.Errorf("pets count should be %v, but got %v", 6, len(pets)) + } + + // Append + DB.Model(&users).Association("Pets").Append( + &Pet{Name: "pet-slice-append-1"}, + []*Pet{{Name: "pet-slice-append-2-1"}, {Name: "pet-slice-append-2-2"}}, + &Pet{Name: "pet-slice-append-3"}, + ) + + AssertAssociationCount(t, users, "Pets", 10, "After Append") + + // Replace -> same as append + DB.Model(&users).Association("Pets").Replace( + []*Pet{{Name: "pet-slice-replace-1-1"}, {Name: "pet-slice-replace-1-2"}}, + []*Pet{{Name: "pet-slice-replace-2-1"}, {Name: "pet-slice-replace-2-2"}}, + &Pet{Name: "pet-slice-replace-3"}, + ) + + AssertAssociationCount(t, users, "Pets", 5, "After Append") + + // Delete + if err := DB.Model(&users).Association("Pets").Delete(&users[2].Pets); err != nil { + t.Errorf("no error should happend when deleting pet, but got %v", err) + } + + AssertAssociationCount(t, users, "Pets", 4, "after delete") + + if err := DB.Debug().Model(&users).Association("Pets").Delete(users[0].Pets[0], users[1].Pets[1]); err != nil { + t.Errorf("no error should happend when deleting pet, but got %v", err) + } + + AssertAssociationCount(t, users, "Pets", 2, "after delete") + + // Clear + DB.Model(&users).Association("Pets").Clear() + AssertAssociationCount(t, users, "Pets", 0, "After Clear") +} From 135d9f8b0308c4bb24286d907f8d799705a24672 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 25 May 2020 11:49:02 +0800 Subject: [PATCH 328/881] Test HasMany Association for Slice --- association.go | 17 +++-- callbacks.go | 3 + callbacks/associations.go | 4 +- tests/associations_test.go | 152 ++++++++++++++++++++++++++++++++++++- 4 files changed, 165 insertions(+), 11 deletions(-) diff --git a/association.go b/association.go index 9405d962..e3aee8f2 100644 --- a/association.go +++ b/association.go @@ -185,6 +185,7 @@ func (association *Association) Delete(values ...interface{}) error { primaryFields, foreignFields []*schema.Field foreignKeys []string updateAttrs = map[string]interface{}{} + conds []clause.Expression ) for _, ref := range rel.References { @@ -193,6 +194,8 @@ func (association *Association) Delete(values ...interface{}) error { foreignFields = append(foreignFields, ref.ForeignKey) foreignKeys = append(foreignKeys, ref.ForeignKey.DBName) updateAttrs[ref.ForeignKey.DBName] = nil + } else { + conds = append(conds, clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue}) } } @@ -205,12 +208,11 @@ func (association *Association) Delete(values ...interface{}) error { ) column, values := schema.ToQueryValues(foreignKeys, queryValues) + conds = append(conds, clause.IN{Column: column, Values: values}) relColumn, relValues := schema.ToQueryValues(rel.FieldSchema.PrimaryFieldDBNames, relQueryValues) + conds = append(conds, clause.IN{Column: relColumn, Values: relValues}) - tx.Session(&Session{}).Model(modelValue).Clauses( - clause.IN{Column: column, Values: values}, - clause.IN{Column: relColumn, Values: relValues}, - ).UpdateColumns(updateAttrs) + tx.Session(&Session{}).Model(modelValue).Clauses(conds...).UpdateColumns(updateAttrs) case schema.BelongsTo: var ( modelValue = reflect.New(rel.Schema.ModelType).Interface() @@ -219,12 +221,11 @@ func (association *Association) Delete(values ...interface{}) error { ) column, values := schema.ToQueryValues(rel.Schema.PrimaryFieldDBNames, queryValues) + conds = append(conds, clause.IN{Column: column, Values: values}) relColumn, relValues := schema.ToQueryValues(foreignKeys, relQueryValues) + conds = append(conds, clause.IN{Column: relColumn, Values: relValues}) - tx.Session(&Session{}).Model(modelValue).Clauses( - clause.IN{Column: column, Values: values}, - clause.IN{Column: relColumn, Values: relValues}, - ).UpdateColumns(updateAttrs) + tx.Session(&Session{}).Model(modelValue).Clauses(conds...).UpdateColumns(updateAttrs) case schema.Many2Many: modelValue := reflect.New(rel.JoinTable.ModelType).Interface() conds := rel.ToQueryConditions(reflectValue) diff --git a/callbacks.go b/callbacks.go index 629b90aa..d05947d9 100644 --- a/callbacks.go +++ b/callbacks.go @@ -87,6 +87,9 @@ func (p *processor) Execute(db *DB) { if stmt.Dest != nil { stmt.ReflectValue = reflect.Indirect(reflect.ValueOf(stmt.Dest)) + for stmt.ReflectValue.Kind() == reflect.Ptr { + stmt.ReflectValue = stmt.ReflectValue.Elem() + } if !stmt.ReflectValue.IsValid() { db.AddError(fmt.Errorf("invalid value")) } diff --git a/callbacks/associations.go b/callbacks/associations.go index 2342f110..d9ecafc7 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -125,7 +125,7 @@ func SaveAfterAssociations(db *gorm.DB) { if _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(rv); isZero { elems = reflect.Append(elems, rv) } else { - db.Session(&gorm.Session{}).Save(rv.Interface()) + db.Session(&gorm.Session{}).Save(rv.Addr().Interface()) } } } @@ -192,7 +192,7 @@ func SaveAfterAssociations(db *gorm.DB) { elems = reflect.Append(elems, elem.Addr()) } } else { - db.Session(&gorm.Session{}).Save(elem.Interface()) + db.Session(&gorm.Session{}).Save(elem.Addr().Interface()) } } } diff --git a/tests/associations_test.go b/tests/associations_test.go index 08733005..dd9f7efb 100644 --- a/tests/associations_test.go +++ b/tests/associations_test.go @@ -606,7 +606,7 @@ func TestHasManyAssociationForSlice(t *testing.T) { AssertAssociationCount(t, users, "Pets", 4, "after delete") - if err := DB.Debug().Model(&users).Association("Pets").Delete(users[0].Pets[0], users[1].Pets[1]); err != nil { + if err := DB.Model(&users).Association("Pets").Delete(users[0].Pets[0], users[1].Pets[1]); err != nil { t.Errorf("no error should happend when deleting pet, but got %v", err) } @@ -616,3 +616,153 @@ func TestHasManyAssociationForSlice(t *testing.T) { DB.Model(&users).Association("Pets").Clear() AssertAssociationCount(t, users, "Pets", 0, "After Clear") } + +func TestPolymorphicHasManyAssociation(t *testing.T) { + var user = *GetUser("hasmany", Config{Toys: 2}) + + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + CheckUser(t, user, user) + + // Find + var user2 User + DB.Find(&user2, "id = ?", user.ID) + DB.Model(&user2).Association("Toys").Find(&user2.Toys) + CheckUser(t, user2, user) + + // Count + AssertAssociationCount(t, user, "Toys", 2, "") + + // Append + var toy = Toy{Name: "toy-has-many-append"} + + if err := DB.Model(&user2).Association("Toys").Append(&toy); err != nil { + t.Fatalf("Error happened when append account, got %v", err) + } + return + + if toy.ID == 0 { + t.Fatalf("Toy's ID should be created") + } + + user.Toys = append(user.Toys, toy) + CheckUser(t, user2, user) + + AssertAssociationCount(t, user, "Toys", 3, "AfterAppend") + + var toys = []Toy{{Name: "toy-has-many-append-1-1"}, {Name: "toy-has-many-append-1-1"}} + + if err := DB.Model(&user2).Association("Toys").Append(&toys); err != nil { + t.Fatalf("Error happened when append toy, got %v", err) + } + + for _, toy := range toys { + var toy = toy + if toy.ID == 0 { + t.Fatalf("Toy's ID should be created") + } + + user.Toys = append(user.Toys, toy) + } + + CheckUser(t, user2, user) + + AssertAssociationCount(t, user, "Toys", 5, "AfterAppendSlice") + + // Replace + var toy2 = Toy{Name: "toy-has-many-replace"} + + if err := DB.Model(&user2).Association("Toys").Replace(&toy2); err != nil { + t.Fatalf("Error happened when append toy, got %v", err) + } + + if toy2.ID == 0 { + t.Fatalf("toy2's ID should be created") + } + + user.Toys = []Toy{toy2} + CheckUser(t, user2, user) + + AssertAssociationCount(t, user2, "Toys", 1, "AfterReplace") + + // Delete + if err := DB.Model(&user2).Association("Toys").Delete(&Toy{}); err != nil { + t.Fatalf("Error happened when delete toy, got %v", err) + } + AssertAssociationCount(t, user2, "Toys", 1, "after delete non-existing data") + + if err := DB.Model(&user2).Association("Toys").Delete(&toy2); err != nil { + t.Fatalf("Error happened when delete Toys, got %v", err) + } + AssertAssociationCount(t, user2, "Toys", 0, "after delete") + + // Prepare Data for Clear + if err := DB.Model(&user2).Association("Toys").Append(&toy); err != nil { + t.Fatalf("Error happened when append Toys, got %v", err) + } + + AssertAssociationCount(t, user2, "Toys", 1, "after prepare data") + + // Clear + if err := DB.Model(&user2).Association("Toys").Clear(); err != nil { + t.Errorf("Error happened when clear Toys, got %v", err) + } + + AssertAssociationCount(t, user2, "Toys", 0, "after clear") +} + +func TestPolymorphicHasManyAssociationForSlice(t *testing.T) { + var users = []User{ + *GetUser("slice-hasmany-1", Config{Toys: 2}), + *GetUser("slice-hasmany-2", Config{Toys: 0}), + *GetUser("slice-hasmany-3", Config{Toys: 4}), + } + + DB.Create(&users) + + // Count + AssertAssociationCount(t, users, "Toys", 6, "") + + // Find + var toys []Toy + if DB.Model(&users).Association("Toys").Find(&toys); len(toys) != 6 { + t.Errorf("toys count should be %v, but got %v", 6, len(toys)) + } + + // Append + DB.Model(&users).Association("Toys").Append( + &Toy{Name: "toy-slice-append-1"}, + []Toy{{Name: "toy-slice-append-2-1"}, {Name: "toy-slice-append-2-2"}}, + &Toy{Name: "toy-slice-append-3"}, + ) + + AssertAssociationCount(t, users, "Toys", 10, "After Append") + + // Replace -> same as append + DB.Model(&users).Association("Toys").Replace( + []*Toy{{Name: "toy-slice-replace-1-1"}, {Name: "toy-slice-replace-1-2"}}, + []*Toy{{Name: "toy-slice-replace-2-1"}, {Name: "toy-slice-replace-2-2"}}, + &Toy{Name: "toy-slice-replace-3"}, + ) + + AssertAssociationCount(t, users, "Toys", 5, "After Append") + + // Delete + if err := DB.Model(&users).Association("Toys").Delete(&users[2].Toys); err != nil { + t.Errorf("no error should happend when deleting toy, but got %v", err) + } + + AssertAssociationCount(t, users, "Toys", 4, "after delete") + + if err := DB.Model(&users).Association("Toys").Delete(users[0].Toys[0], users[1].Toys[1]); err != nil { + t.Errorf("no error should happend when deleting toy, but got %v", err) + } + + AssertAssociationCount(t, users, "Toys", 2, "after delete") + + // Clear + DB.Model(&users).Association("Toys").Clear() + AssertAssociationCount(t, users, "Toys", 0, "After Clear") +} From cc064f26ee7f0c96fa2b9079469f6136c7945273 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 25 May 2020 23:11:42 +0800 Subject: [PATCH 329/881] Add on conflict support --- callbacks/associations.go | 3 +- callbacks/create.go | 4 +- clause/on_conflict.go | 38 +++++++++++++++++ schema/relationship.go | 2 +- tests/associations_test.go | 87 ++++++++++++++++++++++++++++++++++++++ 5 files changed, 130 insertions(+), 4 deletions(-) create mode 100644 clause/on_conflict.go diff --git a/callbacks/associations.go b/callbacks/associations.go index d9ecafc7..76fc5b81 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -4,6 +4,7 @@ import ( "reflect" "github.com/jinzhu/gorm" + "github.com/jinzhu/gorm/clause" "github.com/jinzhu/gorm/schema" "github.com/jinzhu/gorm/utils" ) @@ -282,7 +283,7 @@ func SaveAfterAssociations(db *gorm.DB) { } if joins.Len() > 0 { - db.Session(&gorm.Session{}).Create(joins.Interface()) + db.Session(&gorm.Session{}).Clauses(clause.OnConflict{DoNothing: true}).Create(joins.Interface()) } } } diff --git a/callbacks/create.go b/callbacks/create.go index ff88bc0e..0b30775a 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -51,7 +51,7 @@ func Create(config *Config) func(db *gorm.DB) { }) db.Statement.AddClause(ConvertToCreateValues(db.Statement)) - db.Statement.Build("INSERT", "VALUES", "ON_CONFLICT") + db.Statement.Build("INSERT", "VALUES", "ON CONFLICT") result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) if err == nil { @@ -93,7 +93,7 @@ func CreateWithReturning(db *gorm.DB) { }) db.Statement.AddClause(ConvertToCreateValues(db.Statement)) - db.Statement.Build("INSERT", "VALUES", "ON_CONFLICT") + db.Statement.Build("INSERT", "VALUES", "ON CONFLICT") if sch := db.Statement.Schema; sch != nil && len(sch.FieldsWithDefaultDBValue) > 0 { db.Statement.WriteString(" RETURNING ") diff --git a/clause/on_conflict.go b/clause/on_conflict.go new file mode 100644 index 00000000..6001399f --- /dev/null +++ b/clause/on_conflict.go @@ -0,0 +1,38 @@ +package clause + +type OnConflict struct { + Columns []Column + Where Where + DoNothing bool + DoUpdates Set +} + +func (OnConflict) Name() string { + return "ON CONFLICT" +} + +// Build build onConflict clause +func (onConflict OnConflict) Build(builder Builder) { + if len(onConflict.Columns) > 0 { + builder.WriteQuoted(onConflict.Columns) // FIXME columns + builder.WriteByte(' ') + } + + if len(onConflict.Where.Exprs) > 0 { + builder.WriteString("WHERE ") + onConflict.Where.Build(builder) + builder.WriteByte(' ') + } + + if onConflict.DoNothing { + builder.WriteString("DO NOTHING") + } else { + builder.WriteString("DO UPDATE SET ") + onConflict.DoUpdates.Build(builder) + } +} + +// MergeClause merge onConflict clauses +func (onConflict OnConflict) MergeClause(clause *Clause) { + clause.Expression = onConflict +} diff --git a/schema/relationship.go b/schema/relationship.go index d10bfe30..3dcef9fc 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -355,7 +355,7 @@ func (rel *Relationship) ToQueryConditions(reflectValue reflect.Value) (conds [] for _, ref := range rel.References { if ref.OwnPrimaryKey { foreignFields = append(foreignFields, ref.PrimaryKey) - relForeignKeys = append(relForeignKeys, ref.PrimaryKey.DBName) + relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName) } else if ref.PrimaryValue != "" { conds = append(conds, clause.Eq{ Column: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName}, diff --git a/tests/associations_test.go b/tests/associations_test.go index dd9f7efb..b6ddbd29 100644 --- a/tests/associations_test.go +++ b/tests/associations_test.go @@ -766,3 +766,90 @@ func TestPolymorphicHasManyAssociationForSlice(t *testing.T) { DB.Model(&users).Association("Toys").Clear() AssertAssociationCount(t, users, "Toys", 0, "After Clear") } + +func TestMany2ManyAssociation(t *testing.T) { + var user = *GetUser("many2many", Config{Languages: 2}) + + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + CheckUser(t, user, user) + + // Find + var user2 User + DB.Find(&user2, "id = ?", user.ID) + DB.Model(&user2).Association("Languages").Find(&user2.Languages) + + CheckUser(t, user2, user) + + // Count + AssertAssociationCount(t, user, "Languages", 2, "") + + // Append + var language = Language{Code: "language-has-many-append", Name: "language-has-many-append"} + DB.Create(&language) + + if err := DB.Model(&user2).Association("Languages").Append(&language); err != nil { + t.Fatalf("Error happened when append account, got %v", err) + } + + user.Languages = append(user.Languages, language) + CheckUser(t, user2, user) + + AssertAssociationCount(t, user, "Languages", 3, "AfterAppend") + + var languages = []Language{ + {Code: "language-has-many-append-1-1", Name: "language-has-many-append-1-1"}, + {Code: "language-has-many-append-2-1", Name: "language-has-many-append-2-1"}, + } + DB.Create(&languages) + + if err := DB.Model(&user2).Association("Languages").Append(&languages); err != nil { + t.Fatalf("Error happened when append language, got %v", err) + } + + user.Languages = append(user.Languages, languages...) + + CheckUser(t, user2, user) + + AssertAssociationCount(t, user, "Languages", 5, "AfterAppendSlice") + + // Replace + var language2 = Language{Code: "language-has-many-replace", Name: "language-has-many-replace"} + DB.Create(&language2) + + if err := DB.Model(&user2).Association("Languages").Replace(&language2); err != nil { + t.Fatalf("Error happened when append language, got %v", err) + } + + user.Languages = []Language{language2} + CheckUser(t, user2, user) + + AssertAssociationCount(t, user2, "Languages", 1, "AfterReplace") + + // Delete + if err := DB.Model(&user2).Association("Languages").Delete(&Language{}); err != nil { + t.Fatalf("Error happened when delete language, got %v", err) + } + AssertAssociationCount(t, user2, "Languages", 1, "after delete non-existing data") + + if err := DB.Model(&user2).Association("Languages").Delete(&language2); err != nil { + t.Fatalf("Error happened when delete Languages, got %v", err) + } + AssertAssociationCount(t, user2, "Languages", 0, "after delete") + + // Prepare Data for Clear + if err := DB.Model(&user2).Association("Languages").Append(&language); err != nil { + t.Fatalf("Error happened when append Languages, got %v", err) + } + + AssertAssociationCount(t, user2, "Languages", 1, "after prepare data") + + // Clear + if err := DB.Model(&user2).Association("Languages").Clear(); err != nil { + t.Errorf("Error happened when clear Languages, got %v", err) + } + + AssertAssociationCount(t, user2, "Languages", 0, "after clear") +} From dea48a8c59db900dee3af5c4c76799bb54f79119 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 26 May 2020 00:16:41 +0800 Subject: [PATCH 330/881] Test Many2Many Association --- association.go | 91 +++++++++++++++++++++++--------------- callbacks/delete.go | 20 ++++++--- errors.go | 2 + logger/sql.go | 4 +- tests/associations_test.go | 1 - 5 files changed, 76 insertions(+), 42 deletions(-) diff --git a/association.go b/association.go index e3aee8f2..49fd4558 100644 --- a/association.go +++ b/association.go @@ -128,49 +128,40 @@ func (association *Association) Replace(values ...interface{}) error { } case schema.Many2Many: var primaryFields, relPrimaryFields []*schema.Field - var foreignKeys, relForeignKeys []string - modelValue := reflect.New(rel.JoinTable.ModelType).Interface() - conds := []clause.Expression{} + var joinPrimaryKeys, joinRelPrimaryKeys []string + var conds []clause.Expression for _, ref := range rel.References { - if ref.OwnPrimaryKey { - primaryFields = append(primaryFields, ref.PrimaryKey) - foreignKeys = append(foreignKeys, ref.ForeignKey.DBName) - } else if ref.PrimaryValue != "" { - conds = append(conds, clause.Eq{ - Column: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName}, - Value: ref.PrimaryValue, - }) + if ref.PrimaryValue == "" { + if ref.OwnPrimaryKey { + primaryFields = append(primaryFields, ref.PrimaryKey) + joinPrimaryKeys = append(joinPrimaryKeys, ref.ForeignKey.DBName) + } else { + relPrimaryFields = append(relPrimaryFields, ref.PrimaryKey) + joinRelPrimaryKeys = append(joinRelPrimaryKeys, ref.ForeignKey.DBName) + } } else { - relPrimaryFields = append(relPrimaryFields, ref.PrimaryKey) - relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName) + conds = append(conds, clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue}) } } - generateConds := func(rv reflect.Value) { - _, values := schema.GetIdentityFieldValuesMap(rv, primaryFields) - column, queryValues := schema.ToQueryValues(foreignKeys, values) + var ( + modelValue = reflect.New(rel.JoinTable.ModelType).Interface() + _, queryValues = schema.GetIdentityFieldValuesMap(reflectValue, primaryFields) + _, relQueryValues = schema.GetIdentityFieldValuesMapFromValues(values, relPrimaryFields) + ) - relValue := rel.Field.ReflectValueOf(rv) - _, relValues := schema.GetIdentityFieldValuesMap(relValue, relPrimaryFields) - relColumn, relQueryValues := schema.ToQueryValues(relForeignKeys, relValues) - - conds = append(conds, clause.And( - clause.IN{Column: column, Values: queryValues}, - clause.Not(clause.IN{Column: relColumn, Values: relQueryValues}), - )) + if column, values := schema.ToQueryValues(joinPrimaryKeys, queryValues); len(values) > 0 { + conds = append(conds, clause.IN{Column: column, Values: values}) + } else { + return ErrorPrimaryKeyRequired } - switch reflectValue.Kind() { - case reflect.Struct: - generateConds(reflectValue) - case reflect.Slice, reflect.Array: - for i := 0; i < reflectValue.Len(); i++ { - generateConds(reflectValue.Index(i)) - } + if relColumn, relValues := schema.ToQueryValues(joinRelPrimaryKeys, relQueryValues); len(relValues) > 0 { + conds = append(conds, clause.Not(clause.IN{Column: relColumn, Values: relValues})) } - association.DB.Where(conds).Delete(modelValue) + association.DB.Where(clause.Where{Exprs: conds}).Model(nil).Delete(modelValue) } } return association.Error @@ -227,9 +218,39 @@ func (association *Association) Delete(values ...interface{}) error { tx.Session(&Session{}).Model(modelValue).Clauses(conds...).UpdateColumns(updateAttrs) case schema.Many2Many: - modelValue := reflect.New(rel.JoinTable.ModelType).Interface() - conds := rel.ToQueryConditions(reflectValue) - tx.Clauses(clause.Where{Exprs: conds}).Delete(modelValue) + var primaryFields, relPrimaryFields []*schema.Field + var joinPrimaryKeys, joinRelPrimaryKeys []string + + for _, ref := range rel.References { + if ref.PrimaryValue == "" { + if ref.OwnPrimaryKey { + primaryFields = append(primaryFields, ref.PrimaryKey) + joinPrimaryKeys = append(joinPrimaryKeys, ref.ForeignKey.DBName) + } else { + relPrimaryFields = append(relPrimaryFields, ref.PrimaryKey) + joinRelPrimaryKeys = append(joinRelPrimaryKeys, ref.ForeignKey.DBName) + } + } else { + conds = append(conds, clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue}) + } + } + + var ( + modelValue = reflect.New(rel.JoinTable.ModelType).Interface() + _, queryValues = schema.GetIdentityFieldValuesMap(reflectValue, primaryFields) + _, relQueryValues = schema.GetIdentityFieldValuesMapFromValues(values, relPrimaryFields) + ) + + if column, values := schema.ToQueryValues(joinPrimaryKeys, queryValues); len(values) > 0 { + conds = append(conds, clause.IN{Column: column, Values: values}) + } else { + return ErrorPrimaryKeyRequired + } + + relColumn, relValues := schema.ToQueryValues(joinRelPrimaryKeys, relQueryValues) + conds = append(conds, clause.IN{Column: relColumn, Values: relValues}) + + tx.Where(clause.Where{Exprs: conds}).Model(nil).Delete(modelValue) } relValuesMap, _ := schema.GetIdentityFieldValuesMapFromValues(values, rel.FieldSchema.PrimaryFields) diff --git a/callbacks/delete.go b/callbacks/delete.go index 50b2880a..a88edcf8 100644 --- a/callbacks/delete.go +++ b/callbacks/delete.go @@ -5,6 +5,7 @@ import ( "github.com/jinzhu/gorm" "github.com/jinzhu/gorm/clause" + "github.com/jinzhu/gorm/schema" ) func BeforeDelete(db *gorm.DB) { @@ -37,13 +38,22 @@ func Delete(db *gorm.DB) { db.Statement.AddClauseIfNotExists(clause.Delete{}) values := []reflect.Value{db.Statement.ReflectValue} - if db.Statement.Dest != db.Statement.Model { + if db.Statement.Dest != db.Statement.Model && db.Statement.Model != nil { values = append(values, reflect.ValueOf(db.Statement.Model)) } - for _, field := range db.Statement.Schema.PrimaryFields { - for _, value := range values { - if value, isZero := field.ValueOf(value); !isZero { - db.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}}) + + if db.Statement.Schema != nil { + _, queryValues := schema.GetIdentityFieldValuesMap(db.Statement.ReflectValue, db.Statement.Schema.PrimaryFields) + column, values := schema.ToQueryValues(db.Statement.Schema.PrimaryFieldDBNames, queryValues) + + if len(values) > 0 { + db.Where(clause.IN{Column: column, Values: values}) + } else if db.Statement.Dest != db.Statement.Model && db.Statement.Model != nil { + _, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(db.Statement.Model), db.Statement.Schema.PrimaryFields) + column, values = schema.ToQueryValues(db.Statement.Schema.PrimaryFieldDBNames, queryValues) + + if len(values) > 0 { + db.Where(clause.IN{Column: column, Values: values}) } } } diff --git a/errors.go b/errors.go index 4f2bd4fa..140a5186 100644 --- a/errors.go +++ b/errors.go @@ -21,4 +21,6 @@ var ( ErrUnsupportedRelation = errors.New("unsupported relations") // ErrPtrStructSupported only ptr of struct supported ErrPtrStructSupported = errors.New("only ptr of struct supported") + // ErrorPrimaryKeyRequired primary keys required + ErrorPrimaryKeyRequired = errors.New("primary key required") ) diff --git a/logger/sql.go b/logger/sql.go index 219ae301..bb4e3e06 100644 --- a/logger/sql.go +++ b/logger/sql.go @@ -53,7 +53,9 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, v } else { rv := reflect.ValueOf(v) - if !rv.IsValid() || rv.IsNil() { + if !rv.IsValid() { + vars[idx] = "NULL" + } else if rv.Kind() == reflect.Ptr && rv.IsNil() { vars[idx] = "NULL" } else if rv.Kind() == reflect.Ptr && !rv.IsZero() { convertParams(reflect.Indirect(rv).Interface(), idx) diff --git a/tests/associations_test.go b/tests/associations_test.go index b6ddbd29..3ab69b42 100644 --- a/tests/associations_test.go +++ b/tests/associations_test.go @@ -641,7 +641,6 @@ func TestPolymorphicHasManyAssociation(t *testing.T) { if err := DB.Model(&user2).Association("Toys").Append(&toy); err != nil { t.Fatalf("Error happened when append account, got %v", err) } - return if toy.ID == 0 { t.Fatalf("Toy's ID should be created") From 457f1e5d7390c2b7f54c6111bfa863cfb35c5dbd Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 26 May 2020 01:21:15 +0800 Subject: [PATCH 331/881] Test Many2Many Association for Slice --- association.go | 32 ++++++++++++---- tests/associations_test.go | 78 ++++++++++++++++++++++++++++++++++++-- 2 files changed, 99 insertions(+), 11 deletions(-) diff --git a/association.go b/association.go index 49fd4558..92a19efb 100644 --- a/association.go +++ b/association.go @@ -340,11 +340,16 @@ func (association *Association) Count() (count int64) { return } +type assignBack struct { + Source reflect.Value + Index int + Dest reflect.Value +} + func (association *Association) saveAssociation(clear bool, values ...interface{}) { var ( reflectValue = association.DB.Statement.ReflectValue - assignBacks = [][2]reflect.Value{} - assignBack = association.Relationship.Field.FieldType.Kind() == reflect.Struct + assignBacks []assignBack ) appendToRelations := func(source, rv reflect.Value, clear bool) { @@ -354,14 +359,14 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ case reflect.Slice, reflect.Array: if rv.Len() > 0 { association.Error = association.Relationship.Field.Set(source, rv.Index(0).Addr().Interface()) - if assignBack { - assignBacks = append(assignBacks, [2]reflect.Value{source, rv.Index(0)}) + if association.Relationship.Field.FieldType.Kind() == reflect.Struct { + assignBacks = append(assignBacks, assignBack{Source: source, Dest: rv.Index(0)}) } } case reflect.Struct: association.Error = association.Relationship.Field.Set(source, rv.Addr().Interface()) - if assignBack { - assignBacks = append(assignBacks, [2]reflect.Value{source, rv}) + if association.Relationship.Field.FieldType.Kind() == reflect.Struct { + assignBacks = append(assignBacks, assignBack{Source: source, Dest: rv}) } } case schema.HasMany, schema.Many2Many: @@ -379,6 +384,14 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ } else { association.Error = fmt.Errorf("unsupported data type: %v for relation %v", ev.Type(), association.Relationship.Name) } + + if association.Relationship.Field.IndirectFieldType.Elem().Kind() == reflect.Struct { + assignBacks = append(assignBacks, assignBack{ + Source: source, + Index: fieldValue.Len(), + Dest: ev, + }) + } } switch rv.Kind() { @@ -451,6 +464,11 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ } for _, assignBack := range assignBacks { - reflect.Indirect(assignBack[1]).Set(association.Relationship.Field.ReflectValueOf(assignBack[0])) + fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(assignBack.Source)) + if assignBack.Index > 0 { + reflect.Indirect(assignBack.Dest).Set(fieldValue.Index(assignBack.Index - 1)) + } else { + reflect.Indirect(assignBack.Dest).Set(fieldValue) + } } } diff --git a/tests/associations_test.go b/tests/associations_test.go index 3ab69b42..3aa11edb 100644 --- a/tests/associations_test.go +++ b/tests/associations_test.go @@ -786,7 +786,7 @@ func TestMany2ManyAssociation(t *testing.T) { AssertAssociationCount(t, user, "Languages", 2, "") // Append - var language = Language{Code: "language-has-many-append", Name: "language-has-many-append"} + var language = Language{Code: "language-many2many-append", Name: "language-many2many-append"} DB.Create(&language) if err := DB.Model(&user2).Association("Languages").Append(&language); err != nil { @@ -799,8 +799,8 @@ func TestMany2ManyAssociation(t *testing.T) { AssertAssociationCount(t, user, "Languages", 3, "AfterAppend") var languages = []Language{ - {Code: "language-has-many-append-1-1", Name: "language-has-many-append-1-1"}, - {Code: "language-has-many-append-2-1", Name: "language-has-many-append-2-1"}, + {Code: "language-many2many-append-1-1", Name: "language-many2many-append-1-1"}, + {Code: "language-many2many-append-2-1", Name: "language-many2many-append-2-1"}, } DB.Create(&languages) @@ -815,7 +815,7 @@ func TestMany2ManyAssociation(t *testing.T) { AssertAssociationCount(t, user, "Languages", 5, "AfterAppendSlice") // Replace - var language2 = Language{Code: "language-has-many-replace", Name: "language-has-many-replace"} + var language2 = Language{Code: "language-many2many-replace", Name: "language-many2many-replace"} DB.Create(&language2) if err := DB.Model(&user2).Association("Languages").Replace(&language2); err != nil { @@ -852,3 +852,73 @@ func TestMany2ManyAssociation(t *testing.T) { AssertAssociationCount(t, user2, "Languages", 0, "after clear") } + +func TestMany2ManyAssociationForSlice(t *testing.T) { + var users = []User{ + *GetUser("slice-many2many-1", Config{Languages: 2}), + *GetUser("slice-many2many-2", Config{Languages: 0}), + *GetUser("slice-many2many-3", Config{Languages: 4}), + } + + DB.Create(&users) + + // Count + AssertAssociationCount(t, users, "Languages", 6, "") + + // Find + var languages []Language + if DB.Model(&users).Association("Languages").Find(&languages); len(languages) != 6 { + t.Errorf("languages count should be %v, but got %v", 6, len(languages)) + } + + // Append + var languages1 = []Language{ + {Code: "language-many2many-append-1", Name: "language-many2many-append-1"}, + } + var languages2 = []Language{} + var languages3 = []Language{ + {Code: "language-many2many-append-3-1", Name: "language-many2many-append-3-1"}, + {Code: "language-many2many-append-3-2", Name: "language-many2many-append-3-2"}, + } + DB.Create(&languages1) + DB.Create(&languages3) + + DB.Model(&users).Association("Languages").Append(&languages1, &languages2, &languages3) + + AssertAssociationCount(t, users, "Languages", 9, "After Append") + + languages2_1 := []*Language{ + {Code: "language-slice-replace-1-1", Name: "language-slice-replace-1-1"}, + {Code: "language-slice-replace-1-2", Name: "language-slice-replace-1-2"}, + } + languages2_2 := []*Language{ + {Code: "language-slice-replace-2-1", Name: "language-slice-replace-2-1"}, + {Code: "language-slice-replace-2-2", Name: "language-slice-replace-2-2"}, + } + languages2_3 := &Language{Code: "language-slice-replace-3", Name: "language-slice-replace-3"} + DB.Create(&languages2_1) + DB.Create(&languages2_2) + DB.Create(&languages2_3) + + // Replace + DB.Model(&users).Association("Languages").Replace(&languages2_1, &languages2_2, languages2_3) + + AssertAssociationCount(t, users, "Languages", 5, "After Replace") + + // Delete + if err := DB.Model(&users).Association("Languages").Delete(&users[2].Languages); err != nil { + t.Errorf("no error should happend when deleting language, but got %v", err) + } + + AssertAssociationCount(t, users, "Languages", 4, "after delete") + + if err := DB.Model(&users).Association("Languages").Delete(users[0].Languages[0], users[1].Languages[1]); err != nil { + t.Errorf("no error should happend when deleting language, but got %v", err) + } + + AssertAssociationCount(t, users, "Languages", 2, "after delete") + + // Clear + DB.Model(&users).Association("Languages").Clear() + AssertAssociationCount(t, users, "Languages", 0, "After Clear") +} From 33a58c548b556a3a6d199f6bbebc134ba26f85d9 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 26 May 2020 01:43:10 +0800 Subject: [PATCH 332/881] Test single table has many association --- tests/associations_test.go | 151 +++++++++++++++++++++++++++++++++++++ 1 file changed, 151 insertions(+) diff --git a/tests/associations_test.go b/tests/associations_test.go index 3aa11edb..f01fb92b 100644 --- a/tests/associations_test.go +++ b/tests/associations_test.go @@ -563,6 +563,101 @@ func TestHasManyAssociation(t *testing.T) { AssertAssociationCount(t, user2, "Pets", 0, "after clear") } +func TestSingleTableHasManyAssociation(t *testing.T) { + var user = *GetUser("hasmany", Config{Team: 2}) + + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + CheckUser(t, user, user) + + // Find + var user2 User + DB.Find(&user2, "id = ?", user.ID) + DB.Model(&user2).Association("Team").Find(&user2.Team) + CheckUser(t, user2, user) + + // Count + AssertAssociationCount(t, user, "Team", 2, "") + + // Append + var team = *GetUser("team", Config{}) + + if err := DB.Model(&user2).Association("Team").Append(&team); err != nil { + t.Fatalf("Error happened when append account, got %v", err) + } + + if team.ID == 0 { + t.Fatalf("Team's ID should be created") + } + + user.Team = append(user.Team, team) + CheckUser(t, user2, user) + + AssertAssociationCount(t, user, "Team", 3, "AfterAppend") + + var teams = []User{*GetUser("team-append-1", Config{}), *GetUser("team-append-2", Config{})} + + if err := DB.Model(&user2).Association("Team").Append(&teams); err != nil { + t.Fatalf("Error happened when append team, got %v", err) + } + + for _, team := range teams { + var team = team + if team.ID == 0 { + t.Fatalf("Team's ID should be created") + } + + user.Team = append(user.Team, team) + } + + CheckUser(t, user2, user) + + AssertAssociationCount(t, user, "Team", 5, "AfterAppendSlice") + + // Replace + var team2 = *GetUser("team-replace", Config{}) + + if err := DB.Model(&user2).Association("Team").Replace(&team2); err != nil { + t.Fatalf("Error happened when append team, got %v", err) + } + + if team2.ID == 0 { + t.Fatalf("team2's ID should be created") + } + + user.Team = []User{team2} + CheckUser(t, user2, user) + + AssertAssociationCount(t, user2, "Team", 1, "AfterReplace") + + // Delete + if err := DB.Model(&user2).Association("Team").Delete(&User{}); err != nil { + t.Fatalf("Error happened when delete team, got %v", err) + } + AssertAssociationCount(t, user2, "Team", 1, "after delete non-existing data") + + if err := DB.Model(&user2).Association("Team").Delete(&team2); err != nil { + t.Fatalf("Error happened when delete Team, got %v", err) + } + AssertAssociationCount(t, user2, "Team", 0, "after delete") + + // Prepare Data for Clear + if err := DB.Model(&user2).Association("Team").Append(&team); err != nil { + t.Fatalf("Error happened when append Team, got %v", err) + } + + AssertAssociationCount(t, user2, "Team", 1, "after prepare data") + + // Clear + if err := DB.Model(&user2).Association("Team").Clear(); err != nil { + t.Errorf("Error happened when clear Team, got %v", err) + } + + AssertAssociationCount(t, user2, "Team", 0, "after clear") +} + func TestHasManyAssociationForSlice(t *testing.T) { var users = []User{ *GetUser("slice-hasmany-1", Config{Pets: 2}), @@ -617,6 +712,62 @@ func TestHasManyAssociationForSlice(t *testing.T) { AssertAssociationCount(t, users, "Pets", 0, "After Clear") } +func TestSingleTableHasManyAssociationForSlice(t *testing.T) { + var users = []User{ + *GetUser("slice-hasmany-1", Config{Team: 2}), + *GetUser("slice-hasmany-2", Config{Team: 0}), + *GetUser("slice-hasmany-3", Config{Team: 4}), + } + + if err := DB.Create(&users).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + // Count + AssertAssociationCount(t, users, "Team", 6, "") + + // Find + var teams []User + if DB.Model(&users).Association("Team").Find(&teams); len(teams) != 6 { + t.Errorf("teams count should be %v, but got %v", 6, len(teams)) + } + + // Append + DB.Model(&users).Association("Team").Append( + &User{Name: "pet-slice-append-1"}, + []*User{{Name: "pet-slice-append-2-1"}, {Name: "pet-slice-append-2-2"}}, + &User{Name: "pet-slice-append-3"}, + ) + + AssertAssociationCount(t, users, "Team", 10, "After Append") + + // Replace -> same as append + DB.Model(&users).Association("Team").Replace( + []*User{{Name: "pet-slice-replace-1-1"}, {Name: "pet-slice-replace-1-2"}}, + []*User{{Name: "pet-slice-replace-2-1"}, {Name: "pet-slice-replace-2-2"}}, + &User{Name: "pet-slice-replace-3"}, + ) + + AssertAssociationCount(t, users, "Team", 5, "After Append") + + // Delete + if err := DB.Model(&users).Association("Team").Delete(&users[2].Team); err != nil { + t.Errorf("no error should happend when deleting pet, but got %v", err) + } + + AssertAssociationCount(t, users, "Team", 4, "after delete") + + if err := DB.Model(&users).Association("Team").Delete(users[0].Team[0], users[1].Team[1]); err != nil { + t.Errorf("no error should happend when deleting pet, but got %v", err) + } + + AssertAssociationCount(t, users, "Team", 2, "after delete") + + // Clear + DB.Model(&users).Association("Team").Clear() + AssertAssociationCount(t, users, "Team", 0, "After Clear") +} + func TestPolymorphicHasManyAssociation(t *testing.T) { var user = *GetUser("hasmany", Config{Toys: 2}) From 8de2bb4eab9a73cab8cd59512329c61c5da51a83 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 26 May 2020 01:57:22 +0800 Subject: [PATCH 333/881] Test single table many2many association --- association.go | 16 +++-- tests/associations_test.go | 135 +++++++++++++++++++++++++++++++++++++ 2 files changed, 145 insertions(+), 6 deletions(-) diff --git a/association.go b/association.go index 92a19efb..4871a72f 100644 --- a/association.go +++ b/association.go @@ -422,9 +422,11 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ if clear && len(values) == 0 { for i := 0; i < reflectValue.Len(); i++ { association.Relationship.Field.Set(reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType).Interface()) - for _, ref := range association.Relationship.References { - if !ref.OwnPrimaryKey && ref.PrimaryValue == "" { - ref.ForeignKey.Set(reflectValue.Index(i), reflect.Zero(ref.ForeignKey.FieldType).Interface()) + if association.Relationship.JoinTable == nil { + for _, ref := range association.Relationship.References { + if !ref.OwnPrimaryKey && ref.PrimaryValue == "" { + ref.ForeignKey.Set(reflectValue.Index(i), reflect.Zero(ref.ForeignKey.FieldType).Interface()) + } } } } @@ -446,9 +448,11 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ case reflect.Struct: if clear && len(values) == 0 { association.Relationship.Field.Set(reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType).Interface()) - for _, ref := range association.Relationship.References { - if !ref.OwnPrimaryKey && ref.PrimaryValue == "" { - ref.ForeignKey.Set(reflectValue, reflect.Zero(ref.ForeignKey.FieldType).Interface()) + if association.Relationship.JoinTable == nil { + for _, ref := range association.Relationship.References { + if !ref.OwnPrimaryKey && ref.PrimaryValue == "" { + ref.ForeignKey.Set(reflectValue, reflect.Zero(ref.ForeignKey.FieldType).Interface()) + } } } } diff --git a/tests/associations_test.go b/tests/associations_test.go index f01fb92b..a102fa54 100644 --- a/tests/associations_test.go +++ b/tests/associations_test.go @@ -1073,3 +1073,138 @@ func TestMany2ManyAssociationForSlice(t *testing.T) { DB.Model(&users).Association("Languages").Clear() AssertAssociationCount(t, users, "Languages", 0, "After Clear") } + +func TestSingleTableMany2ManyAssociation(t *testing.T) { + var user = *GetUser("many2many", Config{Friends: 2}) + + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + CheckUser(t, user, user) + + // Find + var user2 User + DB.Find(&user2, "id = ?", user.ID) + DB.Model(&user2).Association("Friends").Find(&user2.Friends) + + CheckUser(t, user2, user) + + // Count + AssertAssociationCount(t, user, "Friends", 2, "") + + // Append + var friend = *GetUser("friend", Config{}) + + if err := DB.Model(&user2).Association("Friends").Append(&friend); err != nil { + t.Fatalf("Error happened when append account, got %v", err) + } + + user.Friends = append(user.Friends, &friend) + CheckUser(t, user2, user) + + AssertAssociationCount(t, user, "Friends", 3, "AfterAppend") + + var friends = []*User{GetUser("friend-append-1", Config{}), GetUser("friend-append-2", Config{})} + + if err := DB.Model(&user2).Association("Friends").Append(&friends); err != nil { + t.Fatalf("Error happened when append friend, got %v", err) + } + + user.Friends = append(user.Friends, friends...) + + CheckUser(t, user2, user) + + AssertAssociationCount(t, user, "Friends", 5, "AfterAppendSlice") + + // Replace + var friend2 = *GetUser("friend-replace-2", Config{}) + + if err := DB.Model(&user2).Association("Friends").Replace(&friend2); err != nil { + t.Fatalf("Error happened when append friend, got %v", err) + } + + user.Friends = []*User{&friend2} + CheckUser(t, user2, user) + + AssertAssociationCount(t, user2, "Friends", 1, "AfterReplace") + + // Delete + if err := DB.Model(&user2).Association("Friends").Delete(&User{}); err != nil { + t.Fatalf("Error happened when delete friend, got %v", err) + } + AssertAssociationCount(t, user2, "Friends", 1, "after delete non-existing data") + + if err := DB.Model(&user2).Association("Friends").Delete(&friend2); err != nil { + t.Fatalf("Error happened when delete Friends, got %v", err) + } + AssertAssociationCount(t, user2, "Friends", 0, "after delete") + + // Prepare Data for Clear + if err := DB.Model(&user2).Association("Friends").Append(&friend); err != nil { + t.Fatalf("Error happened when append Friends, got %v", err) + } + + AssertAssociationCount(t, user2, "Friends", 1, "after prepare data") + + // Clear + if err := DB.Model(&user2).Association("Friends").Clear(); err != nil { + t.Errorf("Error happened when clear Friends, got %v", err) + } + + AssertAssociationCount(t, user2, "Friends", 0, "after clear") +} + +func TestSingleTableMany2ManyAssociationForSlice(t *testing.T) { + var users = []User{ + *GetUser("slice-many2many-1", Config{Team: 2}), + *GetUser("slice-many2many-2", Config{Team: 0}), + *GetUser("slice-many2many-3", Config{Team: 4}), + } + + DB.Create(&users) + + // Count + AssertAssociationCount(t, users, "Team", 6, "") + + // Find + var teams []User + if DB.Model(&users).Association("Team").Find(&teams); len(teams) != 6 { + t.Errorf("teams count should be %v, but got %v", 6, len(teams)) + } + + // Append + var teams1 = []User{*GetUser("friend-append-1", Config{})} + var teams2 = []User{} + var teams3 = []*User{GetUser("friend-append-3-1", Config{}), GetUser("friend-append-3-2", Config{})} + + DB.Model(&users).Association("Team").Append(&teams1, &teams2, &teams3) + + AssertAssociationCount(t, users, "Team", 9, "After Append") + + var teams2_1 = []User{*GetUser("friend-replace-1", Config{}), *GetUser("friend-replace-2", Config{})} + var teams2_2 = []User{*GetUser("friend-replace-2-1", Config{}), *GetUser("friend-replace-2-2", Config{})} + var teams2_3 = GetUser("friend-replace-3-1", Config{}) + + // Replace + DB.Model(&users).Association("Team").Replace(&teams2_1, &teams2_2, teams2_3) + + AssertAssociationCount(t, users, "Team", 5, "After Replace") + + // Delete + if err := DB.Model(&users).Association("Team").Delete(&users[2].Team); err != nil { + t.Errorf("no error should happend when deleting team, but got %v", err) + } + + AssertAssociationCount(t, users, "Team", 4, "after delete") + + if err := DB.Model(&users).Association("Team").Delete(users[0].Team[0], users[1].Team[1]); err != nil { + t.Errorf("no error should happend when deleting team, but got %v", err) + } + + AssertAssociationCount(t, users, "Team", 2, "after delete") + + // Clear + DB.Model(&users).Association("Team").Clear() + AssertAssociationCount(t, users, "Team", 0, "After Clear") +} From c299cb8db606d4cc784f2861a597b5970f5e8c09 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 26 May 2020 09:48:12 +0800 Subject: [PATCH 334/881] Refactor association --- association.go | 197 ++-- tests/associations_belongs_to_test.go | 216 +++++ tests/associations_has_many_test.go | 456 ++++++++++ tests/associations_has_one_test.go | 241 +++++ tests/associations_many2many_test.go | 299 +++++++ tests/associations_test.go | 1185 +------------------------ 6 files changed, 1311 insertions(+), 1283 deletions(-) create mode 100644 tests/associations_belongs_to_test.go create mode 100644 tests/associations_has_many_test.go create mode 100644 tests/associations_has_one_test.go create mode 100644 tests/associations_many2many_test.go diff --git a/association.go b/association.go index 4871a72f..5b777465 100644 --- a/association.go +++ b/association.go @@ -41,7 +41,7 @@ func (association *Association) Find(out interface{}, conds ...interface{}) erro if association.Error == nil { var ( queryConds = association.Relationship.ToQueryConditions(association.DB.Statement.ReflectValue) - tx = association.DB.Model(out).Table("") + tx = association.DB.Model(out) ) if association.Relationship.JoinTable != nil { @@ -80,10 +80,12 @@ func (association *Association) Append(values ...interface{}) error { func (association *Association) Replace(values ...interface{}) error { if association.Error == nil { + // save associations association.saveAssociation(true, values...) + + // set old associations's foreign key to null reflectValue := association.DB.Statement.ReflectValue rel := association.Relationship - switch rel.Type { case schema.BelongsTo: if len(values) == 0 { @@ -97,21 +99,17 @@ func (association *Association) Replace(values ...interface{}) error { } case schema.HasOne, schema.HasMany: var ( - tx = association.DB - primaryFields []*schema.Field - foreignKeys []string - updateMap = map[string]interface{}{} - relPrimaryKeys = []string{} - relValues = schema.GetRelationsValues(reflectValue, []*schema.Relationship{rel}) - modelValue = reflect.New(rel.FieldSchema.ModelType).Interface() + primaryFields []*schema.Field + foreignKeys []string + updateMap = map[string]interface{}{} + relValues = schema.GetRelationsValues(reflectValue, []*schema.Relationship{rel}) + modelValue = reflect.New(rel.FieldSchema.ModelType).Interface() + tx = association.DB.Model(modelValue) ) - for _, field := range rel.FieldSchema.PrimaryFields { - relPrimaryKeys = append(relPrimaryKeys, field.DBName) - } - if _, qvs := schema.GetIdentityFieldValuesMap(relValues, rel.FieldSchema.PrimaryFields); len(qvs) > 0 { - if column, values := schema.ToQueryValues(relPrimaryKeys, qvs); len(values) > 0 { - tx = tx.Not(clause.IN{Column: column, Values: values}) + if _, rvs := schema.GetIdentityFieldValuesMap(relValues, rel.FieldSchema.PrimaryFields); len(rvs) > 0 { + if column, values := schema.ToQueryValues(rel.FieldSchema.PrimaryFieldDBNames, rvs); len(values) > 0 { + tx.Not(clause.IN{Column: column, Values: values}) } } @@ -120,16 +118,22 @@ func (association *Association) Replace(values ...interface{}) error { primaryFields = append(primaryFields, ref.PrimaryKey) foreignKeys = append(foreignKeys, ref.ForeignKey.DBName) updateMap[ref.ForeignKey.DBName] = nil + } else if ref.PrimaryValue != "" { + tx.Where(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue}) } } - if _, qvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields); len(qvs) > 0 { - column, values := schema.ToQueryValues(foreignKeys, qvs) - tx.Model(modelValue).Where(clause.IN{Column: column, Values: values}).UpdateColumns(updateMap) + + if _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields); len(pvs) > 0 { + column, values := schema.ToQueryValues(foreignKeys, pvs) + tx.Where(clause.IN{Column: column, Values: values}).UpdateColumns(updateMap) } case schema.Many2Many: - var primaryFields, relPrimaryFields []*schema.Field - var joinPrimaryKeys, joinRelPrimaryKeys []string - var conds []clause.Expression + var ( + primaryFields, relPrimaryFields []*schema.Field + joinPrimaryKeys, joinRelPrimaryKeys []string + modelValue = reflect.New(rel.JoinTable.ModelType).Interface() + tx = association.DB.Model(modelValue) + ) for _, ref := range rel.References { if ref.PrimaryValue == "" { @@ -141,27 +145,23 @@ func (association *Association) Replace(values ...interface{}) error { joinRelPrimaryKeys = append(joinRelPrimaryKeys, ref.ForeignKey.DBName) } } else { - conds = append(conds, clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue}) + tx.Clauses(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue}) } } - var ( - modelValue = reflect.New(rel.JoinTable.ModelType).Interface() - _, queryValues = schema.GetIdentityFieldValuesMap(reflectValue, primaryFields) - _, relQueryValues = schema.GetIdentityFieldValuesMapFromValues(values, relPrimaryFields) - ) - - if column, values := schema.ToQueryValues(joinPrimaryKeys, queryValues); len(values) > 0 { - conds = append(conds, clause.IN{Column: column, Values: values}) + _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields) + if column, values := schema.ToQueryValues(joinPrimaryKeys, pvs); len(values) > 0 { + tx.Where(clause.IN{Column: column, Values: values}) } else { return ErrorPrimaryKeyRequired } - if relColumn, relValues := schema.ToQueryValues(joinRelPrimaryKeys, relQueryValues); len(relValues) > 0 { - conds = append(conds, clause.Not(clause.IN{Column: relColumn, Values: relValues})) + _, rvs := schema.GetIdentityFieldValuesMapFromValues(values, relPrimaryFields) + if relColumn, relValues := schema.ToQueryValues(joinRelPrimaryKeys, rvs); len(relValues) > 0 { + tx.Where(clause.Not(clause.IN{Column: relColumn, Values: relValues})) } - association.DB.Where(clause.Where{Exprs: conds}).Model(nil).Delete(modelValue) + tx.Delete(modelValue) } } return association.Error @@ -172,7 +172,6 @@ func (association *Association) Delete(values ...interface{}) error { var ( reflectValue = association.DB.Statement.ReflectValue rel = association.Relationship - tx = association.DB primaryFields, foreignFields []*schema.Field foreignKeys []string updateAttrs = map[string]interface{}{} @@ -191,35 +190,36 @@ func (association *Association) Delete(values ...interface{}) error { } switch rel.Type { - case schema.HasOne, schema.HasMany: - var ( - modelValue = reflect.New(rel.FieldSchema.ModelType).Interface() - _, queryValues = schema.GetIdentityFieldValuesMap(reflectValue, primaryFields) - _, relQueryValues = schema.GetIdentityFieldValuesMapFromValues(values, rel.FieldSchema.PrimaryFields) - ) - - column, values := schema.ToQueryValues(foreignKeys, queryValues) - conds = append(conds, clause.IN{Column: column, Values: values}) - relColumn, relValues := schema.ToQueryValues(rel.FieldSchema.PrimaryFieldDBNames, relQueryValues) - conds = append(conds, clause.IN{Column: relColumn, Values: relValues}) - - tx.Session(&Session{}).Model(modelValue).Clauses(conds...).UpdateColumns(updateAttrs) case schema.BelongsTo: - var ( - modelValue = reflect.New(rel.Schema.ModelType).Interface() - _, queryValues = schema.GetIdentityFieldValuesMap(reflectValue, rel.Schema.PrimaryFields) - _, relQueryValues = schema.GetIdentityFieldValuesMapFromValues(values, primaryFields) - ) + tx := association.DB.Model(reflect.New(rel.Schema.ModelType).Interface()) - column, values := schema.ToQueryValues(rel.Schema.PrimaryFieldDBNames, queryValues) - conds = append(conds, clause.IN{Column: column, Values: values}) - relColumn, relValues := schema.ToQueryValues(foreignKeys, relQueryValues) + _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, rel.Schema.PrimaryFields) + pcolumn, pvalues := schema.ToQueryValues(rel.Schema.PrimaryFieldDBNames, pvs) + conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues}) + + _, rvs := schema.GetIdentityFieldValuesMapFromValues(values, primaryFields) + relColumn, relValues := schema.ToQueryValues(foreignKeys, rvs) conds = append(conds, clause.IN{Column: relColumn, Values: relValues}) - tx.Session(&Session{}).Model(modelValue).Clauses(conds...).UpdateColumns(updateAttrs) + association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error + case schema.HasOne, schema.HasMany: + tx := association.DB.Model(reflect.New(rel.FieldSchema.ModelType).Interface()) + + _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields) + pcolumn, pvalues := schema.ToQueryValues(foreignKeys, pvs) + conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues}) + + _, rvs := schema.GetIdentityFieldValuesMapFromValues(values, rel.FieldSchema.PrimaryFields) + relColumn, relValues := schema.ToQueryValues(rel.FieldSchema.PrimaryFieldDBNames, rvs) + conds = append(conds, clause.IN{Column: relColumn, Values: relValues}) + + association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error case schema.Many2Many: - var primaryFields, relPrimaryFields []*schema.Field - var joinPrimaryKeys, joinRelPrimaryKeys []string + var ( + primaryFields, relPrimaryFields []*schema.Field + joinPrimaryKeys, joinRelPrimaryKeys []string + modelValue = reflect.New(rel.JoinTable.ModelType).Interface() + ) for _, ref := range rel.References { if ref.PrimaryValue == "" { @@ -235,41 +235,34 @@ func (association *Association) Delete(values ...interface{}) error { } } - var ( - modelValue = reflect.New(rel.JoinTable.ModelType).Interface() - _, queryValues = schema.GetIdentityFieldValuesMap(reflectValue, primaryFields) - _, relQueryValues = schema.GetIdentityFieldValuesMapFromValues(values, relPrimaryFields) - ) + _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields) + pcolumn, pvalues := schema.ToQueryValues(joinPrimaryKeys, pvs) + conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues}) - if column, values := schema.ToQueryValues(joinPrimaryKeys, queryValues); len(values) > 0 { - conds = append(conds, clause.IN{Column: column, Values: values}) - } else { - return ErrorPrimaryKeyRequired - } - - relColumn, relValues := schema.ToQueryValues(joinRelPrimaryKeys, relQueryValues) + _, rvs := schema.GetIdentityFieldValuesMapFromValues(values, relPrimaryFields) + relColumn, relValues := schema.ToQueryValues(joinRelPrimaryKeys, rvs) conds = append(conds, clause.IN{Column: relColumn, Values: relValues}) - tx.Where(clause.Where{Exprs: conds}).Model(nil).Delete(modelValue) + association.Error = association.DB.Where(clause.Where{Exprs: conds}).Model(nil).Delete(modelValue).Error } - relValuesMap, _ := schema.GetIdentityFieldValuesMapFromValues(values, rel.FieldSchema.PrimaryFields) + if association.Error == nil { + relValuesMap, _ := schema.GetIdentityFieldValuesMapFromValues(values, rel.FieldSchema.PrimaryFields) - if tx.Error == nil { cleanUpDeletedRelations := func(data reflect.Value) { if _, zero := rel.Field.ValueOf(data); !zero { fieldValue := reflect.Indirect(rel.Field.ReflectValueOf(data)) + primaryValues := make([]interface{}, len(rel.FieldSchema.PrimaryFields)) - fieldValues := make([]interface{}, len(rel.FieldSchema.PrimaryFields)) switch fieldValue.Kind() { case reflect.Slice, reflect.Array: - validFieldValues := reflect.Zero(rel.Field.FieldType) + validFieldValues := reflect.Zero(rel.Field.IndirectFieldType) for i := 0; i < fieldValue.Len(); i++ { for idx, field := range rel.FieldSchema.PrimaryFields { - fieldValues[idx], _ = field.ValueOf(fieldValue.Index(i)) + primaryValues[idx], _ = field.ValueOf(fieldValue.Index(i)) } - if _, ok := relValuesMap[utils.ToStringKey(fieldValues...)]; !ok { + if _, ok := relValuesMap[utils.ToStringKey(primaryValues...)]; !ok { validFieldValues = reflect.Append(validFieldValues, fieldValue.Index(i)) } } @@ -277,16 +270,19 @@ func (association *Association) Delete(values ...interface{}) error { rel.Field.Set(data, validFieldValues.Interface()) case reflect.Struct: for idx, field := range rel.FieldSchema.PrimaryFields { - fieldValues[idx], _ = field.ValueOf(fieldValue) + primaryValues[idx], _ = field.ValueOf(fieldValue) } - if _, ok := relValuesMap[utils.ToStringKey(fieldValues...)]; ok { + + if _, ok := relValuesMap[utils.ToStringKey(primaryValues...)]; ok { rel.Field.Set(data, reflect.Zero(rel.FieldSchema.ModelType).Interface()) - for _, ref := range rel.References { - if ref.OwnPrimaryKey { - ref.ForeignKey.Set(fieldValue, reflect.Zero(ref.ForeignKey.FieldType).Interface()) - } else if ref.PrimaryValue == "" { - // FIXME - ref.ForeignKey.Set(data, reflect.Zero(ref.ForeignKey.FieldType).Interface()) + + if rel.JoinTable == nil { + for _, ref := range rel.References { + if ref.OwnPrimaryKey || ref.PrimaryValue != "" { + ref.ForeignKey.Set(fieldValue, reflect.Zero(ref.ForeignKey.FieldType).Interface()) + } else { + ref.ForeignKey.Set(data, reflect.Zero(ref.ForeignKey.FieldType).Interface()) + } } } } @@ -302,10 +298,9 @@ func (association *Association) Delete(values ...interface{}) error { case reflect.Struct: cleanUpDeletedRelations(reflectValue) } - } else { - association.Error = tx.Error } } + return association.Error } @@ -349,7 +344,7 @@ type assignBack struct { func (association *Association) saveAssociation(clear bool, values ...interface{}) { var ( reflectValue = association.DB.Statement.ReflectValue - assignBacks []assignBack + assignBacks []assignBack // assign association values back to arguments after save ) appendToRelations := func(source, rv reflect.Value, clear bool) { @@ -359,12 +354,14 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ case reflect.Slice, reflect.Array: if rv.Len() > 0 { association.Error = association.Relationship.Field.Set(source, rv.Index(0).Addr().Interface()) + if association.Relationship.Field.FieldType.Kind() == reflect.Struct { assignBacks = append(assignBacks, assignBack{Source: source, Dest: rv.Index(0)}) } } case reflect.Struct: association.Error = association.Relationship.Field.Set(source, rv.Addr().Interface()) + if association.Relationship.Field.FieldType.Kind() == reflect.Struct { assignBacks = append(assignBacks, assignBack{Source: source, Dest: rv}) } @@ -385,12 +382,8 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ association.Error = fmt.Errorf("unsupported data type: %v for relation %v", ev.Type(), association.Relationship.Name) } - if association.Relationship.Field.IndirectFieldType.Elem().Kind() == reflect.Struct { - assignBacks = append(assignBacks, assignBack{ - Source: source, - Index: fieldValue.Len(), - Dest: ev, - }) + if elemType.Kind() == reflect.Struct { + assignBacks = append(assignBacks, assignBack{Source: source, Dest: ev, Index: fieldValue.Len()}) } } @@ -409,10 +402,10 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ } } - selectedColumns := []string{association.Relationship.Name} + selectedSaveColumns := []string{association.Relationship.Name} for _, ref := range association.Relationship.References { if !ref.OwnPrimaryKey { - selectedColumns = append(selectedColumns, ref.ForeignKey.Name) + selectedSaveColumns = append(selectedSaveColumns, ref.ForeignKey.Name) } } @@ -422,6 +415,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ if clear && len(values) == 0 { for i := 0; i < reflectValue.Len(); i++ { association.Relationship.Field.Set(reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType).Interface()) + if association.Relationship.JoinTable == nil { for _, ref := range association.Relationship.References { if !ref.OwnPrimaryKey && ref.PrimaryValue == "" { @@ -432,6 +426,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ } break } + association.Error = errors.New("invalid association values, length doesn't match") return } @@ -439,15 +434,13 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ for i := 0; i < reflectValue.Len(); i++ { appendToRelations(reflectValue.Index(i), reflect.Indirect(reflect.ValueOf(values[i])), clear) - if len(values) > 0 { - // TODO support save slice data, sql with case - err := association.DB.Session(&Session{}).Select(selectedColumns).Model(nil).Save(reflectValue.Index(i).Addr().Interface()).Error - association.DB.AddError(err) - } + // TODO support save slice data, sql with case? + association.Error = association.DB.Session(&Session{}).Select(selectedSaveColumns).Model(nil).Save(reflectValue.Index(i).Addr().Interface()).Error } case reflect.Struct: if clear && len(values) == 0 { association.Relationship.Field.Set(reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType).Interface()) + if association.Relationship.JoinTable == nil { for _, ref := range association.Relationship.References { if !ref.OwnPrimaryKey && ref.PrimaryValue == "" { @@ -463,7 +456,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ } if len(values) > 0 { - association.DB.Session(&Session{}).Select(selectedColumns).Model(nil).Save(reflectValue.Addr().Interface()) + association.Error = association.DB.Session(&Session{}).Select(selectedSaveColumns).Model(nil).Save(reflectValue.Addr().Interface()).Error } } diff --git a/tests/associations_belongs_to_test.go b/tests/associations_belongs_to_test.go new file mode 100644 index 00000000..236af191 --- /dev/null +++ b/tests/associations_belongs_to_test.go @@ -0,0 +1,216 @@ +package tests_test + +import ( + "testing" + + . "github.com/jinzhu/gorm/tests" +) + +func TestBelongsToAssociation(t *testing.T) { + var user = *GetUser("belongs-to", Config{Company: true, Manager: true}) + + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + CheckUser(t, user, user) + + // Find + var user2 User + DB.Find(&user2, "id = ?", user.ID) + DB.Model(&user2).Association("Company").Find(&user2.Company) + user2.Manager = &User{} + DB.Model(&user2).Association("Manager").Find(user2.Manager) + CheckUser(t, user2, user) + + // Count + AssertAssociationCount(t, user, "Company", 1, "") + AssertAssociationCount(t, user, "Manager", 1, "") + + // Append + var company = Company{Name: "company-belongs-to-append"} + var manager = GetUser("manager-belongs-to-append", Config{}) + + if err := DB.Model(&user2).Association("Company").Append(&company); err != nil { + t.Fatalf("Error happened when append Company, got %v", err) + } + + if company.ID == 0 { + t.Fatalf("Company's ID should be created") + } + + if err := DB.Model(&user2).Association("Manager").Append(manager); err != nil { + t.Fatalf("Error happened when append Manager, got %v", err) + } + + if manager.ID == 0 { + t.Fatalf("Manager's ID should be created") + } + + user.Company = company + user.Manager = manager + user.CompanyID = &company.ID + user.ManagerID = &manager.ID + CheckUser(t, user2, user) + + AssertAssociationCount(t, user2, "Company", 1, "AfterAppend") + AssertAssociationCount(t, user2, "Manager", 1, "AfterAppend") + + // Replace + var company2 = Company{Name: "company-belongs-to-replace"} + var manager2 = GetUser("manager-belongs-to-replace", Config{}) + + if err := DB.Model(&user2).Association("Company").Replace(&company2); err != nil { + t.Fatalf("Error happened when replace Company, got %v", err) + } + + if company2.ID == 0 { + t.Fatalf("Company's ID should be created") + } + + if err := DB.Model(&user2).Association("Manager").Replace(manager2); err != nil { + t.Fatalf("Error happened when replace Manager, got %v", err) + } + + if manager2.ID == 0 { + t.Fatalf("Manager's ID should be created") + } + + user.Company = company2 + user.Manager = manager2 + user.CompanyID = &company2.ID + user.ManagerID = &manager2.ID + CheckUser(t, user2, user) + + AssertAssociationCount(t, user2, "Company", 1, "AfterReplace") + AssertAssociationCount(t, user2, "Manager", 1, "AfterReplace") + + // Delete + if err := DB.Model(&user2).Association("Company").Delete(&Company{}); err != nil { + t.Fatalf("Error happened when delete Company, got %v", err) + } + AssertAssociationCount(t, user2, "Company", 1, "after delete non-existing data") + + if err := DB.Model(&user2).Association("Company").Delete(&company2); err != nil { + t.Fatalf("Error happened when delete Company, got %v", err) + } + AssertAssociationCount(t, user2, "Company", 0, "after delete") + + if err := DB.Model(&user2).Association("Manager").Delete(&User{}); err != nil { + t.Fatalf("Error happened when delete Manager, got %v", err) + } + AssertAssociationCount(t, user2, "Manager", 1, "after delete non-existing data") + + if err := DB.Model(&user2).Association("Manager").Delete(manager2); err != nil { + t.Fatalf("Error happened when delete Manager, got %v", err) + } + AssertAssociationCount(t, user2, "Manager", 0, "after delete") + + // Prepare Data for Clear + if err := DB.Model(&user2).Association("Company").Append(&company); err != nil { + t.Fatalf("Error happened when append Company, got %v", err) + } + + if err := DB.Model(&user2).Association("Manager").Append(manager); err != nil { + t.Fatalf("Error happened when append Manager, got %v", err) + } + + AssertAssociationCount(t, user2, "Company", 1, "after prepare data") + AssertAssociationCount(t, user2, "Manager", 1, "after prepare data") + + // Clear + if err := DB.Model(&user2).Association("Company").Clear(); err != nil { + t.Errorf("Error happened when clear Company, got %v", err) + } + + if err := DB.Model(&user2).Association("Manager").Clear(); err != nil { + t.Errorf("Error happened when clear Manager, got %v", err) + } + + AssertAssociationCount(t, user2, "Company", 0, "after clear") + AssertAssociationCount(t, user2, "Manager", 0, "after clear") +} + +func TestBelongsToAssociationForSlice(t *testing.T) { + var users = []User{ + *GetUser("slice-belongs-to-1", Config{Company: true, Manager: true}), + *GetUser("slice-belongs-to-2", Config{Company: true, Manager: false}), + *GetUser("slice-belongs-to-3", Config{Company: true, Manager: true}), + } + + DB.Create(&users) + + AssertAssociationCount(t, users, "Company", 3, "") + AssertAssociationCount(t, users, "Manager", 2, "") + + // Find + var companies []Company + if DB.Model(&users).Association("Company").Find(&companies); len(companies) != 3 { + t.Errorf("companies count should be %v, but got %v", 3, len(companies)) + } + + var managers []User + if DB.Model(&users).Association("Manager").Find(&managers); len(managers) != 2 { + t.Errorf("managers count should be %v, but got %v", 2, len(managers)) + } + + // Append + DB.Model(&users).Association("Company").Append( + &Company{Name: "company-slice-append-1"}, + &Company{Name: "company-slice-append-2"}, + &Company{Name: "company-slice-append-3"}, + ) + + AssertAssociationCount(t, users, "Company", 3, "After Append") + + DB.Model(&users).Association("Manager").Append( + GetUser("manager-slice-belongs-to-1", Config{}), + GetUser("manager-slice-belongs-to-2", Config{}), + GetUser("manager-slice-belongs-to-3", Config{}), + ) + AssertAssociationCount(t, users, "Manager", 3, "After Append") + + if err := DB.Model(&users).Association("Manager").Append( + GetUser("manager-slice-belongs-to-test-1", Config{}), + ).Error; err == nil { + t.Errorf("unmatched length when update user's manager") + } + + // Replace -> same as append + + // Delete + if err := DB.Model(&users).Association("Company").Delete(&users[0].Company); err != nil { + t.Errorf("no error should happend when deleting company, but got %v", err) + } + + if users[0].CompanyID != nil || users[0].Company.ID != 0 { + t.Errorf("users[0]'s company should be deleted'") + } + + AssertAssociationCount(t, users, "Company", 2, "After Delete") + + // Clear + DB.Model(&users).Association("Company").Clear() + AssertAssociationCount(t, users, "Company", 0, "After Clear") + + DB.Model(&users).Association("Manager").Clear() + AssertAssociationCount(t, users, "Manager", 0, "After Clear") + + // shared company + company := Company{Name: "shared"} + if err := DB.Model(&users[0]).Association("Company").Append(&company); err != nil { + t.Errorf("Error happened when append company to user, got %v", err) + } + + if err := DB.Model(&users[1]).Association("Company").Append(&company); err != nil { + t.Errorf("Error happened when append company to user, got %v", err) + } + + if users[0].CompanyID == nil || users[1].CompanyID == nil || *users[0].CompanyID != *users[1].CompanyID { + t.Errorf("user's company id should exists and equal, but its: %v, %v", users[0].CompanyID, users[1].CompanyID) + } + + DB.Model(&users[0]).Association("Company").Delete(&company) + AssertAssociationCount(t, users[0], "Company", 0, "After Delete") + AssertAssociationCount(t, users[1], "Company", 1, "After other user Delete") +} diff --git a/tests/associations_has_many_test.go b/tests/associations_has_many_test.go new file mode 100644 index 00000000..2269d701 --- /dev/null +++ b/tests/associations_has_many_test.go @@ -0,0 +1,456 @@ +package tests_test + +import ( + "testing" + + . "github.com/jinzhu/gorm/tests" +) + +func TestHasManyAssociation(t *testing.T) { + var user = *GetUser("hasmany", Config{Pets: 2}) + + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + CheckUser(t, user, user) + + // Find + var user2 User + DB.Find(&user2, "id = ?", user.ID) + DB.Model(&user2).Association("Pets").Find(&user2.Pets) + CheckUser(t, user2, user) + + // Count + AssertAssociationCount(t, user, "Pets", 2, "") + + // Append + var pet = Pet{Name: "pet-has-many-append"} + + if err := DB.Model(&user2).Association("Pets").Append(&pet); err != nil { + t.Fatalf("Error happened when append account, got %v", err) + } + + if pet.ID == 0 { + t.Fatalf("Pet's ID should be created") + } + + user.Pets = append(user.Pets, &pet) + CheckUser(t, user2, user) + + AssertAssociationCount(t, user, "Pets", 3, "AfterAppend") + + var pets = []Pet{{Name: "pet-has-many-append-1-1"}, {Name: "pet-has-many-append-1-1"}} + + if err := DB.Model(&user2).Association("Pets").Append(&pets); err != nil { + t.Fatalf("Error happened when append pet, got %v", err) + } + + for _, pet := range pets { + var pet = pet + if pet.ID == 0 { + t.Fatalf("Pet's ID should be created") + } + + user.Pets = append(user.Pets, &pet) + } + + CheckUser(t, user2, user) + + AssertAssociationCount(t, user, "Pets", 5, "AfterAppendSlice") + + // Replace + var pet2 = Pet{Name: "pet-has-many-replace"} + + if err := DB.Model(&user2).Association("Pets").Replace(&pet2); err != nil { + t.Fatalf("Error happened when append pet, got %v", err) + } + + if pet2.ID == 0 { + t.Fatalf("pet2's ID should be created") + } + + user.Pets = []*Pet{&pet2} + CheckUser(t, user2, user) + + AssertAssociationCount(t, user2, "Pets", 1, "AfterReplace") + + // Delete + if err := DB.Model(&user2).Association("Pets").Delete(&Pet{}); err != nil { + t.Fatalf("Error happened when delete pet, got %v", err) + } + AssertAssociationCount(t, user2, "Pets", 1, "after delete non-existing data") + + if err := DB.Model(&user2).Association("Pets").Delete(&pet2); err != nil { + t.Fatalf("Error happened when delete Pets, got %v", err) + } + AssertAssociationCount(t, user2, "Pets", 0, "after delete") + + // Prepare Data for Clear + if err := DB.Model(&user2).Association("Pets").Append(&pet); err != nil { + t.Fatalf("Error happened when append Pets, got %v", err) + } + + AssertAssociationCount(t, user2, "Pets", 1, "after prepare data") + + // Clear + if err := DB.Model(&user2).Association("Pets").Clear(); err != nil { + t.Errorf("Error happened when clear Pets, got %v", err) + } + + AssertAssociationCount(t, user2, "Pets", 0, "after clear") +} + +func TestSingleTableHasManyAssociation(t *testing.T) { + var user = *GetUser("hasmany", Config{Team: 2}) + + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + CheckUser(t, user, user) + + // Find + var user2 User + DB.Find(&user2, "id = ?", user.ID) + DB.Model(&user2).Association("Team").Find(&user2.Team) + CheckUser(t, user2, user) + + // Count + AssertAssociationCount(t, user, "Team", 2, "") + + // Append + var team = *GetUser("team", Config{}) + + if err := DB.Model(&user2).Association("Team").Append(&team); err != nil { + t.Fatalf("Error happened when append account, got %v", err) + } + + if team.ID == 0 { + t.Fatalf("Team's ID should be created") + } + + user.Team = append(user.Team, team) + CheckUser(t, user2, user) + + AssertAssociationCount(t, user, "Team", 3, "AfterAppend") + + var teams = []User{*GetUser("team-append-1", Config{}), *GetUser("team-append-2", Config{})} + + if err := DB.Model(&user2).Association("Team").Append(&teams); err != nil { + t.Fatalf("Error happened when append team, got %v", err) + } + + for _, team := range teams { + var team = team + if team.ID == 0 { + t.Fatalf("Team's ID should be created") + } + + user.Team = append(user.Team, team) + } + + CheckUser(t, user2, user) + + AssertAssociationCount(t, user, "Team", 5, "AfterAppendSlice") + + // Replace + var team2 = *GetUser("team-replace", Config{}) + + if err := DB.Model(&user2).Association("Team").Replace(&team2); err != nil { + t.Fatalf("Error happened when append team, got %v", err) + } + + if team2.ID == 0 { + t.Fatalf("team2's ID should be created") + } + + user.Team = []User{team2} + CheckUser(t, user2, user) + + AssertAssociationCount(t, user2, "Team", 1, "AfterReplace") + + // Delete + if err := DB.Model(&user2).Association("Team").Delete(&User{}); err != nil { + t.Fatalf("Error happened when delete team, got %v", err) + } + AssertAssociationCount(t, user2, "Team", 1, "after delete non-existing data") + + if err := DB.Model(&user2).Association("Team").Delete(&team2); err != nil { + t.Fatalf("Error happened when delete Team, got %v", err) + } + AssertAssociationCount(t, user2, "Team", 0, "after delete") + + // Prepare Data for Clear + if err := DB.Model(&user2).Association("Team").Append(&team); err != nil { + t.Fatalf("Error happened when append Team, got %v", err) + } + + AssertAssociationCount(t, user2, "Team", 1, "after prepare data") + + // Clear + if err := DB.Model(&user2).Association("Team").Clear(); err != nil { + t.Errorf("Error happened when clear Team, got %v", err) + } + + AssertAssociationCount(t, user2, "Team", 0, "after clear") +} + +func TestHasManyAssociationForSlice(t *testing.T) { + var users = []User{ + *GetUser("slice-hasmany-1", Config{Pets: 2}), + *GetUser("slice-hasmany-2", Config{Pets: 0}), + *GetUser("slice-hasmany-3", Config{Pets: 4}), + } + + DB.Create(&users) + + // Count + AssertAssociationCount(t, users, "Pets", 6, "") + + // Find + var pets []Pet + if DB.Model(&users).Association("Pets").Find(&pets); len(pets) != 6 { + t.Errorf("pets count should be %v, but got %v", 6, len(pets)) + } + + // Append + DB.Model(&users).Association("Pets").Append( + &Pet{Name: "pet-slice-append-1"}, + []*Pet{{Name: "pet-slice-append-2-1"}, {Name: "pet-slice-append-2-2"}}, + &Pet{Name: "pet-slice-append-3"}, + ) + + AssertAssociationCount(t, users, "Pets", 10, "After Append") + + // Replace -> same as append + DB.Model(&users).Association("Pets").Replace( + []*Pet{{Name: "pet-slice-replace-1-1"}, {Name: "pet-slice-replace-1-2"}}, + []*Pet{{Name: "pet-slice-replace-2-1"}, {Name: "pet-slice-replace-2-2"}}, + &Pet{Name: "pet-slice-replace-3"}, + ) + + AssertAssociationCount(t, users, "Pets", 5, "After Append") + + // Delete + if err := DB.Model(&users).Association("Pets").Delete(&users[2].Pets); err != nil { + t.Errorf("no error should happend when deleting pet, but got %v", err) + } + + AssertAssociationCount(t, users, "Pets", 4, "after delete") + + if err := DB.Model(&users).Association("Pets").Delete(users[0].Pets[0], users[1].Pets[1]); err != nil { + t.Errorf("no error should happend when deleting pet, but got %v", err) + } + + AssertAssociationCount(t, users, "Pets", 2, "after delete") + + // Clear + DB.Model(&users).Association("Pets").Clear() + AssertAssociationCount(t, users, "Pets", 0, "After Clear") +} + +func TestSingleTableHasManyAssociationForSlice(t *testing.T) { + var users = []User{ + *GetUser("slice-hasmany-1", Config{Team: 2}), + *GetUser("slice-hasmany-2", Config{Team: 0}), + *GetUser("slice-hasmany-3", Config{Team: 4}), + } + + if err := DB.Create(&users).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + // Count + AssertAssociationCount(t, users, "Team", 6, "") + + // Find + var teams []User + if DB.Model(&users).Association("Team").Find(&teams); len(teams) != 6 { + t.Errorf("teams count should be %v, but got %v", 6, len(teams)) + } + + // Append + DB.Model(&users).Association("Team").Append( + &User{Name: "pet-slice-append-1"}, + []*User{{Name: "pet-slice-append-2-1"}, {Name: "pet-slice-append-2-2"}}, + &User{Name: "pet-slice-append-3"}, + ) + + AssertAssociationCount(t, users, "Team", 10, "After Append") + + // Replace -> same as append + DB.Model(&users).Association("Team").Replace( + []*User{{Name: "pet-slice-replace-1-1"}, {Name: "pet-slice-replace-1-2"}}, + []*User{{Name: "pet-slice-replace-2-1"}, {Name: "pet-slice-replace-2-2"}}, + &User{Name: "pet-slice-replace-3"}, + ) + + AssertAssociationCount(t, users, "Team", 5, "After Append") + + // Delete + if err := DB.Model(&users).Association("Team").Delete(&users[2].Team); err != nil { + t.Errorf("no error should happend when deleting pet, but got %v", err) + } + + AssertAssociationCount(t, users, "Team", 4, "after delete") + + if err := DB.Model(&users).Association("Team").Delete(users[0].Team[0], users[1].Team[1]); err != nil { + t.Errorf("no error should happend when deleting pet, but got %v", err) + } + + AssertAssociationCount(t, users, "Team", 2, "after delete") + + // Clear + DB.Model(&users).Association("Team").Clear() + AssertAssociationCount(t, users, "Team", 0, "After Clear") +} + +func TestPolymorphicHasManyAssociation(t *testing.T) { + var user = *GetUser("hasmany", Config{Toys: 2}) + + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + CheckUser(t, user, user) + + // Find + var user2 User + DB.Find(&user2, "id = ?", user.ID) + DB.Model(&user2).Association("Toys").Find(&user2.Toys) + CheckUser(t, user2, user) + + // Count + AssertAssociationCount(t, user, "Toys", 2, "") + + // Append + var toy = Toy{Name: "toy-has-many-append"} + + if err := DB.Model(&user2).Association("Toys").Append(&toy); err != nil { + t.Fatalf("Error happened when append account, got %v", err) + } + + if toy.ID == 0 { + t.Fatalf("Toy's ID should be created") + } + + user.Toys = append(user.Toys, toy) + CheckUser(t, user2, user) + + AssertAssociationCount(t, user, "Toys", 3, "AfterAppend") + + var toys = []Toy{{Name: "toy-has-many-append-1-1"}, {Name: "toy-has-many-append-1-1"}} + + if err := DB.Model(&user2).Association("Toys").Append(&toys); err != nil { + t.Fatalf("Error happened when append toy, got %v", err) + } + + for _, toy := range toys { + var toy = toy + if toy.ID == 0 { + t.Fatalf("Toy's ID should be created") + } + + user.Toys = append(user.Toys, toy) + } + + CheckUser(t, user2, user) + + AssertAssociationCount(t, user, "Toys", 5, "AfterAppendSlice") + + // Replace + var toy2 = Toy{Name: "toy-has-many-replace"} + + if err := DB.Model(&user2).Association("Toys").Replace(&toy2); err != nil { + t.Fatalf("Error happened when append toy, got %v", err) + } + + if toy2.ID == 0 { + t.Fatalf("toy2's ID should be created") + } + + user.Toys = []Toy{toy2} + CheckUser(t, user2, user) + + AssertAssociationCount(t, user2, "Toys", 1, "AfterReplace") + + // Delete + if err := DB.Model(&user2).Association("Toys").Delete(&Toy{}); err != nil { + t.Fatalf("Error happened when delete toy, got %v", err) + } + AssertAssociationCount(t, user2, "Toys", 1, "after delete non-existing data") + + if err := DB.Model(&user2).Association("Toys").Delete(&toy2); err != nil { + t.Fatalf("Error happened when delete Toys, got %v", err) + } + AssertAssociationCount(t, user2, "Toys", 0, "after delete") + + // Prepare Data for Clear + if err := DB.Model(&user2).Association("Toys").Append(&toy); err != nil { + t.Fatalf("Error happened when append Toys, got %v", err) + } + + AssertAssociationCount(t, user2, "Toys", 1, "after prepare data") + + // Clear + if err := DB.Model(&user2).Association("Toys").Clear(); err != nil { + t.Errorf("Error happened when clear Toys, got %v", err) + } + + AssertAssociationCount(t, user2, "Toys", 0, "after clear") +} + +func TestPolymorphicHasManyAssociationForSlice(t *testing.T) { + var users = []User{ + *GetUser("slice-hasmany-1", Config{Toys: 2}), + *GetUser("slice-hasmany-2", Config{Toys: 0}), + *GetUser("slice-hasmany-3", Config{Toys: 4}), + } + + DB.Create(&users) + + // Count + AssertAssociationCount(t, users, "Toys", 6, "") + + // Find + var toys []Toy + if DB.Model(&users).Association("Toys").Find(&toys); len(toys) != 6 { + t.Errorf("toys count should be %v, but got %v", 6, len(toys)) + } + + // Append + DB.Model(&users).Association("Toys").Append( + &Toy{Name: "toy-slice-append-1"}, + []Toy{{Name: "toy-slice-append-2-1"}, {Name: "toy-slice-append-2-2"}}, + &Toy{Name: "toy-slice-append-3"}, + ) + + AssertAssociationCount(t, users, "Toys", 10, "After Append") + + // Replace -> same as append + DB.Model(&users).Association("Toys").Replace( + []*Toy{{Name: "toy-slice-replace-1-1"}, {Name: "toy-slice-replace-1-2"}}, + []*Toy{{Name: "toy-slice-replace-2-1"}, {Name: "toy-slice-replace-2-2"}}, + &Toy{Name: "toy-slice-replace-3"}, + ) + + AssertAssociationCount(t, users, "Toys", 5, "After Append") + + // Delete + if err := DB.Model(&users).Association("Toys").Delete(&users[2].Toys); err != nil { + t.Errorf("no error should happend when deleting toy, but got %v", err) + } + + AssertAssociationCount(t, users, "Toys", 4, "after delete") + + if err := DB.Model(&users).Association("Toys").Delete(users[0].Toys[0], users[1].Toys[1]); err != nil { + t.Errorf("no error should happend when deleting toy, but got %v", err) + } + + AssertAssociationCount(t, users, "Toys", 2, "after delete") + + // Clear + DB.Model(&users).Association("Toys").Clear() + AssertAssociationCount(t, users, "Toys", 0, "After Clear") +} diff --git a/tests/associations_has_one_test.go b/tests/associations_has_one_test.go new file mode 100644 index 00000000..a863cb36 --- /dev/null +++ b/tests/associations_has_one_test.go @@ -0,0 +1,241 @@ +package tests_test + +import ( + "testing" + + . "github.com/jinzhu/gorm/tests" +) + +func TestHasOneAssociation(t *testing.T) { + var user = *GetUser("hasone", Config{Account: true}) + + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + CheckUser(t, user, user) + + // Find + var user2 User + DB.Find(&user2, "id = ?", user.ID) + DB.Model(&user2).Association("Account").Find(&user2.Account) + CheckUser(t, user2, user) + + // Count + AssertAssociationCount(t, user, "Account", 1, "") + + // Append + var account = Account{Number: "account-has-one-append"} + + if err := DB.Model(&user2).Association("Account").Append(&account); err != nil { + t.Fatalf("Error happened when append account, got %v", err) + } + + if account.ID == 0 { + t.Fatalf("Account's ID should be created") + } + + user.Account = account + CheckUser(t, user2, user) + + AssertAssociationCount(t, user, "Account", 1, "AfterAppend") + + // Replace + var account2 = Account{Number: "account-has-one-replace"} + + if err := DB.Model(&user2).Association("Account").Replace(&account2); err != nil { + t.Fatalf("Error happened when append Account, got %v", err) + } + + if account2.ID == 0 { + t.Fatalf("account2's ID should be created") + } + + user.Account = account2 + CheckUser(t, user2, user) + + AssertAssociationCount(t, user2, "Account", 1, "AfterReplace") + + // Delete + if err := DB.Model(&user2).Association("Account").Delete(&Account{}); err != nil { + t.Fatalf("Error happened when delete account, got %v", err) + } + AssertAssociationCount(t, user2, "Account", 1, "after delete non-existing data") + + if err := DB.Model(&user2).Association("Account").Delete(&account2); err != nil { + t.Fatalf("Error happened when delete Account, got %v", err) + } + AssertAssociationCount(t, user2, "Account", 0, "after delete") + + // Prepare Data for Clear + if err := DB.Model(&user2).Association("Account").Append(&account); err != nil { + t.Fatalf("Error happened when append Account, got %v", err) + } + + AssertAssociationCount(t, user2, "Account", 1, "after prepare data") + + // Clear + if err := DB.Model(&user2).Association("Account").Clear(); err != nil { + t.Errorf("Error happened when clear Account, got %v", err) + } + + AssertAssociationCount(t, user2, "Account", 0, "after clear") +} + +func TestHasOneAssociationForSlice(t *testing.T) { + var users = []User{ + *GetUser("slice-hasone-1", Config{Account: true}), + *GetUser("slice-hasone-2", Config{Account: false}), + *GetUser("slice-hasone-3", Config{Account: true}), + } + + DB.Create(&users) + + // Count + AssertAssociationCount(t, users, "Account", 2, "") + + // Find + var accounts []Account + if DB.Model(&users).Association("Account").Find(&accounts); len(accounts) != 2 { + t.Errorf("accounts count should be %v, but got %v", 3, len(accounts)) + } + + // Append + DB.Model(&users).Association("Account").Append( + &Account{Number: "account-slice-append-1"}, + &Account{Number: "account-slice-append-2"}, + &Account{Number: "account-slice-append-3"}, + ) + + AssertAssociationCount(t, users, "Account", 3, "After Append") + + // Replace -> same as append + + // Delete + if err := DB.Model(&users).Association("Account").Delete(&users[0].Account); err != nil { + t.Errorf("no error should happend when deleting account, but got %v", err) + } + + AssertAssociationCount(t, users, "Account", 2, "after delete") + + // Clear + DB.Model(&users).Association("Account").Clear() + AssertAssociationCount(t, users, "Account", 0, "After Clear") +} + +func TestPolymorphicHasOneAssociation(t *testing.T) { + var pet = Pet{Name: "hasone", Toy: Toy{Name: "toy-has-one"}} + + if err := DB.Create(&pet).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + CheckPet(t, pet, pet) + + // Find + var pet2 Pet + DB.Find(&pet2, "id = ?", pet.ID) + DB.Model(&pet2).Association("Toy").Find(&pet2.Toy) + CheckPet(t, pet2, pet) + + // Count + AssertAssociationCount(t, pet, "Toy", 1, "") + + // Append + var toy = Toy{Name: "toy-has-one-append"} + + if err := DB.Model(&pet2).Association("Toy").Append(&toy); err != nil { + t.Fatalf("Error happened when append toy, got %v", err) + } + + if toy.ID == 0 { + t.Fatalf("Toy's ID should be created") + } + + pet.Toy = toy + CheckPet(t, pet2, pet) + + AssertAssociationCount(t, pet, "Toy", 1, "AfterAppend") + + // Replace + var toy2 = Toy{Name: "toy-has-one-replace"} + + if err := DB.Model(&pet2).Association("Toy").Replace(&toy2); err != nil { + t.Fatalf("Error happened when append Toy, got %v", err) + } + + if toy2.ID == 0 { + t.Fatalf("toy2's ID should be created") + } + + pet.Toy = toy2 + CheckPet(t, pet2, pet) + + AssertAssociationCount(t, pet2, "Toy", 1, "AfterReplace") + + // Delete + if err := DB.Model(&pet2).Association("Toy").Delete(&Toy{}); err != nil { + t.Fatalf("Error happened when delete toy, got %v", err) + } + AssertAssociationCount(t, pet2, "Toy", 1, "after delete non-existing data") + + if err := DB.Model(&pet2).Association("Toy").Delete(&toy2); err != nil { + t.Fatalf("Error happened when delete Toy, got %v", err) + } + AssertAssociationCount(t, pet2, "Toy", 0, "after delete") + + // Prepare Data for Clear + if err := DB.Model(&pet2).Association("Toy").Append(&toy); err != nil { + t.Fatalf("Error happened when append Toy, got %v", err) + } + + AssertAssociationCount(t, pet2, "Toy", 1, "after prepare data") + + // Clear + if err := DB.Model(&pet2).Association("Toy").Clear(); err != nil { + t.Errorf("Error happened when clear Toy, got %v", err) + } + + AssertAssociationCount(t, pet2, "Toy", 0, "after clear") +} + +func TestPolymorphicHasOneAssociationForSlice(t *testing.T) { + var pets = []Pet{ + {Name: "hasone-1", Toy: Toy{Name: "toy-has-one"}}, + {Name: "hasone-2", Toy: Toy{}}, + {Name: "hasone-3", Toy: Toy{Name: "toy-has-one"}}, + } + + DB.Create(&pets) + + // Count + AssertAssociationCount(t, pets, "Toy", 2, "") + + // Find + var toys []Toy + if DB.Model(&pets).Association("Toy").Find(&toys); len(toys) != 2 { + t.Errorf("toys count should be %v, but got %v", 3, len(toys)) + } + + // Append + DB.Model(&pets).Association("Toy").Append( + &Toy{Name: "toy-slice-append-1"}, + &Toy{Name: "toy-slice-append-2"}, + &Toy{Name: "toy-slice-append-3"}, + ) + + AssertAssociationCount(t, pets, "Toy", 3, "After Append") + + // Replace -> same as append + + // Delete + if err := DB.Model(&pets).Association("Toy").Delete(&pets[0].Toy); err != nil { + t.Errorf("no error should happend when deleting toy, but got %v", err) + } + + AssertAssociationCount(t, pets, "Toy", 2, "after delete") + + // Clear + DB.Model(&pets).Association("Toy").Clear() + AssertAssociationCount(t, pets, "Toy", 0, "After Clear") +} diff --git a/tests/associations_many2many_test.go b/tests/associations_many2many_test.go new file mode 100644 index 00000000..a2db9675 --- /dev/null +++ b/tests/associations_many2many_test.go @@ -0,0 +1,299 @@ +package tests_test + +import ( + "testing" + + . "github.com/jinzhu/gorm/tests" +) + +func TestMany2ManyAssociation(t *testing.T) { + var user = *GetUser("many2many", Config{Languages: 2}) + + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + CheckUser(t, user, user) + + // Find + var user2 User + DB.Find(&user2, "id = ?", user.ID) + DB.Model(&user2).Association("Languages").Find(&user2.Languages) + + CheckUser(t, user2, user) + + // Count + AssertAssociationCount(t, user, "Languages", 2, "") + + // Append + var language = Language{Code: "language-many2many-append", Name: "language-many2many-append"} + DB.Create(&language) + + if err := DB.Model(&user2).Association("Languages").Append(&language); err != nil { + t.Fatalf("Error happened when append account, got %v", err) + } + + user.Languages = append(user.Languages, language) + CheckUser(t, user2, user) + + AssertAssociationCount(t, user, "Languages", 3, "AfterAppend") + + var languages = []Language{ + {Code: "language-many2many-append-1-1", Name: "language-many2many-append-1-1"}, + {Code: "language-many2many-append-2-1", Name: "language-many2many-append-2-1"}, + } + DB.Create(&languages) + + if err := DB.Model(&user2).Association("Languages").Append(&languages); err != nil { + t.Fatalf("Error happened when append language, got %v", err) + } + + user.Languages = append(user.Languages, languages...) + + CheckUser(t, user2, user) + + AssertAssociationCount(t, user, "Languages", 5, "AfterAppendSlice") + + // Replace + var language2 = Language{Code: "language-many2many-replace", Name: "language-many2many-replace"} + DB.Create(&language2) + + if err := DB.Model(&user2).Association("Languages").Replace(&language2); err != nil { + t.Fatalf("Error happened when append language, got %v", err) + } + + user.Languages = []Language{language2} + CheckUser(t, user2, user) + + AssertAssociationCount(t, user2, "Languages", 1, "AfterReplace") + + // Delete + if err := DB.Model(&user2).Association("Languages").Delete(&Language{}); err != nil { + t.Fatalf("Error happened when delete language, got %v", err) + } + AssertAssociationCount(t, user2, "Languages", 1, "after delete non-existing data") + + if err := DB.Model(&user2).Association("Languages").Delete(&language2); err != nil { + t.Fatalf("Error happened when delete Languages, got %v", err) + } + AssertAssociationCount(t, user2, "Languages", 0, "after delete") + + // Prepare Data for Clear + if err := DB.Model(&user2).Association("Languages").Append(&language); err != nil { + t.Fatalf("Error happened when append Languages, got %v", err) + } + + AssertAssociationCount(t, user2, "Languages", 1, "after prepare data") + + // Clear + if err := DB.Model(&user2).Association("Languages").Clear(); err != nil { + t.Errorf("Error happened when clear Languages, got %v", err) + } + + AssertAssociationCount(t, user2, "Languages", 0, "after clear") +} + +func TestMany2ManyAssociationForSlice(t *testing.T) { + var users = []User{ + *GetUser("slice-many2many-1", Config{Languages: 2}), + *GetUser("slice-many2many-2", Config{Languages: 0}), + *GetUser("slice-many2many-3", Config{Languages: 4}), + } + + DB.Create(&users) + + // Count + AssertAssociationCount(t, users, "Languages", 6, "") + + // Find + var languages []Language + if DB.Model(&users).Association("Languages").Find(&languages); len(languages) != 6 { + t.Errorf("languages count should be %v, but got %v", 6, len(languages)) + } + + // Append + var languages1 = []Language{ + {Code: "language-many2many-append-1", Name: "language-many2many-append-1"}, + } + var languages2 = []Language{} + var languages3 = []Language{ + {Code: "language-many2many-append-3-1", Name: "language-many2many-append-3-1"}, + {Code: "language-many2many-append-3-2", Name: "language-many2many-append-3-2"}, + } + DB.Create(&languages1) + DB.Create(&languages3) + + DB.Model(&users).Association("Languages").Append(&languages1, &languages2, &languages3) + + AssertAssociationCount(t, users, "Languages", 9, "After Append") + + languages2_1 := []*Language{ + {Code: "language-slice-replace-1-1", Name: "language-slice-replace-1-1"}, + {Code: "language-slice-replace-1-2", Name: "language-slice-replace-1-2"}, + } + languages2_2 := []*Language{ + {Code: "language-slice-replace-2-1", Name: "language-slice-replace-2-1"}, + {Code: "language-slice-replace-2-2", Name: "language-slice-replace-2-2"}, + } + languages2_3 := &Language{Code: "language-slice-replace-3", Name: "language-slice-replace-3"} + DB.Create(&languages2_1) + DB.Create(&languages2_2) + DB.Create(&languages2_3) + + // Replace + DB.Model(&users).Association("Languages").Replace(&languages2_1, &languages2_2, languages2_3) + + AssertAssociationCount(t, users, "Languages", 5, "After Replace") + + // Delete + if err := DB.Model(&users).Association("Languages").Delete(&users[2].Languages); err != nil { + t.Errorf("no error should happend when deleting language, but got %v", err) + } + + AssertAssociationCount(t, users, "Languages", 4, "after delete") + + if err := DB.Model(&users).Association("Languages").Delete(users[0].Languages[0], users[1].Languages[1]); err != nil { + t.Errorf("no error should happend when deleting language, but got %v", err) + } + + AssertAssociationCount(t, users, "Languages", 2, "after delete") + + // Clear + DB.Model(&users).Association("Languages").Clear() + AssertAssociationCount(t, users, "Languages", 0, "After Clear") +} + +func TestSingleTableMany2ManyAssociation(t *testing.T) { + var user = *GetUser("many2many", Config{Friends: 2}) + + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + CheckUser(t, user, user) + + // Find + var user2 User + DB.Find(&user2, "id = ?", user.ID) + DB.Model(&user2).Association("Friends").Find(&user2.Friends) + + CheckUser(t, user2, user) + + // Count + AssertAssociationCount(t, user, "Friends", 2, "") + + // Append + var friend = *GetUser("friend", Config{}) + + if err := DB.Model(&user2).Association("Friends").Append(&friend); err != nil { + t.Fatalf("Error happened when append account, got %v", err) + } + + user.Friends = append(user.Friends, &friend) + CheckUser(t, user2, user) + + AssertAssociationCount(t, user, "Friends", 3, "AfterAppend") + + var friends = []*User{GetUser("friend-append-1", Config{}), GetUser("friend-append-2", Config{})} + + if err := DB.Model(&user2).Association("Friends").Append(&friends); err != nil { + t.Fatalf("Error happened when append friend, got %v", err) + } + + user.Friends = append(user.Friends, friends...) + + CheckUser(t, user2, user) + + AssertAssociationCount(t, user, "Friends", 5, "AfterAppendSlice") + + // Replace + var friend2 = *GetUser("friend-replace-2", Config{}) + + if err := DB.Model(&user2).Association("Friends").Replace(&friend2); err != nil { + t.Fatalf("Error happened when append friend, got %v", err) + } + + user.Friends = []*User{&friend2} + CheckUser(t, user2, user) + + AssertAssociationCount(t, user2, "Friends", 1, "AfterReplace") + + // Delete + if err := DB.Model(&user2).Association("Friends").Delete(&User{}); err != nil { + t.Fatalf("Error happened when delete friend, got %v", err) + } + AssertAssociationCount(t, user2, "Friends", 1, "after delete non-existing data") + + if err := DB.Model(&user2).Association("Friends").Delete(&friend2); err != nil { + t.Fatalf("Error happened when delete Friends, got %v", err) + } + AssertAssociationCount(t, user2, "Friends", 0, "after delete") + + // Prepare Data for Clear + if err := DB.Model(&user2).Association("Friends").Append(&friend); err != nil { + t.Fatalf("Error happened when append Friends, got %v", err) + } + + AssertAssociationCount(t, user2, "Friends", 1, "after prepare data") + + // Clear + if err := DB.Model(&user2).Association("Friends").Clear(); err != nil { + t.Errorf("Error happened when clear Friends, got %v", err) + } + + AssertAssociationCount(t, user2, "Friends", 0, "after clear") +} + +func TestSingleTableMany2ManyAssociationForSlice(t *testing.T) { + var users = []User{ + *GetUser("slice-many2many-1", Config{Team: 2}), + *GetUser("slice-many2many-2", Config{Team: 0}), + *GetUser("slice-many2many-3", Config{Team: 4}), + } + + DB.Create(&users) + + // Count + AssertAssociationCount(t, users, "Team", 6, "") + + // Find + var teams []User + if DB.Model(&users).Association("Team").Find(&teams); len(teams) != 6 { + t.Errorf("teams count should be %v, but got %v", 6, len(teams)) + } + + // Append + var teams1 = []User{*GetUser("friend-append-1", Config{})} + var teams2 = []User{} + var teams3 = []*User{GetUser("friend-append-3-1", Config{}), GetUser("friend-append-3-2", Config{})} + + DB.Model(&users).Association("Team").Append(&teams1, &teams2, &teams3) + + AssertAssociationCount(t, users, "Team", 9, "After Append") + + var teams2_1 = []User{*GetUser("friend-replace-1", Config{}), *GetUser("friend-replace-2", Config{})} + var teams2_2 = []User{*GetUser("friend-replace-2-1", Config{}), *GetUser("friend-replace-2-2", Config{})} + var teams2_3 = GetUser("friend-replace-3-1", Config{}) + + // Replace + DB.Model(&users).Association("Team").Replace(&teams2_1, &teams2_2, teams2_3) + + AssertAssociationCount(t, users, "Team", 5, "After Replace") + + // Delete + if err := DB.Model(&users).Association("Team").Delete(&users[2].Team); err != nil { + t.Errorf("no error should happend when deleting team, but got %v", err) + } + + AssertAssociationCount(t, users, "Team", 4, "after delete") + + if err := DB.Model(&users).Association("Team").Delete(users[0].Team[0], users[1].Team[1]); err != nil { + t.Errorf("no error should happend when deleting team, but got %v", err) + } + + AssertAssociationCount(t, users, "Team", 2, "after delete") + + // Clear + DB.Model(&users).Association("Team").Clear() + AssertAssociationCount(t, users, "Team", 0, "After Clear") +} diff --git a/tests/associations_test.go b/tests/associations_test.go index a102fa54..89bbe142 100644 --- a/tests/associations_test.go +++ b/tests/associations_test.go @@ -25,1186 +25,9 @@ func AssertAssociationCount(t *testing.T, data interface{}, name string, result } } -func TestBelongsToAssociation(t *testing.T) { - var user = *GetUser("belongs-to", Config{Company: true, Manager: true}) - - if err := DB.Create(&user).Error; err != nil { - t.Fatalf("errors happened when create: %v", err) +func TestInvalidAssociation(t *testing.T) { + var user = *GetUser("invalid", Config{Company: true, Manager: true}) + if err := DB.Model(&user).Association("Invalid").Find(&user.Company).Error; err == nil { + t.Errorf("should return errors for invalid association, but got nil") } - - CheckUser(t, user, user) - - // Find - var user2 User - DB.Find(&user2, "id = ?", user.ID) - DB.Model(&user2).Association("Company").Find(&user2.Company) - user2.Manager = &User{} - DB.Model(&user2).Association("Manager").Find(user2.Manager) - CheckUser(t, user2, user) - - // Count - AssertAssociationCount(t, user, "Company", 1, "") - AssertAssociationCount(t, user, "Manager", 1, "") - - // Append - var company = Company{Name: "company-belongs-to-append"} - var manager = GetUser("manager-belongs-to-append", Config{}) - - if err := DB.Model(&user2).Association("Company").Append(&company); err != nil { - t.Fatalf("Error happened when append Company, got %v", err) - } - - if company.ID == 0 { - t.Fatalf("Company's ID should be created") - } - - if err := DB.Model(&user2).Association("Manager").Append(manager); err != nil { - t.Fatalf("Error happened when append Manager, got %v", err) - } - - if manager.ID == 0 { - t.Fatalf("Manager's ID should be created") - } - - user.Company = company - user.Manager = manager - user.CompanyID = &company.ID - user.ManagerID = &manager.ID - CheckUser(t, user2, user) - - AssertAssociationCount(t, user2, "Company", 1, "AfterAppend") - AssertAssociationCount(t, user2, "Manager", 1, "AfterAppend") - - // Replace - var company2 = Company{Name: "company-belongs-to-replace"} - var manager2 = GetUser("manager-belongs-to-replace", Config{}) - - if err := DB.Model(&user2).Association("Company").Replace(&company2); err != nil { - t.Fatalf("Error happened when replace Company, got %v", err) - } - - if company2.ID == 0 { - t.Fatalf("Company's ID should be created") - } - - if err := DB.Model(&user2).Association("Manager").Replace(manager2); err != nil { - t.Fatalf("Error happened when replace Manager, got %v", err) - } - - if manager2.ID == 0 { - t.Fatalf("Manager's ID should be created") - } - - user.Company = company2 - user.Manager = manager2 - user.CompanyID = &company2.ID - user.ManagerID = &manager2.ID - CheckUser(t, user2, user) - - AssertAssociationCount(t, user2, "Company", 1, "AfterReplace") - AssertAssociationCount(t, user2, "Manager", 1, "AfterReplace") - - // Delete - if err := DB.Model(&user2).Association("Company").Delete(&Company{}); err != nil { - t.Fatalf("Error happened when delete Company, got %v", err) - } - AssertAssociationCount(t, user2, "Company", 1, "after delete non-existing data") - - if err := DB.Model(&user2).Association("Company").Delete(&company2); err != nil { - t.Fatalf("Error happened when delete Company, got %v", err) - } - AssertAssociationCount(t, user2, "Company", 0, "after delete") - - if err := DB.Model(&user2).Association("Manager").Delete(&User{}); err != nil { - t.Fatalf("Error happened when delete Manager, got %v", err) - } - AssertAssociationCount(t, user2, "Manager", 1, "after delete non-existing data") - - if err := DB.Model(&user2).Association("Manager").Delete(manager2); err != nil { - t.Fatalf("Error happened when delete Manager, got %v", err) - } - AssertAssociationCount(t, user2, "Manager", 0, "after delete") - - // Prepare Data for Clear - if err := DB.Model(&user2).Association("Company").Append(&company); err != nil { - t.Fatalf("Error happened when append Company, got %v", err) - } - - if err := DB.Model(&user2).Association("Manager").Append(manager); err != nil { - t.Fatalf("Error happened when append Manager, got %v", err) - } - - AssertAssociationCount(t, user2, "Company", 1, "after prepare data") - AssertAssociationCount(t, user2, "Manager", 1, "after prepare data") - - // Clear - if err := DB.Model(&user2).Association("Company").Clear(); err != nil { - t.Errorf("Error happened when clear Company, got %v", err) - } - - if err := DB.Model(&user2).Association("Manager").Clear(); err != nil { - t.Errorf("Error happened when clear Manager, got %v", err) - } - - AssertAssociationCount(t, user2, "Company", 0, "after clear") - AssertAssociationCount(t, user2, "Manager", 0, "after clear") -} - -func TestBelongsToAssociationForSlice(t *testing.T) { - var users = []User{ - *GetUser("slice-belongs-to-1", Config{Company: true, Manager: true}), - *GetUser("slice-belongs-to-2", Config{Company: true, Manager: false}), - *GetUser("slice-belongs-to-3", Config{Company: true, Manager: true}), - } - - DB.Create(&users) - - AssertAssociationCount(t, users, "Company", 3, "") - AssertAssociationCount(t, users, "Manager", 2, "") - - // Find - var companies []Company - if DB.Model(&users).Association("Company").Find(&companies); len(companies) != 3 { - t.Errorf("companies count should be %v, but got %v", 3, len(companies)) - } - - var managers []User - if DB.Model(&users).Association("Manager").Find(&managers); len(managers) != 2 { - t.Errorf("managers count should be %v, but got %v", 2, len(managers)) - } - - // Append - DB.Model(&users).Association("Company").Append( - &Company{Name: "company-slice-append-1"}, - &Company{Name: "company-slice-append-2"}, - &Company{Name: "company-slice-append-3"}, - ) - - AssertAssociationCount(t, users, "Company", 3, "After Append") - - DB.Model(&users).Association("Manager").Append( - GetUser("manager-slice-belongs-to-1", Config{}), - GetUser("manager-slice-belongs-to-2", Config{}), - GetUser("manager-slice-belongs-to-3", Config{}), - ) - AssertAssociationCount(t, users, "Manager", 3, "After Append") - - if err := DB.Model(&users).Association("Manager").Append( - GetUser("manager-slice-belongs-to-test-1", Config{}), - ).Error; err == nil { - t.Errorf("unmatched length when update user's manager") - } - - // Replace -> same as append - - // Delete - if err := DB.Model(&users).Association("Company").Delete(&users[0].Company); err != nil { - t.Errorf("no error should happend when deleting company, but got %v", err) - } - - if users[0].CompanyID != nil || users[0].Company.ID != 0 { - t.Errorf("users[0]'s company should be deleted'") - } - - AssertAssociationCount(t, users, "Company", 2, "After Delete") - - // Clear - DB.Model(&users).Association("Company").Clear() - AssertAssociationCount(t, users, "Company", 0, "After Clear") - - DB.Model(&users).Association("Manager").Clear() - AssertAssociationCount(t, users, "Manager", 0, "After Clear") - - // shared company - company := Company{Name: "shared"} - if err := DB.Model(&users[0]).Association("Company").Append(&company); err != nil { - t.Errorf("Error happened when append company to user, got %v", err) - } - - if err := DB.Model(&users[1]).Association("Company").Append(&company); err != nil { - t.Errorf("Error happened when append company to user, got %v", err) - } - - if users[0].CompanyID == nil || users[1].CompanyID == nil || *users[0].CompanyID != *users[1].CompanyID { - t.Errorf("user's company id should exists and equal, but its: %v, %v", users[0].CompanyID, users[1].CompanyID) - } - - DB.Model(&users[0]).Association("Company").Delete(&company) - AssertAssociationCount(t, users[0], "Company", 0, "After Delete") - AssertAssociationCount(t, users[1], "Company", 1, "After other user Delete") -} - -func TestHasOneAssociation(t *testing.T) { - var user = *GetUser("hasone", Config{Account: true}) - - if err := DB.Create(&user).Error; err != nil { - t.Fatalf("errors happened when create: %v", err) - } - - CheckUser(t, user, user) - - // Find - var user2 User - DB.Find(&user2, "id = ?", user.ID) - DB.Model(&user2).Association("Account").Find(&user2.Account) - CheckUser(t, user2, user) - - // Count - AssertAssociationCount(t, user, "Account", 1, "") - - // Append - var account = Account{Number: "account-has-one-append"} - - if err := DB.Model(&user2).Association("Account").Append(&account); err != nil { - t.Fatalf("Error happened when append account, got %v", err) - } - - if account.ID == 0 { - t.Fatalf("Account's ID should be created") - } - - user.Account = account - CheckUser(t, user2, user) - - AssertAssociationCount(t, user, "Account", 1, "AfterAppend") - - // Replace - var account2 = Account{Number: "account-has-one-replace"} - - if err := DB.Model(&user2).Association("Account").Replace(&account2); err != nil { - t.Fatalf("Error happened when append Account, got %v", err) - } - - if account2.ID == 0 { - t.Fatalf("account2's ID should be created") - } - - user.Account = account2 - CheckUser(t, user2, user) - - AssertAssociationCount(t, user2, "Account", 1, "AfterReplace") - - // Delete - if err := DB.Model(&user2).Association("Account").Delete(&Account{}); err != nil { - t.Fatalf("Error happened when delete account, got %v", err) - } - AssertAssociationCount(t, user2, "Account", 1, "after delete non-existing data") - - if err := DB.Model(&user2).Association("Account").Delete(&account2); err != nil { - t.Fatalf("Error happened when delete Account, got %v", err) - } - AssertAssociationCount(t, user2, "Account", 0, "after delete") - - // Prepare Data for Clear - if err := DB.Model(&user2).Association("Account").Append(&account); err != nil { - t.Fatalf("Error happened when append Account, got %v", err) - } - - AssertAssociationCount(t, user2, "Account", 1, "after prepare data") - - // Clear - if err := DB.Model(&user2).Association("Account").Clear(); err != nil { - t.Errorf("Error happened when clear Account, got %v", err) - } - - AssertAssociationCount(t, user2, "Account", 0, "after clear") -} - -func TestHasOneAssociationForSlice(t *testing.T) { - var users = []User{ - *GetUser("slice-hasone-1", Config{Account: true}), - *GetUser("slice-hasone-2", Config{Account: false}), - *GetUser("slice-hasone-3", Config{Account: true}), - } - - DB.Create(&users) - - // Count - AssertAssociationCount(t, users, "Account", 2, "") - - // Find - var accounts []Account - if DB.Model(&users).Association("Account").Find(&accounts); len(accounts) != 2 { - t.Errorf("accounts count should be %v, but got %v", 3, len(accounts)) - } - - // Append - DB.Model(&users).Association("Account").Append( - &Account{Number: "account-slice-append-1"}, - &Account{Number: "account-slice-append-2"}, - &Account{Number: "account-slice-append-3"}, - ) - - AssertAssociationCount(t, users, "Account", 3, "After Append") - - // Replace -> same as append - - // Delete - if err := DB.Model(&users).Association("Account").Delete(&users[0].Account); err != nil { - t.Errorf("no error should happend when deleting account, but got %v", err) - } - - AssertAssociationCount(t, users, "Account", 2, "after delete") - - // Clear - DB.Model(&users).Association("Account").Clear() - AssertAssociationCount(t, users, "Account", 0, "After Clear") -} - -func TestPolymorphicHasOneAssociation(t *testing.T) { - var pet = Pet{Name: "hasone", Toy: Toy{Name: "toy-has-one"}} - - if err := DB.Create(&pet).Error; err != nil { - t.Fatalf("errors happened when create: %v", err) - } - - CheckPet(t, pet, pet) - - // Find - var pet2 Pet - DB.Find(&pet2, "id = ?", pet.ID) - DB.Model(&pet2).Association("Toy").Find(&pet2.Toy) - CheckPet(t, pet2, pet) - - // Count - AssertAssociationCount(t, pet, "Toy", 1, "") - - // Append - var toy = Toy{Name: "toy-has-one-append"} - - if err := DB.Model(&pet2).Association("Toy").Append(&toy); err != nil { - t.Fatalf("Error happened when append toy, got %v", err) - } - - if toy.ID == 0 { - t.Fatalf("Toy's ID should be created") - } - - pet.Toy = toy - CheckPet(t, pet2, pet) - - AssertAssociationCount(t, pet, "Toy", 1, "AfterAppend") - - // Replace - var toy2 = Toy{Name: "toy-has-one-replace"} - - if err := DB.Model(&pet2).Association("Toy").Replace(&toy2); err != nil { - t.Fatalf("Error happened when append Toy, got %v", err) - } - - if toy2.ID == 0 { - t.Fatalf("toy2's ID should be created") - } - - pet.Toy = toy2 - CheckPet(t, pet2, pet) - - AssertAssociationCount(t, pet2, "Toy", 1, "AfterReplace") - - // Delete - if err := DB.Model(&pet2).Association("Toy").Delete(&Toy{}); err != nil { - t.Fatalf("Error happened when delete toy, got %v", err) - } - AssertAssociationCount(t, pet2, "Toy", 1, "after delete non-existing data") - - if err := DB.Model(&pet2).Association("Toy").Delete(&toy2); err != nil { - t.Fatalf("Error happened when delete Toy, got %v", err) - } - AssertAssociationCount(t, pet2, "Toy", 0, "after delete") - - // Prepare Data for Clear - if err := DB.Model(&pet2).Association("Toy").Append(&toy); err != nil { - t.Fatalf("Error happened when append Toy, got %v", err) - } - - AssertAssociationCount(t, pet2, "Toy", 1, "after prepare data") - - // Clear - if err := DB.Model(&pet2).Association("Toy").Clear(); err != nil { - t.Errorf("Error happened when clear Toy, got %v", err) - } - - AssertAssociationCount(t, pet2, "Toy", 0, "after clear") -} - -func TestPolymorphicHasOneAssociationForSlice(t *testing.T) { - var pets = []Pet{ - {Name: "hasone-1", Toy: Toy{Name: "toy-has-one"}}, - {Name: "hasone-2", Toy: Toy{}}, - {Name: "hasone-3", Toy: Toy{Name: "toy-has-one"}}, - } - - DB.Create(&pets) - - // Count - AssertAssociationCount(t, pets, "Toy", 2, "") - - // Find - var toys []Toy - if DB.Model(&pets).Association("Toy").Find(&toys); len(toys) != 2 { - t.Errorf("toys count should be %v, but got %v", 3, len(toys)) - } - - // Append - DB.Model(&pets).Association("Toy").Append( - &Toy{Name: "toy-slice-append-1"}, - &Toy{Name: "toy-slice-append-2"}, - &Toy{Name: "toy-slice-append-3"}, - ) - - AssertAssociationCount(t, pets, "Toy", 3, "After Append") - - // Replace -> same as append - - // Delete - if err := DB.Model(&pets).Association("Toy").Delete(&pets[0].Toy); err != nil { - t.Errorf("no error should happend when deleting toy, but got %v", err) - } - - AssertAssociationCount(t, pets, "Toy", 2, "after delete") - - // Clear - DB.Model(&pets).Association("Toy").Clear() - AssertAssociationCount(t, pets, "Toy", 0, "After Clear") -} - -func TestHasManyAssociation(t *testing.T) { - var user = *GetUser("hasmany", Config{Pets: 2}) - - if err := DB.Create(&user).Error; err != nil { - t.Fatalf("errors happened when create: %v", err) - } - - CheckUser(t, user, user) - - // Find - var user2 User - DB.Find(&user2, "id = ?", user.ID) - DB.Model(&user2).Association("Pets").Find(&user2.Pets) - CheckUser(t, user2, user) - - // Count - AssertAssociationCount(t, user, "Pets", 2, "") - - // Append - var pet = Pet{Name: "pet-has-many-append"} - - if err := DB.Model(&user2).Association("Pets").Append(&pet); err != nil { - t.Fatalf("Error happened when append account, got %v", err) - } - - if pet.ID == 0 { - t.Fatalf("Pet's ID should be created") - } - - user.Pets = append(user.Pets, &pet) - CheckUser(t, user2, user) - - AssertAssociationCount(t, user, "Pets", 3, "AfterAppend") - - var pets = []Pet{{Name: "pet-has-many-append-1-1"}, {Name: "pet-has-many-append-1-1"}} - - if err := DB.Model(&user2).Association("Pets").Append(&pets); err != nil { - t.Fatalf("Error happened when append pet, got %v", err) - } - - for _, pet := range pets { - var pet = pet - if pet.ID == 0 { - t.Fatalf("Pet's ID should be created") - } - - user.Pets = append(user.Pets, &pet) - } - - CheckUser(t, user2, user) - - AssertAssociationCount(t, user, "Pets", 5, "AfterAppendSlice") - - // Replace - var pet2 = Pet{Name: "pet-has-many-replace"} - - if err := DB.Model(&user2).Association("Pets").Replace(&pet2); err != nil { - t.Fatalf("Error happened when append pet, got %v", err) - } - - if pet2.ID == 0 { - t.Fatalf("pet2's ID should be created") - } - - user.Pets = []*Pet{&pet2} - CheckUser(t, user2, user) - - AssertAssociationCount(t, user2, "Pets", 1, "AfterReplace") - - // Delete - if err := DB.Model(&user2).Association("Pets").Delete(&Pet{}); err != nil { - t.Fatalf("Error happened when delete pet, got %v", err) - } - AssertAssociationCount(t, user2, "Pets", 1, "after delete non-existing data") - - if err := DB.Model(&user2).Association("Pets").Delete(&pet2); err != nil { - t.Fatalf("Error happened when delete Pets, got %v", err) - } - AssertAssociationCount(t, user2, "Pets", 0, "after delete") - - // Prepare Data for Clear - if err := DB.Model(&user2).Association("Pets").Append(&pet); err != nil { - t.Fatalf("Error happened when append Pets, got %v", err) - } - - AssertAssociationCount(t, user2, "Pets", 1, "after prepare data") - - // Clear - if err := DB.Model(&user2).Association("Pets").Clear(); err != nil { - t.Errorf("Error happened when clear Pets, got %v", err) - } - - AssertAssociationCount(t, user2, "Pets", 0, "after clear") -} - -func TestSingleTableHasManyAssociation(t *testing.T) { - var user = *GetUser("hasmany", Config{Team: 2}) - - if err := DB.Create(&user).Error; err != nil { - t.Fatalf("errors happened when create: %v", err) - } - - CheckUser(t, user, user) - - // Find - var user2 User - DB.Find(&user2, "id = ?", user.ID) - DB.Model(&user2).Association("Team").Find(&user2.Team) - CheckUser(t, user2, user) - - // Count - AssertAssociationCount(t, user, "Team", 2, "") - - // Append - var team = *GetUser("team", Config{}) - - if err := DB.Model(&user2).Association("Team").Append(&team); err != nil { - t.Fatalf("Error happened when append account, got %v", err) - } - - if team.ID == 0 { - t.Fatalf("Team's ID should be created") - } - - user.Team = append(user.Team, team) - CheckUser(t, user2, user) - - AssertAssociationCount(t, user, "Team", 3, "AfterAppend") - - var teams = []User{*GetUser("team-append-1", Config{}), *GetUser("team-append-2", Config{})} - - if err := DB.Model(&user2).Association("Team").Append(&teams); err != nil { - t.Fatalf("Error happened when append team, got %v", err) - } - - for _, team := range teams { - var team = team - if team.ID == 0 { - t.Fatalf("Team's ID should be created") - } - - user.Team = append(user.Team, team) - } - - CheckUser(t, user2, user) - - AssertAssociationCount(t, user, "Team", 5, "AfterAppendSlice") - - // Replace - var team2 = *GetUser("team-replace", Config{}) - - if err := DB.Model(&user2).Association("Team").Replace(&team2); err != nil { - t.Fatalf("Error happened when append team, got %v", err) - } - - if team2.ID == 0 { - t.Fatalf("team2's ID should be created") - } - - user.Team = []User{team2} - CheckUser(t, user2, user) - - AssertAssociationCount(t, user2, "Team", 1, "AfterReplace") - - // Delete - if err := DB.Model(&user2).Association("Team").Delete(&User{}); err != nil { - t.Fatalf("Error happened when delete team, got %v", err) - } - AssertAssociationCount(t, user2, "Team", 1, "after delete non-existing data") - - if err := DB.Model(&user2).Association("Team").Delete(&team2); err != nil { - t.Fatalf("Error happened when delete Team, got %v", err) - } - AssertAssociationCount(t, user2, "Team", 0, "after delete") - - // Prepare Data for Clear - if err := DB.Model(&user2).Association("Team").Append(&team); err != nil { - t.Fatalf("Error happened when append Team, got %v", err) - } - - AssertAssociationCount(t, user2, "Team", 1, "after prepare data") - - // Clear - if err := DB.Model(&user2).Association("Team").Clear(); err != nil { - t.Errorf("Error happened when clear Team, got %v", err) - } - - AssertAssociationCount(t, user2, "Team", 0, "after clear") -} - -func TestHasManyAssociationForSlice(t *testing.T) { - var users = []User{ - *GetUser("slice-hasmany-1", Config{Pets: 2}), - *GetUser("slice-hasmany-2", Config{Pets: 0}), - *GetUser("slice-hasmany-3", Config{Pets: 4}), - } - - DB.Create(&users) - - // Count - AssertAssociationCount(t, users, "Pets", 6, "") - - // Find - var pets []Pet - if DB.Model(&users).Association("Pets").Find(&pets); len(pets) != 6 { - t.Errorf("pets count should be %v, but got %v", 6, len(pets)) - } - - // Append - DB.Model(&users).Association("Pets").Append( - &Pet{Name: "pet-slice-append-1"}, - []*Pet{{Name: "pet-slice-append-2-1"}, {Name: "pet-slice-append-2-2"}}, - &Pet{Name: "pet-slice-append-3"}, - ) - - AssertAssociationCount(t, users, "Pets", 10, "After Append") - - // Replace -> same as append - DB.Model(&users).Association("Pets").Replace( - []*Pet{{Name: "pet-slice-replace-1-1"}, {Name: "pet-slice-replace-1-2"}}, - []*Pet{{Name: "pet-slice-replace-2-1"}, {Name: "pet-slice-replace-2-2"}}, - &Pet{Name: "pet-slice-replace-3"}, - ) - - AssertAssociationCount(t, users, "Pets", 5, "After Append") - - // Delete - if err := DB.Model(&users).Association("Pets").Delete(&users[2].Pets); err != nil { - t.Errorf("no error should happend when deleting pet, but got %v", err) - } - - AssertAssociationCount(t, users, "Pets", 4, "after delete") - - if err := DB.Model(&users).Association("Pets").Delete(users[0].Pets[0], users[1].Pets[1]); err != nil { - t.Errorf("no error should happend when deleting pet, but got %v", err) - } - - AssertAssociationCount(t, users, "Pets", 2, "after delete") - - // Clear - DB.Model(&users).Association("Pets").Clear() - AssertAssociationCount(t, users, "Pets", 0, "After Clear") -} - -func TestSingleTableHasManyAssociationForSlice(t *testing.T) { - var users = []User{ - *GetUser("slice-hasmany-1", Config{Team: 2}), - *GetUser("slice-hasmany-2", Config{Team: 0}), - *GetUser("slice-hasmany-3", Config{Team: 4}), - } - - if err := DB.Create(&users).Error; err != nil { - t.Fatalf("errors happened when create: %v", err) - } - - // Count - AssertAssociationCount(t, users, "Team", 6, "") - - // Find - var teams []User - if DB.Model(&users).Association("Team").Find(&teams); len(teams) != 6 { - t.Errorf("teams count should be %v, but got %v", 6, len(teams)) - } - - // Append - DB.Model(&users).Association("Team").Append( - &User{Name: "pet-slice-append-1"}, - []*User{{Name: "pet-slice-append-2-1"}, {Name: "pet-slice-append-2-2"}}, - &User{Name: "pet-slice-append-3"}, - ) - - AssertAssociationCount(t, users, "Team", 10, "After Append") - - // Replace -> same as append - DB.Model(&users).Association("Team").Replace( - []*User{{Name: "pet-slice-replace-1-1"}, {Name: "pet-slice-replace-1-2"}}, - []*User{{Name: "pet-slice-replace-2-1"}, {Name: "pet-slice-replace-2-2"}}, - &User{Name: "pet-slice-replace-3"}, - ) - - AssertAssociationCount(t, users, "Team", 5, "After Append") - - // Delete - if err := DB.Model(&users).Association("Team").Delete(&users[2].Team); err != nil { - t.Errorf("no error should happend when deleting pet, but got %v", err) - } - - AssertAssociationCount(t, users, "Team", 4, "after delete") - - if err := DB.Model(&users).Association("Team").Delete(users[0].Team[0], users[1].Team[1]); err != nil { - t.Errorf("no error should happend when deleting pet, but got %v", err) - } - - AssertAssociationCount(t, users, "Team", 2, "after delete") - - // Clear - DB.Model(&users).Association("Team").Clear() - AssertAssociationCount(t, users, "Team", 0, "After Clear") -} - -func TestPolymorphicHasManyAssociation(t *testing.T) { - var user = *GetUser("hasmany", Config{Toys: 2}) - - if err := DB.Create(&user).Error; err != nil { - t.Fatalf("errors happened when create: %v", err) - } - - CheckUser(t, user, user) - - // Find - var user2 User - DB.Find(&user2, "id = ?", user.ID) - DB.Model(&user2).Association("Toys").Find(&user2.Toys) - CheckUser(t, user2, user) - - // Count - AssertAssociationCount(t, user, "Toys", 2, "") - - // Append - var toy = Toy{Name: "toy-has-many-append"} - - if err := DB.Model(&user2).Association("Toys").Append(&toy); err != nil { - t.Fatalf("Error happened when append account, got %v", err) - } - - if toy.ID == 0 { - t.Fatalf("Toy's ID should be created") - } - - user.Toys = append(user.Toys, toy) - CheckUser(t, user2, user) - - AssertAssociationCount(t, user, "Toys", 3, "AfterAppend") - - var toys = []Toy{{Name: "toy-has-many-append-1-1"}, {Name: "toy-has-many-append-1-1"}} - - if err := DB.Model(&user2).Association("Toys").Append(&toys); err != nil { - t.Fatalf("Error happened when append toy, got %v", err) - } - - for _, toy := range toys { - var toy = toy - if toy.ID == 0 { - t.Fatalf("Toy's ID should be created") - } - - user.Toys = append(user.Toys, toy) - } - - CheckUser(t, user2, user) - - AssertAssociationCount(t, user, "Toys", 5, "AfterAppendSlice") - - // Replace - var toy2 = Toy{Name: "toy-has-many-replace"} - - if err := DB.Model(&user2).Association("Toys").Replace(&toy2); err != nil { - t.Fatalf("Error happened when append toy, got %v", err) - } - - if toy2.ID == 0 { - t.Fatalf("toy2's ID should be created") - } - - user.Toys = []Toy{toy2} - CheckUser(t, user2, user) - - AssertAssociationCount(t, user2, "Toys", 1, "AfterReplace") - - // Delete - if err := DB.Model(&user2).Association("Toys").Delete(&Toy{}); err != nil { - t.Fatalf("Error happened when delete toy, got %v", err) - } - AssertAssociationCount(t, user2, "Toys", 1, "after delete non-existing data") - - if err := DB.Model(&user2).Association("Toys").Delete(&toy2); err != nil { - t.Fatalf("Error happened when delete Toys, got %v", err) - } - AssertAssociationCount(t, user2, "Toys", 0, "after delete") - - // Prepare Data for Clear - if err := DB.Model(&user2).Association("Toys").Append(&toy); err != nil { - t.Fatalf("Error happened when append Toys, got %v", err) - } - - AssertAssociationCount(t, user2, "Toys", 1, "after prepare data") - - // Clear - if err := DB.Model(&user2).Association("Toys").Clear(); err != nil { - t.Errorf("Error happened when clear Toys, got %v", err) - } - - AssertAssociationCount(t, user2, "Toys", 0, "after clear") -} - -func TestPolymorphicHasManyAssociationForSlice(t *testing.T) { - var users = []User{ - *GetUser("slice-hasmany-1", Config{Toys: 2}), - *GetUser("slice-hasmany-2", Config{Toys: 0}), - *GetUser("slice-hasmany-3", Config{Toys: 4}), - } - - DB.Create(&users) - - // Count - AssertAssociationCount(t, users, "Toys", 6, "") - - // Find - var toys []Toy - if DB.Model(&users).Association("Toys").Find(&toys); len(toys) != 6 { - t.Errorf("toys count should be %v, but got %v", 6, len(toys)) - } - - // Append - DB.Model(&users).Association("Toys").Append( - &Toy{Name: "toy-slice-append-1"}, - []Toy{{Name: "toy-slice-append-2-1"}, {Name: "toy-slice-append-2-2"}}, - &Toy{Name: "toy-slice-append-3"}, - ) - - AssertAssociationCount(t, users, "Toys", 10, "After Append") - - // Replace -> same as append - DB.Model(&users).Association("Toys").Replace( - []*Toy{{Name: "toy-slice-replace-1-1"}, {Name: "toy-slice-replace-1-2"}}, - []*Toy{{Name: "toy-slice-replace-2-1"}, {Name: "toy-slice-replace-2-2"}}, - &Toy{Name: "toy-slice-replace-3"}, - ) - - AssertAssociationCount(t, users, "Toys", 5, "After Append") - - // Delete - if err := DB.Model(&users).Association("Toys").Delete(&users[2].Toys); err != nil { - t.Errorf("no error should happend when deleting toy, but got %v", err) - } - - AssertAssociationCount(t, users, "Toys", 4, "after delete") - - if err := DB.Model(&users).Association("Toys").Delete(users[0].Toys[0], users[1].Toys[1]); err != nil { - t.Errorf("no error should happend when deleting toy, but got %v", err) - } - - AssertAssociationCount(t, users, "Toys", 2, "after delete") - - // Clear - DB.Model(&users).Association("Toys").Clear() - AssertAssociationCount(t, users, "Toys", 0, "After Clear") -} - -func TestMany2ManyAssociation(t *testing.T) { - var user = *GetUser("many2many", Config{Languages: 2}) - - if err := DB.Create(&user).Error; err != nil { - t.Fatalf("errors happened when create: %v", err) - } - - CheckUser(t, user, user) - - // Find - var user2 User - DB.Find(&user2, "id = ?", user.ID) - DB.Model(&user2).Association("Languages").Find(&user2.Languages) - - CheckUser(t, user2, user) - - // Count - AssertAssociationCount(t, user, "Languages", 2, "") - - // Append - var language = Language{Code: "language-many2many-append", Name: "language-many2many-append"} - DB.Create(&language) - - if err := DB.Model(&user2).Association("Languages").Append(&language); err != nil { - t.Fatalf("Error happened when append account, got %v", err) - } - - user.Languages = append(user.Languages, language) - CheckUser(t, user2, user) - - AssertAssociationCount(t, user, "Languages", 3, "AfterAppend") - - var languages = []Language{ - {Code: "language-many2many-append-1-1", Name: "language-many2many-append-1-1"}, - {Code: "language-many2many-append-2-1", Name: "language-many2many-append-2-1"}, - } - DB.Create(&languages) - - if err := DB.Model(&user2).Association("Languages").Append(&languages); err != nil { - t.Fatalf("Error happened when append language, got %v", err) - } - - user.Languages = append(user.Languages, languages...) - - CheckUser(t, user2, user) - - AssertAssociationCount(t, user, "Languages", 5, "AfterAppendSlice") - - // Replace - var language2 = Language{Code: "language-many2many-replace", Name: "language-many2many-replace"} - DB.Create(&language2) - - if err := DB.Model(&user2).Association("Languages").Replace(&language2); err != nil { - t.Fatalf("Error happened when append language, got %v", err) - } - - user.Languages = []Language{language2} - CheckUser(t, user2, user) - - AssertAssociationCount(t, user2, "Languages", 1, "AfterReplace") - - // Delete - if err := DB.Model(&user2).Association("Languages").Delete(&Language{}); err != nil { - t.Fatalf("Error happened when delete language, got %v", err) - } - AssertAssociationCount(t, user2, "Languages", 1, "after delete non-existing data") - - if err := DB.Model(&user2).Association("Languages").Delete(&language2); err != nil { - t.Fatalf("Error happened when delete Languages, got %v", err) - } - AssertAssociationCount(t, user2, "Languages", 0, "after delete") - - // Prepare Data for Clear - if err := DB.Model(&user2).Association("Languages").Append(&language); err != nil { - t.Fatalf("Error happened when append Languages, got %v", err) - } - - AssertAssociationCount(t, user2, "Languages", 1, "after prepare data") - - // Clear - if err := DB.Model(&user2).Association("Languages").Clear(); err != nil { - t.Errorf("Error happened when clear Languages, got %v", err) - } - - AssertAssociationCount(t, user2, "Languages", 0, "after clear") -} - -func TestMany2ManyAssociationForSlice(t *testing.T) { - var users = []User{ - *GetUser("slice-many2many-1", Config{Languages: 2}), - *GetUser("slice-many2many-2", Config{Languages: 0}), - *GetUser("slice-many2many-3", Config{Languages: 4}), - } - - DB.Create(&users) - - // Count - AssertAssociationCount(t, users, "Languages", 6, "") - - // Find - var languages []Language - if DB.Model(&users).Association("Languages").Find(&languages); len(languages) != 6 { - t.Errorf("languages count should be %v, but got %v", 6, len(languages)) - } - - // Append - var languages1 = []Language{ - {Code: "language-many2many-append-1", Name: "language-many2many-append-1"}, - } - var languages2 = []Language{} - var languages3 = []Language{ - {Code: "language-many2many-append-3-1", Name: "language-many2many-append-3-1"}, - {Code: "language-many2many-append-3-2", Name: "language-many2many-append-3-2"}, - } - DB.Create(&languages1) - DB.Create(&languages3) - - DB.Model(&users).Association("Languages").Append(&languages1, &languages2, &languages3) - - AssertAssociationCount(t, users, "Languages", 9, "After Append") - - languages2_1 := []*Language{ - {Code: "language-slice-replace-1-1", Name: "language-slice-replace-1-1"}, - {Code: "language-slice-replace-1-2", Name: "language-slice-replace-1-2"}, - } - languages2_2 := []*Language{ - {Code: "language-slice-replace-2-1", Name: "language-slice-replace-2-1"}, - {Code: "language-slice-replace-2-2", Name: "language-slice-replace-2-2"}, - } - languages2_3 := &Language{Code: "language-slice-replace-3", Name: "language-slice-replace-3"} - DB.Create(&languages2_1) - DB.Create(&languages2_2) - DB.Create(&languages2_3) - - // Replace - DB.Model(&users).Association("Languages").Replace(&languages2_1, &languages2_2, languages2_3) - - AssertAssociationCount(t, users, "Languages", 5, "After Replace") - - // Delete - if err := DB.Model(&users).Association("Languages").Delete(&users[2].Languages); err != nil { - t.Errorf("no error should happend when deleting language, but got %v", err) - } - - AssertAssociationCount(t, users, "Languages", 4, "after delete") - - if err := DB.Model(&users).Association("Languages").Delete(users[0].Languages[0], users[1].Languages[1]); err != nil { - t.Errorf("no error should happend when deleting language, but got %v", err) - } - - AssertAssociationCount(t, users, "Languages", 2, "after delete") - - // Clear - DB.Model(&users).Association("Languages").Clear() - AssertAssociationCount(t, users, "Languages", 0, "After Clear") -} - -func TestSingleTableMany2ManyAssociation(t *testing.T) { - var user = *GetUser("many2many", Config{Friends: 2}) - - if err := DB.Create(&user).Error; err != nil { - t.Fatalf("errors happened when create: %v", err) - } - - CheckUser(t, user, user) - - // Find - var user2 User - DB.Find(&user2, "id = ?", user.ID) - DB.Model(&user2).Association("Friends").Find(&user2.Friends) - - CheckUser(t, user2, user) - - // Count - AssertAssociationCount(t, user, "Friends", 2, "") - - // Append - var friend = *GetUser("friend", Config{}) - - if err := DB.Model(&user2).Association("Friends").Append(&friend); err != nil { - t.Fatalf("Error happened when append account, got %v", err) - } - - user.Friends = append(user.Friends, &friend) - CheckUser(t, user2, user) - - AssertAssociationCount(t, user, "Friends", 3, "AfterAppend") - - var friends = []*User{GetUser("friend-append-1", Config{}), GetUser("friend-append-2", Config{})} - - if err := DB.Model(&user2).Association("Friends").Append(&friends); err != nil { - t.Fatalf("Error happened when append friend, got %v", err) - } - - user.Friends = append(user.Friends, friends...) - - CheckUser(t, user2, user) - - AssertAssociationCount(t, user, "Friends", 5, "AfterAppendSlice") - - // Replace - var friend2 = *GetUser("friend-replace-2", Config{}) - - if err := DB.Model(&user2).Association("Friends").Replace(&friend2); err != nil { - t.Fatalf("Error happened when append friend, got %v", err) - } - - user.Friends = []*User{&friend2} - CheckUser(t, user2, user) - - AssertAssociationCount(t, user2, "Friends", 1, "AfterReplace") - - // Delete - if err := DB.Model(&user2).Association("Friends").Delete(&User{}); err != nil { - t.Fatalf("Error happened when delete friend, got %v", err) - } - AssertAssociationCount(t, user2, "Friends", 1, "after delete non-existing data") - - if err := DB.Model(&user2).Association("Friends").Delete(&friend2); err != nil { - t.Fatalf("Error happened when delete Friends, got %v", err) - } - AssertAssociationCount(t, user2, "Friends", 0, "after delete") - - // Prepare Data for Clear - if err := DB.Model(&user2).Association("Friends").Append(&friend); err != nil { - t.Fatalf("Error happened when append Friends, got %v", err) - } - - AssertAssociationCount(t, user2, "Friends", 1, "after prepare data") - - // Clear - if err := DB.Model(&user2).Association("Friends").Clear(); err != nil { - t.Errorf("Error happened when clear Friends, got %v", err) - } - - AssertAssociationCount(t, user2, "Friends", 0, "after clear") -} - -func TestSingleTableMany2ManyAssociationForSlice(t *testing.T) { - var users = []User{ - *GetUser("slice-many2many-1", Config{Team: 2}), - *GetUser("slice-many2many-2", Config{Team: 0}), - *GetUser("slice-many2many-3", Config{Team: 4}), - } - - DB.Create(&users) - - // Count - AssertAssociationCount(t, users, "Team", 6, "") - - // Find - var teams []User - if DB.Model(&users).Association("Team").Find(&teams); len(teams) != 6 { - t.Errorf("teams count should be %v, but got %v", 6, len(teams)) - } - - // Append - var teams1 = []User{*GetUser("friend-append-1", Config{})} - var teams2 = []User{} - var teams3 = []*User{GetUser("friend-append-3-1", Config{}), GetUser("friend-append-3-2", Config{})} - - DB.Model(&users).Association("Team").Append(&teams1, &teams2, &teams3) - - AssertAssociationCount(t, users, "Team", 9, "After Append") - - var teams2_1 = []User{*GetUser("friend-replace-1", Config{}), *GetUser("friend-replace-2", Config{})} - var teams2_2 = []User{*GetUser("friend-replace-2-1", Config{}), *GetUser("friend-replace-2-2", Config{})} - var teams2_3 = GetUser("friend-replace-3-1", Config{}) - - // Replace - DB.Model(&users).Association("Team").Replace(&teams2_1, &teams2_2, teams2_3) - - AssertAssociationCount(t, users, "Team", 5, "After Replace") - - // Delete - if err := DB.Model(&users).Association("Team").Delete(&users[2].Team); err != nil { - t.Errorf("no error should happend when deleting team, but got %v", err) - } - - AssertAssociationCount(t, users, "Team", 4, "after delete") - - if err := DB.Model(&users).Association("Team").Delete(users[0].Team[0], users[1].Team[1]); err != nil { - t.Errorf("no error should happend when deleting team, but got %v", err) - } - - AssertAssociationCount(t, users, "Team", 2, "after delete") - - // Clear - DB.Model(&users).Association("Team").Clear() - AssertAssociationCount(t, users, "Team", 0, "After Clear") } From 51c5be05039ff1ca287d1353b3bd539f5984f032 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 26 May 2020 21:30:17 +0800 Subject: [PATCH 335/881] Finish Scan support --- callbacks/query.go | 6 +++++- finisher_api.go | 22 ++++++++++++---------- tests/scan_test.go | 40 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 57 insertions(+), 11 deletions(-) create mode 100644 tests/scan_test.go diff --git a/callbacks/query.go b/callbacks/query.go index 95b5ead3..c9fa160f 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -90,7 +90,11 @@ func Query(db *gorm.DB) { db.Statement.AddClauseIfNotExists(clause.From{}) } - db.Statement.AddClause(clauseSelect) + if len(clauseSelect.Columns) > 0 { + db.Statement.AddClause(clauseSelect) + } else { + db.Statement.AddClauseIfNotExists(clauseSelect) + } db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR") } diff --git a/finisher_api.go b/finisher_api.go index c64ecdda..84168e23 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -48,7 +48,7 @@ func (db *DB) Save(value interface{}) (tx *DB) { } // First find first record that match given conditions, order by primary key -func (db *DB) First(out interface{}, conds ...interface{}) (tx *DB) { +func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) { tx = db.getInstance().Limit(1).Order(clause.OrderByColumn{ Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, }) @@ -56,25 +56,25 @@ func (db *DB) First(out interface{}, conds ...interface{}) (tx *DB) { tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondtion(conds[0], conds[1:]...)}) } tx.Statement.RaiseErrorOnNotFound = true - tx.Statement.Dest = out + tx.Statement.Dest = dest tx.callbacks.Query().Execute(tx) return } // Take return a record that match given conditions, the order will depend on the database implementation -func (db *DB) Take(out interface{}, conds ...interface{}) (tx *DB) { +func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) { tx = db.getInstance().Limit(1) if len(conds) > 0 { tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondtion(conds[0], conds[1:]...)}) } tx.Statement.RaiseErrorOnNotFound = true - tx.Statement.Dest = out + tx.Statement.Dest = dest tx.callbacks.Query().Execute(tx) return } // Last find last record that match given conditions, order by primary key -func (db *DB) Last(out interface{}, conds ...interface{}) (tx *DB) { +func (db *DB) Last(dest interface{}, conds ...interface{}) (tx *DB) { tx = db.getInstance().Limit(1).Order(clause.OrderByColumn{ Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Desc: true, @@ -83,28 +83,28 @@ func (db *DB) Last(out interface{}, conds ...interface{}) (tx *DB) { tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondtion(conds[0], conds[1:]...)}) } tx.Statement.RaiseErrorOnNotFound = true - tx.Statement.Dest = out + tx.Statement.Dest = dest tx.callbacks.Query().Execute(tx) return } // Find find records that match given conditions -func (db *DB) Find(out interface{}, conds ...interface{}) (tx *DB) { +func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) { tx = db.getInstance() if len(conds) > 0 { tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondtion(conds[0], conds[1:]...)}) } - tx.Statement.Dest = out + tx.Statement.Dest = dest tx.callbacks.Query().Execute(tx) return } -func (db *DB) FirstOrInit(out interface{}, where ...interface{}) (tx *DB) { +func (db *DB) FirstOrInit(dest interface{}, where ...interface{}) (tx *DB) { tx = db.getInstance() return } -func (db *DB) FirstOrCreate(out interface{}, where ...interface{}) (tx *DB) { +func (db *DB) FirstOrCreate(dest interface{}, where ...interface{}) (tx *DB) { tx = db.getInstance() return } @@ -181,6 +181,8 @@ func (db *DB) Rows() (*sql.Rows, error) { // Scan scan value to a struct func (db *DB) Scan(dest interface{}) (tx *DB) { tx = db.getInstance() + tx.Statement.Dest = dest + tx.callbacks.Query().Execute(tx) return } diff --git a/tests/scan_test.go b/tests/scan_test.go new file mode 100644 index 00000000..f7a14636 --- /dev/null +++ b/tests/scan_test.go @@ -0,0 +1,40 @@ +package tests_test + +import ( + "testing" + + . "github.com/jinzhu/gorm/tests" +) + +func TestScan(t *testing.T) { + user1 := User{Name: "ScanUser1", Age: 1} + user2 := User{Name: "ScanUser2", Age: 10} + user3 := User{Name: "ScanUser3", Age: 20} + DB.Save(&user1).Save(&user2).Save(&user3) + + type result struct { + Name string + Age int + } + + var res result + DB.Table("users").Select("name, age").Where("id = ?", user3.ID).Scan(&res) + if res.Name != user3.Name || res.Age != int(user3.Age) { + t.Errorf("Scan into struct should work") + } + + var doubleAgeRes = &result{} + if err := DB.Debug().Table("users").Select("age + age as age").Where("id = ?", user3.ID).Scan(&doubleAgeRes).Error; err != nil { + t.Errorf("Scan to pointer of pointer") + } + + if doubleAgeRes.Age != int(res.Age)*2 { + t.Errorf("Scan double age as age, expect: %v, got %v", res.Age*2, doubleAgeRes.Age) + } + + var ress []result + DB.Table("users").Select("name, age").Where("id in ?", []uint{user2.ID, user3.ID}).Scan(&ress) + if len(ress) != 2 || ress[0].Name != user2.Name || ress[1].Name != user3.Name { + t.Errorf("Scan into struct map") + } +} From 5be642a435afd43d0346c81a1c50da4e205c23f5 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 26 May 2020 23:13:05 +0800 Subject: [PATCH 336/881] Add ScanRows support --- callbacks/query.go | 2 +- finisher_api.go | 9 +++++-- callbacks/scan.go => scan.go | 20 +++++++++------- tests/scan_test.go | 46 ++++++++++++++++++++++++++++++++---- 4 files changed, 61 insertions(+), 16 deletions(-) rename callbacks/scan.go => scan.go (91%) diff --git a/callbacks/query.go b/callbacks/query.go index c9fa160f..84b9ed98 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -105,7 +105,7 @@ func Query(db *gorm.DB) { } defer rows.Close() - Scan(rows, db) + gorm.Scan(rows, db, false) } func Preload(db *gorm.DB) { diff --git a/finisher_api.go b/finisher_api.go index 84168e23..04b25ed2 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -186,8 +186,13 @@ func (db *DB) Scan(dest interface{}) (tx *DB) { return } -func (db *DB) ScanRows(rows *sql.Rows, result interface{}) error { - return nil +func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error { + tx := db.getInstance() + tx.Error = tx.Statement.Parse(dest) + tx.Statement.Dest = dest + tx.Statement.ReflectValue = reflect.Indirect(reflect.ValueOf(dest)) + Scan(rows, tx, true) + return tx.Error } // Transaction start a transaction as a block, return error will rollback, otherwise to commit. diff --git a/callbacks/scan.go b/scan.go similarity index 91% rename from callbacks/scan.go rename to scan.go index 9ffcab4a..d2169f87 100644 --- a/callbacks/scan.go +++ b/scan.go @@ -1,15 +1,14 @@ -package callbacks +package gorm import ( "database/sql" "reflect" "strings" - "github.com/jinzhu/gorm" "github.com/jinzhu/gorm/schema" ) -func Scan(rows *sql.Rows, db *gorm.DB) { +func Scan(rows *sql.Rows, db *DB, initialized bool) { columns, _ := rows.Columns() values := make([]interface{}, len(columns)) @@ -19,7 +18,7 @@ func Scan(rows *sql.Rows, db *gorm.DB) { values[idx] = new(interface{}) } - if rows.Next() { + if initialized || rows.Next() { db.RowsAffected++ rows.Scan(values...) } @@ -39,7 +38,8 @@ func Scan(rows *sql.Rows, db *gorm.DB) { values[idx] = new(interface{}) } - for rows.Next() { + for initialized || rows.Next() { + initialized = false db.RowsAffected++ rows.Scan(values...) @@ -50,7 +50,8 @@ func Scan(rows *sql.Rows, db *gorm.DB) { *dest = append(*dest, v) } case *int, *int64, *uint, *uint64: - for rows.Next() { + for initialized || rows.Next() { + initialized = false db.RowsAffected++ rows.Scan(dest) } @@ -78,7 +79,8 @@ func Scan(rows *sql.Rows, db *gorm.DB) { } } - for rows.Next() { + for initialized || rows.Next() { + initialized = false elem := reflect.New(db.Statement.Schema.ModelType).Elem() for idx, field := range fields { if field != nil { @@ -118,7 +120,7 @@ func Scan(rows *sql.Rows, db *gorm.DB) { } } - if rows.Next() { + if initialized || rows.Next() { db.RowsAffected++ if err := rows.Scan(values...); err != nil { db.AddError(err) @@ -128,6 +130,6 @@ func Scan(rows *sql.Rows, db *gorm.DB) { } if db.RowsAffected == 0 && db.Statement.RaiseErrorOnNotFound { - db.AddError(gorm.ErrRecordNotFound) + db.AddError(ErrRecordNotFound) } } diff --git a/tests/scan_test.go b/tests/scan_test.go index f7a14636..fc6c1721 100644 --- a/tests/scan_test.go +++ b/tests/scan_test.go @@ -1,6 +1,9 @@ package tests_test import ( + "reflect" + "sort" + "strings" "testing" . "github.com/jinzhu/gorm/tests" @@ -24,7 +27,7 @@ func TestScan(t *testing.T) { } var doubleAgeRes = &result{} - if err := DB.Debug().Table("users").Select("age + age as age").Where("id = ?", user3.ID).Scan(&doubleAgeRes).Error; err != nil { + if err := DB.Table("users").Select("age + age as age").Where("id = ?", user3.ID).Scan(&doubleAgeRes).Error; err != nil { t.Errorf("Scan to pointer of pointer") } @@ -32,9 +35,44 @@ func TestScan(t *testing.T) { t.Errorf("Scan double age as age, expect: %v, got %v", res.Age*2, doubleAgeRes.Age) } - var ress []result - DB.Table("users").Select("name, age").Where("id in ?", []uint{user2.ID, user3.ID}).Scan(&ress) - if len(ress) != 2 || ress[0].Name != user2.Name || ress[1].Name != user3.Name { + var results []result + DB.Table("users").Select("name, age").Where("id in ?", []uint{user2.ID, user3.ID}).Scan(&results) + + sort.Slice(results, func(i, j int) bool { + return strings.Compare(results[i].Name, results[j].Name) < -1 + }) + + if len(results) != 2 || results[0].Name != user2.Name || results[1].Name != user3.Name { t.Errorf("Scan into struct map") } } + +func TestScanRows(t *testing.T) { + user1 := User{Name: "ScanRowsUser1", Age: 1} + user2 := User{Name: "ScanRowsUser2", Age: 10} + user3 := User{Name: "ScanRowsUser3", Age: 20} + DB.Save(&user1).Save(&user2).Save(&user3) + + rows, err := DB.Table("users").Where("name = ? or name = ?", user2.Name, user3.Name).Select("name, age").Rows() + if err != nil { + t.Errorf("Not error should happen, got %v", err) + } + + type Result struct { + Name string + Age int + } + + var results []Result + for rows.Next() { + var result Result + if err := DB.ScanRows(rows, &result); err != nil { + t.Errorf("should get no error, but got %v", err) + } + results = append(results, result) + } + + if !reflect.DeepEqual(results, []Result{{Name: "ScanRowsUser2", Age: 10}, {Name: "ScanRowsUser3", Age: 20}}) { + t.Errorf("Should find expected results") + } +} From ac8708b5008bff7459701dc7485300919df4dbbb Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 28 May 2020 13:12:56 +0800 Subject: [PATCH 337/881] Add FirstOrInit support --- chainable_api.go | 6 +++-- clause/expression.go | 11 --------- finisher_api.go | 46 +++++++++++++++++++++++++++++++++++- statement.go | 49 ++++++++++++++++++++++++-------------- tests/upsert_test.go | 56 ++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 136 insertions(+), 32 deletions(-) create mode 100644 tests/upsert_test.go diff --git a/chainable_api.go b/chainable_api.go index 6b91c9ad..8336b787 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -224,13 +224,15 @@ func (db *DB) Preload(query string, args ...interface{}) (tx *DB) { return } -func (db *DB) Assign(attrs ...interface{}) (tx *DB) { +func (db *DB) Attrs(attrs ...interface{}) (tx *DB) { tx = db.getInstance() + tx.Statement.attrs = attrs return } -func (db *DB) Attrs(attrs ...interface{}) (tx *DB) { +func (db *DB) Assign(attrs ...interface{}) (tx *DB) { tx = db.getInstance() + tx.Statement.assigns = attrs return } diff --git a/clause/expression.go b/clause/expression.go index 872736ce..067774d4 100644 --- a/clause/expression.go +++ b/clause/expression.go @@ -171,14 +171,3 @@ func (like Like) NegationBuild(builder Builder) { builder.WriteString(" NOT LIKE ") builder.AddVar(builder, like.Value) } - -// Map -type Map map[interface{}]interface{} - -func (m Map) Build(builder Builder) { - // TODO -} - -func (m Map) NegationBuild(builder Builder) { - // TODO -} diff --git a/finisher_api.go b/finisher_api.go index 04b25ed2..2590e422 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -2,6 +2,7 @@ package gorm import ( "database/sql" + "errors" "reflect" "strings" @@ -99,13 +100,56 @@ func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) { return } -func (db *DB) FirstOrInit(dest interface{}, where ...interface{}) (tx *DB) { +func (tx *DB) assignExprsToValue(exprs []clause.Expression) { + for _, expr := range exprs { + if eq, ok := expr.(clause.Eq); ok { + switch column := eq.Column.(type) { + case string: + if field := tx.Statement.Schema.LookUpField(column); field != nil { + field.Set(tx.Statement.ReflectValue, eq.Value) + } + case clause.Column: + if field := tx.Statement.Schema.LookUpField(column.Name); field != nil { + field.Set(tx.Statement.ReflectValue, eq.Value) + } + default: + } + } + } +} + +func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { tx = db.getInstance() + if tx = tx.First(dest, conds...); errors.Is(tx.Error, ErrRecordNotFound) { + if c, ok := tx.Statement.Clauses["WHERE"]; ok { + if where, ok := c.Expression.(clause.Where); ok { + tx.assignExprsToValue(where.Exprs) + } + } + + // initialize with attrs, conds + if len(tx.Statement.attrs) > 0 { + exprs := tx.Statement.BuildCondtion(tx.Statement.attrs[0], tx.Statement.attrs[1:]) + tx.assignExprsToValue(exprs) + } + tx.Error = nil + } + + // initialize with attrs, conds + if len(tx.Statement.assigns) > 0 { + exprs := tx.Statement.BuildCondtion(tx.Statement.assigns[0], tx.Statement.assigns[1:]) + tx.assignExprsToValue(exprs) + } return } func (db *DB) FirstOrCreate(dest interface{}, where ...interface{}) (tx *DB) { tx = db.getInstance() + // if err := tx.First(dest, conds...).Error; errors.Is(err, ErrRecordNotFound) { + // // initialize with attrs, conds + // } + + // assign dest return } diff --git a/statement.go b/statement.go index d37622dd..51dea6fc 100644 --- a/statement.go +++ b/statement.go @@ -34,6 +34,8 @@ type Statement struct { SQL strings.Builder Vars []interface{} NamedVars []sql.NamedArg + attrs []interface{} + assigns []interface{} } // StatementModifier statement modifier interface @@ -195,7 +197,7 @@ func (stmt *Statement) AddClauseIfNotExists(v clause.Interface) { } // BuildCondtion build condition -func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (conditions []clause.Expression) { +func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (conds []clause.Expression) { if sql, ok := query.(string); ok { if i, err := strconv.Atoi(sql); err == nil { query = i @@ -212,42 +214,53 @@ func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (con switch v := arg.(type) { case clause.Expression: - conditions = append(conditions, v) + conds = append(conds, v) case *DB: if v.Statement == nil { if cs, ok := v.Statement.Clauses["WHERE"]; ok { - conditions = append(conditions, cs.Expression) + conds = append(conds, cs.Expression) } } case map[interface{}]interface{}: - var clauseMap = clause.Map{} for i, j := range v { - clauseMap[i] = j + conds = append(conds, clause.Eq{Column: i, Value: j}) } - conditions = append(conditions, clauseMap) case map[string]string: - var clauseMap = clause.Map{} for i, j := range v { - clauseMap[i] = j + conds = append(conds, clause.Eq{Column: i, Value: j}) } - conditions = append(conditions, clauseMap) case map[string]interface{}: - var clauseMap = clause.Map{} for i, j := range v { - clauseMap[i] = j + conds = append(conds, clause.Eq{Column: i, Value: j}) } - conditions = append(conditions, clauseMap) default: - // TODO check is struct - // struct, slice -> ids + reflectValue := reflect.Indirect(reflect.ValueOf(arg)) + if s, err := schema.Parse(arg, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil { + switch reflectValue.Kind() { + case reflect.Struct: + for _, field := range s.FieldsByDBName { + if v, isZero := field.ValueOf(reflectValue); !isZero { + conds = append(conds, clause.Eq{Column: field.DBName, Value: v}) + } + } + case reflect.Slice, reflect.Array: + for i := 0; i < reflectValue.Len(); i++ { + for _, field := range s.FieldsByDBName { + if v, isZero := field.ValueOf(reflectValue.Index(i)); !isZero { + conds = append(conds, clause.Eq{Column: field.DBName, Value: v}) + } + } + } + } + } } } - if len(conditions) == 0 { - conditions = append(conditions, clause.IN{Column: clause.PrimaryColumn, Values: args}) + if len(conds) == 0 { + conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: args}) } - return conditions + return } // Build build sql with clauses names @@ -337,7 +350,7 @@ func (stmt *Statement) reinit() { // return true // }) - stmt.Schema = nil + // stmt.Schema = nil stmt.SQL.Reset() stmt.Vars = nil stmt.NamedVars = nil diff --git a/tests/upsert_test.go b/tests/upsert_test.go new file mode 100644 index 00000000..728550d5 --- /dev/null +++ b/tests/upsert_test.go @@ -0,0 +1,56 @@ +package tests_test + +import ( + "testing" + + . "github.com/jinzhu/gorm/tests" +) + +func TestFindOrInitialize(t *testing.T) { + var user1, user2, user3, user4, user5, user6 User + if err := DB.Where(&User{Name: "find or init", Age: 33}).FirstOrInit(&user1).Error; err != nil { + t.Errorf("no error should happen when FirstOrInit, but got %v", err) + } + if user1.Name != "find or init" || user1.ID != 0 || user1.Age != 33 { + t.Errorf("user should be initialized with search value") + } + + DB.Where(User{Name: "find or init", Age: 33}).FirstOrInit(&user2) + if user2.Name != "find or init" || user2.ID != 0 || user2.Age != 33 { + t.Errorf("user should be initialized with search value") + } + + DB.FirstOrInit(&user3, map[string]interface{}{"name": "find or init 2"}) + if user3.Name != "find or init 2" || user3.ID != 0 { + t.Errorf("user should be initialized with inline search value") + } + + DB.Where(&User{Name: "find or init"}).Attrs(User{Age: 44}).FirstOrInit(&user4) + if user4.Name != "find or init" || user4.ID != 0 || user4.Age != 44 { + t.Errorf("user should be initialized with search value and attrs") + } + + DB.Where(&User{Name: "find or init"}).Assign("age", 44).FirstOrInit(&user4) + if user4.Name != "find or init" || user4.ID != 0 || user4.Age != 44 { + t.Errorf("user should be initialized with search value and assign attrs") + } + + DB.Save(&User{Name: "find or init", Age: 33}) + DB.Where(&User{Name: "find or init"}).Attrs("age", 44).FirstOrInit(&user5) + if user5.Name != "find or init" || user5.ID == 0 || user5.Age != 33 { + t.Errorf("user should be found and not initialized by Attrs") + } + + DB.Where(&User{Name: "find or init", Age: 33}).FirstOrInit(&user6) + if user6.Name != "find or init" || user6.ID == 0 || user6.Age != 33 { + t.Errorf("user should be found with FirstOrInit") + } + + DB.Where(&User{Name: "find or init"}).Assign(User{Age: 44}).FirstOrInit(&user6) + if user6.Name != "find or init" || user6.ID == 0 || user6.Age != 44 { + t.Errorf("user should be found and updated with assigned attrs") + } +} + +func TestFindOrCreate(t *testing.T) { +} From dca5244387642c000bed71b5d0a195b711860cd8 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 28 May 2020 16:10:10 +0800 Subject: [PATCH 338/881] Add FirstOrCreate support --- finisher_api.go | 53 +++++++++++++++++++++++++++++++++++------ statement.go | 18 ++++++++++---- tests/upsert_test.go | 56 ++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 116 insertions(+), 11 deletions(-) diff --git a/finisher_api.go b/finisher_api.go index 2590e422..c47e12af 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -129,7 +129,7 @@ func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { // initialize with attrs, conds if len(tx.Statement.attrs) > 0 { - exprs := tx.Statement.BuildCondtion(tx.Statement.attrs[0], tx.Statement.attrs[1:]) + exprs := tx.Statement.BuildCondtion(tx.Statement.attrs[0], tx.Statement.attrs[1:]...) tx.assignExprsToValue(exprs) } tx.Error = nil @@ -137,19 +137,54 @@ func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { // initialize with attrs, conds if len(tx.Statement.assigns) > 0 { - exprs := tx.Statement.BuildCondtion(tx.Statement.assigns[0], tx.Statement.assigns[1:]) + exprs := tx.Statement.BuildCondtion(tx.Statement.assigns[0], tx.Statement.assigns[1:]...) tx.assignExprsToValue(exprs) } return } -func (db *DB) FirstOrCreate(dest interface{}, where ...interface{}) (tx *DB) { +func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) { tx = db.getInstance() - // if err := tx.First(dest, conds...).Error; errors.Is(err, ErrRecordNotFound) { - // // initialize with attrs, conds - // } + if err := tx.First(dest, conds...).Error; errors.Is(err, ErrRecordNotFound) { + tx.Error = nil + + if c, ok := tx.Statement.Clauses["WHERE"]; ok { + if where, ok := c.Expression.(clause.Where); ok { + tx.assignExprsToValue(where.Exprs) + } + } + + // initialize with attrs, conds + if len(tx.Statement.attrs) > 0 { + exprs := tx.Statement.BuildCondtion(tx.Statement.attrs[0], tx.Statement.attrs[1:]...) + tx.assignExprsToValue(exprs) + } + + // initialize with attrs, conds + if len(tx.Statement.assigns) > 0 { + exprs := tx.Statement.BuildCondtion(tx.Statement.assigns[0], tx.Statement.assigns[1:]...) + tx.assignExprsToValue(exprs) + } + + return tx.Create(dest) + } else if len(tx.Statement.assigns) > 0 { + exprs := tx.Statement.BuildCondtion(tx.Statement.assigns[0], tx.Statement.assigns[1:]) + assigns := map[string]interface{}{} + for _, expr := range exprs { + if eq, ok := expr.(clause.Eq); ok { + switch column := eq.Column.(type) { + case string: + assigns[column] = eq.Value + case clause.Column: + assigns[column.Name] = eq.Value + default: + } + } + } + + return tx.Model(dest).Updates(assigns) + } - // assign dest return } @@ -307,3 +342,7 @@ func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) { tx.callbacks.Raw().Execute(tx) return } + +func (db *DB) RecordNotFound() bool { + return errors.Is(db.Error, ErrRecordNotFound) +} diff --git a/statement.go b/statement.go index 51dea6fc..b110ac1b 100644 --- a/statement.go +++ b/statement.go @@ -203,6 +203,8 @@ func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (con query = i } else if len(args) == 0 || (len(args) > 0 && strings.Contains(sql, "?")) || strings.Contains(sql, "@") { return []clause.Expression{clause.Expr{SQL: sql, Vars: args}} + } else if len(args) == 1 { + return []clause.Expression{clause.Eq{Column: sql, Value: args[0]}} } } @@ -238,16 +240,24 @@ func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (con if s, err := schema.Parse(arg, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil { switch reflectValue.Kind() { case reflect.Struct: - for _, field := range s.FieldsByDBName { + for _, field := range s.Fields { if v, isZero := field.ValueOf(reflectValue); !isZero { - conds = append(conds, clause.Eq{Column: field.DBName, Value: v}) + if field.DBName == "" { + conds = append(conds, clause.Eq{Column: field.Name, Value: v}) + } else { + conds = append(conds, clause.Eq{Column: field.DBName, Value: v}) + } } } case reflect.Slice, reflect.Array: for i := 0; i < reflectValue.Len(); i++ { - for _, field := range s.FieldsByDBName { + for _, field := range s.Fields { if v, isZero := field.ValueOf(reflectValue.Index(i)); !isZero { - conds = append(conds, clause.Eq{Column: field.DBName, Value: v}) + if field.DBName == "" { + conds = append(conds, clause.Eq{Column: field.Name, Value: v}) + } else { + conds = append(conds, clause.Eq{Column: field.DBName, Value: v}) + } } } } diff --git a/tests/upsert_test.go b/tests/upsert_test.go index 728550d5..bd540620 100644 --- a/tests/upsert_test.go +++ b/tests/upsert_test.go @@ -2,6 +2,7 @@ package tests_test import ( "testing" + "time" . "github.com/jinzhu/gorm/tests" ) @@ -53,4 +54,59 @@ func TestFindOrInitialize(t *testing.T) { } func TestFindOrCreate(t *testing.T) { + var user1, user2, user3, user4, user5, user6, user7, user8 User + DB.Where(&User{Name: "find or create", Age: 33}).FirstOrCreate(&user1) + if user1.Name != "find or create" || user1.ID == 0 || user1.Age != 33 { + t.Errorf("user should be created with search value") + } + + DB.Where(&User{Name: "find or create", Age: 33}).FirstOrCreate(&user2) + if user1.ID != user2.ID || user2.Name != "find or create" || user2.ID == 0 || user2.Age != 33 { + t.Errorf("user should be created with search value") + } + + DB.FirstOrCreate(&user3, map[string]interface{}{"name": "find or create 2"}) + if user3.Name != "find or create 2" || user3.ID == 0 { + t.Errorf("user should be created with inline search value") + } + + DB.Where(&User{Name: "find or create 3"}).Attrs("age", 44).FirstOrCreate(&user4) + if user4.Name != "find or create 3" || user4.ID == 0 || user4.Age != 44 { + t.Errorf("user should be created with search value and attrs") + } + + updatedAt1 := user4.UpdatedAt + DB.Where(&User{Name: "find or create 3"}).Assign("age", 55).FirstOrCreate(&user4) + if updatedAt1.Format(time.RFC3339Nano) == user4.UpdatedAt.Format(time.RFC3339Nano) { + t.Errorf("UpdateAt should be changed when update values with assign") + } + + DB.Where(&User{Name: "find or create 4"}).Assign(User{Age: 44}).FirstOrCreate(&user4) + if user4.Name != "find or create 4" || user4.ID == 0 || user4.Age != 44 { + t.Errorf("user should be created with search value and assigned attrs") + } + + DB.Where(&User{Name: "find or create"}).Attrs("age", 44).FirstOrInit(&user5) + if user5.Name != "find or create" || user5.ID == 0 || user5.Age != 33 { + t.Errorf("user should be found and not initialized by Attrs") + } + + DB.Where(&User{Name: "find or create"}).Assign(User{Age: 44}).FirstOrCreate(&user6) + if user6.Name != "find or create" || user6.ID == 0 || user6.Age != 44 { + t.Errorf("user should be found and updated with assigned attrs") + } + + DB.Where(&User{Name: "find or create"}).Find(&user7) + if user7.Name != "find or create" || user7.ID == 0 || user7.Age != 44 { + t.Errorf("user should be found and updated with assigned attrs") + } + + DB.Where(&User{Name: "find or create embedded struct"}).Assign(User{Age: 44, Account: Account{Number: "1231231231"}, Pets: []*Pet{{Name: "first_or_create_pet1"}, {Name: "first_or_create_pet2"}}}).FirstOrCreate(&user8) + if DB.Where("name = ?", "first_or_create_pet1").First(&Pet{}).RecordNotFound() { + t.Errorf("has many association should be saved") + } + + if DB.Where("number = ?", "1231231231").First(&Account{}).RecordNotFound() { + t.Errorf("belongs to association should be saved") + } } From 55074213bc94fea6c3adc03fd1bdf4f12d7b0472 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 29 May 2020 07:35:45 +0800 Subject: [PATCH 339/881] Add SoftDelete support --- association.go | 12 +++--- callbacks/create.go | 37 ++++++++++++----- callbacks/delete.go | 21 ++++++---- callbacks/query.go | 6 +++ callbacks/update.go | 20 ++++++--- chainable_api.go | 1 + model.go | 2 +- schema/field.go | 22 ++++++++++ schema/schema.go | 16 ++++++++ soft_delete.go | 86 +++++++++++++++++++++++++++++++++++++++ statement.go | 1 + tests/soft_delete_test.go | 28 +++++++++++++ tests/upsert_test.go | 6 ++- 13 files changed, 225 insertions(+), 33 deletions(-) create mode 100644 soft_delete.go create mode 100644 tests/soft_delete_test.go diff --git a/association.go b/association.go index 5b777465..bed89837 100644 --- a/association.go +++ b/association.go @@ -44,11 +44,7 @@ func (association *Association) Find(out interface{}, conds ...interface{}) erro tx = association.DB.Model(out) ) - if association.Relationship.JoinTable != nil { - for _, queryClause := range association.Relationship.JoinTable.QueryClauses { - tx.Clauses(queryClause) - } - + if association.Relationship.JoinTable != nil && !tx.Statement.Unscoped { tx.Clauses(clause.From{Joins: []clause.Join{{ Table: clause.Table{Name: association.Relationship.JoinTable.Table}, ON: clause.Where{Exprs: queryConds}, @@ -317,8 +313,10 @@ func (association *Association) Count() (count int64) { ) if association.Relationship.JoinTable != nil { - for _, queryClause := range association.Relationship.JoinTable.QueryClauses { - tx.Clauses(queryClause) + if !tx.Statement.Unscoped { + for _, queryClause := range association.Relationship.JoinTable.QueryClauses { + tx.Clauses(queryClause) + } } tx.Clauses(clause.From{Joins: []clause.Join{{ diff --git a/callbacks/create.go b/callbacks/create.go index 0b30775a..18f25c9a 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -46,12 +46,21 @@ func Create(config *Config) func(db *gorm.DB) { return CreateWithReturning } else { return func(db *gorm.DB) { - db.Statement.AddClauseIfNotExists(clause.Insert{ - Table: clause.Table{Name: db.Statement.Table}, - }) - db.Statement.AddClause(ConvertToCreateValues(db.Statement)) + if db.Statement.Schema != nil && !db.Statement.Unscoped { + for _, c := range db.Statement.Schema.CreateClauses { + db.Statement.AddClause(c) + } + } + + if db.Statement.SQL.String() == "" { + db.Statement.AddClauseIfNotExists(clause.Insert{ + Table: clause.Table{Name: db.Statement.Table}, + }) + db.Statement.AddClause(ConvertToCreateValues(db.Statement)) + + db.Statement.Build("INSERT", "VALUES", "ON CONFLICT") + } - db.Statement.Build("INSERT", "VALUES", "ON CONFLICT") result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) if err == nil { @@ -88,12 +97,20 @@ func Create(config *Config) func(db *gorm.DB) { } func CreateWithReturning(db *gorm.DB) { - db.Statement.AddClauseIfNotExists(clause.Insert{ - Table: clause.Table{Name: db.Statement.Table}, - }) - db.Statement.AddClause(ConvertToCreateValues(db.Statement)) + if db.Statement.Schema != nil && !db.Statement.Unscoped { + for _, c := range db.Statement.Schema.CreateClauses { + db.Statement.AddClause(c) + } + } - db.Statement.Build("INSERT", "VALUES", "ON CONFLICT") + if db.Statement.SQL.String() == "" { + db.Statement.AddClauseIfNotExists(clause.Insert{ + Table: clause.Table{Name: db.Statement.Table}, + }) + db.Statement.AddClause(ConvertToCreateValues(db.Statement)) + + db.Statement.Build("INSERT", "VALUES", "ON CONFLICT") + } if sch := db.Statement.Schema; sch != nil && len(sch.FieldsWithDefaultDBValue) > 0 { db.Statement.WriteString(" RETURNING ") diff --git a/callbacks/delete.go b/callbacks/delete.go index a88edcf8..1c59afbe 100644 --- a/callbacks/delete.go +++ b/callbacks/delete.go @@ -1,6 +1,7 @@ package callbacks import ( + "fmt" "reflect" "github.com/jinzhu/gorm" @@ -34,26 +35,30 @@ func BeforeDelete(db *gorm.DB) { } func Delete(db *gorm.DB) { + if db.Statement.Schema != nil && !db.Statement.Unscoped { + for _, c := range db.Statement.Schema.DeleteClauses { + db.Statement.AddClause(c) + fmt.Println(db.Statement.SQL.String()) + } + } + if db.Statement.SQL.String() == "" { db.Statement.AddClauseIfNotExists(clause.Delete{}) - values := []reflect.Value{db.Statement.ReflectValue} - if db.Statement.Dest != db.Statement.Model && db.Statement.Model != nil { - values = append(values, reflect.ValueOf(db.Statement.Model)) - } - if db.Statement.Schema != nil { _, queryValues := schema.GetIdentityFieldValuesMap(db.Statement.ReflectValue, db.Statement.Schema.PrimaryFields) column, values := schema.ToQueryValues(db.Statement.Schema.PrimaryFieldDBNames, queryValues) if len(values) > 0 { - db.Where(clause.IN{Column: column, Values: values}) - } else if db.Statement.Dest != db.Statement.Model && db.Statement.Model != nil { + db.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}}) + } + + if db.Statement.Dest != db.Statement.Model && db.Statement.Model != nil { _, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(db.Statement.Model), db.Statement.Schema.PrimaryFields) column, values = schema.ToQueryValues(db.Statement.Schema.PrimaryFieldDBNames, queryValues) if len(values) > 0 { - db.Where(clause.IN{Column: column, Values: values}) + db.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}}) } } } diff --git a/callbacks/query.go b/callbacks/query.go index 84b9ed98..ee3f5c8d 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -12,6 +12,12 @@ import ( ) func Query(db *gorm.DB) { + if db.Statement.Schema != nil && !db.Statement.Unscoped { + for _, c := range db.Statement.Schema.QueryClauses { + db.Statement.AddClause(c) + } + } + if db.Statement.SQL.String() == "" { clauseSelect := clause.Select{} diff --git a/callbacks/update.go b/callbacks/update.go index f9b20981..f56aa22c 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -44,13 +44,21 @@ func BeforeUpdate(db *gorm.DB) { } func Update(db *gorm.DB) { - db.Statement.AddClauseIfNotExists(clause.Update{}) - if set := ConvertToAssignments(db.Statement); len(set) != 0 { - db.Statement.AddClause(set) - } else { - return + if db.Statement.Schema != nil && !db.Statement.Unscoped { + for _, c := range db.Statement.Schema.UpdateClauses { + db.Statement.AddClause(c) + } + } + + if db.Statement.SQL.String() == "" { + db.Statement.AddClauseIfNotExists(clause.Update{}) + if set := ConvertToAssignments(db.Statement); len(set) != 0 { + db.Statement.AddClause(set) + } else { + return + } + db.Statement.Build("UPDATE", "SET", "WHERE") } - db.Statement.Build("UPDATE", "SET", "WHERE") result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) diff --git a/chainable_api.go b/chainable_api.go index 8336b787..afcdccd2 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -238,6 +238,7 @@ func (db *DB) Assign(attrs ...interface{}) (tx *DB) { func (db *DB) Unscoped() (tx *DB) { tx = db.getInstance() + tx.Statement.Unscoped = true return } diff --git a/model.go b/model.go index fdee99dc..dcc3cdc2 100644 --- a/model.go +++ b/model.go @@ -11,5 +11,5 @@ type Model struct { ID uint `gorm:"primarykey"` CreatedAt time.Time UpdatedAt time.Time - DeletedAt *time.Time `gorm:"index"` + DeletedAt DeletedAt `gorm:"index"` } diff --git a/schema/field.go b/schema/field.go index 8b8b190d..75ff71f6 100644 --- a/schema/field.go +++ b/schema/field.go @@ -86,6 +86,23 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } fieldValue := reflect.New(field.IndirectFieldType) + + if fc, ok := fieldValue.Interface().(CreateClausesInterface); ok { + field.Schema.CreateClauses = append(field.Schema.CreateClauses, fc.CreateClauses()...) + } + + if fc, ok := fieldValue.Interface().(QueryClausesInterface); ok { + field.Schema.QueryClauses = append(field.Schema.QueryClauses, fc.QueryClauses()...) + } + + if fc, ok := fieldValue.Interface().(UpdateClausesInterface); ok { + field.Schema.UpdateClauses = append(field.Schema.UpdateClauses, fc.UpdateClauses()...) + } + + if fc, ok := fieldValue.Interface().(DeleteClausesInterface); ok { + field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, fc.DeleteClauses()...) + } + // if field is valuer, used its value or first fields as data type if valuer, isValueOf := fieldValue.Interface().(driver.Valuer); isValueOf { var overrideFieldValue bool @@ -283,6 +300,11 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { ef.TagSettings[k] = v } } + + field.Schema.CreateClauses = append(field.Schema.CreateClauses, field.EmbeddedSchema.CreateClauses...) + field.Schema.QueryClauses = append(field.Schema.QueryClauses, field.EmbeddedSchema.QueryClauses...) + field.Schema.UpdateClauses = append(field.Schema.UpdateClauses, field.EmbeddedSchema.UpdateClauses...) + field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, field.EmbeddedSchema.DeleteClauses...) } return field diff --git a/schema/schema.go b/schema/schema.go index e66084a3..77b9832c 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -42,6 +42,22 @@ type Schema struct { cacheStore *sync.Map } +type CreateClausesInterface interface { + CreateClauses() []clause.Interface +} + +type QueryClausesInterface interface { + QueryClauses() []clause.Interface +} + +type UpdateClausesInterface interface { + UpdateClauses() []clause.Interface +} + +type DeleteClausesInterface interface { + DeleteClauses() []clause.Interface +} + func (schema Schema) String() string { if schema.ModelType.Name() == "" { return fmt.Sprintf("%v(%v)", schema.Name, schema.Table) diff --git a/soft_delete.go b/soft_delete.go new file mode 100644 index 00000000..138c9c63 --- /dev/null +++ b/soft_delete.go @@ -0,0 +1,86 @@ +package gorm + +import ( + "database/sql" + "database/sql/driver" + "reflect" + "time" + + "github.com/jinzhu/gorm/clause" + "github.com/jinzhu/gorm/schema" +) + +type DeletedAt sql.NullTime + +// Scan implements the Scanner interface. +func (n *DeletedAt) Scan(value interface{}) error { + return (*sql.NullTime)(n).Scan(value) +} + +// Value implements the driver Valuer interface. +func (n DeletedAt) Value() (driver.Value, error) { + if !n.Valid { + return nil, nil + } + return n.Time, nil +} + +func (DeletedAt) QueryClauses() []clause.Interface { + return []clause.Interface{ + clause.Where{Exprs: []clause.Expression{ + clause.Eq{ + Column: clause.Column{Table: clause.CurrentTable, Name: "deleted_at"}, + Value: nil, + }, + }}, + } +} + +func (DeletedAt) DeleteClauses() []clause.Interface { + return []clause.Interface{SoftDeleteClause{}} +} + +type SoftDeleteClause struct { +} + +func (SoftDeleteClause) Name() string { + return "" +} + +func (SoftDeleteClause) Build(clause.Builder) { +} + +func (SoftDeleteClause) MergeClause(*clause.Clause) { +} + +func (SoftDeleteClause) ModifyStatement(stmt *Statement) { + if stmt.SQL.String() == "" { + stmt.AddClause(clause.Set{{Column: clause.Column{Name: "deleted_at"}, Value: time.Now()}}) + + if stmt.Schema != nil { + _, queryValues := schema.GetIdentityFieldValuesMap(stmt.ReflectValue, stmt.Schema.PrimaryFields) + column, values := schema.ToQueryValues(stmt.Schema.PrimaryFieldDBNames, queryValues) + + if len(values) > 0 { + stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}}) + } + + if stmt.Dest != stmt.Model && stmt.Model != nil { + _, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(stmt.Model), stmt.Schema.PrimaryFields) + column, values = schema.ToQueryValues(stmt.Schema.PrimaryFieldDBNames, queryValues) + + if len(values) > 0 { + stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}}) + } + } + } + + if _, ok := stmt.Clauses["WHERE"]; !ok { + stmt.DB.AddError(ErrMissingWhereClause) + return + } + + stmt.AddClauseIfNotExists(clause.Update{}) + stmt.Build("UPDATE", "SET", "WHERE") + } +} diff --git a/statement.go b/statement.go index b110ac1b..626ca689 100644 --- a/statement.go +++ b/statement.go @@ -19,6 +19,7 @@ type Statement struct { *DB Table string Model interface{} + Unscoped bool Dest interface{} ReflectValue reflect.Value Clauses map[string]clause.Clause diff --git a/tests/soft_delete_test.go b/tests/soft_delete_test.go new file mode 100644 index 00000000..f91052c1 --- /dev/null +++ b/tests/soft_delete_test.go @@ -0,0 +1,28 @@ +package tests_test + +import ( + "testing" + + . "github.com/jinzhu/gorm/tests" +) + +func TestSoftDelete(t *testing.T) { + user := *GetUser("SoftDelete", Config{}) + DB.Save(&user) + if err := DB.Delete(&user).Error; err != nil { + t.Fatalf("No error should happen when soft delete user, but got %v", err) + } + + if DB.First(&User{}, "name = ?", user.Name).Error == nil { + t.Errorf("Can't find a soft deleted record") + } + + if err := DB.Unscoped().First(&User{}, "name = ?", user.Name).Error; err != nil { + t.Errorf("Should find soft deleted record with Unscoped, but got err %s", err) + } + + DB.Unscoped().Delete(&user) + if !DB.Unscoped().First(&User{}, "name = ?", user.Name).RecordNotFound() { + t.Errorf("Can't find permanently deleted record") + } +} diff --git a/tests/upsert_test.go b/tests/upsert_test.go index bd540620..615ead95 100644 --- a/tests/upsert_test.go +++ b/tests/upsert_test.go @@ -12,6 +12,7 @@ func TestFindOrInitialize(t *testing.T) { if err := DB.Where(&User{Name: "find or init", Age: 33}).FirstOrInit(&user1).Error; err != nil { t.Errorf("no error should happen when FirstOrInit, but got %v", err) } + if user1.Name != "find or init" || user1.ID != 0 || user1.Age != 33 { t.Errorf("user should be initialized with search value") } @@ -55,7 +56,10 @@ func TestFindOrInitialize(t *testing.T) { func TestFindOrCreate(t *testing.T) { var user1, user2, user3, user4, user5, user6, user7, user8 User - DB.Where(&User{Name: "find or create", Age: 33}).FirstOrCreate(&user1) + if err := DB.Where(&User{Name: "find or create", Age: 33}).FirstOrCreate(&user1).Error; err != nil { + t.Errorf("no error should happen when FirstOrInit, but got %v", err) + } + if user1.Name != "find or create" || user1.ID == 0 || user1.Age != 33 { t.Errorf("user should be created with search value") } From d05128be7868349084a8e3818a2676976cfac97a Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 29 May 2020 22:34:35 +0800 Subject: [PATCH 340/881] OnConflict support for mysql --- clause/clause.go | 6 ++---- dialects/mysql/mysql.go | 34 ++++++++++++++++++++++++++++++++++ gorm.go | 4 ++++ statement.go | 2 +- 4 files changed, 41 insertions(+), 5 deletions(-) diff --git a/clause/clause.go b/clause/clause.go index 59b229ce..9a5d1273 100644 --- a/clause/clause.go +++ b/clause/clause.go @@ -8,9 +8,7 @@ type Interface interface { } // ClauseBuilder clause builder, allows to custmize how to build clause -type ClauseBuilder interface { - Build(Clause, Builder) -} +type ClauseBuilder func(Clause, Builder) type Writer interface { WriteByte(byte) error @@ -38,7 +36,7 @@ type Clause struct { // Build build clause func (c Clause) Build(builder Builder) { if c.Builder != nil { - c.Builder.Build(c, builder) + c.Builder(c, builder) } else { builders := c.BeforeExpressions if c.Name != "" { diff --git a/dialects/mysql/mysql.go b/dialects/mysql/mysql.go index 7b8f0491..6ca9f5f5 100644 --- a/dialects/mysql/mysql.go +++ b/dialects/mysql/mysql.go @@ -26,9 +26,43 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) { // register callbacks callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{}) db.ConnPool, err = sql.Open("mysql", dialector.DSN) + + for k, v := range dialector.ClauseBuilders() { + db.ClauseBuilders[k] = v + } return } +func (dialector Dialector) ClauseBuilders() map[string]clause.ClauseBuilder { + return map[string]clause.ClauseBuilder{ + "ON CONFLICT": func(c clause.Clause, builder clause.Builder) { + if onConflict, ok := c.Expression.(clause.OnConflict); ok { + builder.WriteString("ON DUPLICATE KEY UPDATE ") + if len(onConflict.DoUpdates) == 0 { + if s := builder.(*gorm.Statement).Schema; s != nil { + var column clause.Column + onConflict.DoNothing = false + + if s.PrioritizedPrimaryField != nil { + column = clause.Column{Name: s.PrioritizedPrimaryField.DBName} + } else { + for _, field := range s.FieldsByDBName { + column = clause.Column{Name: field.DBName} + break + } + } + onConflict.DoUpdates = []clause.Assignment{{Column: column, Value: column}} + } + } + + onConflict.DoUpdates.Build(builder) + } else { + c.Build(builder) + } + }, + } +} + func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator { return Migrator{migrator.Migrator{Config: migrator.Config{ DB: db, diff --git a/gorm.go b/gorm.go index 1fa69383..942024cf 100644 --- a/gorm.go +++ b/gorm.go @@ -95,6 +95,10 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) { db.callbacks = initializeCallbacks(db) + if config.ClauseBuilders == nil { + config.ClauseBuilders = map[string]clause.ClauseBuilder{} + } + if dialector != nil { err = dialector.Initialize(db) } diff --git a/statement.go b/statement.go index 626ca689..f81ae0e5 100644 --- a/statement.go +++ b/statement.go @@ -286,7 +286,7 @@ func (stmt *Statement) Build(clauses ...string) { firstClauseWritten = true if b, ok := stmt.DB.ClauseBuilders[name]; ok { - b.Build(c, stmt) + b(c, stmt) } else { c.Build(stmt) } From 6f4602af11c17d79610386df1112b2bf13fe509b Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 29 May 2020 23:38:03 +0800 Subject: [PATCH 341/881] Fix mysql tests --- callbacks/preload.go | 6 +++++- logger/logger.go | 4 ++-- scan.go | 8 ++++++++ schema/field.go | 39 ++++++++++++++++++++++++++++++++------- 4 files changed, 47 insertions(+), 10 deletions(-) diff --git a/callbacks/preload.go b/callbacks/preload.go index f48777c2..cfea4f94 100644 --- a/callbacks/preload.go +++ b/callbacks/preload.go @@ -101,7 +101,11 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { } for _, data := range identityMap[utils.ToStringKey(fieldValues...)] { - reflectFieldValue := reflect.Indirect(rel.Field.ReflectValueOf(data)) + reflectFieldValue := rel.Field.ReflectValueOf(data) + if reflectFieldValue.Kind() == reflect.Ptr && reflectFieldValue.IsNil() { + reflectFieldValue.Set(reflect.New(rel.Field.FieldType.Elem())) + } + reflectFieldValue = reflect.Indirect(reflectFieldValue) switch reflectFieldValue.Kind() { case reflect.Struct: rel.Field.Set(data, reflectResults.Index(i).Interface()) diff --git a/logger/logger.go b/logger/logger.go index 24cee821..7121b4fb 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -78,7 +78,7 @@ func New(writer Writer, config Config) Interface { traceErrStr = RedBold + "%s " + MagentaBold + "%s\n" + Reset + Yellow + "[%.3fms] " + Blue + "[rows:%d]" + Reset + " %s" } - return logger{ + return &logger{ Writer: writer, Config: config, infoStr: infoStr, @@ -98,7 +98,7 @@ type logger struct { } // LogMode log mode -func (l logger) LogMode(level LogLevel) Interface { +func (l *logger) LogMode(level LogLevel) Interface { l.LogLevel = level return l } diff --git a/scan.go b/scan.go index d2169f87..c223f6eb 100644 --- a/scan.go +++ b/scan.go @@ -87,6 +87,10 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { values[idx] = field.ReflectValueOf(elem).Addr().Interface() } else if joinFields[idx][0] != nil { relValue := joinFields[idx][0].ReflectValueOf(elem) + if relValue.Kind() == reflect.Ptr && relValue.IsNil() { + relValue.Set(reflect.New(relValue.Type().Elem())) + } + values[idx] = joinFields[idx][1].ReflectValueOf(relValue).Addr().Interface() } } @@ -110,6 +114,10 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { if rel, ok := db.Statement.Schema.Relationships.Relations[names[0]]; ok { relValue := rel.Field.ReflectValueOf(db.Statement.ReflectValue) if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { + if relValue.Kind() == reflect.Ptr && relValue.IsNil() { + relValue.Set(reflect.New(relValue.Type().Elem())) + } + values[idx] = field.ReflectValueOf(relValue).Addr().Interface() continue } diff --git a/schema/field.go b/schema/field.go index 75ff71f6..f4fbad95 100644 --- a/schema/field.go +++ b/schema/field.go @@ -353,9 +353,6 @@ func (field *Field) setupValuerAndSetter() { if field.FieldType.Kind() == reflect.Ptr { field.ReflectValueOf = func(value reflect.Value) reflect.Value { fieldValue := reflect.Indirect(value).Field(field.StructField.Index[0]) - if fieldValue.IsNil() { - fieldValue.Set(reflect.New(field.FieldType.Elem())) - } return fieldValue } } else { @@ -406,7 +403,14 @@ func (field *Field) setupValuerAndSetter() { return setter(value, v) } } else if field.FieldType.Kind() == reflect.Ptr && reflectV.Type().ConvertibleTo(field.FieldType.Elem()) { - field.ReflectValueOf(value).Elem().Set(reflectV.Convert(field.FieldType.Elem())) + fieldValue := field.ReflectValueOf(value) + if fieldValue.IsNil() { + if v == nil { + return nil + } + fieldValue.Set(reflect.New(field.FieldType.Elem())) + } + fieldValue.Elem().Set(reflectV.Convert(field.FieldType.Elem())) } else if reflectV.Kind() == reflect.Ptr { return field.Set(value, reflectV.Elem().Interface()) } else { @@ -607,12 +611,26 @@ func (field *Field) setupValuerAndSetter() { field.Set = func(value reflect.Value, v interface{}) error { switch data := v.(type) { case time.Time: - field.ReflectValueOf(value).Elem().Set(reflect.ValueOf(v)) + fieldValue := field.ReflectValueOf(value) + if fieldValue.IsNil() { + if v == nil { + return nil + } + fieldValue.Set(reflect.New(field.FieldType.Elem())) + } + fieldValue.Elem().Set(reflect.ValueOf(v)) case *time.Time: field.ReflectValueOf(value).Set(reflect.ValueOf(v)) case string: if t, err := now.Parse(data); err == nil { - field.ReflectValueOf(value).Elem().Set(reflect.ValueOf(t)) + fieldValue := field.ReflectValueOf(value) + if fieldValue.IsNil() { + if v == "" { + return nil + } + fieldValue.Set(reflect.New(field.FieldType.Elem())) + } + fieldValue.Elem().Set(reflect.ValueOf(t)) } else { return fmt.Errorf("failed to set string %v to time.Time field %v, failed to parse it as time, got error %v", v, field.Name, err) } @@ -651,7 +669,14 @@ func (field *Field) setupValuerAndSetter() { if reflectV.Type().ConvertibleTo(field.FieldType) { field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType)) } else if reflectV.Type().ConvertibleTo(field.FieldType.Elem()) { - field.ReflectValueOf(value).Elem().Set(reflectV.Convert(field.FieldType.Elem())) + fieldValue := field.ReflectValueOf(value) + if fieldValue.IsNil() { + if v == nil { + return nil + } + fieldValue.Set(reflect.New(field.FieldType.Elem())) + } + fieldValue.Elem().Set(reflectV.Convert(field.FieldType.Elem())) } else if valuer, ok := v.(driver.Valuer); ok { if v, err = valuer.Value(); err == nil { err = field.ReflectValueOf(value).Interface().(sql.Scanner).Scan(v) From db428f221f8a09c1af532fb248ffffd18082a156 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 30 May 2020 00:16:33 +0800 Subject: [PATCH 342/881] Fix postgres tests --- callbacks/query.go | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/callbacks/query.go b/callbacks/query.go index ee3f5c8d..6edfee0b 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -51,29 +51,33 @@ func Query(db *gorm.DB) { for name, conds := range db.Statement.Joins { if relation, ok := db.Statement.Schema.Relationships.Relations[name]; ok { + tableAliasName := relation.Name + for _, s := range relation.FieldSchema.DBNames { clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ - Table: relation.Name, + Table: tableAliasName, Name: s, - Alias: relation.Name + "__" + s, + Alias: tableAliasName + "__" + s, }) } var exprs []clause.Expression for _, ref := range relation.References { if ref.OwnPrimaryKey { - exprs = append(exprs, clause.Expr{ - SQL: fmt.Sprintf("%s.%s = %s.%s", db.Statement.Schema.Table, ref.PrimaryKey.DBName, relation.Name, ref.ForeignKey.DBName), + exprs = append(exprs, clause.Eq{ + Column: clause.Column{Table: db.Statement.Schema.Table, Name: ref.PrimaryKey.DBName}, + Value: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName}, }) } else { if ref.PrimaryValue == "" { - exprs = append(exprs, clause.Expr{ - SQL: fmt.Sprintf("%s.%s = %s.%s", db.Statement.Schema.Table, ref.ForeignKey.DBName, relation.Name, ref.PrimaryKey.DBName), + exprs = append(exprs, clause.Eq{ + Column: clause.Column{Table: db.Statement.Schema.Table, Name: ref.ForeignKey.DBName}, + Value: clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName}, }) } else { - exprs = append(exprs, clause.Expr{ - SQL: fmt.Sprintf("%s.%s = ?", relation.Name, ref.PrimaryKey.DBName), - Vars: []interface{}{ref.PrimaryValue}, + exprs = append(exprs, clause.Eq{ + Column: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName}, + Value: ref.PrimaryValue, }) } } @@ -81,7 +85,7 @@ func Query(db *gorm.DB) { joins = append(joins, clause.Join{ Type: clause.LeftJoin, - Table: clause.Table{Name: relation.FieldSchema.Table, Alias: relation.Name}, + Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName}, ON: clause.Where{Exprs: exprs}, }) } else { From c07a08d88bc4ea7fccf90bcc08b6e2264cf0f78c Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 30 May 2020 10:43:41 +0800 Subject: [PATCH 343/881] Support mssql --- dialects/mssql/create.go | 95 ++++++++++++++++++++++++++++++++++++++++ dialects/mssql/mssql.go | 28 ++++++++++++ 2 files changed, 123 insertions(+) create mode 100644 dialects/mssql/create.go diff --git a/dialects/mssql/create.go b/dialects/mssql/create.go new file mode 100644 index 00000000..4aecce10 --- /dev/null +++ b/dialects/mssql/create.go @@ -0,0 +1,95 @@ +package mssql + +import ( + "reflect" + + "github.com/jinzhu/gorm" + "github.com/jinzhu/gorm/callbacks" + "github.com/jinzhu/gorm/clause" +) + +func Create(db *gorm.DB) { + if db.Statement.Schema != nil && !db.Statement.Unscoped { + for _, c := range db.Statement.Schema.CreateClauses { + db.Statement.AddClause(c) + } + } + + if db.Statement.SQL.String() == "" { + db.Statement.AddClauseIfNotExists(clause.Insert{ + Table: clause.Table{Name: db.Statement.Table}, + }) + db.Statement.AddClause(callbacks.ConvertToCreateValues(db.Statement)) + + db.Statement.Build("INSERT") + db.Statement.WriteByte(' ') + + c := db.Statement.Clauses["VALUES"] + if values, ok := c.Expression.(clause.Values); ok { + if len(values.Columns) > 0 { + db.Statement.WriteByte('(') + for idx, column := range values.Columns { + if idx > 0 { + db.Statement.WriteByte(',') + } + db.Statement.WriteQuoted(column) + } + db.Statement.WriteByte(')') + + if db.Statement.Schema.PrioritizedPrimaryField != nil { + db.Statement.WriteString(" OUTPUT INSERTED.") + db.Statement.WriteQuoted(db.Statement.Schema.PrioritizedPrimaryField.DBName) + } + + db.Statement.WriteString(" VALUES ") + + for idx, value := range values.Values { + if idx > 0 { + db.Statement.WriteByte(',') + } + + db.Statement.WriteByte('(') + db.Statement.AddVar(db.Statement, value...) + db.Statement.WriteByte(')') + } + } else { + db.Statement.WriteString("DEFAULT VALUES") + } + } + + db.Statement.WriteByte(' ') + db.Statement.Build("ON CONFLICT") + } + + rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) + + if err == nil { + defer rows.Close() + + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + for rows.Next() { + // for idx, field := range fields { + // values[idx] = field.ReflectValueOf(db.Statement.ReflectValue.Index(int(db.RowsAffected))).Addr().Interface() + // } + + values := db.Statement.Schema.PrioritizedPrimaryField.ReflectValueOf(db.Statement.ReflectValue.Index(int(db.RowsAffected))).Addr().Interface() + if err := rows.Scan(values); err != nil { + db.AddError(err) + } + db.RowsAffected++ + } + case reflect.Struct: + // for idx, field := range fields { + // values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface() + // } + values := db.Statement.Schema.PrioritizedPrimaryField.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface() + + if rows.Next() { + err = rows.Scan(values) + } + } + } else { + db.AddError(err) + } +} diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go index ad6782c7..35fcb484 100644 --- a/dialects/mssql/mssql.go +++ b/dialects/mssql/mssql.go @@ -26,10 +26,38 @@ func Open(dsn string) gorm.Dialector { func (dialector Dialector) Initialize(db *gorm.DB) (err error) { // register callbacks callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{}) + db.Callback().Create().Replace("gorm:create", Create) db.ConnPool, err = sql.Open("sqlserver", dialector.DSN) + + for k, v := range dialector.ClauseBuilders() { + db.ClauseBuilders[k] = v + } return } +func (dialector Dialector) ClauseBuilders() map[string]clause.ClauseBuilder { + return map[string]clause.ClauseBuilder{ + "LIMIT": func(c clause.Clause, builder clause.Builder) { + if limit, ok := c.Expression.(clause.Limit); ok { + if limit.Offset > 0 { + builder.WriteString("OFFSET ") + builder.WriteString(strconv.Itoa(limit.Offset)) + builder.WriteString("ROWS") + } + + if limit.Limit > 0 { + if limit.Offset == 0 { + builder.WriteString(" OFFSET 0 ROWS") + } + builder.WriteString(" FETCH NEXT ") + builder.WriteString(strconv.Itoa(limit.Limit)) + builder.WriteString(" ROWS ONLY") + } + } + }, + } +} + func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator { return Migrator{migrator.Migrator{Config: migrator.Config{ DB: db, From cc07ee0444cac16388778b413be93e877ed80816 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 30 May 2020 12:46:30 +0800 Subject: [PATCH 344/881] Support mssql merge --- dialects/mssql/create.go | 143 +++++++++++++++++++++++++++++---------- dialects/mssql/mssql.go | 2 +- 2 files changed, 108 insertions(+), 37 deletions(-) diff --git a/dialects/mssql/create.go b/dialects/mssql/create.go index 4aecce10..9183ba76 100644 --- a/dialects/mssql/create.go +++ b/dialects/mssql/create.go @@ -16,49 +16,48 @@ func Create(db *gorm.DB) { } if db.Statement.SQL.String() == "" { - db.Statement.AddClauseIfNotExists(clause.Insert{ - Table: clause.Table{Name: db.Statement.Table}, - }) - db.Statement.AddClause(callbacks.ConvertToCreateValues(db.Statement)) + c := db.Statement.Clauses["ON CONFLICT"] + onConflict, hasConflict := c.Expression.(clause.OnConflict) - db.Statement.Build("INSERT") - db.Statement.WriteByte(' ') - - c := db.Statement.Clauses["VALUES"] - if values, ok := c.Expression.(clause.Values); ok { - if len(values.Columns) > 0 { - db.Statement.WriteByte('(') - for idx, column := range values.Columns { - if idx > 0 { - db.Statement.WriteByte(',') - } - db.Statement.WriteQuoted(column) - } - db.Statement.WriteByte(')') - - if db.Statement.Schema.PrioritizedPrimaryField != nil { - db.Statement.WriteString(" OUTPUT INSERTED.") - db.Statement.WriteQuoted(db.Statement.Schema.PrioritizedPrimaryField.DBName) - } - - db.Statement.WriteString(" VALUES ") - - for idx, value := range values.Values { - if idx > 0 { - db.Statement.WriteByte(',') - } + if hasConflict { + MergeCreate(db, onConflict) + } else { + db.Statement.AddClauseIfNotExists(clause.Insert{Table: clause.Table{Name: db.Statement.Table}}) + db.Statement.Build("INSERT") + db.Statement.WriteByte(' ') + db.Statement.AddClause(callbacks.ConvertToCreateValues(db.Statement)) + if values, ok := db.Statement.Clauses["VALUES"].Expression.(clause.Values); ok { + if len(values.Columns) > 0 { db.Statement.WriteByte('(') - db.Statement.AddVar(db.Statement, value...) + for idx, column := range values.Columns { + if idx > 0 { + db.Statement.WriteByte(',') + } + db.Statement.WriteQuoted(column) + } db.Statement.WriteByte(')') + + outputInserted(db) + + db.Statement.WriteString(" VALUES ") + + for idx, value := range values.Values { + if idx > 0 { + db.Statement.WriteByte(',') + } + + db.Statement.WriteByte('(') + db.Statement.AddVar(db.Statement, value...) + db.Statement.WriteByte(')') + } + + db.Statement.WriteString(";") + } else { + db.Statement.WriteString("DEFAULT VALUES") } - } else { - db.Statement.WriteString("DEFAULT VALUES") } } - - db.Statement.WriteByte(' ') - db.Statement.Build("ON CONFLICT") } rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) @@ -93,3 +92,75 @@ func Create(db *gorm.DB) { db.AddError(err) } } + +func MergeCreate(db *gorm.DB, onConflict clause.OnConflict) { + values := callbacks.ConvertToCreateValues(db.Statement) + + db.Statement.WriteString("MERGE INTO ") + db.Statement.WriteQuoted(db.Statement.Table) + db.Statement.WriteString(" USING (VALUES") + for idx, value := range values.Values { + if idx > 0 { + db.Statement.WriteByte(',') + } + + db.Statement.WriteByte('(') + db.Statement.AddVar(db.Statement, value...) + db.Statement.WriteByte(')') + } + + db.Statement.WriteString(") AS source (") + for idx, column := range values.Columns { + if idx > 0 { + db.Statement.WriteByte(',') + } + db.Statement.WriteQuoted(column.Name) + } + db.Statement.WriteString(") ON ") + + var where clause.Where + for _, field := range db.Statement.Schema.PrimaryFields { + where.Exprs = append(where.Exprs, clause.Eq{ + Column: clause.Column{Table: db.Statement.Table, Name: field.DBName}, + Value: clause.Column{Table: "source", Name: field.DBName}, + }) + } + where.Build(db.Statement) + + if len(onConflict.DoUpdates) > 0 { + db.Statement.WriteString(" WHEN MATCHED THEN UPDATE SET ") + onConflict.DoUpdates.Build(db.Statement) + } + + db.Statement.WriteString(" WHEN NOT MATCHED THEN INSERT (") + + for idx, column := range values.Columns { + if idx > 0 { + db.Statement.WriteByte(',') + } + db.Statement.WriteQuoted(column.Name) + } + + db.Statement.WriteString(") VALUES (") + + for idx, column := range values.Columns { + if idx > 0 { + db.Statement.WriteByte(',') + } + db.Statement.WriteQuoted(clause.Column{ + Table: "source", + Name: column.Name, + }) + } + + db.Statement.WriteString(")") + outputInserted(db) + db.Statement.WriteString(";") +} + +func outputInserted(db *gorm.DB) { + if db.Statement.Schema.PrioritizedPrimaryField != nil { + db.Statement.WriteString(" OUTPUT INSERTED.") + db.Statement.WriteQuoted(db.Statement.Schema.PrioritizedPrimaryField.DBName) + } +} diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go index 35fcb484..de82f375 100644 --- a/dialects/mssql/mssql.go +++ b/dialects/mssql/mssql.go @@ -112,7 +112,7 @@ func (dialector Dialector) DataTypeOf(field *schema.Field) string { if size > 0 && size <= 4000 { return fmt.Sprintf("nvarchar(%d)", size) } - return "ntext" + return "nvarchar(MAX)" case schema.Time: return "datetimeoffset" case schema.Bytes: From 05e1af3bfbe34c1a04645b1559d662d013e74a9f Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 30 May 2020 13:46:33 +0800 Subject: [PATCH 345/881] Test Upsert --- dialects/mssql/create.go | 18 ++++++++++++++++++ tests/upsert_test.go | 40 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 58 insertions(+) diff --git a/dialects/mssql/create.go b/dialects/mssql/create.go index 9183ba76..b17a2227 100644 --- a/dialects/mssql/create.go +++ b/dialects/mssql/create.go @@ -6,6 +6,7 @@ import ( "github.com/jinzhu/gorm" "github.com/jinzhu/gorm/callbacks" "github.com/jinzhu/gorm/clause" + "github.com/jinzhu/gorm/schema" ) func Create(db *gorm.DB) { @@ -85,6 +86,7 @@ func Create(db *gorm.DB) { values := db.Statement.Schema.PrioritizedPrimaryField.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface() if rows.Next() { + db.RowsAffected++ err = rows.Scan(values) } } @@ -95,6 +97,16 @@ func Create(db *gorm.DB) { func MergeCreate(db *gorm.DB, onConflict clause.OnConflict) { values := callbacks.ConvertToCreateValues(db.Statement) + setIdentityInsert := false + + if field := db.Statement.Schema.PrioritizedPrimaryField; field != nil { + if field.DataType == schema.Int || field.DataType == schema.Uint { + setIdentityInsert = true + db.Statement.WriteString("SET IDENTITY_INSERT ") + db.Statement.WriteQuoted(db.Statement.Table) + db.Statement.WriteString("ON;") + } + } db.Statement.WriteString("MERGE INTO ") db.Statement.WriteQuoted(db.Statement.Table) @@ -156,6 +168,12 @@ func MergeCreate(db *gorm.DB, onConflict clause.OnConflict) { db.Statement.WriteString(")") outputInserted(db) db.Statement.WriteString(";") + + if setIdentityInsert { + db.Statement.WriteString("SET IDENTITY_INSERT ") + db.Statement.WriteQuoted(db.Statement.Table) + db.Statement.WriteString("OFF;") + } } func outputInserted(db *gorm.DB) { diff --git a/tests/upsert_test.go b/tests/upsert_test.go index 615ead95..6f67f603 100644 --- a/tests/upsert_test.go +++ b/tests/upsert_test.go @@ -4,9 +4,49 @@ import ( "testing" "time" + "github.com/jinzhu/gorm/clause" . "github.com/jinzhu/gorm/tests" ) +func TestUpsert(t *testing.T) { + lang := Language{Code: "upsert", Name: "Upsert"} + DB.Clauses(clause.OnConflict{DoNothing: true}).Create(&lang) + + lang2 := Language{Code: "upsert", Name: "Upsert"} + DB.Clauses(clause.OnConflict{DoNothing: true}).Create(&lang2) + + var langs []Language + if err := DB.Find(&langs, "code = ?", lang.Code).Error; err != nil { + t.Errorf("no error should happen when find languages with code, but got %v", err) + } else if len(langs) != 1 { + t.Errorf("should only find only 1 languages, but got %+v", langs) + } +} + +func TestUpsertSlice(t *testing.T) { + langs := []Language{ + {Code: "upsert-slice1", Name: "Upsert-slice1"}, + {Code: "upsert-slice2", Name: "Upsert-slice2"}, + {Code: "upsert-slice3", Name: "Upsert-slice3"}, + } + DB.Clauses(clause.OnConflict{DoNothing: true}).Create(&langs) + + var langs2 []Language + if err := DB.Find(&langs2, "code LIKE ?", "upsert-slice%").Error; err != nil { + t.Errorf("no error should happen when find languages with code, but got %v", err) + } else if len(langs2) != 3 { + t.Errorf("should only find only 3 languages, but got %+v", langs2) + } + + DB.Clauses(clause.OnConflict{DoNothing: true}).Create(&langs) + var langs3 []Language + if err := DB.Find(&langs3, "code LIKE ?", "upsert-slice%").Error; err != nil { + t.Errorf("no error should happen when find languages with code, but got %v", err) + } else if len(langs3) != 3 { + t.Errorf("should only find only 3 languages, but got %+v", langs3) + } +} + func TestFindOrInitialize(t *testing.T) { var user1, user2, user3, user4, user5, user6 User if err := DB.Where(&User{Name: "find or init", Age: 33}).FirstOrInit(&user1).Error; err != nil { From d2741ae51eddfe927c503626d435b4a3444996fc Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 30 May 2020 14:29:05 +0800 Subject: [PATCH 346/881] Fix test failed due to time round --- tests/utils.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/utils.go b/tests/utils.go index 001d77e9..92163d5c 100644 --- a/tests/utils.go +++ b/tests/utils.go @@ -86,8 +86,8 @@ func AssertEqual(t *testing.T, got, expect interface{}) { if curTime, ok := got.(time.Time); ok { format := "2006-01-02T15:04:05Z07:00" - if curTime.Format(format) != expect.(time.Time).Format(format) { - t.Errorf("%v: expect: %v, got %v", utils.FileWithLineNum(), expect.(time.Time).Format(format), curTime.Format(format)) + if curTime.Round(time.Second).Format(format) != expect.(time.Time).Round(time.Second).Format(format) { + t.Errorf("%v: expect: %v, got %v", utils.FileWithLineNum(), expect.(time.Time).Round(time.Second).Format(format), curTime.Round(time.Second).Format(format)) } } else if got != expect { t.Errorf("%v: expect: %#v, got %#v", utils.FileWithLineNum(), expect, got) From abae7f71c5deac2dac48101dd622824bbd2499a2 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 30 May 2020 16:03:27 +0800 Subject: [PATCH 347/881] Test non std primary key and default value --- callbacks/update.go | 4 +- schema/field.go | 2 + tests/non_std_test.go | 63 +++++++++++++++++ tests/update_belongs_to_test.go | 25 +++++++ tests/update_has_many_test.go | 41 +++++++++++ tests/update_has_one_test.go | 43 ++++++++++++ tests/update_many2many_test.go | 29 ++++++++ tests/update_test.go | 120 ++------------------------------ 8 files changed, 211 insertions(+), 116 deletions(-) create mode 100644 tests/non_std_test.go create mode 100644 tests/update_belongs_to_test.go create mode 100644 tests/update_has_many_test.go create mode 100644 tests/update_has_one_test.go create mode 100644 tests/update_many2many_test.go diff --git a/callbacks/update.go b/callbacks/update.go index f56aa22c..17de97f0 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -119,7 +119,9 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { } case reflect.Struct: assignValue = func(field *schema.Field, value interface{}) { - field.Set(reflectModelValue, value) + if reflectModelValue.CanAddr() { + field.Set(reflectModelValue, value) + } } default: assignValue = func(field *schema.Field, value interface{}) { diff --git a/schema/field.go b/schema/field.go index f4fbad95..d435c928 100644 --- a/schema/field.go +++ b/schema/field.go @@ -231,6 +231,8 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { case reflect.String: field.DataType = String if field.HasDefaultValue { + field.DefaultValue = strings.Trim(field.DefaultValue, "'") + field.DefaultValue = strings.Trim(field.DefaultValue, "\"") field.DefaultValueInterface = field.DefaultValue } case reflect.Struct: diff --git a/tests/non_std_test.go b/tests/non_std_test.go new file mode 100644 index 00000000..b8a278fe --- /dev/null +++ b/tests/non_std_test.go @@ -0,0 +1,63 @@ +package tests_test + +import ( + "testing" + "time" + + . "github.com/jinzhu/gorm/tests" +) + +type Animal struct { + Counter uint64 `gorm:"primary_key:yes"` + Name string `gorm:"DEFAULT:'galeone'"` + From string //test reserved sql keyword as field name + Age time.Time `gorm:"DEFAULT:current_timestamp"` + unexported string // unexported value + CreatedAt time.Time + UpdatedAt time.Time +} + +func init() { + DB.Migrator().DropTable(&Animal{}) + DB.AutoMigrate(&Animal{}) +} + +func TestNonStdPrimaryKeyAndDefaultValues(t *testing.T) { + animal := Animal{Name: "Ferdinand"} + DB.Save(&animal) + updatedAt1 := animal.UpdatedAt + + DB.Save(&animal).Update("name", "Francis") + if updatedAt1.Format(time.RFC3339Nano) == animal.UpdatedAt.Format(time.RFC3339Nano) { + t.Errorf("UpdatedAt should be updated") + } + + var animals []Animal + DB.Find(&animals) + if count := DB.Model(Animal{}).Update("CreatedAt", time.Now().Add(2*time.Hour)).RowsAffected; count != int64(len(animals)) { + t.Error("RowsAffected should be correct when do batch update") + } + + animal = Animal{From: "somewhere"} // No name fields, should be filled with the default value (galeone) + DB.Save(&animal).Update("From", "a nice place") // The name field shoul be untouched + DB.First(&animal, animal.Counter) + if animal.Name != "galeone" { + t.Errorf("Name fields shouldn't be changed if untouched, but got %v", animal.Name) + } + + // When changing a field with a default value, the change must occur + animal.Name = "amazing horse" + DB.Save(&animal) + DB.First(&animal, animal.Counter) + if animal.Name != "amazing horse" { + t.Errorf("Update a filed with a default value should occur. But got %v\n", animal.Name) + } + + // When changing a field with a default value with blank value + animal.Name = "" + DB.Save(&animal) + DB.First(&animal, animal.Counter) + if animal.Name != "" { + t.Errorf("Update a filed to blank with a default value should occur. But got %v\n", animal.Name) + } +} diff --git a/tests/update_belongs_to_test.go b/tests/update_belongs_to_test.go new file mode 100644 index 00000000..267fd4e8 --- /dev/null +++ b/tests/update_belongs_to_test.go @@ -0,0 +1,25 @@ +package tests_test + +import ( + "testing" + + . "github.com/jinzhu/gorm/tests" +) + +func TestUpdateBelongsTo(t *testing.T) { + var user = *GetUser("update-belongs-to", Config{}) + + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + user.Company = Company{Name: "company-belongs-to-association"} + user.Manager = &User{Name: "manager-belongs-to-association"} + if err := DB.Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user2 User + DB.Preload("Company").Preload("Manager").Find(&user2, "id = ?", user.ID) + CheckUser(t, user2, user) +} diff --git a/tests/update_has_many_test.go b/tests/update_has_many_test.go new file mode 100644 index 00000000..e723b940 --- /dev/null +++ b/tests/update_has_many_test.go @@ -0,0 +1,41 @@ +package tests_test + +import ( + "testing" + + . "github.com/jinzhu/gorm/tests" +) + +func TestUpdateHasManyAssociations(t *testing.T) { + var user = *GetUser("update-has-many", Config{}) + + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + user.Pets = []*Pet{{Name: "pet1"}, {Name: "pet2"}} + if err := DB.Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user2 User + DB.Preload("Pets").Find(&user2, "id = ?", user.ID) + CheckUser(t, user2, user) + + t.Run("Polymorphic", func(t *testing.T) { + var user = *GetUser("update-has-many", Config{}) + + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + user.Toys = []Toy{{Name: "toy1"}, {Name: "toy2"}} + if err := DB.Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user2 User + DB.Preload("Toys").Find(&user2, "id = ?", user.ID) + CheckUser(t, user2, user) + }) +} diff --git a/tests/update_has_one_test.go b/tests/update_has_one_test.go new file mode 100644 index 00000000..4c5036cf --- /dev/null +++ b/tests/update_has_one_test.go @@ -0,0 +1,43 @@ +package tests_test + +import ( + "testing" + + . "github.com/jinzhu/gorm/tests" +) + +func TestUpdateHasOne(t *testing.T) { + var user = *GetUser("update-has-one", Config{}) + + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + user.Account = Account{Number: "account-has-one-association"} + + if err := DB.Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user2 User + DB.Preload("Account").Find(&user2, "id = ?", user.ID) + CheckUser(t, user2, user) + + t.Run("Polymorphic", func(t *testing.T) { + var pet = Pet{Name: "create"} + + if err := DB.Create(&pet).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + pet.Toy = Toy{Name: "Update-HasOneAssociation-Polymorphic"} + + if err := DB.Save(&pet).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + var pet2 Pet + DB.Preload("Toy").Find(&pet2, "id = ?", pet.ID) + CheckPet(t, pet2, pet) + }) +} diff --git a/tests/update_many2many_test.go b/tests/update_many2many_test.go new file mode 100644 index 00000000..bc7a60af --- /dev/null +++ b/tests/update_many2many_test.go @@ -0,0 +1,29 @@ +package tests_test + +import ( + "testing" + + . "github.com/jinzhu/gorm/tests" +) + +func TestUpdateMany2ManyAssociations(t *testing.T) { + var user = *GetUser("update-many2many", Config{}) + + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + user.Languages = []Language{{Code: "zh-CN", Name: "Chinese"}, {Code: "en", Name: "English"}} + for _, lang := range user.Languages { + DB.Create(&lang) + } + user.Friends = []*User{{Name: "friend-1"}, {Name: "friend-2"}} + + if err := DB.Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user2 User + DB.Preload("Languages").Preload("Friends").Find(&user2, "id = ?", user.ID) + CheckUser(t, user2, user) +} diff --git a/tests/update_test.go b/tests/update_test.go index 10835f97..71da0751 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -18,7 +18,7 @@ func TestUpdate(t *testing.T) { lastUpdatedAt time.Time ) - checkUpdatedTime := func(name string, n time.Time) { + checkUpdatedAtChanged := func(name string, n time.Time) { if n.UnixNano() == lastUpdatedAt.UnixNano() { t.Errorf("%v: user's updated at should be changed, but got %v, was %v", name, n, lastUpdatedAt) } @@ -52,7 +52,7 @@ func TestUpdate(t *testing.T) { } else if user.Age != 10 { t.Errorf("Age should equals to 10, but got %v", user.Age) } - checkUpdatedTime("Update", user.UpdatedAt) + checkUpdatedAtChanged("Update", user.UpdatedAt) checkOtherData("Update") var result User @@ -70,7 +70,7 @@ func TestUpdate(t *testing.T) { } else if user.Active != true { t.Errorf("Active should be true, but got %v", user.Active) } - checkUpdatedTime("Updates with map", user.UpdatedAt) + checkUpdatedAtChanged("Updates with map", user.UpdatedAt) checkOtherData("Updates with map") var result2 User @@ -85,7 +85,7 @@ func TestUpdate(t *testing.T) { } else if user.Age != 2 { t.Errorf("Age should equals to 2, but got %v", user.Age) } - checkUpdatedTime("Updates with struct", user.UpdatedAt) + checkUpdatedAtChanged("Updates with struct", user.UpdatedAt) checkOtherData("Updates with struct") var result3 User @@ -104,7 +104,7 @@ func TestUpdate(t *testing.T) { } else if user.Active != false { t.Errorf("Active should equals to false, but got %v", user.Active) } - checkUpdatedTime("Save", user.UpdatedAt) + checkUpdatedAtChanged("Save", user.UpdatedAt) checkOtherData("Save") var result4 User @@ -114,113 +114,3 @@ func TestUpdate(t *testing.T) { CheckUser(t, result4, *user) } } - -func TestUpdateBelongsTo(t *testing.T) { - var user = *GetUser("update-belongs-to", Config{}) - - if err := DB.Create(&user).Error; err != nil { - t.Fatalf("errors happened when create: %v", err) - } - - user.Company = Company{Name: "company-belongs-to-association"} - user.Manager = &User{Name: "manager-belongs-to-association"} - if err := DB.Save(&user).Error; err != nil { - t.Fatalf("errors happened when update: %v", err) - } - - var user2 User - DB.Preload("Company").Preload("Manager").Find(&user2, "id = ?", user.ID) - CheckUser(t, user2, user) -} - -func TestUpdateHasOne(t *testing.T) { - var user = *GetUser("update-has-one", Config{}) - - if err := DB.Create(&user).Error; err != nil { - t.Fatalf("errors happened when create: %v", err) - } - - user.Account = Account{Number: "account-has-one-association"} - - if err := DB.Save(&user).Error; err != nil { - t.Fatalf("errors happened when update: %v", err) - } - - var user2 User - DB.Preload("Account").Find(&user2, "id = ?", user.ID) - CheckUser(t, user2, user) - - t.Run("Polymorphic", func(t *testing.T) { - var pet = Pet{Name: "create"} - - if err := DB.Create(&pet).Error; err != nil { - t.Fatalf("errors happened when create: %v", err) - } - - pet.Toy = Toy{Name: "Update-HasOneAssociation-Polymorphic"} - - if err := DB.Save(&pet).Error; err != nil { - t.Fatalf("errors happened when create: %v", err) - } - - var pet2 Pet - DB.Preload("Toy").Find(&pet2, "id = ?", pet.ID) - CheckPet(t, pet2, pet) - }) -} - -func TestUpdateHasManyAssociations(t *testing.T) { - var user = *GetUser("update-has-many", Config{}) - - if err := DB.Create(&user).Error; err != nil { - t.Fatalf("errors happened when create: %v", err) - } - - user.Pets = []*Pet{{Name: "pet1"}, {Name: "pet2"}} - if err := DB.Save(&user).Error; err != nil { - t.Fatalf("errors happened when update: %v", err) - } - - var user2 User - DB.Preload("Pets").Find(&user2, "id = ?", user.ID) - CheckUser(t, user2, user) - - t.Run("Polymorphic", func(t *testing.T) { - var user = *GetUser("update-has-many", Config{}) - - if err := DB.Create(&user).Error; err != nil { - t.Fatalf("errors happened when create: %v", err) - } - - user.Toys = []Toy{{Name: "toy1"}, {Name: "toy2"}} - if err := DB.Save(&user).Error; err != nil { - t.Fatalf("errors happened when update: %v", err) - } - - var user2 User - DB.Preload("Toys").Find(&user2, "id = ?", user.ID) - CheckUser(t, user2, user) - }) -} - -func TestUpdateMany2ManyAssociations(t *testing.T) { - var user = *GetUser("update-many2many", Config{}) - - if err := DB.Create(&user).Error; err != nil { - t.Fatalf("errors happened when create: %v", err) - } - - user.Languages = []Language{{Code: "zh-CN", Name: "Chinese"}, {Code: "en", Name: "English"}} - for _, lang := range user.Languages { - DB.Create(&lang) - } - user.Friends = []*User{{Name: "friend-1"}, {Name: "friend-2"}} - - if err := DB.Save(&user).Error; err != nil { - t.Fatalf("errors happened when update: %v", err) - } - - var user2 User - DB.Preload("Languages").Preload("Friends").Find(&user2, "id = ?", user.ID) - CheckUser(t, user2, user) -} From 028c9d6e17d733aae984ea1b21ce250822507a92 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 30 May 2020 16:47:16 +0800 Subject: [PATCH 348/881] Test Updates --- callbacks/update.go | 4 +-- dialects/mssql/mssql_test.go | 35 ----------------------- dialects/mysql/mysql_test.go | 35 ----------------------- dialects/postgres/postgres_test.go | 35 ----------------------- dialects/sqlite/sqlite_test.go | 31 -------------------- gorm.go | 4 +++ schema/field_test.go | 8 +++--- tests/tests_all.sh | 8 ------ tests/update_test.go | 45 ++++++++++++++++++++++++++++++ 9 files changed, 55 insertions(+), 150 deletions(-) delete mode 100644 dialects/mssql/mssql_test.go delete mode 100644 dialects/mysql/mysql_test.go delete mode 100644 dialects/postgres/postgres_test.go delete mode 100644 dialects/sqlite/sqlite_test.go diff --git a/callbacks/update.go b/callbacks/update.go index 17de97f0..7e8c0f3e 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -164,7 +164,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { case reflect.Struct: set = make([]clause.Assignment, 0, len(stmt.Schema.FieldsByDBName)) for _, field := range stmt.Schema.FieldsByDBName { - if !field.PrimaryKey || stmt.Dest != stmt.Model { + if !field.PrimaryKey || (!stmt.ReflectValue.CanAddr() || stmt.Dest != stmt.Model) { if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { value, isZero := field.ValueOf(stmt.ReflectValue) if field.AutoUpdateTime > 0 { @@ -186,7 +186,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { } } - if stmt.Dest != stmt.Model { + if !stmt.ReflectValue.CanAddr() || stmt.Dest != stmt.Model { reflectValue := reflect.Indirect(reflect.ValueOf(stmt.Model)) switch reflectValue.Kind() { case reflect.Slice, reflect.Array: diff --git a/dialects/mssql/mssql_test.go b/dialects/mssql/mssql_test.go deleted file mode 100644 index 49b3cd6a..00000000 --- a/dialects/mssql/mssql_test.go +++ /dev/null @@ -1,35 +0,0 @@ -package mssql_test - -import ( - "fmt" - "os" - "testing" - - "github.com/jinzhu/gorm" - "github.com/jinzhu/gorm/dialects/mssql" - "github.com/jinzhu/gorm/tests" -) - -var ( - DB *gorm.DB - err error -) - -func init() { - dsn := "sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm" - if os.Getenv("GORM_DSN") != "" { - dsn = os.Getenv("GORM_DSN") - } - - if DB, err = gorm.Open(mssql.Open(dsn), &gorm.Config{}); err != nil { - panic(fmt.Sprintf("failed to initialize database, got error %v", err)) - } -} - -func TestCURD(t *testing.T) { - tests.RunTestsSuit(t, DB) -} - -func TestMigrate(t *testing.T) { - tests.TestMigrate(t, DB) -} diff --git a/dialects/mysql/mysql_test.go b/dialects/mysql/mysql_test.go deleted file mode 100644 index cb3b240a..00000000 --- a/dialects/mysql/mysql_test.go +++ /dev/null @@ -1,35 +0,0 @@ -package mysql_test - -import ( - "fmt" - "os" - "testing" - - "github.com/jinzhu/gorm" - "github.com/jinzhu/gorm/dialects/mysql" - "github.com/jinzhu/gorm/tests" -) - -var ( - DB *gorm.DB - err error -) - -func init() { - dsn := "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True&loc=Local" - if os.Getenv("GORM_DSN") != "" { - dsn = os.Getenv("GORM_DSN") - } - - if DB, err = gorm.Open(mysql.Open(dsn), &gorm.Config{}); err != nil { - panic(fmt.Sprintf("failed to initialize database, got error %v", err)) - } -} - -func TestCURD(t *testing.T) { - tests.RunTestsSuit(t, DB) -} - -func TestMigrate(t *testing.T) { - tests.TestMigrate(t, DB) -} diff --git a/dialects/postgres/postgres_test.go b/dialects/postgres/postgres_test.go deleted file mode 100644 index 2185c19c..00000000 --- a/dialects/postgres/postgres_test.go +++ /dev/null @@ -1,35 +0,0 @@ -package postgres_test - -import ( - "fmt" - "os" - "testing" - - "github.com/jinzhu/gorm" - "github.com/jinzhu/gorm/dialects/postgres" - "github.com/jinzhu/gorm/tests" -) - -var ( - DB *gorm.DB - err error -) - -func init() { - dsn := "user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable TimeZone=Asia/Shanghai" - if os.Getenv("GORM_DSN") != "" { - dsn = os.Getenv("GORM_DSN") - } - - if DB, err = gorm.Open(postgres.Open(dsn), &gorm.Config{}); err != nil { - panic(fmt.Sprintf("failed to initialize database, got error %v", err)) - } -} - -func TestCURD(t *testing.T) { - tests.RunTestsSuit(t, DB) -} - -func TestMigrate(t *testing.T) { - tests.TestMigrate(t, DB) -} diff --git a/dialects/sqlite/sqlite_test.go b/dialects/sqlite/sqlite_test.go deleted file mode 100644 index a42bc8ee..00000000 --- a/dialects/sqlite/sqlite_test.go +++ /dev/null @@ -1,31 +0,0 @@ -package sqlite_test - -import ( - "fmt" - "os" - "path/filepath" - "testing" - - "github.com/jinzhu/gorm" - "github.com/jinzhu/gorm/dialects/sqlite" - "github.com/jinzhu/gorm/tests" -) - -var ( - DB *gorm.DB - err error -) - -func init() { - if DB, err = gorm.Open(sqlite.Open(filepath.Join(os.TempDir(), "gorm.db")), &gorm.Config{}); err != nil { - panic(fmt.Sprintf("failed to initialize database, got error %v", err)) - } -} - -func TestCURD(t *testing.T) { - tests.RunTestsSuit(t, DB) -} - -func TestMigrate(t *testing.T) { - tests.TestMigrate(t, DB) -} diff --git a/gorm.go b/gorm.go index 942024cf..6b2a6d75 100644 --- a/gorm.go +++ b/gorm.go @@ -189,3 +189,7 @@ func (db *DB) getInstance() *DB { return db } + +func Expr(expr string, args ...interface{}) clause.Expr { + return clause.Expr{SQL: expr, Vars: args} +} diff --git a/schema/field_test.go b/schema/field_test.go index c04149ff..aac46de9 100644 --- a/schema/field_test.go +++ b/schema/field_test.go @@ -19,7 +19,7 @@ func TestFieldValuerAndSetter(t *testing.T) { Model: gorm.Model{ ID: 10, CreatedAt: time.Now(), - DeletedAt: tests.Now(), + DeletedAt: gorm.DeletedAt{Time: time.Now(), Valid: true}, }, Name: "valuer_and_setter", Age: 18, @@ -46,7 +46,7 @@ func TestFieldValuerAndSetter(t *testing.T) { "name": "valuer_and_setter_2", "id": 2, "created_at": time.Now(), - "deleted_at": tests.Now(), + "deleted_at": time.Now(), "age": 20, "birthday": time.Now(), "active": false, @@ -89,7 +89,7 @@ func TestPointerFieldValuerAndSetter(t *testing.T) { Model: &gorm.Model{ ID: 10, CreatedAt: time.Now(), - DeletedAt: tests.Now(), + DeletedAt: gorm.DeletedAt{Time: time.Now(), Valid: true}, }, Name: &name, Age: &age, @@ -116,7 +116,7 @@ func TestPointerFieldValuerAndSetter(t *testing.T) { "name": "valuer_and_setter_2", "id": 2, "created_at": time.Now(), - "deleted_at": tests.Now(), + "deleted_at": time.Now(), "age": 20, "birthday": time.Now(), "active": false, diff --git a/tests/tests_all.sh b/tests/tests_all.sh index cd42e1e0..0c24a888 100755 --- a/tests/tests_all.sh +++ b/tests/tests_all.sh @@ -9,16 +9,8 @@ for dialect in "${dialects[@]}" ; do then if [ "$GORM_VERBOSE" = "" ] then - cd dialects/${dialect} - DEBUG=false GORM_DIALECT=${dialect} go test -race ./... - cd ../.. - DEBUG=false GORM_DIALECT=${dialect} go test -race ./... else - cd dialects/${dialect} - DEBUG=false GORM_DIALECT=${dialect} go test -race ./... - cd ../.. - DEBUG=false GORM_DIALECT=${dialect} go test -race -v ./... fi fi diff --git a/tests/update_test.go b/tests/update_test.go index 71da0751..cb61b40e 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -4,6 +4,7 @@ import ( "testing" "time" + "github.com/jinzhu/gorm" . "github.com/jinzhu/gorm/tests" ) @@ -114,3 +115,47 @@ func TestUpdate(t *testing.T) { CheckUser(t, result4, *user) } } + +func TestUpdates(t *testing.T) { + var users = []*User{ + GetUser("updates_01", Config{}), + GetUser("updates_02", Config{}), + } + + DB.Create(&users) + lastUpdatedAt := users[0].UpdatedAt + + // update with map + DB.Model(users[0]).Updates(map[string]interface{}{"name": "updates_01_newname", "age": 100}) + if users[0].Name != "updates_01_newname" || users[0].Age != 100 { + t.Errorf("Record should be updated also with map") + } + + if users[0].UpdatedAt.UnixNano() == lastUpdatedAt.UnixNano() { + t.Errorf("User's updated at should be changed, but got %v, was %v", users[0].UpdatedAt.UnixNano(), lastUpdatedAt) + } + + // user2 should not be updated + var user1, user2 User + DB.First(&user1, users[0].ID) + DB.First(&user2, users[1].ID) + CheckUser(t, user1, *users[0]) + CheckUser(t, user2, *users[1]) + + // update with struct + DB.Table("users").Where("name in ?", []string{users[1].Name}).Updates(User{Name: "updates_02_newname"}) + + var user3 User + if DB.First(&user3, "name = ?", "updates_02_newname").RecordNotFound() { + t.Errorf("User2's name should be updated") + } + AssertEqual(t, user2.UpdatedAt, user3.UpdatedAt) + + // update with gorm exprs + DB.Model(&user3).Updates(map[string]interface{}{"age": gorm.Expr("age + ?", 100)}) + var user4 User + DB.First(&user4, user3.ID) + + user3.Age += 100 + AssertEqual(t, user4.UpdatedAt, user3.UpdatedAt) +} From 9dd516a7e8aaccad326778abac631782f24689e1 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 30 May 2020 17:34:22 +0800 Subject: [PATCH 349/881] Test UpdateColumn --- callbacks/update.go | 23 ++++++++++---------- finisher_api.go | 2 ++ statement.go | 1 + tests/update_test.go | 52 ++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 67 insertions(+), 11 deletions(-) diff --git a/callbacks/update.go b/callbacks/update.go index 7e8c0f3e..623d64fe 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -141,9 +141,6 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { for _, k := range keys { if field := stmt.Schema.LookUpField(k); field != nil { if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { - if field.AutoUpdateTime > 0 { - value[k] = time.Now() - } set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value[k]}) assignValue(field, value[k]) } @@ -152,11 +149,13 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { } } - for _, field := range stmt.Schema.FieldsByDBName { - if field.AutoUpdateTime > 0 && value[field.Name] == nil && value[field.DBName] == nil { - now := time.Now() - set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now}) - assignValue(field, now) + if !stmt.DisableUpdateTime { + for _, field := range stmt.Schema.FieldsByDBName { + if field.AutoUpdateTime > 0 && value[field.Name] == nil && value[field.DBName] == nil { + now := time.Now() + set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now}) + assignValue(field, now) + } } } default: @@ -167,9 +166,11 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { if !field.PrimaryKey || (!stmt.ReflectValue.CanAddr() || stmt.Dest != stmt.Model) { if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { value, isZero := field.ValueOf(stmt.ReflectValue) - if field.AutoUpdateTime > 0 { - value = time.Now() - isZero = false + if !stmt.DisableUpdateTime { + if field.AutoUpdateTime > 0 { + value = time.Now() + isZero = false + } } if ok || !isZero { diff --git a/finisher_api.go b/finisher_api.go index c47e12af..f14bcfbe 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -207,6 +207,7 @@ func (db *DB) Updates(values interface{}) (tx *DB) { func (db *DB) UpdateColumn(column string, value interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.Dest = map[string]interface{}{column: value} + tx.Statement.DisableUpdateTime = true tx.callbacks.Update().Execute(tx) return } @@ -214,6 +215,7 @@ func (db *DB) UpdateColumn(column string, value interface{}) (tx *DB) { func (db *DB) UpdateColumns(values interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.Dest = values + tx.Statement.DisableUpdateTime = true tx.callbacks.Update().Execute(tx) return } diff --git a/statement.go b/statement.go index f81ae0e5..42df148a 100644 --- a/statement.go +++ b/statement.go @@ -32,6 +32,7 @@ type Statement struct { Schema *schema.Schema Context context.Context RaiseErrorOnNotFound bool + DisableUpdateTime bool SQL strings.Builder Vars []interface{} NamedVars []sql.NamedArg diff --git a/tests/update_test.go b/tests/update_test.go index cb61b40e..371a9f78 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -159,3 +159,55 @@ func TestUpdates(t *testing.T) { user3.Age += 100 AssertEqual(t, user4.UpdatedAt, user3.UpdatedAt) } + +func TestUpdateColumn(t *testing.T) { + var users = []*User{ + GetUser("update_column_01", Config{}), + GetUser("update_column_02", Config{}), + } + + DB.Create(&users) + lastUpdatedAt := users[1].UpdatedAt + + // update with map + DB.Model(users[1]).UpdateColumns(map[string]interface{}{"name": "update_column_02_newname", "age": 100}) + if users[1].Name != "update_column_02_newname" || users[1].Age != 100 { + t.Errorf("user 2 should be updated with update column") + } + AssertEqual(t, lastUpdatedAt.UnixNano(), users[1].UpdatedAt.UnixNano()) + + // user2 should not be updated + var user1, user2 User + DB.First(&user1, users[0].ID) + DB.First(&user2, users[1].ID) + CheckUser(t, user1, *users[0]) + CheckUser(t, user2, *users[1]) + + DB.Model(users[1]).UpdateColumn("name", "update_column_02_newnew") + AssertEqual(t, lastUpdatedAt.UnixNano(), users[1].UpdatedAt.UnixNano()) + + if users[1].Name != "update_column_02_newnew" { + t.Errorf("user 2's name should be updated, but got %v", users[1].Name) + } + + DB.Model(users[1]).UpdateColumn("age", gorm.Expr("age + 100 - 50")) + var user3 User + DB.First(&user3, users[1].ID) + + users[1].Age += 50 + CheckUser(t, user3, *users[1]) + + // update with struct + DB.Model(users[1]).UpdateColumns(User{Name: "update_column_02_newnew2", Age: 200}) + if users[1].Name != "update_column_02_newnew2" || users[1].Age != 200 { + t.Errorf("user 2 should be updated with update column") + } + AssertEqual(t, lastUpdatedAt.UnixNano(), users[1].UpdatedAt.UnixNano()) + + // user2 should not be updated + var user5, user6 User + DB.First(&user5, users[0].ID) + DB.First(&user6, users[1].ID) + CheckUser(t, user5, *users[0]) + CheckUser(t, user6, *users[1]) +} From c422d75f4b474d36f60a9559273d08d080bc0c28 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 30 May 2020 18:50:20 +0800 Subject: [PATCH 350/881] Add Scopes tests --- callbacks/delete.go | 2 -- clause/expression.go | 30 +++++++++++++++++++++++++-- tests/scopes_test.go | 48 ++++++++++++++++++++++++++++++++++++++++++++ tests/utils.go | 2 +- 4 files changed, 77 insertions(+), 5 deletions(-) create mode 100644 tests/scopes_test.go diff --git a/callbacks/delete.go b/callbacks/delete.go index 1c59afbe..b3278c83 100644 --- a/callbacks/delete.go +++ b/callbacks/delete.go @@ -1,7 +1,6 @@ package callbacks import ( - "fmt" "reflect" "github.com/jinzhu/gorm" @@ -38,7 +37,6 @@ func Delete(db *gorm.DB) { if db.Statement.Schema != nil && !db.Statement.Unscoped { for _, c := range db.Statement.Schema.DeleteClauses { db.Statement.AddClause(c) - fmt.Println(db.Statement.SQL.String()) } } diff --git a/clause/expression.go b/clause/expression.go index 067774d4..e54da1af 100644 --- a/clause/expression.go +++ b/clause/expression.go @@ -1,5 +1,7 @@ package clause +import "reflect" + // Expression expression interface type Expression interface { Build(builder Builder) @@ -18,12 +20,36 @@ type Expr struct { // Build build raw expression func (expr Expr) Build(builder Builder) { - var idx int + var ( + afterParenthesis bool + idx int + ) + for _, v := range []byte(expr.SQL) { if v == '?' { - builder.AddVar(builder, expr.Vars[idx]) + if afterParenthesis { + switch rv := reflect.ValueOf(expr.Vars[idx]); rv.Kind() { + case reflect.Slice, reflect.Array: + for i := 0; i < rv.Len(); i++ { + if i > 0 { + builder.WriteByte(',') + } + builder.AddVar(builder, rv.Index(i).Interface()) + } + default: + builder.AddVar(builder, expr.Vars[idx]) + } + } else { + builder.AddVar(builder, expr.Vars[idx]) + } + idx++ } else { + if v == '(' { + afterParenthesis = true + } else { + afterParenthesis = false + } builder.WriteByte(v) } } diff --git a/tests/scopes_test.go b/tests/scopes_test.go new file mode 100644 index 00000000..c0530da5 --- /dev/null +++ b/tests/scopes_test.go @@ -0,0 +1,48 @@ +package tests_test + +import ( + "testing" + + "github.com/jinzhu/gorm" + . "github.com/jinzhu/gorm/tests" +) + +func NameIn1And2(d *gorm.DB) *gorm.DB { + return d.Where("name in (?)", []string{"ScopeUser1", "ScopeUser2"}) +} + +func NameIn2And3(d *gorm.DB) *gorm.DB { + return d.Where("name in (?)", []string{"ScopeUser2", "ScopeUser3"}) +} + +func NameIn(names []string) func(d *gorm.DB) *gorm.DB { + return func(d *gorm.DB) *gorm.DB { + return d.Where("name in (?)", names) + } +} + +func TestScopes(t *testing.T) { + var users = []*User{ + GetUser("ScopeUser1", Config{}), + GetUser("ScopeUser2", Config{}), + GetUser("ScopeUser3", Config{}), + } + + DB.Create(&users) + + var users1, users2, users3 []User + DB.Scopes(NameIn1And2).Find(&users1) + if len(users1) != 2 { + t.Errorf("Should found two users's name in 1, 2, but got %v", len(users1)) + } + + DB.Scopes(NameIn1And2, NameIn2And3).Find(&users2) + if len(users2) != 1 { + t.Errorf("Should found one user's name is 2, but got %v", len(users2)) + } + + DB.Scopes(NameIn([]string{users[0].Name, users[2].Name})).Find(&users3) + if len(users3) != 2 { + t.Errorf("Should found two users's name in 1, 3, but got %v", len(users3)) + } +} diff --git a/tests/utils.go b/tests/utils.go index 92163d5c..041dc9b1 100644 --- a/tests/utils.go +++ b/tests/utils.go @@ -87,7 +87,7 @@ func AssertEqual(t *testing.T, got, expect interface{}) { format := "2006-01-02T15:04:05Z07:00" if curTime.Round(time.Second).Format(format) != expect.(time.Time).Round(time.Second).Format(format) { - t.Errorf("%v: expect: %v, got %v", utils.FileWithLineNum(), expect.(time.Time).Round(time.Second).Format(format), curTime.Round(time.Second).Format(format)) + t.Errorf("%v: expect: %v, got %v after time round", utils.FileWithLineNum(), expect.(time.Time).Round(time.Second).Format(format), curTime.Round(time.Second).Format(format)) } } else if got != expect { t.Errorf("%v: expect: %#v, got %#v", utils.FileWithLineNum(), expect, got) From c291c2f42cc66892198d5254592602e000c0dac6 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 30 May 2020 21:05:27 +0800 Subject: [PATCH 351/881] Add Scanner, Valuer tests --- clause/expression.go | 27 ++++-- logger/sql.go | 7 +- schema/field.go | 2 +- statement.go | 3 + tests/scanner_valuer_test.go | 175 +++++++++++++++++++++++++++++++++++ tests/utils.go | 14 ++- 6 files changed, 211 insertions(+), 17 deletions(-) create mode 100644 tests/scanner_valuer_test.go diff --git a/clause/expression.go b/clause/expression.go index e54da1af..ecf8ba85 100644 --- a/clause/expression.go +++ b/clause/expression.go @@ -1,6 +1,9 @@ package clause -import "reflect" +import ( + "database/sql/driver" + "reflect" +) // Expression expression interface type Expression interface { @@ -28,16 +31,20 @@ func (expr Expr) Build(builder Builder) { for _, v := range []byte(expr.SQL) { if v == '?' { if afterParenthesis { - switch rv := reflect.ValueOf(expr.Vars[idx]); rv.Kind() { - case reflect.Slice, reflect.Array: - for i := 0; i < rv.Len(); i++ { - if i > 0 { - builder.WriteByte(',') - } - builder.AddVar(builder, rv.Index(i).Interface()) - } - default: + if _, ok := expr.Vars[idx].(driver.Valuer); ok { builder.AddVar(builder, expr.Vars[idx]) + } else { + switch rv := reflect.ValueOf(expr.Vars[idx]); rv.Kind() { + case reflect.Slice, reflect.Array: + for i := 0; i < rv.Len(); i++ { + if i > 0 { + builder.WriteByte(',') + } + builder.AddVar(builder, rv.Index(i).Interface()) + } + default: + builder.AddVar(builder, expr.Vars[idx]) + } } } else { builder.AddVar(builder, expr.Vars[idx]) diff --git a/logger/sql.go b/logger/sql.go index bb4e3e06..dd502324 100644 --- a/logger/sql.go +++ b/logger/sql.go @@ -57,6 +57,9 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, v vars[idx] = "NULL" } else if rv.Kind() == reflect.Ptr && rv.IsNil() { vars[idx] = "NULL" + } else if valuer, ok := v.(driver.Valuer); ok { + v, _ = valuer.Value() + convertParams(v, idx) } else if rv.Kind() == reflect.Ptr && !rv.IsZero() { convertParams(reflect.Indirect(rv).Interface(), idx) } else { @@ -74,10 +77,6 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, v } for idx, v := range vars { - if valuer, ok := v.(driver.Valuer); ok { - v, _ = valuer.Value() - } - convertParams(v, idx) } diff --git a/schema/field.go b/schema/field.go index d435c928..57ba3ac7 100644 --- a/schema/field.go +++ b/schema/field.go @@ -207,7 +207,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { field.DBDataType = val } - switch fieldValue.Elem().Kind() { + switch reflect.Indirect(fieldValue).Kind() { case reflect.Bool: field.DataType = Bool if field.HasDefaultValue { diff --git a/statement.go b/statement.go index 42df148a..e0d92c5e 100644 --- a/statement.go +++ b/statement.go @@ -146,6 +146,9 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) { case clause.Expr: writer.WriteString(v.SQL) stmt.Vars = append(stmt.Vars, v.Vars...) + case driver.Valuer: + stmt.Vars = append(stmt.Vars, v) + stmt.DB.Dialector.BindVarTo(writer, stmt, v) case []interface{}: if len(v) > 0 { writer.WriteByte('(') diff --git a/tests/scanner_valuer_test.go b/tests/scanner_valuer_test.go new file mode 100644 index 00000000..38ffc919 --- /dev/null +++ b/tests/scanner_valuer_test.go @@ -0,0 +1,175 @@ +package tests_test + +import ( + "database/sql" + "database/sql/driver" + "encoding/json" + "errors" + "reflect" + "strconv" + "testing" + "time" + + "github.com/jinzhu/gorm" + . "github.com/jinzhu/gorm/tests" +) + +func TestScannerValuer(t *testing.T) { + DB.Migrator().DropTable(&ScannerValuerStruct{}) + if err := DB.Migrator().AutoMigrate(&ScannerValuerStruct{}); err != nil { + t.Errorf("no error should happen when migrate scanner, valuer struct") + } + + data := ScannerValuerStruct{ + Name: sql.NullString{String: "name", Valid: true}, + Gender: &sql.NullString{String: "M", Valid: true}, + Age: sql.NullInt64{Int64: 18, Valid: true}, + Male: sql.NullBool{Bool: true, Valid: true}, + Height: sql.NullFloat64{Float64: 1.8888, Valid: true}, + Birthday: sql.NullTime{Time: time.Now(), Valid: true}, + Password: EncryptedData("pass1"), + Num: 18, + Strings: StringsSlice{"a", "b", "c"}, + Structs: StructsSlice{ + {"name1", "value1"}, + {"name2", "value2"}, + }, + } + + if err := DB.Create(&data).Error; err != nil { + t.Errorf("No error should happend when create scanner valuer struct, but got %v", err) + } + + var result ScannerValuerStruct + + if err := DB.Find(&result).Error; err != nil { + t.Errorf("no error should happen when query scanner, valuer struct, but got %v", err) + } + + AssertObjEqual(t, data, result, "Name", "Gender", "Age", "Male", "Height", "Birthday", "Password", "Num", "Strings", "Structs") +} + +func TestInvalidValuer(t *testing.T) { + DB.Migrator().DropTable(&ScannerValuerStruct{}) + if err := DB.Migrator().AutoMigrate(&ScannerValuerStruct{}); err != nil { + t.Errorf("no error should happen when migrate scanner, valuer struct") + } + + data := ScannerValuerStruct{ + Password: EncryptedData("xpass1"), + } + + if err := DB.Create(&data).Error; err == nil { + t.Errorf("Should failed to create data with invalid data") + } + + data.Password = EncryptedData("pass1") + if err := DB.Create(&data).Error; err != nil { + t.Errorf("Should got no error when creating data, but got %v", err) + } + + if err := DB.Model(&data).Update("password", EncryptedData("xnewpass")).Error; err == nil { + t.Errorf("Should failed to update data with invalid data") + } + + if err := DB.Model(&data).Update("password", EncryptedData("newpass")).Error; err != nil { + t.Errorf("Should got no error update data with valid data, but got %v", err) + } + + AssertEqual(t, data.Password, EncryptedData("newpass")) +} + +type ScannerValuerStruct struct { + gorm.Model + Name sql.NullString + Gender *sql.NullString + Age sql.NullInt64 + Male sql.NullBool + Height sql.NullFloat64 + Birthday sql.NullTime + Password EncryptedData + Num Num + Strings StringsSlice + Structs StructsSlice +} + +type EncryptedData []byte + +func (data *EncryptedData) Scan(value interface{}) error { + if b, ok := value.([]byte); ok { + if len(b) < 3 || b[0] != '*' || b[1] != '*' || b[2] != '*' { + return errors.New("Too short") + } + + *data = b[3:] + return nil + } + + return errors.New("Bytes expected") +} + +func (data EncryptedData) Value() (driver.Value, error) { + if len(data) > 0 && data[0] == 'x' { + //needed to test failures + return nil, errors.New("Should not start with 'x'") + } + + //prepend asterisks + return append([]byte("***"), data...), nil +} + +type Num int64 + +func (i *Num) Scan(src interface{}) error { + switch s := src.(type) { + case []byte: + n, _ := strconv.Atoi(string(s)) + *i = Num(n) + case int64: + *i = Num(s) + default: + return errors.New("Cannot scan NamedInt from " + reflect.ValueOf(src).String()) + } + return nil +} + +type StringsSlice []string + +func (l StringsSlice) Value() (driver.Value, error) { + bytes, err := json.Marshal(l) + return string(bytes), err +} + +func (l *StringsSlice) Scan(input interface{}) error { + switch value := input.(type) { + case string: + return json.Unmarshal([]byte(value), l) + case []byte: + return json.Unmarshal(value, l) + default: + return errors.New("not supported") + } +} + +type ExampleStruct struct { + Name string + Value string +} + +type StructsSlice []ExampleStruct + +func (l StructsSlice) Value() (driver.Value, error) { + bytes, err := json.Marshal(l) + return string(bytes), err +} + +func (l *StructsSlice) Scan(input interface{}) error { + switch value := input.(type) { + case string: + return json.Unmarshal([]byte(value), l) + case []byte: + return json.Unmarshal(value, l) + default: + return errors.New("not supported") + } +} diff --git a/tests/utils.go b/tests/utils.go index 041dc9b1..dfddf848 100644 --- a/tests/utils.go +++ b/tests/utils.go @@ -1,6 +1,8 @@ package tests import ( + "database/sql/driver" + "fmt" "reflect" "sort" "strconv" @@ -89,12 +91,12 @@ func AssertEqual(t *testing.T, got, expect interface{}) { if curTime.Round(time.Second).Format(format) != expect.(time.Time).Round(time.Second).Format(format) { t.Errorf("%v: expect: %v, got %v after time round", utils.FileWithLineNum(), expect.(time.Time).Round(time.Second).Format(format), curTime.Round(time.Second).Format(format)) } - } else if got != expect { + } else if fmt.Sprint(got) != fmt.Sprint(expect) { t.Errorf("%v: expect: %#v, got %#v", utils.FileWithLineNum(), expect, got) } } - if got == expect { + if fmt.Sprint(got) == fmt.Sprint(expect) { return } @@ -103,6 +105,14 @@ func AssertEqual(t *testing.T, got, expect interface{}) { return } + if valuer, ok := got.(driver.Valuer); ok { + got, _ = valuer.Value() + } + + if valuer, ok := expect.(driver.Valuer); ok { + expect, _ = valuer.Value() + } + if got != nil { got = reflect.Indirect(reflect.ValueOf(got)).Interface() } From 7c0de9199c6f9225de3958b377e7a8ee0f691694 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 30 May 2020 22:27:20 +0800 Subject: [PATCH 352/881] Test Migrate Indexes --- dialects/mssql/migrator.go | 4 +++ dialects/postgres/migrator.go | 26 +++++--------- dialects/sqlite/migrator.go | 66 +++++++++++++++++++++-------------- migrator/migrator.go | 26 ++++++-------- schema/index.go | 17 +++++++++ tests/delete_test.go | 18 ++++++++++ tests/migrate_test.go | 44 +++++++++++++++++++++++ tests/tests_all.sh | 2 ++ 8 files changed, 145 insertions(+), 58 deletions(-) diff --git a/dialects/mssql/migrator.go b/dialects/mssql/migrator.go index 412d86c6..4707a637 100644 --- a/dialects/mssql/migrator.go +++ b/dialects/mssql/migrator.go @@ -23,6 +23,10 @@ func (m Migrator) HasTable(value interface{}) bool { func (m Migrator) HasIndex(value interface{}, name string) bool { var count int m.RunWithValue(value, func(stmt *gorm.Statement) error { + if idx := stmt.Schema.LookIndex(name); idx != nil { + name = idx.Name + } + return m.DB.Raw( "SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", name, stmt.Table, diff --git a/dialects/postgres/migrator.go b/dialects/postgres/migrator.go index f06af25f..b144f573 100644 --- a/dialects/postgres/migrator.go +++ b/dialects/postgres/migrator.go @@ -37,11 +37,15 @@ func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statem return } -func (m Migrator) HasIndex(value interface{}, indexName string) bool { +func (m Migrator) HasIndex(value interface{}, name string) bool { var count int64 m.RunWithValue(value, func(stmt *gorm.Statement) error { + if idx := stmt.Schema.LookIndex(name); idx != nil { + name = idx.Name + } + return m.DB.Raw( - "SELECT count(*) FROM pg_indexes WHERE tablename = ? AND indexname = ? AND schemaname = CURRENT_SCHEMA()", stmt.Table, indexName, + "SELECT count(*) FROM pg_indexes WHERE tablename = ? AND indexname = ? AND schemaname = CURRENT_SCHEMA()", stmt.Table, name, ).Row().Scan(&count) }) @@ -50,10 +54,7 @@ func (m Migrator) HasIndex(value interface{}, indexName string) bool { func (m Migrator) CreateIndex(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { - err := fmt.Errorf("failed to create index with name %v", name) - indexes := stmt.Schema.ParseIndexes() - - if idx, ok := indexes[name]; ok { + if idx := stmt.Schema.LookIndex(name); idx != nil { opts := m.BuildIndexOptions(idx.Fields, stmt) values := []interface{}{clause.Column{Name: idx.Name}, clause.Table{Name: stmt.Table}, opts} @@ -73,18 +74,9 @@ func (m Migrator) CreateIndex(value interface{}, name string) error { } return m.DB.Exec(createIndexSQL, values...).Error - } else if field := stmt.Schema.LookUpField(name); field != nil { - for _, idx := range indexes { - for _, idxOpt := range idx.Fields { - if idxOpt.Field == field { - if err = m.CreateIndex(value, idx.Name); err != nil { - return err - } - } - } - } } - return err + + return fmt.Errorf("failed to create index with name %v", name) }) } diff --git a/dialects/sqlite/migrator.go b/dialects/sqlite/migrator.go index 601de126..5f3671b4 100644 --- a/dialects/sqlite/migrator.go +++ b/dialects/sqlite/migrator.go @@ -2,6 +2,7 @@ package sqlite import ( "fmt" + "strings" "github.com/jinzhu/gorm" "github.com/jinzhu/gorm/clause" @@ -37,17 +38,6 @@ func (m Migrator) HasColumn(value interface{}, field string) bool { return count > 0 } -func (m Migrator) HasIndex(value interface{}, name string) bool { - var count int - m.RunWithValue(value, func(stmt *gorm.Statement) error { - return m.DB.Raw( - "SELECT count(*) FROM sqlite_master WHERE type = ? AND tbl_name = ? AND sql LIKE ?", - "index", stmt.Table, "%INDEX "+name+" ON%", - ).Row().Scan(&count) - }) - return count > 0 -} - func (m Migrator) CreateConstraint(interface{}, string) error { return gorm.ErrNotImplemented } @@ -83,10 +73,7 @@ func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statem func (m Migrator) CreateIndex(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { - err := fmt.Errorf("failed to create index with name %v", name) - indexes := stmt.Schema.ParseIndexes() - - if idx, ok := indexes[name]; ok { + if idx := stmt.Schema.LookIndex(name); idx != nil { opts := m.BuildIndexOptions(idx.Fields, stmt) values := []interface{}{clause.Column{Name: idx.Name}, clause.Table{Name: stmt.Table}, opts} @@ -106,17 +93,44 @@ func (m Migrator) CreateIndex(value interface{}, name string) error { } return m.DB.Exec(createIndexSQL, values...).Error - } else if field := stmt.Schema.LookUpField(name); field != nil { - for _, idx := range indexes { - for _, idxOpt := range idx.Fields { - if idxOpt.Field == field { - if err = m.CreateIndex(value, idx.Name); err != nil { - return err - } - } - } - } } - return err + + return fmt.Errorf("failed to create index with name %v", name) + }) +} + +func (m Migrator) HasIndex(value interface{}, name string) bool { + var count int + m.RunWithValue(value, func(stmt *gorm.Statement) error { + if idx := stmt.Schema.LookIndex(name); idx != nil { + name = idx.Name + } + + m.DB.Raw( + "SELECT count(*) FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "index", stmt.Table, name, + ).Row().Scan(&count) + return nil + }) + return count > 0 +} + +func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + var sql string + m.DB.Raw("SELECT sql FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "index", stmt.Table, oldName).Row().Scan(&sql) + if sql != "" { + return m.DB.Exec(strings.Replace(sql, oldName, newName, 1)).Error + } + return fmt.Errorf("failed to find index with name %v", oldName) + }) +} + +func (m Migrator) DropIndex(value interface{}, name string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + if idx := stmt.Schema.LookIndex(name); idx != nil { + name = idx.Name + } + + return m.DB.Exec("DROP INDEX ?", clause.Column{Name: name}).Error }) } diff --git a/migrator/migrator.go b/migrator/migrator.go index cab266a3..1b0edf68 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -418,10 +418,7 @@ type BuildIndexOptionsInterface interface { func (m Migrator) CreateIndex(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { - err := fmt.Errorf("failed to create index with name %v", name) - indexes := stmt.Schema.ParseIndexes() - - if idx, ok := indexes[name]; ok { + if idx := stmt.Schema.LookIndex(name); idx != nil { opts := m.DB.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt) values := []interface{}{clause.Column{Name: idx.Name}, clause.Table{Name: stmt.Table}, opts} @@ -441,23 +438,18 @@ func (m Migrator) CreateIndex(value interface{}, name string) error { } return m.DB.Exec(createIndexSQL, values...).Error - } else if field := stmt.Schema.LookUpField(name); field != nil { - for _, idx := range indexes { - for _, idxOpt := range idx.Fields { - if idxOpt.Field == field { - if err = m.CreateIndex(value, idx.Name); err != nil { - return err - } - } - } - } } - return err + + return fmt.Errorf("failed to create index with name %v", name) }) } func (m Migrator) DropIndex(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { + if idx := stmt.Schema.LookIndex(name); idx != nil { + name = idx.Name + } + return m.DB.Exec("DROP INDEX ? ON ?", clause.Column{Name: name}, clause.Table{Name: stmt.Table}).Error }) } @@ -466,6 +458,10 @@ func (m Migrator) HasIndex(value interface{}, name string) bool { var count int64 m.RunWithValue(value, func(stmt *gorm.Statement) error { currentDatabase := m.DB.Migrator().CurrentDatabase() + if idx := stmt.Schema.LookIndex(name); idx != nil { + name = idx.Name + } + return m.DB.Raw( "SELECT count(*) FROM information_schema.statistics WHERE table_schema = ? AND table_name = ? AND index_name = ?", currentDatabase, stmt.Table, name, diff --git a/schema/index.go b/schema/index.go index c5c96aa4..4228bba2 100644 --- a/schema/index.go +++ b/schema/index.go @@ -52,6 +52,23 @@ func (schema *Schema) ParseIndexes() map[string]Index { return indexes } +func (schema *Schema) LookIndex(name string) *Index { + indexes := schema.ParseIndexes() + for _, index := range indexes { + if index.Name == name { + return &index + } + + for _, field := range index.Fields { + if field.Name == name { + return &index + } + } + } + + return nil +} + func parseFieldIndexes(field *Field) (indexes []Index) { for _, value := range strings.Split(field.Tag.Get("gorm"), ";") { if value != "" { diff --git a/tests/delete_test.go b/tests/delete_test.go index 8be072d3..3f17f1a1 100644 --- a/tests/delete_test.go +++ b/tests/delete_test.go @@ -46,3 +46,21 @@ func TestDelete(t *testing.T) { } } } + +func TestInlineCondDelete(t *testing.T) { + user1 := *GetUser("inline_delete_1", Config{}) + user2 := *GetUser("inline_delete_2", Config{}) + DB.Save(&user1).Save(&user2) + + if DB.Delete(&User{}, user1.ID).Error != nil { + t.Errorf("No error should happen when delete a record") + } else if !DB.Where("name = ?", user1.Name).First(&User{}).RecordNotFound() { + t.Errorf("User can't be found after delete") + } + + if err := DB.Delete(&User{}, "name = ?", user2.Name).Error; err != nil { + t.Errorf("No error should happen when delete a record, err=%s", err) + } else if !DB.Where("name = ?", user2.Name).First(&User{}).RecordNotFound() { + t.Errorf("User can't be found after delete") + } +} diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 917fba75..d944dfa2 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -5,6 +5,7 @@ import ( "testing" "time" + "github.com/jinzhu/gorm" . "github.com/jinzhu/gorm/tests" ) @@ -27,3 +28,46 @@ func TestMigrate(t *testing.T) { } } } + +func TestIndexes(t *testing.T) { + type User struct { + gorm.Model + Name string `gorm:"index"` + } + + if err := DB.Migrator().CreateIndex(&User{}, "Name"); err != nil { + t.Errorf("Got error when tried to create index: %+v", err) + } + + if !DB.Migrator().HasIndex(&User{}, "Name") { + t.Errorf("Failed to find index for user's name") + } + + if err := DB.Migrator().DropIndex(&User{}, "Name"); err != nil { + t.Errorf("Failed to drop index for user's name, got err %v", err) + } + + if DB.Migrator().HasIndex(&User{}, "Name") { + t.Errorf("Should not find index for user's name after delete") + } + + if err := DB.Migrator().CreateIndex(&User{}, "Name"); err != nil { + t.Errorf("Got error when tried to create index: %+v", err) + } + + if err := DB.Migrator().RenameIndex(&User{}, "idx_users_name", "idx_users_name_1"); err != nil { + t.Errorf("no error should happen when rename index, but got %v", err) + } + + if !DB.Migrator().HasIndex(&User{}, "idx_users_name_1") { + t.Errorf("Should find index for user's name after rename") + } + + if err := DB.Migrator().DropIndex(&User{}, "idx_users_name_1"); err != nil { + t.Errorf("Failed to drop index for user's name, got err %v", err) + } + + if DB.Migrator().HasIndex(&User{}, "idx_users_name_1") { + t.Errorf("Should not find index for user's name after delete") + } +} diff --git a/tests/tests_all.sh b/tests/tests_all.sh index 0c24a888..9435b2b1 100755 --- a/tests/tests_all.sh +++ b/tests/tests_all.sh @@ -7,6 +7,8 @@ fi for dialect in "${dialects[@]}" ; do if [ "$GORM_DIALECT" = "" ] || [ "$GORM_DIALECT" = "${dialect}" ] then + echo "testing ${dialect}..." + if [ "$GORM_VERBOSE" = "" ] then DEBUG=false GORM_DIALECT=${dialect} go test -race ./... From 7b6b9c4d22f2aacde8c2815ec35934d9d265019e Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 31 May 2020 00:42:52 +0800 Subject: [PATCH 353/881] Add tests for Columns --- clause/set.go | 3 -- clause/set_test.go | 2 +- dialects/mysql/mysql.go | 6 ++-- dialects/postgres/migrator.go | 19 +++++++++++ gorm.go | 2 +- logger/logger.go | 5 +-- migrator/migrator.go | 25 +++++++++----- tests/migrate_test.go | 64 +++++++++++++++++++++++++++++------ tests/non_std_test.go | 20 +++++------ tests/tests.go | 2 +- tests/utils.go | 2 +- 11 files changed, 109 insertions(+), 41 deletions(-) diff --git a/clause/set.go b/clause/set.go index de78b1be..590e27d5 100644 --- a/clause/set.go +++ b/clause/set.go @@ -30,8 +30,5 @@ func (set Set) Build(builder Builder) { // MergeClause merge assignments clauses func (set Set) MergeClause(clause *Clause) { - if v, ok := clause.Expression.(Set); ok { - set = append(v, set...) - } clause.Expression = set } diff --git a/clause/set_test.go b/clause/set_test.go index 85754737..48131218 100644 --- a/clause/set_test.go +++ b/clause/set_test.go @@ -26,7 +26,7 @@ func TestSet(t *testing.T) { clause.Set([]clause.Assignment{{clause.PrimaryColumn, 1}}), clause.Set([]clause.Assignment{{clause.Column{Name: "name"}, "jinzhu"}}), }, - "UPDATE `users` SET `users`.`id`=?,`name`=?", []interface{}{1, "jinzhu"}, + "UPDATE `users` SET `name`=?", []interface{}{"jinzhu"}, }, } diff --git a/dialects/mysql/mysql.go b/dialects/mysql/mysql.go index 6ca9f5f5..23525ed7 100644 --- a/dialects/mysql/mysql.go +++ b/dialects/mysql/mysql.go @@ -116,8 +116,10 @@ func (dialector Dialector) DataTypeOf(field *schema.Field) string { return "double" case schema.String: size := field.Size - if field.PrimaryKey && size == 0 { - size = 256 + if size == 0 { + if field.PrimaryKey || field.HasDefaultValue { + size = 256 + } } if size >= 65536 && size <= int(math.Pow(2, 24)) { diff --git a/dialects/postgres/migrator.go b/dialects/postgres/migrator.go index b144f573..d93f681c 100644 --- a/dialects/postgres/migrator.go +++ b/dialects/postgres/migrator.go @@ -80,6 +80,25 @@ func (m Migrator) CreateIndex(value interface{}, name string) error { }) } +func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + return m.DB.Exec( + "ALTER INDEX ? RENAME TO ?", + clause.Column{Name: oldName}, clause.Column{Name: newName}, + ).Error + }) +} + +func (m Migrator) DropIndex(value interface{}, name string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + if idx := stmt.Schema.LookIndex(name); idx != nil { + name = idx.Name + } + + return m.DB.Exec("DROP INDEX ?", clause.Column{Name: name}).Error + }) +} + func (m Migrator) HasTable(value interface{}) bool { var count int64 m.RunWithValue(value, func(stmt *gorm.Statement) error { diff --git a/gorm.go b/gorm.go index 6b2a6d75..9adc0858 100644 --- a/gorm.go +++ b/gorm.go @@ -66,7 +66,7 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) { } if config.NowFunc == nil { - config.NowFunc = func() time.Time { return time.Now().Local() } + config.NowFunc = func() time.Time { return time.Now().Local().Round(time.Second) } } if dialector != nil { diff --git a/logger/logger.go b/logger/logger.go index 7121b4fb..ae7c22c9 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -28,7 +28,8 @@ const ( type LogLevel int const ( - Error LogLevel = iota + 1 + Silent LogLevel = iota + 1 + Error Warn Info ) @@ -129,7 +130,7 @@ func (l logger) Trace(ctx context.Context, begin time.Time, fc func() (string, i if l.LogLevel > 0 { elapsed := time.Now().Sub(begin) switch { - case err != nil: + case err != nil && l.LogLevel >= Error: sql, rows := fc() l.Printf(l.traceErrStr, utils.FileWithLineNum(), err, float64(elapsed.Nanoseconds())/1e6, rows, sql) case elapsed > l.SlowThreshold && l.SlowThreshold != 0 && l.LogLevel >= Warn: diff --git a/migrator/migrator.go b/migrator/migrator.go index 1b0edf68..8f35cbea 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -47,25 +47,32 @@ func (m Migrator) DataTypeOf(field *schema.Field) string { return m.Dialector.DataTypeOf(field) } -func (m Migrator) FullDataTypeOf(field *schema.Field) string { - dataType := m.DataTypeOf(field) +func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) { + expr.SQL = m.DataTypeOf(field) if field.AutoIncrement { - dataType += " AUTO_INCREMENT" + expr.SQL += " AUTO_INCREMENT" } if field.NotNull { - dataType += " NOT NULL" + expr.SQL += " NOT NULL" } if field.Unique { - dataType += " UNIQUE" + expr.SQL += " UNIQUE" } if field.HasDefaultValue { - dataType += " DEFAULT " + field.DefaultValue + if field.DataType == schema.String { + defaultStmt := &gorm.Statement{Vars: []interface{}{field.DefaultValue}} + m.Dialector.BindVarTo(defaultStmt, defaultStmt, field.DefaultValue) + expr.SQL += " DEFAULT " + m.Dialector.Explain(defaultStmt.SQL.String(), field.DefaultValue) + } else { + expr.SQL += " DEFAULT " + field.DefaultValue + } } - return dataType + + return } // AutoMigrate @@ -138,7 +145,7 @@ func (m Migrator) CreateTable(values ...interface{}) error { field := stmt.Schema.FieldsByDBName[dbName] createTableSQL += fmt.Sprintf("? ?") hasPrimaryKeyInDataType = hasPrimaryKeyInDataType || strings.Contains(strings.ToUpper(field.DBDataType), "PRIMARY KEY") - values = append(values, clause.Column{Name: dbName}, clause.Expr{SQL: m.FullDataTypeOf(field)}) + values = append(values, clause.Column{Name: dbName}, m.FullDataTypeOf(field)) createTableSQL += "," } @@ -229,7 +236,7 @@ func (m Migrator) AddColumn(value interface{}, field string) error { if field := stmt.Schema.LookUpField(field); field != nil { return m.DB.Exec( "ALTER TABLE ? ADD ? ?", - clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, clause.Expr{SQL: m.FullDataTypeOf(field)}, + clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, m.FullDataTypeOf(field), ).Error } return fmt.Errorf("failed to look up field with name: %s", field) diff --git a/tests/migrate_test.go b/tests/migrate_test.go index d944dfa2..00025c58 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -30,44 +30,86 @@ func TestMigrate(t *testing.T) { } func TestIndexes(t *testing.T) { - type User struct { + type IndexStruct struct { gorm.Model - Name string `gorm:"index"` + Name string `gorm:"size:255;index"` } - if err := DB.Migrator().CreateIndex(&User{}, "Name"); err != nil { + DB.Migrator().DropTable(&IndexStruct{}) + DB.AutoMigrate(&IndexStruct{}) + + if err := DB.Migrator().DropIndex(&IndexStruct{}, "Name"); err != nil { + t.Errorf("Failed to drop index for user's name, got err %v", err) + } + + if err := DB.Migrator().CreateIndex(&IndexStruct{}, "Name"); err != nil { t.Errorf("Got error when tried to create index: %+v", err) } - if !DB.Migrator().HasIndex(&User{}, "Name") { + if !DB.Migrator().HasIndex(&IndexStruct{}, "Name") { t.Errorf("Failed to find index for user's name") } - if err := DB.Migrator().DropIndex(&User{}, "Name"); err != nil { + if err := DB.Migrator().DropIndex(&IndexStruct{}, "Name"); err != nil { t.Errorf("Failed to drop index for user's name, got err %v", err) } - if DB.Migrator().HasIndex(&User{}, "Name") { + if DB.Migrator().HasIndex(&IndexStruct{}, "Name") { t.Errorf("Should not find index for user's name after delete") } - if err := DB.Migrator().CreateIndex(&User{}, "Name"); err != nil { + if err := DB.Migrator().CreateIndex(&IndexStruct{}, "Name"); err != nil { t.Errorf("Got error when tried to create index: %+v", err) } - if err := DB.Migrator().RenameIndex(&User{}, "idx_users_name", "idx_users_name_1"); err != nil { + if err := DB.Migrator().RenameIndex(&IndexStruct{}, "idx_index_structs_name", "idx_users_name_1"); err != nil { t.Errorf("no error should happen when rename index, but got %v", err) } - if !DB.Migrator().HasIndex(&User{}, "idx_users_name_1") { + if !DB.Migrator().HasIndex(&IndexStruct{}, "idx_users_name_1") { t.Errorf("Should find index for user's name after rename") } - if err := DB.Migrator().DropIndex(&User{}, "idx_users_name_1"); err != nil { + if err := DB.Migrator().DropIndex(&IndexStruct{}, "idx_users_name_1"); err != nil { t.Errorf("Failed to drop index for user's name, got err %v", err) } - if DB.Migrator().HasIndex(&User{}, "idx_users_name_1") { + if DB.Migrator().HasIndex(&IndexStruct{}, "idx_users_name_1") { t.Errorf("Should not find index for user's name after delete") } } + +func TestColumns(t *testing.T) { + type ColumnStruct struct { + gorm.Model + Name string + } + + DB.Migrator().DropTable(&ColumnStruct{}) + + if err := DB.AutoMigrate(&ColumnStruct{}); err != nil { + t.Errorf("Failed to migrate, got %v", err) + } + + type NewColumnStruct struct { + gorm.Model + Name string + NewName string + } + + if err := DB.Table("column_structs").Migrator().AddColumn(&NewColumnStruct{}, "NewName"); err != nil { + t.Errorf("Failed to add column, got %v", err) + } + + if !DB.Table("column_structs").Migrator().HasColumn(&NewColumnStruct{}, "NewName") { + t.Errorf("Failed to find added column") + } + + if err := DB.Table("column_structs").Migrator().DropColumn(&NewColumnStruct{}, "NewName"); err != nil { + t.Errorf("Failed to add column, got %v", err) + } + + if DB.Table("column_structs").Migrator().HasColumn(&NewColumnStruct{}, "NewName") { + t.Errorf("Found deleted column") + } +} diff --git a/tests/non_std_test.go b/tests/non_std_test.go index b8a278fe..e5e50141 100644 --- a/tests/non_std_test.go +++ b/tests/non_std_test.go @@ -8,21 +8,21 @@ import ( ) type Animal struct { - Counter uint64 `gorm:"primary_key:yes"` - Name string `gorm:"DEFAULT:'galeone'"` - From string //test reserved sql keyword as field name - Age time.Time `gorm:"DEFAULT:current_timestamp"` - unexported string // unexported value + Counter uint64 `gorm:"primary_key:yes"` + Name string `gorm:"DEFAULT:'galeone'"` + From string //test reserved sql keyword as field name + Age *time.Time + unexported string // unexported value CreatedAt time.Time UpdatedAt time.Time } -func init() { - DB.Migrator().DropTable(&Animal{}) - DB.AutoMigrate(&Animal{}) -} - func TestNonStdPrimaryKeyAndDefaultValues(t *testing.T) { + DB.Migrator().DropTable(&Animal{}) + if err := DB.AutoMigrate(&Animal{}); err != nil { + t.Fatalf("no error should happen when migrate but got %v", err) + } + animal := Animal{Name: "Ferdinand"} DB.Save(&animal) updatedAt1 := animal.UpdatedAt diff --git a/tests/tests.go b/tests/tests.go index 2b2bfc20..7e216776 100644 --- a/tests/tests.go +++ b/tests/tests.go @@ -61,7 +61,7 @@ func OpenTestConnection() (db *gorm.DB, err error) { if debug := os.Getenv("DEBUG"); debug == "true" { db.Logger.LogMode(logger.Info) } else if debug == "false" { - db.Logger.LogMode(logger.Error) + db.Logger.LogMode(logger.Silent) } return diff --git a/tests/utils.go b/tests/utils.go index dfddf848..0a33edee 100644 --- a/tests/utils.go +++ b/tests/utils.go @@ -26,7 +26,7 @@ type Config struct { func GetUser(name string, config Config) *User { var ( - birthday = time.Now() + birthday = time.Now().Round(time.Second) user = User{ Name: name, Age: 18, From 2b56fa04725364eed4f2087b0055ea07d577beb2 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 31 May 2020 01:21:16 +0800 Subject: [PATCH 354/881] Fix Scanner tests on mssql --- dialects/mssql/create.go | 2 +- dialects/mssql/mssql.go | 14 ++++++++++++-- go.mod | 2 +- scan.go | 14 +++++--------- tests/scanner_valuer_test.go | 3 +++ tests/utils.go | 5 +++++ 6 files changed, 27 insertions(+), 13 deletions(-) diff --git a/dialects/mssql/create.go b/dialects/mssql/create.go index b17a2227..c85997fb 100644 --- a/dialects/mssql/create.go +++ b/dialects/mssql/create.go @@ -87,7 +87,7 @@ func Create(db *gorm.DB) { if rows.Next() { db.RowsAffected++ - err = rows.Scan(values) + db.AddError(rows.Scan(values)) } } } else { diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go index de82f375..8e309faf 100644 --- a/dialects/mssql/mssql.go +++ b/dialects/mssql/mssql.go @@ -2,6 +2,7 @@ package mssql import ( "database/sql" + "database/sql/driver" "fmt" "regexp" "strconv" @@ -80,6 +81,15 @@ func (dialector Dialector) QuoteTo(writer clause.Writer, str string) { var numericPlaceholder = regexp.MustCompile("@p(\\d+)") func (dialector Dialector) Explain(sql string, vars ...interface{}) string { + for idx, v := range vars { + if valuer, ok := v.(driver.Valuer); ok { + v, _ = valuer.Value() + } + + if v, ok := v.(bool); ok { + vars[idx] = strconv.FormatBool(v) + } + } return logger.ExplainSQL(sql, numericPlaceholder, `'`, vars...) } @@ -103,7 +113,7 @@ func (dialector Dialector) DataTypeOf(field *schema.Field) string { } return sqlType case schema.Float: - return "decimal" + return "float" case schema.String: size := field.Size if field.PrimaryKey && size == 0 { @@ -116,7 +126,7 @@ func (dialector Dialector) DataTypeOf(field *schema.Field) string { case schema.Time: return "datetimeoffset" case schema.Bytes: - return "binary" + return "varbinary(MAX)" } return "" diff --git a/go.mod b/go.mod index 45bcf69c..7dabdd39 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module github.com/jinzhu/gorm go 1.14 require ( - github.com/denisenkom/go-mssqldb v0.0.0-20191124224453-732737034ffd + github.com/denisenkom/go-mssqldb v0.0.0-20200428022330-06a60b6afbbc github.com/erikstmartin/go-testdb v0.0.0-20160219214506-8d10e4a1bae5 github.com/go-sql-driver/mysql v1.5.0 github.com/jinzhu/inflection v1.0.0 diff --git a/scan.go b/scan.go index c223f6eb..66cb0b94 100644 --- a/scan.go +++ b/scan.go @@ -20,7 +20,7 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { if initialized || rows.Next() { db.RowsAffected++ - rows.Scan(values...) + db.AddError(rows.Scan(values...)) } mapValue, ok := dest.(map[string]interface{}) @@ -41,7 +41,7 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { for initialized || rows.Next() { initialized = false db.RowsAffected++ - rows.Scan(values...) + db.AddError(rows.Scan(values...)) v := map[string]interface{}{} for idx, column := range columns { @@ -53,7 +53,7 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { for initialized || rows.Next() { initialized = false db.RowsAffected++ - rows.Scan(dest) + db.AddError(rows.Scan(dest)) } default: switch db.Statement.ReflectValue.Kind() { @@ -96,9 +96,7 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { } db.RowsAffected++ - if err := rows.Scan(values...); err != nil { - db.AddError(err) - } + db.AddError(rows.Scan(values...)) if isPtr { db.Statement.ReflectValue.Set(reflect.Append(db.Statement.ReflectValue, elem.Addr())) @@ -130,9 +128,7 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { if initialized || rows.Next() { db.RowsAffected++ - if err := rows.Scan(values...); err != nil { - db.AddError(err) - } + db.AddError(rows.Scan(values...)) } } } diff --git a/tests/scanner_valuer_test.go b/tests/scanner_valuer_test.go index 38ffc919..88e7e12e 100644 --- a/tests/scanner_valuer_test.go +++ b/tests/scanner_valuer_test.go @@ -103,6 +103,9 @@ func (data *EncryptedData) Scan(value interface{}) error { *data = b[3:] return nil + } else if s, ok := value.(string); ok { + *data = []byte(s)[3:] + return nil } return errors.New("Bytes expected") diff --git a/tests/utils.go b/tests/utils.go index 0a33edee..0add8143 100644 --- a/tests/utils.go +++ b/tests/utils.go @@ -121,6 +121,11 @@ func AssertEqual(t *testing.T, got, expect interface{}) { expect = reflect.Indirect(reflect.ValueOf(expect)).Interface() } + if reflect.ValueOf(got).IsValid() != reflect.ValueOf(expect).IsValid() { + t.Errorf("%v: expect: %+v, got %+v", utils.FileWithLineNum(), expect, got) + return + } + if reflect.ValueOf(got).Type().ConvertibleTo(reflect.ValueOf(expect).Type()) { got = reflect.ValueOf(got).Convert(reflect.ValueOf(expect).Type()).Interface() isEqual() From 58bc0f51c105bfed6d82549897bda968a1b55adf Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 31 May 2020 07:57:13 +0800 Subject: [PATCH 355/881] Fix mssql rename index, has column --- callbacks/update.go | 5 ++--- dialects/mssql/migrator.go | 31 +++++++++++++++++++++++++++++++ dialects/mssql/mssql.go | 10 ---------- tests/tests_all.sh | 4 ++-- 4 files changed, 35 insertions(+), 15 deletions(-) diff --git a/callbacks/update.go b/callbacks/update.go index 623d64fe..cfa8c86b 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -3,7 +3,6 @@ package callbacks import ( "reflect" "sort" - "time" "github.com/jinzhu/gorm" "github.com/jinzhu/gorm/clause" @@ -152,7 +151,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { if !stmt.DisableUpdateTime { for _, field := range stmt.Schema.FieldsByDBName { if field.AutoUpdateTime > 0 && value[field.Name] == nil && value[field.DBName] == nil { - now := time.Now() + now := stmt.DB.NowFunc() set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now}) assignValue(field, now) } @@ -168,7 +167,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { value, isZero := field.ValueOf(stmt.ReflectValue) if !stmt.DisableUpdateTime { if field.AutoUpdateTime > 0 { - value = time.Now() + value = stmt.DB.NowFunc() isZero = false } } diff --git a/dialects/mssql/migrator.go b/dialects/mssql/migrator.go index 4707a637..d1abd0e9 100644 --- a/dialects/mssql/migrator.go +++ b/dialects/mssql/migrator.go @@ -1,7 +1,10 @@ package mssql import ( + "fmt" + "github.com/jinzhu/gorm" + "github.com/jinzhu/gorm/clause" "github.com/jinzhu/gorm/migrator" ) @@ -20,6 +23,24 @@ func (m Migrator) HasTable(value interface{}) bool { return count > 0 } +func (m Migrator) HasColumn(value interface{}, field string) bool { + var count int64 + m.RunWithValue(value, func(stmt *gorm.Statement) error { + currentDatabase := m.DB.Migrator().CurrentDatabase() + name := field + if field := stmt.Schema.LookUpField(field); field != nil { + name = field.DBName + } + + return m.DB.Raw( + "SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", + currentDatabase, stmt.Table, name, + ).Row().Scan(&count) + }) + + return count > 0 +} + func (m Migrator) HasIndex(value interface{}, name string) bool { var count int m.RunWithValue(value, func(stmt *gorm.Statement) error { @@ -35,6 +56,16 @@ func (m Migrator) HasIndex(value interface{}, name string) bool { return count > 0 } +func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + + return m.DB.Exec( + "sp_rename @objname = ?, @newname = ?, @objtype = 'INDEX';", + fmt.Sprintf("%s.%s", stmt.Table, oldName), clause.Column{Name: newName}, + ).Error + }) +} + func (m Migrator) HasConstraint(value interface{}, name string) bool { var count int64 m.RunWithValue(value, func(stmt *gorm.Statement) error { diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go index 8e309faf..3828c546 100644 --- a/dialects/mssql/mssql.go +++ b/dialects/mssql/mssql.go @@ -2,7 +2,6 @@ package mssql import ( "database/sql" - "database/sql/driver" "fmt" "regexp" "strconv" @@ -81,15 +80,6 @@ func (dialector Dialector) QuoteTo(writer clause.Writer, str string) { var numericPlaceholder = regexp.MustCompile("@p(\\d+)") func (dialector Dialector) Explain(sql string, vars ...interface{}) string { - for idx, v := range vars { - if valuer, ok := v.(driver.Valuer); ok { - v, _ = valuer.Value() - } - - if v, ok := v.(bool); ok { - vars[idx] = strconv.FormatBool(v) - } - } return logger.ExplainSQL(sql, numericPlaceholder, `'`, vars...) } diff --git a/tests/tests_all.sh b/tests/tests_all.sh index 9435b2b1..243af787 100755 --- a/tests/tests_all.sh +++ b/tests/tests_all.sh @@ -11,9 +11,9 @@ for dialect in "${dialects[@]}" ; do if [ "$GORM_VERBOSE" = "" ] then - DEBUG=false GORM_DIALECT=${dialect} go test -race ./... + DEBUG=false GORM_DIALECT=${dialect} go test -race -count=1 ./... else - DEBUG=false GORM_DIALECT=${dialect} go test -race -v ./... + DEBUG=false GORM_DIALECT=${dialect} go test -race -count=1 -v ./... fi fi done From 24285060d5d37898700802f567a9eaa1f875827e Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 31 May 2020 08:58:08 +0800 Subject: [PATCH 356/881] Fix RenameColumn for mssql, DropColumn for sqlite --- dialects/mssql/migrator.go | 17 ++++++++++++++ dialects/sqlite/migrator.go | 45 ++++++++++++++++++++++++++++++++++--- gorm.go | 2 +- migrator/migrator.go | 33 +++++++++++++++------------ tests/migrate_test.go | 28 +++++++++++++++++++---- tests/utils.go | 4 ++-- 6 files changed, 105 insertions(+), 24 deletions(-) diff --git a/dialects/mssql/migrator.go b/dialects/mssql/migrator.go index d1abd0e9..42a6b9b9 100644 --- a/dialects/mssql/migrator.go +++ b/dialects/mssql/migrator.go @@ -41,6 +41,23 @@ func (m Migrator) HasColumn(value interface{}, field string) bool { return count > 0 } +func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + if field := stmt.Schema.LookUpField(oldName); field != nil { + oldName = field.DBName + } + + if field := stmt.Schema.LookUpField(newName); field != nil { + newName = field.DBName + } + + return m.DB.Exec( + "sp_rename @objname = ?, @newname = ?, @objtype = 'COLUMN';", + fmt.Sprintf("%s.%s", stmt.Table, oldName), clause.Column{Name: newName}, + ).Error + }) +} + func (m Migrator) HasIndex(value interface{}, name string) bool { var count int m.RunWithValue(value, func(stmt *gorm.Statement) error { diff --git a/dialects/sqlite/migrator.go b/dialects/sqlite/migrator.go index 5f3671b4..e36dc5e7 100644 --- a/dialects/sqlite/migrator.go +++ b/dialects/sqlite/migrator.go @@ -2,6 +2,7 @@ package sqlite import ( "fmt" + "regexp" "strings" "github.com/jinzhu/gorm" @@ -22,11 +23,10 @@ func (m Migrator) HasTable(value interface{}) bool { return count > 0 } -func (m Migrator) HasColumn(value interface{}, field string) bool { +func (m Migrator) HasColumn(value interface{}, name string) bool { var count int m.Migrator.RunWithValue(value, func(stmt *gorm.Statement) error { - name := field - if field := stmt.Schema.LookUpField(field); field != nil { + if field := stmt.Schema.LookUpField(name); field != nil { name = field.DBName } @@ -38,6 +38,45 @@ func (m Migrator) HasColumn(value interface{}, field string) bool { return count > 0 } +func (m Migrator) DropColumn(value interface{}, name string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + if field := stmt.Schema.LookUpField(name); field != nil { + name = field.DBName + } + + var ( + createSQL string + newTableName = stmt.Table + "__temp" + ) + + m.DB.Raw("SELECT sql FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "table", stmt.Table, stmt.Table).Row().Scan(&createSQL) + + if reg, err := regexp.Compile("(`|'|\"| )" + name + "(`|'|\"| ) .*?,"); err == nil { + tableReg, err := regexp.Compile(" ('|`|\"| )" + stmt.Table + "('|`|\"| ) ") + if err != nil { + return err + } + + createSQL = tableReg.ReplaceAllString(createSQL, fmt.Sprintf(" `%v` ", newTableName)) + createSQL = reg.ReplaceAllString(createSQL, "") + + var columns []string + columnTypes, _ := m.DB.Migrator().ColumnTypes(value) + for _, columnType := range columnTypes { + if columnType.Name() != name { + columns = append(columns, fmt.Sprintf("`%v`", columnType.Name())) + } + } + + createSQL = fmt.Sprintf("PRAGMA foreign_keys=off;BEGIN TRANSACTION;"+createSQL+";INSERT INTO `%v`(%v) SELECT %v FROM `%v`;DROP TABLE `%v`;ALTER TABLE `%v` RENAME TO `%v`;COMMIT;", newTableName, strings.Join(columns, ","), strings.Join(columns, ","), stmt.Table, stmt.Table, newTableName, stmt.Table) + + return m.DB.Exec(createSQL).Error + } else { + return err + } + }) +} + func (m Migrator) CreateConstraint(interface{}, string) error { return gorm.ErrNotImplemented } diff --git a/gorm.go b/gorm.go index 9adc0858..6b2a6d75 100644 --- a/gorm.go +++ b/gorm.go @@ -66,7 +66,7 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) { } if config.NowFunc == nil { - config.NowFunc = func() time.Time { return time.Now().Local().Round(time.Second) } + config.NowFunc = func() time.Time { return time.Now().Local() } } if dialector != nil { diff --git a/migrator/migrator.go b/migrator/migrator.go index 8f35cbea..d41646f4 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -243,14 +243,15 @@ func (m Migrator) AddColumn(value interface{}, field string) error { }) } -func (m Migrator) DropColumn(value interface{}, field string) error { +func (m Migrator) DropColumn(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { - if field := stmt.Schema.LookUpField(field); field != nil { - return m.DB.Exec( - "ALTER TABLE ? DROP COLUMN ?", clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, - ).Error + if field := stmt.Schema.LookUpField(name); field != nil { + name = field.DBName } - return fmt.Errorf("failed to look up field with name: %s", field) + + return m.DB.Exec( + "ALTER TABLE ? DROP COLUMN ?", clause.Table{Name: stmt.Table}, clause.Column{Name: name}, + ).Error }) } @@ -284,16 +285,20 @@ func (m Migrator) HasColumn(value interface{}, field string) bool { return count > 0 } -func (m Migrator) RenameColumn(value interface{}, oldName, field string) error { +func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { - if field := stmt.Schema.LookUpField(field); field != nil { - oldName = m.DB.NamingStrategy.ColumnName(stmt.Table, oldName) - return m.DB.Exec( - "ALTER TABLE ? RENAME COLUMN ? TO ?", - clause.Table{Name: stmt.Table}, clause.Column{Name: oldName}, clause.Column{Name: field.DBName}, - ).Error + if field := stmt.Schema.LookUpField(oldName); field != nil { + oldName = field.DBName } - return fmt.Errorf("failed to look up field with name: %s", field) + + if field := stmt.Schema.LookUpField(newName); field != nil { + newName = field.DBName + } + + return m.DB.Exec( + "ALTER TABLE ? RENAME COLUMN ? TO ?", + clause.Table{Name: stmt.Table}, clause.Column{Name: oldName}, clause.Column{Name: newName}, + ).Error }) } diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 00025c58..2252d09d 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -98,18 +98,38 @@ func TestColumns(t *testing.T) { } if err := DB.Table("column_structs").Migrator().AddColumn(&NewColumnStruct{}, "NewName"); err != nil { - t.Errorf("Failed to add column, got %v", err) + t.Fatalf("Failed to add column, got %v", err) } if !DB.Table("column_structs").Migrator().HasColumn(&NewColumnStruct{}, "NewName") { - t.Errorf("Failed to find added column") + t.Fatalf("Failed to find added column") } if err := DB.Table("column_structs").Migrator().DropColumn(&NewColumnStruct{}, "NewName"); err != nil { - t.Errorf("Failed to add column, got %v", err) + t.Fatalf("Failed to add column, got %v", err) } if DB.Table("column_structs").Migrator().HasColumn(&NewColumnStruct{}, "NewName") { - t.Errorf("Found deleted column") + t.Fatalf("Found deleted column") + } + + if err := DB.Table("column_structs").Migrator().AddColumn(&NewColumnStruct{}, "NewName"); err != nil { + t.Fatalf("Failed to add column, got %v", err) + } + + if err := DB.Table("column_structs").Migrator().RenameColumn(&NewColumnStruct{}, "NewName", "new_new_name"); err != nil { + t.Fatalf("Failed to add column, got %v", err) + } + + if !DB.Table("column_structs").Migrator().HasColumn(&NewColumnStruct{}, "new_new_name") { + t.Fatalf("Failed to found renamed column") + } + + if err := DB.Table("column_structs").Migrator().DropColumn(&NewColumnStruct{}, "new_new_name"); err != nil { + t.Fatalf("Failed to add column, got %v", err) + } + + if DB.Table("column_structs").Migrator().HasColumn(&NewColumnStruct{}, "new_new_name") { + t.Fatalf("Found deleted column") } } diff --git a/tests/utils.go b/tests/utils.go index 0add8143..7cc6d2bc 100644 --- a/tests/utils.go +++ b/tests/utils.go @@ -88,8 +88,8 @@ func AssertEqual(t *testing.T, got, expect interface{}) { if curTime, ok := got.(time.Time); ok { format := "2006-01-02T15:04:05Z07:00" - if curTime.Round(time.Second).Format(format) != expect.(time.Time).Round(time.Second).Format(format) { - t.Errorf("%v: expect: %v, got %v after time round", utils.FileWithLineNum(), expect.(time.Time).Round(time.Second).Format(format), curTime.Round(time.Second).Format(format)) + if curTime.Round(time.Second).Format(format) != expect.(time.Time).Round(time.Second).Format(format) && curTime.Truncate(time.Second).Format(format) != expect.(time.Time).Truncate(time.Second).Format(format) { + t.Errorf("%v: expect: %v, got %v after time round", utils.FileWithLineNum(), expect.(time.Time), curTime) } } else if fmt.Sprint(got) != fmt.Sprint(expect) { t.Errorf("%v: expect: %#v, got %#v", utils.FileWithLineNum(), expect, got) From d81179557dfcc64a011e8198b3f2febe8a0c9a39 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 31 May 2020 10:24:49 +0800 Subject: [PATCH 357/881] Add tests for Tables --- dialects/mssql/migrator.go | 30 +++++++++++++++++++ migrator.go | 2 +- migrator/migrator.go | 27 +++++++++++++++-- tests/migrate_test.go | 59 +++++++++++++++++++++++++++++--------- 4 files changed, 102 insertions(+), 16 deletions(-) diff --git a/dialects/mssql/migrator.go b/dialects/mssql/migrator.go index 42a6b9b9..b334268e 100644 --- a/dialects/mssql/migrator.go +++ b/dialects/mssql/migrator.go @@ -23,6 +23,36 @@ func (m Migrator) HasTable(value interface{}) bool { return count > 0 } +func (m Migrator) RenameTable(oldName, newName interface{}) error { + var oldTable, newTable string + if v, ok := oldName.(string); ok { + oldTable = v + } else { + stmt := &gorm.Statement{DB: m.DB} + if err := stmt.Parse(oldName); err == nil { + oldTable = stmt.Table + } else { + return err + } + } + + if v, ok := newName.(string); ok { + newTable = v + } else { + stmt := &gorm.Statement{DB: m.DB} + if err := stmt.Parse(newName); err == nil { + newTable = stmt.Table + } else { + return err + } + } + + return m.DB.Exec( + "sp_rename @objname = ?, @newname = ?;", + clause.Table{Name: oldTable}, clause.Table{Name: newTable}, + ).Error +} + func (m Migrator) HasColumn(value interface{}, field string) bool { var count int64 m.RunWithValue(value, func(stmt *gorm.Statement) error { diff --git a/migrator.go b/migrator.go index d90c362f..865a08ef 100644 --- a/migrator.go +++ b/migrator.go @@ -27,7 +27,7 @@ type Migrator interface { CreateTable(dst ...interface{}) error DropTable(dst ...interface{}) error HasTable(dst interface{}) bool - RenameTable(oldName, newName string) error + RenameTable(oldName, newName interface{}) error // Columns AddColumn(dst interface{}, field string) error diff --git a/migrator/migrator.go b/migrator/migrator.go index d41646f4..f22d6d2c 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -227,8 +227,31 @@ func (m Migrator) HasTable(value interface{}) bool { return count > 0 } -func (m Migrator) RenameTable(oldName, newName string) error { - return m.DB.Exec("RENAME TABLE ? TO ?", oldName, newName).Error +func (m Migrator) RenameTable(oldName, newName interface{}) error { + var oldTable, newTable string + if v, ok := oldName.(string); ok { + oldTable = v + } else { + stmt := &gorm.Statement{DB: m.DB} + if err := stmt.Parse(oldName); err == nil { + oldTable = stmt.Table + } else { + return err + } + } + + if v, ok := newName.(string); ok { + newTable = v + } else { + stmt := &gorm.Statement{DB: m.DB} + if err := stmt.Parse(newName); err == nil { + newTable = stmt.Table + } else { + return err + } + } + + return m.DB.Exec("ALTER TABLE ? RENAME TO ?", clause.Table{Name: oldTable}, clause.Table{Name: newTable}).Error } func (m Migrator) AddColumn(value interface{}, field string) error { diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 2252d09d..748ee816 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -15,20 +15,53 @@ func TestMigrate(t *testing.T) { rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] }) if err := DB.Migrator().DropTable(allModels...); err != nil { - t.Errorf("Failed to drop table, got error %v", err) + t.Fatalf("Failed to drop table, got error %v", err) } if err := DB.AutoMigrate(allModels...); err != nil { - t.Errorf("Failed to auto migrate, but got error %v", err) + t.Fatalf("Failed to auto migrate, but got error %v", err) } for _, m := range allModels { if !DB.Migrator().HasTable(m) { - t.Errorf("Failed to create table for %#v", m) + t.Fatalf("Failed to create table for %#v", m) } } } +func TestTable(t *testing.T) { + type TableStruct struct { + gorm.Model + Name string + } + + DB.Migrator().DropTable(&TableStruct{}) + DB.AutoMigrate(&TableStruct{}) + + if !DB.Migrator().HasTable(&TableStruct{}) { + t.Fatalf("should found created table") + } + + type NewTableStruct struct { + gorm.Model + Name string + } + + if err := DB.Migrator().RenameTable(&TableStruct{}, &NewTableStruct{}); err != nil { + t.Fatalf("Failed to rename table, got error %v", err) + } + + if !DB.Migrator().HasTable("new_table_structs") { + t.Fatal("should found renamed table") + } + + DB.Migrator().DropTable("new_table_structs") + + if DB.Migrator().HasTable(&NewTableStruct{}) { + t.Fatal("should not found droped table") + } +} + func TestIndexes(t *testing.T) { type IndexStruct struct { gorm.Model @@ -39,43 +72,43 @@ func TestIndexes(t *testing.T) { DB.AutoMigrate(&IndexStruct{}) if err := DB.Migrator().DropIndex(&IndexStruct{}, "Name"); err != nil { - t.Errorf("Failed to drop index for user's name, got err %v", err) + t.Fatalf("Failed to drop index for user's name, got err %v", err) } if err := DB.Migrator().CreateIndex(&IndexStruct{}, "Name"); err != nil { - t.Errorf("Got error when tried to create index: %+v", err) + t.Fatalf("Got error when tried to create index: %+v", err) } if !DB.Migrator().HasIndex(&IndexStruct{}, "Name") { - t.Errorf("Failed to find index for user's name") + t.Fatalf("Failed to find index for user's name") } if err := DB.Migrator().DropIndex(&IndexStruct{}, "Name"); err != nil { - t.Errorf("Failed to drop index for user's name, got err %v", err) + t.Fatalf("Failed to drop index for user's name, got err %v", err) } if DB.Migrator().HasIndex(&IndexStruct{}, "Name") { - t.Errorf("Should not find index for user's name after delete") + t.Fatalf("Should not find index for user's name after delete") } if err := DB.Migrator().CreateIndex(&IndexStruct{}, "Name"); err != nil { - t.Errorf("Got error when tried to create index: %+v", err) + t.Fatalf("Got error when tried to create index: %+v", err) } if err := DB.Migrator().RenameIndex(&IndexStruct{}, "idx_index_structs_name", "idx_users_name_1"); err != nil { - t.Errorf("no error should happen when rename index, but got %v", err) + t.Fatalf("no error should happen when rename index, but got %v", err) } if !DB.Migrator().HasIndex(&IndexStruct{}, "idx_users_name_1") { - t.Errorf("Should find index for user's name after rename") + t.Fatalf("Should find index for user's name after rename") } if err := DB.Migrator().DropIndex(&IndexStruct{}, "idx_users_name_1"); err != nil { - t.Errorf("Failed to drop index for user's name, got err %v", err) + t.Fatalf("Failed to drop index for user's name, got err %v", err) } if DB.Migrator().HasIndex(&IndexStruct{}, "idx_users_name_1") { - t.Errorf("Should not find index for user's name after delete") + t.Fatalf("Should not find index for user's name after delete") } } From 536e4d34b078ea812521e209be5ac304848559e9 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 31 May 2020 10:38:01 +0800 Subject: [PATCH 358/881] Add test for AlterColumn --- dialects/mssql/migrator.go | 12 ++++++++++++ dialects/mysql/migrator.go | 4 ++-- dialects/postgres/postgres.go | 2 +- dialects/sqlite/migrator.go | 36 +++++++++++++++++++++++++++++++++++ migrator/migrator.go | 2 +- tests/migrate_test.go | 26 +++++++++++++++++++++++++ 6 files changed, 78 insertions(+), 4 deletions(-) diff --git a/dialects/mssql/migrator.go b/dialects/mssql/migrator.go index b334268e..1de49ae9 100644 --- a/dialects/mssql/migrator.go +++ b/dialects/mssql/migrator.go @@ -71,6 +71,18 @@ func (m Migrator) HasColumn(value interface{}, field string) bool { return count > 0 } +func (m Migrator) AlterColumn(value interface{}, field string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + if field := stmt.Schema.LookUpField(field); field != nil { + return m.DB.Exec( + "ALTER TABLE ? ALTER COLUMN ? ?", + clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, m.FullDataTypeOf(field), + ).Error + } + return fmt.Errorf("failed to look up field with name: %s", field) + }) +} + func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { if field := stmt.Schema.LookUpField(oldName); field != nil { diff --git a/dialects/mysql/migrator.go b/dialects/mysql/migrator.go index 2c11af94..74c11277 100644 --- a/dialects/mysql/migrator.go +++ b/dialects/mysql/migrator.go @@ -16,8 +16,8 @@ func (m Migrator) AlterColumn(value interface{}, field string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { if field := stmt.Schema.LookUpField(field); field != nil { return m.DB.Exec( - "ALTER TABLE ? MODIFY COLUMN ? TYPE ?", - clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, clause.Expr{SQL: field.DBDataType}, + "ALTER TABLE ? MODIFY COLUMN ? ?", + clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, m.FullDataTypeOf(field), ).Error } return fmt.Errorf("failed to look up field with name: %s", field) diff --git a/dialects/postgres/postgres.go b/dialects/postgres/postgres.go index 73a19e9d..db559b9d 100644 --- a/dialects/postgres/postgres.go +++ b/dialects/postgres/postgres.go @@ -89,7 +89,7 @@ func (dialector Dialector) DataTypeOf(field *schema.Field) string { } return "text" case schema.Time: - return "timestamp with time zone" + return "timestamptz" case schema.Bytes: return "bytea" } diff --git a/dialects/sqlite/migrator.go b/dialects/sqlite/migrator.go index e36dc5e7..252e4183 100644 --- a/dialects/sqlite/migrator.go +++ b/dialects/sqlite/migrator.go @@ -38,6 +38,42 @@ func (m Migrator) HasColumn(value interface{}, name string) bool { return count > 0 } +func (m Migrator) AlterColumn(value interface{}, name string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + if field := stmt.Schema.LookUpField(name); field != nil { + var ( + createSQL string + newTableName = stmt.Table + "__temp" + ) + + m.DB.Raw("SELECT sql FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "table", stmt.Table, stmt.Table).Row().Scan(&createSQL) + + if reg, err := regexp.Compile("(`|'|\"| )" + name + "(`|'|\"| ) .*?,"); err == nil { + tableReg, err := regexp.Compile(" ('|`|\"| )" + stmt.Table + "('|`|\"| ) ") + if err != nil { + return err + } + + createSQL = tableReg.ReplaceAllString(createSQL, fmt.Sprintf(" `%v` ", newTableName)) + createSQL = reg.ReplaceAllString(createSQL, "?") + + var columns []string + columnTypes, _ := m.DB.Migrator().ColumnTypes(value) + for _, columnType := range columnTypes { + columns = append(columns, fmt.Sprintf("`%v`", columnType.Name())) + } + + createSQL = fmt.Sprintf("PRAGMA foreign_keys=off;BEGIN TRANSACTION;"+createSQL+";INSERT INTO `%v`(%v) SELECT %v FROM `%v`;DROP TABLE `%v`;ALTER TABLE `%v` RENAME TO `%v`;COMMIT;", newTableName, strings.Join(columns, ","), strings.Join(columns, ","), stmt.Table, stmt.Table, newTableName, stmt.Table) + return m.DB.Exec(createSQL, m.FullDataTypeOf(field)).Error + } else { + return err + } + } else { + return fmt.Errorf("failed to alter field with name %v", name) + } + }) +} + func (m Migrator) DropColumn(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { if field := stmt.Schema.LookUpField(name); field != nil { diff --git a/migrator/migrator.go b/migrator/migrator.go index f22d6d2c..5a06beb1 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -283,7 +283,7 @@ func (m Migrator) AlterColumn(value interface{}, field string) error { if field := stmt.Schema.LookUpField(field); field != nil { return m.DB.Exec( "ALTER TABLE ? ALTER COLUMN ? TYPE ?", - clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, clause.Expr{SQL: m.DataTypeOf(field)}, + clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, m.FullDataTypeOf(field), ).Error } return fmt.Errorf("failed to look up field with name: %s", field) diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 748ee816..957db8d6 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -2,6 +2,7 @@ package tests_test import ( "math/rand" + "strings" "testing" "time" @@ -124,6 +125,31 @@ func TestColumns(t *testing.T) { t.Errorf("Failed to migrate, got %v", err) } + type ColumnStruct2 struct { + gorm.Model + Name string `gorm:"size:100"` + } + + if err := DB.Table("column_structs").Migrator().AlterColumn(&ColumnStruct2{}, "Name"); err != nil { + t.Fatalf("no error should happend when alter column, but got %v", err) + } + + if columnTypes, err := DB.Migrator().ColumnTypes(&ColumnStruct{}); err != nil { + t.Fatalf("no error should returns for ColumnTypes") + } else { + stmt := &gorm.Statement{DB: DB} + stmt.Parse(&ColumnStruct2{}) + + for _, columnType := range columnTypes { + if columnType.Name() == "name" { + dataType := DB.Dialector.DataTypeOf(stmt.Schema.LookUpField(columnType.Name())) + if !strings.Contains(strings.ToUpper(dataType), strings.ToUpper(columnType.DatabaseTypeName())) { + t.Errorf("column type should be correct, name: %v, length: %v, expects: %v", columnType.Name(), columnType.DatabaseTypeName(), dataType) + } + } + } + } + type NewColumnStruct struct { gorm.Model Name string From 1e7eb12cbad363e6b1511fd6a3b9a3314d077ddb Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 31 May 2020 11:19:45 +0800 Subject: [PATCH 359/881] Test empty struct --- callbacks/create.go | 1 + dialects/mysql/mysql.go | 7 +++++++ tests/create_test.go | 27 +++++++++++++++++++++++---- 3 files changed, 31 insertions(+), 4 deletions(-) diff --git a/callbacks/create.go b/callbacks/create.go index 18f25c9a..ac63c89b 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -153,6 +153,7 @@ func CreateWithReturning(db *gorm.DB) { } if rows.Next() { + db.RowsAffected++ err = rows.Scan(values...) } } diff --git a/dialects/mysql/mysql.go b/dialects/mysql/mysql.go index 23525ed7..baeb79c7 100644 --- a/dialects/mysql/mysql.go +++ b/dialects/mysql/mysql.go @@ -60,6 +60,13 @@ func (dialector Dialector) ClauseBuilders() map[string]clause.ClauseBuilder { c.Build(builder) } }, + "VALUES": func(c clause.Clause, builder clause.Builder) { + if values, ok := c.Expression.(clause.Values); ok && len(values.Columns) == 0 { + builder.WriteString("VALUES()") + return + } + c.Build(builder) + }, } } diff --git a/tests/create_test.go b/tests/create_test.go index 5b859e99..43e2c718 100644 --- a/tests/create_test.go +++ b/tests/create_test.go @@ -9,8 +9,10 @@ import ( func TestCreate(t *testing.T) { var user = *GetUser("create", Config{}) - if err := DB.Create(&user).Error; err != nil { - t.Fatalf("errors happened when create: %v", err) + if results := DB.Create(&user); results.Error != nil { + t.Fatalf("errors happened when create: %v", results.Error) + } else if results.RowsAffected != 1 { + t.Fatalf("rows affected expects: %v, got %v", 1, results.RowsAffected) } if user.ID == 0 { @@ -68,8 +70,10 @@ func TestBulkCreateWithAssociations(t *testing.T) { *GetUser("bulk_8", Config{Account: false, Pets: 0, Toys: 0, Company: false, Manager: false, Team: 0, Languages: 0, Friends: 0}), } - if err := DB.Create(&users).Error; err != nil { - t.Fatalf("errors happened when create: %v", err) + if results := DB.Create(&users); results.Error != nil { + t.Fatalf("errors happened when create: %v", results.Error) + } else if results.RowsAffected != int64(len(users)) { + t.Fatalf("rows affected expects: %v, got %v", len(users), results.RowsAffected) } var userIDs []uint @@ -182,3 +186,18 @@ func TestPolymorphicHasOne(t *testing.T) { } }) } + +func TestCreateEmptyStrut(t *testing.T) { + type EmptyStruct struct { + ID uint + } + DB.Migrator().DropTable(&EmptyStruct{}) + + if err := DB.AutoMigrate(&EmptyStruct{}); err != nil { + t.Errorf("no error should happen when auto migrate, but got %v", err) + } + + if err := DB.Create(&EmptyStruct{}).Error; err != nil { + t.Errorf("No error should happen when creating user, but got %v", err) + } +} From b3b19a55773b2c4a004c469960dcac78eb068a96 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 31 May 2020 11:34:59 +0800 Subject: [PATCH 360/881] Test Override NowFunc --- gorm.go | 24 +++++++++--------------- soft_delete.go | 3 +-- tests/create_test.go | 44 ++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 54 insertions(+), 17 deletions(-) diff --git a/gorm.go b/gorm.go index 6b2a6d75..70751cb3 100644 --- a/gorm.go +++ b/gorm.go @@ -30,9 +30,8 @@ type Config struct { // Dialector database dialector Dialector - statementPool sync.Pool - callbacks *callbacks - cacheStore *sync.Map + callbacks *callbacks + cacheStore *sync.Map } // DB GORM DB definition @@ -77,17 +76,6 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) { config.cacheStore = &sync.Map{} } - config.statementPool = sync.Pool{ - New: func() interface{} { - return &Statement{ - DB: db, - ConnPool: db.ConnPool, - Clauses: map[string]clause.Clause{}, - Context: context.Background(), - } - }, - } - db = &DB{ Config: config, clone: true, @@ -179,7 +167,13 @@ func (db *DB) AddError(err error) error { func (db *DB) getInstance() *DB { if db.clone { - stmt := db.Config.statementPool.Get().(*Statement) + stmt := &Statement{ + DB: db, + ConnPool: db.ConnPool, + Clauses: map[string]clause.Clause{}, + Context: context.Background(), + } + if db.Statement != nil { stmt.Context = db.Statement.Context } diff --git a/soft_delete.go b/soft_delete.go index 138c9c63..09cfff37 100644 --- a/soft_delete.go +++ b/soft_delete.go @@ -4,7 +4,6 @@ import ( "database/sql" "database/sql/driver" "reflect" - "time" "github.com/jinzhu/gorm/clause" "github.com/jinzhu/gorm/schema" @@ -55,7 +54,7 @@ func (SoftDeleteClause) MergeClause(*clause.Clause) { func (SoftDeleteClause) ModifyStatement(stmt *Statement) { if stmt.SQL.String() == "" { - stmt.AddClause(clause.Set{{Column: clause.Column{Name: "deleted_at"}, Value: time.Now()}}) + stmt.AddClause(clause.Set{{Column: clause.Column{Name: "deleted_at"}, Value: stmt.DB.NowFunc()}}) if stmt.Schema != nil { _, queryValues := schema.GetIdentityFieldValuesMap(stmt.ReflectValue, stmt.Schema.PrimaryFields) diff --git a/tests/create_test.go b/tests/create_test.go index 43e2c718..a3b3b598 100644 --- a/tests/create_test.go +++ b/tests/create_test.go @@ -1,9 +1,13 @@ package tests_test import ( + "fmt" "testing" + "time" + "github.com/jinzhu/gorm" . "github.com/jinzhu/gorm/tests" + "github.com/jinzhu/now" ) func TestCreate(t *testing.T) { @@ -201,3 +205,43 @@ func TestCreateEmptyStrut(t *testing.T) { t.Errorf("No error should happen when creating user, but got %v", err) } } + +func TestCreateWithExistingTimestamp(t *testing.T) { + user := User{Name: "CreateUserExistingTimestamp"} + curTime := now.MustParse("2016-01-01") + user.CreatedAt = curTime + user.UpdatedAt = curTime + DB.Save(&user) + + AssertEqual(t, user.CreatedAt, curTime) + AssertEqual(t, user.UpdatedAt, curTime) + + var newUser User + DB.First(&newUser, user.ID) + + AssertEqual(t, newUser.CreatedAt, curTime) + AssertEqual(t, newUser.UpdatedAt, curTime) +} + +func TestCreateWithNowFuncOverride(t *testing.T) { + user := User{Name: "CreateUserTimestampOverride"} + curTime := now.MustParse("2016-01-01") + + NEW := DB.Session(&gorm.Session{ + NowFunc: func() time.Time { + fmt.Println("11iiiin") + return curTime + }, + }) + + NEW.Save(&user) + + AssertEqual(t, user.CreatedAt, curTime) + AssertEqual(t, user.UpdatedAt, curTime) + + var newUser User + NEW.First(&newUser, user.ID) + + AssertEqual(t, newUser.CreatedAt, curTime) + AssertEqual(t, newUser.UpdatedAt, curTime) +} From 1546f8a4a19d598ff5b16aefa52acf36bd6b3d4e Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 31 May 2020 12:52:49 +0800 Subject: [PATCH 361/881] Test CreateWithNoGORMPrimayKey --- callbacks/create.go | 2 +- dialects/mssql/create.go | 46 ++++++++++++++++++++++-------------- migrator/migrator.go | 2 +- tests/create_test.go | 18 ++++++++++++++ tests/scanner_valuer_test.go | 23 ++++++++++++++++++ 5 files changed, 71 insertions(+), 20 deletions(-) diff --git a/callbacks/create.go b/callbacks/create.go index ac63c89b..f558d7ae 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -64,7 +64,7 @@ func Create(config *Config) func(db *gorm.DB) { result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) if err == nil { - if db.Statement.Schema != nil { + if db.Statement.Schema != nil && db.Statement.Schema.PrioritizedPrimaryField != nil { if _, ok := db.Statement.Schema.FieldsWithDefaultDBValue[db.Statement.Schema.PrioritizedPrimaryField.DBName]; ok { if insertID, err := result.LastInsertId(); err == nil { switch db.Statement.ReflectValue.Kind() { diff --git a/dialects/mssql/create.go b/dialects/mssql/create.go index c85997fb..ebdeeab0 100644 --- a/dialects/mssql/create.go +++ b/dialects/mssql/create.go @@ -68,26 +68,30 @@ func Create(db *gorm.DB) { switch db.Statement.ReflectValue.Kind() { case reflect.Slice, reflect.Array: - for rows.Next() { - // for idx, field := range fields { - // values[idx] = field.ReflectValueOf(db.Statement.ReflectValue.Index(int(db.RowsAffected))).Addr().Interface() - // } + if len(db.Statement.Schema.PrimaryFields) > 0 { + values := make([]interface{}, len(db.Statement.Schema.PrimaryFields)) - values := db.Statement.Schema.PrioritizedPrimaryField.ReflectValueOf(db.Statement.ReflectValue.Index(int(db.RowsAffected))).Addr().Interface() - if err := rows.Scan(values); err != nil { - db.AddError(err) + for rows.Next() { + for idx, field := range db.Statement.Schema.PrimaryFields { + values[idx] = field.ReflectValueOf(db.Statement.ReflectValue.Index(int(db.RowsAffected))).Addr().Interface() + } + + db.RowsAffected++ + db.AddError(rows.Scan(values...)) } - db.RowsAffected++ } case reflect.Struct: - // for idx, field := range fields { - // values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface() - // } - values := db.Statement.Schema.PrioritizedPrimaryField.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface() + if len(db.Statement.Schema.PrimaryFields) > 0 { + values := make([]interface{}, len(db.Statement.Schema.PrimaryFields)) - if rows.Next() { - db.RowsAffected++ - db.AddError(rows.Scan(values)) + for idx, field := range db.Statement.Schema.PrimaryFields { + values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface() + } + + if rows.Next() { + db.RowsAffected++ + db.AddError(rows.Scan(values...)) + } } } } else { @@ -177,8 +181,14 @@ func MergeCreate(db *gorm.DB, onConflict clause.OnConflict) { } func outputInserted(db *gorm.DB) { - if db.Statement.Schema.PrioritizedPrimaryField != nil { - db.Statement.WriteString(" OUTPUT INSERTED.") - db.Statement.WriteQuoted(db.Statement.Schema.PrioritizedPrimaryField.DBName) + if len(db.Statement.Schema.PrimaryFields) > 0 { + db.Statement.WriteString(" OUTPUT ") + for idx, field := range db.Statement.Schema.PrimaryFields { + if idx > 0 { + db.Statement.WriteString(",") + } + db.Statement.WriteString(" INSERTED.") + db.Statement.AddVar(db.Statement, clause.Column{Name: field.DBName}) + } } } diff --git a/migrator/migrator.go b/migrator/migrator.go index 5a06beb1..4e0f28b5 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -149,7 +149,7 @@ func (m Migrator) CreateTable(values ...interface{}) error { createTableSQL += "," } - if !hasPrimaryKeyInDataType { + if !hasPrimaryKeyInDataType && len(stmt.Schema.PrimaryFields) > 0 { createTableSQL += "PRIMARY KEY ?," primaryKeys := []interface{}{} for _, field := range stmt.Schema.PrimaryFields { diff --git a/tests/create_test.go b/tests/create_test.go index a3b3b598..6421ca34 100644 --- a/tests/create_test.go +++ b/tests/create_test.go @@ -245,3 +245,21 @@ func TestCreateWithNowFuncOverride(t *testing.T) { AssertEqual(t, newUser.CreatedAt, curTime) AssertEqual(t, newUser.UpdatedAt, curTime) } + +func TestCreateWithNoGORMPrimayKey(t *testing.T) { + type JoinTable struct { + UserID uint + FriendID uint + } + + DB.Migrator().DropTable(&JoinTable{}) + if err := DB.AutoMigrate(&JoinTable{}); err != nil { + t.Errorf("no error should happen when auto migrate, but got %v", err) + } + + jt := JoinTable{UserID: 1, FriendID: 2} + err := DB.Create(&jt).Error + if err != nil { + t.Errorf("No error should happen when create a record without a GORM primary key. But in the database this primary key exists and is the union of 2 or more fields\n But got: %s", err) + } +} diff --git a/tests/scanner_valuer_test.go b/tests/scanner_valuer_test.go index 88e7e12e..04c91ab2 100644 --- a/tests/scanner_valuer_test.go +++ b/tests/scanner_valuer_test.go @@ -34,6 +34,7 @@ func TestScannerValuer(t *testing.T) { {"name1", "value1"}, {"name2", "value2"}, }, + Role: Role{Name: "admin"}, } if err := DB.Create(&data).Error; err != nil { @@ -91,6 +92,7 @@ type ScannerValuerStruct struct { Num Num Strings StringsSlice Structs StructsSlice + Role Role } type EncryptedData []byte @@ -176,3 +178,24 @@ func (l *StructsSlice) Scan(input interface{}) error { return errors.New("not supported") } } + +type Role struct { + Name string `gorm:"size:256"` +} + +func (role *Role) Scan(value interface{}) error { + if b, ok := value.([]uint8); ok { + role.Name = string(b) + } else { + role.Name = value.(string) + } + return nil +} + +func (role Role) Value() (driver.Value, error) { + return role.Name, nil +} + +func (role Role) IsAdmin() bool { + return role.Name == "admin" +} From 9d3e929790141fd7604a83d2b5e14f2e79427b7e Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 31 May 2020 13:34:53 +0800 Subject: [PATCH 362/881] Test Select, Omit with Create --- callbacks/helper.go | 2 +- tests/create_test.go | 33 +++++++++++++++++++++++++++++++-- 2 files changed, 32 insertions(+), 3 deletions(-) diff --git a/callbacks/helper.go b/callbacks/helper.go index 8da74690..818d9c2c 100644 --- a/callbacks/helper.go +++ b/callbacks/helper.go @@ -31,7 +31,7 @@ func SelectAndOmitColumns(stmt *gorm.Statement, requireCreate, requireUpdate boo // omit columns for _, omit := range stmt.Omits { - if field := stmt.Schema.LookUpField(omit); field != nil { + if field := stmt.Schema.LookUpField(omit); field != nil && field.DBName != "" { results[field.DBName] = false } else { results[omit] = false diff --git a/tests/create_test.go b/tests/create_test.go index 6421ca34..4b9694b6 100644 --- a/tests/create_test.go +++ b/tests/create_test.go @@ -1,7 +1,6 @@ package tests_test import ( - "fmt" "testing" "time" @@ -229,7 +228,6 @@ func TestCreateWithNowFuncOverride(t *testing.T) { NEW := DB.Session(&gorm.Session{ NowFunc: func() time.Time { - fmt.Println("11iiiin") return curTime }, }) @@ -263,3 +261,34 @@ func TestCreateWithNoGORMPrimayKey(t *testing.T) { t.Errorf("No error should happen when create a record without a GORM primary key. But in the database this primary key exists and is the union of 2 or more fields\n But got: %s", err) } } + +func TestSelectWithCreate(t *testing.T) { + user := *GetUser("select_create", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4}) + DB.Select("Account", "Toys", "Manager", "ManagerID", "Languages", "Name", "CreatedAt", "UpdatedAt", "Age", "Active").Create(&user) + + var user2 User + DB.Preload("Account").Preload("Pets").Preload("Toys").Preload("Company").Preload("Manager").Preload("Team").Preload("Languages").Preload("Friends").First(&user2, user.ID) + + user.Birthday = nil + user.Pets = nil + user.Company = Company{} + user.Team = nil + user.Friends = nil + + CheckUser(t, user2, user) +} + +func TestOmitWithCreate(t *testing.T) { + user := *GetUser("omit_create", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4}) + DB.Omit("Account", "Toys", "Manager", "Birthday").Create(&user) + + var user2 User + DB.Preload("Account").Preload("Pets").Preload("Toys").Preload("Company").Preload("Manager").Preload("Team").Preload("Languages").Preload("Friends").First(&user2, user.ID) + + user.Birthday = nil + user.Account = Account{} + user.Toys = nil + user.Manager = nil + + CheckUser(t, user2, user) +} From 6d555ef8d586a3101131407c60fbf10ae3f3557d Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 31 May 2020 14:18:07 +0800 Subject: [PATCH 363/881] Test embedded struct --- schema/field.go | 8 +++ tests/embedded_struct_test.go | 105 ++++++++++++++++++++++++++++++++++ 2 files changed, 113 insertions(+) create mode 100644 tests/embedded_struct_test.go diff --git a/schema/field.go b/schema/field.go index 57ba3ac7..f52dd6a6 100644 --- a/schema/field.go +++ b/schema/field.go @@ -298,6 +298,14 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { ef.DBName = prefix + ef.DBName } + if val, ok := ef.TagSettings["PRIMARYKEY"]; ok && utils.CheckTruth(val) { + ef.PrimaryKey = true + } else if val, ok := ef.TagSettings["PRIMARY_KEY"]; ok && utils.CheckTruth(val) { + ef.PrimaryKey = true + } else { + ef.PrimaryKey = false + } + for k, v := range field.TagSettings { ef.TagSettings[k] = v } diff --git a/tests/embedded_struct_test.go b/tests/embedded_struct_test.go new file mode 100644 index 00000000..af003786 --- /dev/null +++ b/tests/embedded_struct_test.go @@ -0,0 +1,105 @@ +package tests_test + +import ( + "testing" + + "github.com/jinzhu/gorm" + . "github.com/jinzhu/gorm/tests" +) + +func TestEmbeddedStruct(t *testing.T) { + type BasePost struct { + Id int64 + Title string + URL string + } + + type Author struct { + ID string + Name string + Email string + } + + type HNPost struct { + BasePost + Author `gorm:"EmbeddedPrefix:user_"` // Embedded struct + Upvotes int32 + } + + type EngadgetPost struct { + BasePost BasePost `gorm:"Embedded"` + Author Author `gorm:"Embedded;EmbeddedPrefix:author_"` // Embedded struct + ImageUrl string + } + + DB.Migrator().DropTable(&HNPost{}, &EngadgetPost{}) + if err := DB.Migrator().AutoMigrate(&HNPost{}, &EngadgetPost{}); err != nil { + t.Fatalf("failed to auto migrate, got error: %v", err) + } + + for _, name := range []string{"author_id", "author_name", "author_email"} { + if !DB.Migrator().HasColumn(&EngadgetPost{}, name) { + t.Errorf("should has prefixed column %v", name) + } + } + + stmt := gorm.Statement{DB: DB} + if err := stmt.Parse(&EngadgetPost{}); err != nil { + t.Fatalf("failed to parse embedded struct") + } else if len(stmt.Schema.PrimaryFields) != 1 { + t.Errorf("should have only one primary field with embedded struct, but got %v", len(stmt.Schema.PrimaryFields)) + } + + for _, name := range []string{"user_id", "user_name", "user_email"} { + if !DB.Migrator().HasColumn(&HNPost{}, name) { + t.Errorf("should has prefixed column %v", name) + } + } + + // save embedded struct + DB.Save(&HNPost{BasePost: BasePost{Title: "news"}}) + DB.Save(&HNPost{BasePost: BasePost{Title: "hn_news"}}) + var news HNPost + if err := DB.First(&news, "title = ?", "hn_news").Error; err != nil { + t.Errorf("no error should happen when query with embedded struct, but got %v", err) + } else if news.Title != "hn_news" { + t.Errorf("embedded struct's value should be scanned correctly") + } + + DB.Save(&EngadgetPost{BasePost: BasePost{Title: "engadget_news"}}) + var egNews EngadgetPost + if err := DB.First(&egNews, "title = ?", "engadget_news").Error; err != nil { + t.Errorf("no error should happen when query with embedded struct, but got %v", err) + } else if egNews.BasePost.Title != "engadget_news" { + t.Errorf("embedded struct's value should be scanned correctly") + } +} + +func TestEmbeddedPointerTypeStruct(t *testing.T) { + type BasePost struct { + Id int64 + Title string + URL string + } + + type HNPost struct { + *BasePost + Upvotes int32 + } + + DB.Migrator().DropTable(&HNPost{}) + if err := DB.Migrator().AutoMigrate(&HNPost{}); err != nil { + t.Fatalf("failed to auto migrate, got error: %v", err) + } + + DB.Create(&HNPost{BasePost: &BasePost{Title: "embedded_pointer_type"}}) + + var hnPost HNPost + if err := DB.First(&hnPost, "title = ?", "embedded_pointer_type").Error; err != nil { + t.Errorf("No error should happen when find embedded pointer type, but got %v", err) + } + + if hnPost.Title != "embedded_pointer_type" { + t.Errorf("Should find correct value for embedded pointer type") + } +} From aa959ec38309082cbf07efb80d68f518296a246a Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 31 May 2020 14:41:45 +0800 Subject: [PATCH 364/881] Test NamedPolymorphic --- callbacks/preload.go | 4 +- tests/named_polymorphic_test.go | 146 ++++++++++++++++++++++++++++++++ 2 files changed, 148 insertions(+), 2 deletions(-) create mode 100644 tests/named_polymorphic_test.go diff --git a/callbacks/preload.go b/callbacks/preload.go index cfea4f94..a77db2b1 100644 --- a/callbacks/preload.go +++ b/callbacks/preload.go @@ -34,7 +34,7 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { joinForeignFields = append(joinForeignFields, ref.ForeignKey) foreignFields = append(foreignFields, ref.PrimaryKey) } else if ref.PrimaryValue != "" { - tx.Where(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue}) + tx = tx.Where(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue}) } else { joinRelForeignFields = append(joinRelForeignFields, ref.ForeignKey) relForeignKeys = append(relForeignKeys, ref.PrimaryKey.DBName) @@ -76,7 +76,7 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { relForeignFields = append(relForeignFields, ref.ForeignKey) foreignFields = append(foreignFields, ref.PrimaryKey) } else if ref.PrimaryValue != "" { - tx.Where(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue}) + tx = tx.Where(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue}) } else { relForeignKeys = append(relForeignKeys, ref.PrimaryKey.DBName) relForeignFields = append(relForeignFields, ref.PrimaryKey) diff --git a/tests/named_polymorphic_test.go b/tests/named_polymorphic_test.go new file mode 100644 index 00000000..7af548a4 --- /dev/null +++ b/tests/named_polymorphic_test.go @@ -0,0 +1,146 @@ +package tests_test + +import ( + "testing" + + . "github.com/jinzhu/gorm/tests" +) + +type Hamster struct { + Id int + Name string + PreferredToy Toy `gorm:"polymorphic:Owner;polymorphic_value:hamster_preferred"` + OtherToy Toy `gorm:"polymorphic:Owner;polymorphic_value:hamster_other"` +} + +func TestNamedPolymorphic(t *testing.T) { + DB.AutoMigrate(&Hamster{}) + + hamster := Hamster{Name: "Mr. Hammond", PreferredToy: Toy{Name: "bike"}, OtherToy: Toy{Name: "treadmill"}} + DB.Save(&hamster) + + hamster2 := Hamster{} + DB.Debug().Preload("PreferredToy").Preload("OtherToy").Find(&hamster2, hamster.Id) + + if hamster2.PreferredToy.ID != hamster.PreferredToy.ID || hamster2.PreferredToy.Name != hamster.PreferredToy.Name { + t.Errorf("Hamster's preferred toy failed to preload") + } + + if hamster2.OtherToy.ID != hamster.OtherToy.ID || hamster2.OtherToy.Name != hamster.OtherToy.Name { + t.Errorf("Hamster's other toy failed to preload") + } + + // clear to omit Toy.ID in count + hamster2.PreferredToy = Toy{} + hamster2.OtherToy = Toy{} + + if DB.Model(&hamster2).Association("PreferredToy").Count() != 1 { + t.Errorf("Hamster's preferred toy count should be 1") + } + + if DB.Model(&hamster2).Association("OtherToy").Count() != 1 { + t.Errorf("Hamster's other toy count should be 1") + } + + // Query + hamsterToy := Toy{} + DB.Model(&hamster).Association("PreferredToy").Find(&hamsterToy) + if hamsterToy.Name != hamster.PreferredToy.Name { + t.Errorf("Should find has one polymorphic association") + } + + hamsterToy = Toy{} + DB.Model(&hamster).Association("OtherToy").Find(&hamsterToy) + if hamsterToy.Name != hamster.OtherToy.Name { + t.Errorf("Should find has one polymorphic association") + } + + // Append + DB.Model(&hamster).Association("PreferredToy").Append(&Toy{ + Name: "bike 2", + }) + + DB.Model(&hamster).Association("OtherToy").Append(&Toy{ + Name: "treadmill 2", + }) + + hamsterToy = Toy{} + DB.Model(&hamster).Association("PreferredToy").Find(&hamsterToy) + if hamsterToy.Name != "bike 2" { + t.Errorf("Should update has one polymorphic association with Append") + } + + hamsterToy = Toy{} + DB.Model(&hamster).Association("OtherToy").Find(&hamsterToy) + if hamsterToy.Name != "treadmill 2" { + t.Errorf("Should update has one polymorphic association with Append") + } + + if DB.Model(&hamster2).Association("PreferredToy").Count() != 1 { + t.Errorf("Hamster's toys count should be 1 after Append") + } + + if DB.Model(&hamster2).Association("OtherToy").Count() != 1 { + t.Errorf("Hamster's toys count should be 1 after Append") + } + + // Replace + DB.Model(&hamster).Association("PreferredToy").Replace(&Toy{ + Name: "bike 3", + }) + + DB.Model(&hamster).Association("OtherToy").Replace(&Toy{ + Name: "treadmill 3", + }) + + hamsterToy = Toy{} + DB.Model(&hamster).Association("PreferredToy").Find(&hamsterToy) + if hamsterToy.Name != "bike 3" { + t.Errorf("Should update has one polymorphic association with Replace") + } + + hamsterToy = Toy{} + DB.Model(&hamster).Association("OtherToy").Find(&hamsterToy) + if hamsterToy.Name != "treadmill 3" { + t.Errorf("Should update has one polymorphic association with Replace") + } + + if DB.Model(&hamster2).Association("PreferredToy").Count() != 1 { + t.Errorf("hamster's toys count should be 1 after Replace") + } + + if DB.Model(&hamster2).Association("OtherToy").Count() != 1 { + t.Errorf("hamster's toys count should be 1 after Replace") + } + + // Clear + DB.Model(&hamster).Association("PreferredToy").Append(&Toy{ + Name: "bike 2", + }) + DB.Model(&hamster).Association("OtherToy").Append(&Toy{ + Name: "treadmill 2", + }) + + if DB.Model(&hamster).Association("PreferredToy").Count() != 1 { + t.Errorf("Hamster's toys should be added with Append") + } + + if DB.Model(&hamster).Association("OtherToy").Count() != 1 { + t.Errorf("Hamster's toys should be added with Append") + } + + DB.Model(&hamster).Association("PreferredToy").Clear() + + if DB.Model(&hamster2).Association("PreferredToy").Count() != 0 { + t.Errorf("Hamster's preferred toy should be cleared with Clear") + } + + if DB.Model(&hamster2).Association("OtherToy").Count() != 1 { + t.Errorf("Hamster's other toy should be still available") + } + + DB.Model(&hamster).Association("OtherToy").Clear() + if DB.Model(&hamster).Association("OtherToy").Count() != 0 { + t.Errorf("Hamster's other toy should be cleared with Clear") + } +} From 49310d09746ccf1852d347fa27d00355470400b8 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 31 May 2020 17:42:21 +0800 Subject: [PATCH 365/881] Test override foreign key, reference --- schema/relationship.go | 83 +++++++++++---- schema/relationship_test.go | 199 +++++++++++++++++++++++++++++++++++ schema/schema_helper_test.go | 2 +- 3 files changed, 264 insertions(+), 20 deletions(-) create mode 100644 schema/relationship_test.go diff --git a/schema/relationship.go b/schema/relationship.go index 3dcef9fc..dffe5988 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -168,31 +168,76 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel joinTableFields []reflect.StructField fieldsMap = map[string]*Field{} ownFieldsMap = map[string]bool{} // fix self join many2many + joinForeignKeys = toColumns(field.TagSettings["JOINFOREIGNKEY"]) + joinReferences = toColumns(field.TagSettings["JOINREFERENCES"]) ) - for _, s := range []*Schema{schema, relation.FieldSchema} { - for _, primaryField := range s.PrimaryFields { - fieldName := s.Name + primaryField.Name - if _, ok := fieldsMap[fieldName]; ok { - if field.Name != s.Name { - fieldName = inflection.Singular(field.Name) + primaryField.Name - } else { - fieldName = s.Name + primaryField.Name + "Reference" - } - } else { - ownFieldsMap[fieldName] = true - } + ownForeignFields := schema.PrimaryFields + refForeignFields := relation.FieldSchema.PrimaryFields - fieldsMap[fieldName] = primaryField - joinTableFields = append(joinTableFields, reflect.StructField{ - Name: fieldName, - PkgPath: primaryField.StructField.PkgPath, - Type: primaryField.StructField.Type, - Tag: removeSettingFromTag(primaryField.StructField.Tag, "column"), - }) + if len(relation.foreignKeys) > 0 { + ownForeignFields = []*Field{} + for _, foreignKey := range relation.foreignKeys { + if field := schema.LookUpField(foreignKey); field != nil { + ownForeignFields = append(ownForeignFields, field) + } else { + schema.err = fmt.Errorf("invalid foreign key: %v", foreignKey) + return + } } } + if len(relation.primaryKeys) > 0 { + refForeignFields = []*Field{} + for _, foreignKey := range relation.primaryKeys { + if field := relation.FieldSchema.LookUpField(foreignKey); field != nil { + refForeignFields = append(refForeignFields, field) + } else { + schema.err = fmt.Errorf("invalid foreign key: %v", foreignKey) + return + } + } + } + + for idx, ownField := range ownForeignFields { + joinFieldName := schema.Name + ownField.Name + if len(joinForeignKeys) > idx { + joinFieldName = joinForeignKeys[idx] + } + + ownFieldsMap[joinFieldName] = true + fieldsMap[joinFieldName] = ownField + joinTableFields = append(joinTableFields, reflect.StructField{ + Name: joinFieldName, + PkgPath: ownField.StructField.PkgPath, + Type: ownField.StructField.Type, + Tag: removeSettingFromTag(ownField.StructField.Tag, "column"), + }) + } + + for idx, relField := range refForeignFields { + joinFieldName := relation.FieldSchema.Name + relField.Name + if len(joinReferences) > idx { + joinFieldName = joinReferences[idx] + } + + if _, ok := ownFieldsMap[joinFieldName]; ok { + if field.Name != relation.FieldSchema.Name { + joinFieldName = inflection.Singular(field.Name) + relField.Name + } else { + joinFieldName += "Reference" + } + } + + fieldsMap[joinFieldName] = relField + joinTableFields = append(joinTableFields, reflect.StructField{ + Name: joinFieldName, + PkgPath: relField.StructField.PkgPath, + Type: relField.StructField.Type, + Tag: removeSettingFromTag(relField.StructField.Tag, "column"), + }) + } + if relation.JoinTable, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, schema.namer); err != nil { schema.err = err } diff --git a/schema/relationship_test.go b/schema/relationship_test.go new file mode 100644 index 00000000..41e8c7bd --- /dev/null +++ b/schema/relationship_test.go @@ -0,0 +1,199 @@ +package schema_test + +import ( + "sync" + "testing" + + "github.com/jinzhu/gorm" + "github.com/jinzhu/gorm/schema" +) + +func checkStructRelation(t *testing.T, data interface{}, relations ...Relation) { + if s, err := schema.Parse(data, &sync.Map{}, schema.NamingStrategy{}); err != nil { + t.Errorf("Failed to parse schema") + } else { + for _, rel := range relations { + checkSchemaRelation(t, s, rel) + } + } +} + +func TestBelongsToOverrideForeignKey(t *testing.T) { + type Profile struct { + gorm.Model + Name string + } + + type User struct { + gorm.Model + Profile Profile `gorm:"ForeignKey:ProfileRefer"` + ProfileRefer int + } + + checkStructRelation(t, &User{}, Relation{ + Name: "Profile", Type: schema.BelongsTo, Schema: "User", FieldSchema: "Profile", + References: []Reference{{"ID", "Profile", "ProfileRefer", "User", "", false}}, + }) +} + +func TestBelongsToOverrideReferences(t *testing.T) { + type Profile struct { + gorm.Model + Refer string + Name string + } + + type User struct { + gorm.Model + Profile Profile `gorm:"ForeignKey:ProfileID;References:Refer"` + ProfileID int + } + + checkStructRelation(t, &User{}, Relation{ + Name: "Profile", Type: schema.BelongsTo, Schema: "User", FieldSchema: "Profile", + References: []Reference{{"Refer", "Profile", "ProfileID", "User", "", false}}, + }) +} + +func TestHasOneOverrideForeignKey(t *testing.T) { + type Profile struct { + gorm.Model + Name string + UserRefer uint + } + + type User struct { + gorm.Model + Profile Profile `gorm:"ForeignKey:UserRefer"` + } + + checkStructRelation(t, &User{}, Relation{ + Name: "Profile", Type: schema.HasOne, Schema: "User", FieldSchema: "Profile", + References: []Reference{{"ID", "User", "UserRefer", "Profile", "", true}}, + }) +} + +func TestHasOneOverrideReferences(t *testing.T) { + type Profile struct { + gorm.Model + Name string + UserID uint + } + + type User struct { + gorm.Model + Refer string + Profile Profile `gorm:"ForeignKey:UserID;References:Refer"` + } + + checkStructRelation(t, &User{}, Relation{ + Name: "Profile", Type: schema.HasOne, Schema: "User", FieldSchema: "Profile", + References: []Reference{{"Refer", "User", "UserID", "Profile", "", true}}, + }) +} + +func TestHasManyOverrideForeignKey(t *testing.T) { + type Profile struct { + gorm.Model + Name string + UserRefer uint + } + + type User struct { + gorm.Model + Profile []Profile `gorm:"ForeignKey:UserRefer"` + } + + checkStructRelation(t, &User{}, Relation{ + Name: "Profile", Type: schema.HasMany, Schema: "User", FieldSchema: "Profile", + References: []Reference{{"ID", "User", "UserRefer", "Profile", "", true}}, + }) +} + +func TestHasManyOverrideReferences(t *testing.T) { + type Profile struct { + gorm.Model + Name string + UserID uint + } + + type User struct { + gorm.Model + Refer string + Profile []Profile `gorm:"ForeignKey:UserID;References:Refer"` + } + + checkStructRelation(t, &User{}, Relation{ + Name: "Profile", Type: schema.HasMany, Schema: "User", FieldSchema: "Profile", + References: []Reference{{"Refer", "User", "UserID", "Profile", "", true}}, + }) +} + +func TestMany2ManyOverrideForeignKeyAndReferences(t *testing.T) { + type Profile struct { + gorm.Model + Name string + UserRefer uint + } + + type User struct { + gorm.Model + Profiles []Profile `gorm:"many2many:user_profiles;ForeignKey:Refer;JoinForeignKey:UserReferID;References:UserRefer;JoinReferences:ProfileRefer"` + Refer uint + } + + checkStructRelation(t, &User{}, Relation{ + Name: "Profiles", Type: schema.Many2Many, Schema: "User", FieldSchema: "Profile", + JoinTable: JoinTable{Name: "user_profiles", Table: "user_profiles"}, + References: []Reference{ + {"Refer", "User", "UserReferID", "user_profiles", "", true}, + {"UserRefer", "Profile", "ProfileRefer", "user_profiles", "", false}, + }, + }) +} + +func TestMany2ManyOverrideForeignKey(t *testing.T) { + type Profile struct { + gorm.Model + Name string + UserRefer uint + } + + type User struct { + gorm.Model + Profiles []Profile `gorm:"many2many:user_profiles;ForeignKey:Refer;References:UserRefer"` + Refer uint + } + + checkStructRelation(t, &User{}, Relation{ + Name: "Profiles", Type: schema.Many2Many, Schema: "User", FieldSchema: "Profile", + JoinTable: JoinTable{Name: "user_profiles", Table: "user_profiles"}, + References: []Reference{ + {"Refer", "User", "UserRefer", "user_profiles", "", true}, + {"UserRefer", "Profile", "ProfileUserRefer", "user_profiles", "", false}, + }, + }) +} + +func TestMany2ManyOverrideJoinForeignKey(t *testing.T) { + type Profile struct { + gorm.Model + Name string + UserRefer uint + } + + type User struct { + gorm.Model + Profiles []Profile `gorm:"many2many:user_profiles;JoinForeignKey:UserReferID;JoinReferences:ProfileRefer"` + Refer uint + } + + checkStructRelation(t, &User{}, Relation{ + Name: "Profiles", Type: schema.Many2Many, Schema: "User", FieldSchema: "Profile", + JoinTable: JoinTable{Name: "user_profiles", Table: "user_profiles"}, + References: []Reference{ + {"ID", "User", "UserReferID", "user_profiles", "", true}, + {"ID", "Profile", "ProfileRefer", "user_profiles", "", false}, + }, + }) +} diff --git a/schema/schema_helper_test.go b/schema/schema_helper_test.go index 24920515..b5474fe7 100644 --- a/schema/schema_helper_test.go +++ b/schema/schema_helper_test.go @@ -127,7 +127,7 @@ func checkSchemaRelation(t *testing.T, s *schema.Schema, relation Relation) { } if r.FieldSchema.Name != relation.FieldSchema { - t.Errorf("schema %v relation's schema expects %v, but got %v", s, relation.Schema, r.Schema.Name) + t.Errorf("schema %v field relation's schema expects %v, but got %v", s, relation.FieldSchema, r.FieldSchema.Name) } if r.Polymorphic != nil { From ae9e4f1dd85c59caaa2707f8040a3ec1ea58bb46 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 31 May 2020 17:49:31 +0800 Subject: [PATCH 366/881] Fix change log level --- logger/logger.go | 5 +++-- tests/named_polymorphic_test.go | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/logger/logger.go b/logger/logger.go index ae7c22c9..694adedc 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -100,8 +100,9 @@ type logger struct { // LogMode log mode func (l *logger) LogMode(level LogLevel) Interface { - l.LogLevel = level - return l + newlogger := *l + newlogger.LogLevel = level + return &newlogger } // Info print info diff --git a/tests/named_polymorphic_test.go b/tests/named_polymorphic_test.go index 7af548a4..95b8ec7d 100644 --- a/tests/named_polymorphic_test.go +++ b/tests/named_polymorphic_test.go @@ -20,7 +20,7 @@ func TestNamedPolymorphic(t *testing.T) { DB.Save(&hamster) hamster2 := Hamster{} - DB.Debug().Preload("PreferredToy").Preload("OtherToy").Find(&hamster2, hamster.Id) + DB.Preload("PreferredToy").Preload("OtherToy").Find(&hamster2, hamster.Id) if hamster2.PreferredToy.ID != hamster.PreferredToy.ID || hamster2.PreferredToy.Name != hamster.PreferredToy.Name { t.Errorf("Hamster's preferred toy failed to preload") From 5457fe88e6f8df372aecef18570fa1b62c318ad3 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 31 May 2020 18:51:43 +0800 Subject: [PATCH 367/881] Test Transactions --- finisher_api.go | 12 +++- gorm.go | 11 ++-- tests/main_test.go | 37 +++++++++++ tests/transaction_test.go | 135 ++++++++++++++++++++++++++++++++++++++ 4 files changed, 188 insertions(+), 7 deletions(-) create mode 100644 tests/main_test.go create mode 100644 tests/transaction_test.go diff --git a/finisher_api.go b/finisher_api.go index f14bcfbe..cfbb98c1 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -267,6 +267,16 @@ func (db *DB) Scan(dest interface{}) (tx *DB) { return } +// Pluck used to query single column from a model as a map +// var ages []int64 +// db.Find(&users).Pluck("age", &ages) +func (db *DB) Pluck(column string, dest interface{}) (tx *DB) { + tx = db.getInstance() + tx.Statement.Dest = dest + tx.callbacks.Query().Execute(tx) + return +} + func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error { tx := db.getInstance() tx.Error = tx.Statement.Parse(dest) @@ -307,7 +317,7 @@ func (db *DB) Begin(opts ...*sql.TxOptions) (tx *DB) { opt = opts[0] } - if tx.Statement.ConnPool, err = beginner.BeginTx(db.Statement.Context, opt); err != nil { + if tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt); err != nil { tx.AddError(err) } } else { diff --git a/gorm.go b/gorm.go index 70751cb3..ac4bff5e 100644 --- a/gorm.go +++ b/gorm.go @@ -167,15 +167,14 @@ func (db *DB) AddError(err error) error { func (db *DB) getInstance() *DB { if db.clone { - stmt := &Statement{ - DB: db, - ConnPool: db.ConnPool, - Clauses: map[string]clause.Clause{}, - Context: context.Background(), - } + stmt := &Statement{DB: db, Clauses: map[string]clause.Clause{}} if db.Statement != nil { stmt.Context = db.Statement.Context + stmt.ConnPool = db.Statement.ConnPool + } else { + stmt.Context = context.Background() + stmt.ConnPool = db.ConnPool } return &DB{Config: db.Config, Statement: stmt} diff --git a/tests/main_test.go b/tests/main_test.go new file mode 100644 index 00000000..da2003d6 --- /dev/null +++ b/tests/main_test.go @@ -0,0 +1,37 @@ +package tests_test + +import ( + "testing" + + . "github.com/jinzhu/gorm/tests" +) + +func TestExceptionsWithInvalidSql(t *testing.T) { + var columns []string + if DB.Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns).Error == nil { + t.Errorf("Should got error with invalid SQL") + } + + if DB.Model(&User{}).Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns).Error == nil { + t.Errorf("Should got error with invalid SQL") + } + + if DB.Where("sdsd.zaaa = ?", "sd;;;aa").Find(&User{}).Error == nil { + t.Errorf("Should got error with invalid SQL") + } + + var count1, count2 int64 + DB.Model(&User{}).Count(&count1) + if count1 <= 0 { + t.Errorf("Should find some users") + } + + if DB.Where("name = ?", "jinzhu; delete * from users").First(&User{}).Error == nil { + t.Errorf("Should got error with invalid SQL") + } + + DB.Model(&User{}).Count(&count2) + if count1 != count2 { + t.Errorf("No user should not be deleted by invalid SQL") + } +} diff --git a/tests/transaction_test.go b/tests/transaction_test.go new file mode 100644 index 00000000..9405fd76 --- /dev/null +++ b/tests/transaction_test.go @@ -0,0 +1,135 @@ +package tests_test + +import ( + "database/sql" + "errors" + "testing" + + "github.com/jinzhu/gorm" + . "github.com/jinzhu/gorm/tests" +) + +func TestTransaction(t *testing.T) { + tx := DB.Begin() + user := *GetUser("transcation", Config{}) + + if err := tx.Save(&user).Error; err != nil { + t.Errorf("No error should raise, but got %v", err) + } + + if err := tx.First(&User{}, "name = ?", "transcation").Error; err != nil { + t.Errorf("Should find saved record, but got %v", err) + } + + if sqlTx, ok := tx.Statement.ConnPool.(*sql.Tx); !ok || sqlTx == nil { + t.Errorf("Should return the underlying sql.Tx") + } + + tx.Rollback() + + if err := DB.First(&User{}, "name = ?", "transcation").Error; err == nil { + t.Errorf("Should not find record after rollback, but got %v", err) + } + + tx2 := DB.Begin() + user2 := *GetUser("transcation-2", Config{}) + if err := tx2.Save(&user2).Error; err != nil { + t.Errorf("No error should raise, but got %v", err) + } + + if err := tx2.First(&User{}, "name = ?", "transcation-2").Error; err != nil { + t.Errorf("Should find saved record, but got %v", err) + } + + tx2.Commit() + + if err := DB.First(&User{}, "name = ?", "transcation-2").Error; err != nil { + t.Errorf("Should be able to find committed record, but got %v", err) + } +} + +func TestTransactionWithBlock(t *testing.T) { + assertPanic := func(f func()) { + defer func() { + if r := recover(); r == nil { + t.Errorf("The code did not panic") + } + }() + f() + } + + // rollback + err := DB.Transaction(func(tx *gorm.DB) error { + user := *GetUser("transcation-block", Config{}) + if err := tx.Save(&user).Error; err != nil { + t.Errorf("No error should raise") + } + + if err := tx.First(&User{}, "name = ?", user.Name).Error; err != nil { + t.Errorf("Should find saved record") + } + + return errors.New("the error message") + }) + + if err.Error() != "the error message" { + t.Errorf("Transaction return error will equal the block returns error") + } + + if err := DB.First(&User{}, "name = ?", "transcation-block").Error; err == nil { + t.Errorf("Should not find record after rollback") + } + + // commit + DB.Transaction(func(tx *gorm.DB) error { + user := *GetUser("transcation-block-2", Config{}) + if err := tx.Save(&user).Error; err != nil { + t.Errorf("No error should raise") + } + + if err := tx.First(&User{}, "name = ?", user.Name).Error; err != nil { + t.Errorf("Should find saved record") + } + return nil + }) + + if err := DB.First(&User{}, "name = ?", "transcation-block-2").Error; err != nil { + t.Errorf("Should be able to find committed record") + } + + // panic will rollback + assertPanic(func() { + DB.Transaction(func(tx *gorm.DB) error { + user := *GetUser("transcation-block-3", Config{}) + if err := tx.Save(&user).Error; err != nil { + t.Errorf("No error should raise") + } + + if err := tx.First(&User{}, "name = ?", user.Name).Error; err != nil { + t.Errorf("Should find saved record") + } + + panic("force panic") + }) + }) + + if err := DB.First(&User{}, "name = ?", "transcation-block-3").Error; err == nil { + t.Errorf("Should not find record after panic rollback") + } +} + +func TestTransactionRaiseErrorOnRollbackAfterCommit(t *testing.T) { + tx := DB.Begin() + user := User{Name: "transcation"} + if err := tx.Save(&user).Error; err != nil { + t.Errorf("No error should raise") + } + + if err := tx.Commit().Error; err != nil { + t.Errorf("Commit should not raise error") + } + + if err := tx.Rollback().Error; err == nil { + t.Errorf("Rollback after commit should raise error") + } +} From 749ca37eb0bdb149dbdc8fa7a47c39cf708f51ba Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 31 May 2020 19:23:32 +0800 Subject: [PATCH 368/881] Add sql builder test --- callbacks/query.go | 188 ++++++++++++++++++++------------------ callbacks/row.go | 6 +- tests/sql_builder_test.go | 82 +++++++++++++++++ 3 files changed, 184 insertions(+), 92 deletions(-) create mode 100644 tests/sql_builder_test.go diff --git a/callbacks/query.go b/callbacks/query.go index 6edfee0b..9f96fd1a 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -19,93 +19,7 @@ func Query(db *gorm.DB) { } if db.Statement.SQL.String() == "" { - clauseSelect := clause.Select{} - - if len(db.Statement.Selects) > 0 { - for _, name := range db.Statement.Selects { - if f := db.Statement.Schema.LookUpField(name); f != nil { - clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ - Name: f.DBName, - }) - } else { - clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ - Name: name, - Raw: true, - }) - } - } - } - - // inline joins - if len(db.Statement.Joins) != 0 { - joins := []clause.Join{} - - if len(db.Statement.Selects) == 0 { - for _, dbName := range db.Statement.Schema.DBNames { - clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ - Table: db.Statement.Table, - Name: dbName, - }) - } - } - - for name, conds := range db.Statement.Joins { - if relation, ok := db.Statement.Schema.Relationships.Relations[name]; ok { - tableAliasName := relation.Name - - for _, s := range relation.FieldSchema.DBNames { - clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ - Table: tableAliasName, - Name: s, - Alias: tableAliasName + "__" + s, - }) - } - - var exprs []clause.Expression - for _, ref := range relation.References { - if ref.OwnPrimaryKey { - exprs = append(exprs, clause.Eq{ - Column: clause.Column{Table: db.Statement.Schema.Table, Name: ref.PrimaryKey.DBName}, - Value: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName}, - }) - } else { - if ref.PrimaryValue == "" { - exprs = append(exprs, clause.Eq{ - Column: clause.Column{Table: db.Statement.Schema.Table, Name: ref.ForeignKey.DBName}, - Value: clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName}, - }) - } else { - exprs = append(exprs, clause.Eq{ - Column: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName}, - Value: ref.PrimaryValue, - }) - } - } - } - - joins = append(joins, clause.Join{ - Type: clause.LeftJoin, - Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName}, - ON: clause.Where{Exprs: exprs}, - }) - } else { - joins = append(joins, clause.Join{ - Expression: clause.Expr{SQL: name, Vars: conds}, - }) - } - } - - db.Statement.AddClause(clause.From{Joins: joins}) - } else { - db.Statement.AddClauseIfNotExists(clause.From{}) - } - - if len(clauseSelect.Columns) > 0 { - db.Statement.AddClause(clauseSelect) - } else { - db.Statement.AddClauseIfNotExists(clauseSelect) - } - db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR") + BuildQuerySQL(db) } rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) @@ -118,6 +32,106 @@ func Query(db *gorm.DB) { gorm.Scan(rows, db, false) } +func BuildQuerySQL(db *gorm.DB) { + clauseSelect := clause.Select{} + + if len(db.Statement.Selects) > 0 { + for _, name := range db.Statement.Selects { + if db.Statement.Schema == nil { + clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ + Name: name, + Raw: true, + }) + } else if f := db.Statement.Schema.LookUpField(name); f != nil { + clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ + Name: f.DBName, + }) + } else { + clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ + Name: name, + Raw: true, + }) + } + } + } + + // inline joins + if len(db.Statement.Joins) != 0 { + joins := []clause.Join{} + + if len(db.Statement.Selects) == 0 { + for _, dbName := range db.Statement.Schema.DBNames { + clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ + Table: db.Statement.Table, + Name: dbName, + }) + } + } + + for name, conds := range db.Statement.Joins { + if db.Statement.Schema == nil { + joins = append(joins, clause.Join{ + Expression: clause.Expr{SQL: name, Vars: conds}, + }) + } else if relation, ok := db.Statement.Schema.Relationships.Relations[name]; ok { + tableAliasName := relation.Name + + for _, s := range relation.FieldSchema.DBNames { + clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ + Table: tableAliasName, + Name: s, + Alias: tableAliasName + "__" + s, + }) + } + + var exprs []clause.Expression + for _, ref := range relation.References { + if ref.OwnPrimaryKey { + exprs = append(exprs, clause.Eq{ + Column: clause.Column{Table: db.Statement.Schema.Table, Name: ref.PrimaryKey.DBName}, + Value: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName}, + }) + } else { + if ref.PrimaryValue == "" { + exprs = append(exprs, clause.Eq{ + Column: clause.Column{Table: db.Statement.Schema.Table, Name: ref.ForeignKey.DBName}, + Value: clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName}, + }) + } else { + exprs = append(exprs, clause.Eq{ + Column: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName}, + Value: ref.PrimaryValue, + }) + } + } + } + + joins = append(joins, clause.Join{ + Type: clause.LeftJoin, + Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName}, + ON: clause.Where{Exprs: exprs}, + }) + } else { + joins = append(joins, clause.Join{ + Expression: clause.Expr{SQL: name, Vars: conds}, + }) + } + } + + db.Statement.AddClause(clause.From{Joins: joins}) + } else { + db.Statement.AddClauseIfNotExists(clause.From{}) + } + + if len(clauseSelect.Columns) > 0 { + db.Statement.AddClause(clauseSelect) + } else { + db.Statement.AddClauseIfNotExists(clauseSelect) + } + + db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR") +} + func Preload(db *gorm.DB) { if len(db.Statement.Preloads) > 0 { preloadMap := map[string][]string{} diff --git a/callbacks/row.go b/callbacks/row.go index b84cf694..004a89d5 100644 --- a/callbacks/row.go +++ b/callbacks/row.go @@ -2,15 +2,11 @@ package callbacks import ( "github.com/jinzhu/gorm" - "github.com/jinzhu/gorm/clause" ) func RowQuery(db *gorm.DB) { if db.Statement.SQL.String() == "" { - db.Statement.AddClauseIfNotExists(clause.Select{}) - db.Statement.AddClauseIfNotExists(clause.From{}) - - db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR") + BuildQuerySQL(db) } if _, ok := db.Get("rows"); ok { diff --git a/tests/sql_builder_test.go b/tests/sql_builder_test.go new file mode 100644 index 00000000..4cd40c7a --- /dev/null +++ b/tests/sql_builder_test.go @@ -0,0 +1,82 @@ +package tests_test + +import ( + "testing" + + "github.com/jinzhu/gorm" + . "github.com/jinzhu/gorm/tests" +) + +func TestRow(t *testing.T) { + user1 := User{Name: "RowUser1", Age: 1} + user2 := User{Name: "RowUser2", Age: 10} + user3 := User{Name: "RowUser3", Age: 20} + DB.Save(&user1).Save(&user2).Save(&user3) + + row := DB.Table("users").Where("name = ?", user2.Name).Select("age").Row() + + var age int64 + if err := row.Scan(&age); err != nil { + t.Fatalf("Failed to scan age, got %v", err) + } + + if age != 10 { + t.Errorf("Scan with Row, age expects: %v, got %v", user2.Age, age) + } +} + +func TestRows(t *testing.T) { + user1 := User{Name: "RowsUser1", Age: 1} + user2 := User{Name: "RowsUser2", Age: 10} + user3 := User{Name: "RowsUser3", Age: 20} + DB.Save(&user1).Save(&user2).Save(&user3) + + rows, err := DB.Table("users").Where("name = ? or name = ?", user2.Name, user3.Name).Select("name, age").Rows() + if err != nil { + t.Errorf("Not error should happen, got %v", err) + } + + count := 0 + for rows.Next() { + var name string + var age int64 + rows.Scan(&name, &age) + count++ + } + + if count != 2 { + t.Errorf("Should found two records") + } +} + +func TestRaw(t *testing.T) { + user1 := User{Name: "ExecRawSqlUser1", Age: 1} + user2 := User{Name: "ExecRawSqlUser2", Age: 10} + user3 := User{Name: "ExecRawSqlUser3", Age: 20} + DB.Save(&user1).Save(&user2).Save(&user3) + + type result struct { + Name string + Email string + } + + var results []result + DB.Raw("SELECT name, age FROM users WHERE name = ? or name = ?", user2.Name, user3.Name).Scan(&results) + if len(results) != 2 || results[0].Name != user2.Name || results[1].Name != user3.Name { + t.Errorf("Raw with scan") + } + + rows, _ := DB.Raw("select name, age from users where name = ?", user3.Name).Rows() + count := 0 + for rows.Next() { + count++ + } + if count != 1 { + t.Errorf("Raw with Rows should find one record with name 3") + } + + DB.Exec("update users set name=? where name in (?)", "jinzhu", []string{user1.Name, user2.Name, user3.Name}) + if DB.Where("name in (?)", []string{user1.Name, user2.Name, user3.Name}).First(&User{}).Error != gorm.ErrRecordNotFound { + t.Error("Raw sql to update records") + } +} From 5b1d3e4a771947f5caae6950b86ab32fd8e56507 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 31 May 2020 20:21:52 +0800 Subject: [PATCH 369/881] Test Joins --- callbacks/query.go | 6 +----- finisher_api.go | 5 +++-- statement.go | 10 ++++----- tests/joins_test.go | 52 +++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 61 insertions(+), 12 deletions(-) diff --git a/callbacks/query.go b/callbacks/query.go index 9f96fd1a..55f2c65b 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -123,11 +123,7 @@ func BuildQuerySQL(db *gorm.DB) { db.Statement.AddClauseIfNotExists(clause.From{}) } - if len(clauseSelect.Columns) > 0 { - db.Statement.AddClause(clauseSelect) - } else { - db.Statement.AddClauseIfNotExists(clauseSelect) - } + db.Statement.AddClauseIfNotExists(clauseSelect) db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR") } diff --git a/finisher_api.go b/finisher_api.go index cfbb98c1..49b08fa4 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -233,9 +233,10 @@ func (db *DB) Delete(value interface{}, conds ...interface{}) (tx *DB) { func (db *DB) Count(count *int64) (tx *DB) { tx = db.getInstance() - if len(tx.Statement.Selects) == 0 { - tx.Statement.Selects = []string{"count(1)"} + if s, ok := tx.Statement.Clauses["SELECT"].Expression.(clause.Select); !ok || len(s.Columns) == 0 { + tx.Statement.AddClause(clause.Select{Expression: clause.Expr{SQL: "count(1)"}}) } + if tx.Statement.Model == nil { tx.Statement.Model = tx.Statement.Dest } diff --git a/statement.go b/statement.go index e0d92c5e..444d5c37 100644 --- a/statement.go +++ b/statement.go @@ -196,7 +196,7 @@ func (stmt *Statement) AddClause(v clause.Interface) { // AddClauseIfNotExists add clause if not exists func (stmt *Statement) AddClauseIfNotExists(v clause.Interface) { - if _, ok := stmt.Clauses[v.Name()]; !ok { + if c, ok := stmt.Clauses[v.Name()]; !ok && c.Expression == nil { stmt.AddClause(v) } } @@ -248,9 +248,9 @@ func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (con for _, field := range s.Fields { if v, isZero := field.ValueOf(reflectValue); !isZero { if field.DBName == "" { - conds = append(conds, clause.Eq{Column: field.Name, Value: v}) + conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.Name}, Value: v}) } else { - conds = append(conds, clause.Eq{Column: field.DBName, Value: v}) + conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.DBName}, Value: v}) } } } @@ -259,9 +259,9 @@ func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (con for _, field := range s.Fields { if v, isZero := field.ValueOf(reflectValue.Index(i)); !isZero { if field.DBName == "" { - conds = append(conds, clause.Eq{Column: field.Name, Value: v}) + conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.Name}, Value: v}) } else { - conds = append(conds, clause.Eq{Column: field.DBName, Value: v}) + conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.DBName}, Value: v}) } } } diff --git a/tests/joins_test.go b/tests/joins_test.go index 556130ee..8a9cdde5 100644 --- a/tests/joins_test.go +++ b/tests/joins_test.go @@ -4,6 +4,7 @@ import ( "sort" "testing" + "github.com/jinzhu/gorm" . "github.com/jinzhu/gorm/tests" ) @@ -53,3 +54,54 @@ func TestJoinsForSlice(t *testing.T) { CheckUser(t, user, users2[idx]) } } + +func TestJoinConds(t *testing.T) { + var user = *GetUser("joins-conds", Config{Account: true, Pets: 3}) + DB.Save(&user) + + var users1 []User + DB.Joins("left join pets on pets.user_id = users.id").Where("users.name = ?", user.Name).Find(&users1) + if len(users1) != 3 { + t.Errorf("should find two users using left join, but got %v", len(users1)) + } + + var users2 []User + DB.Joins("left join pets on pets.user_id = users.id AND pets.name = ?", user.Pets[0].Name).Where("users.name = ?", user.Name).First(&users2) + if len(users2) != 1 { + t.Errorf("should find one users using left join with conditions, but got %v", len(users2)) + } + + var users3 []User + DB.Joins("left join pets on pets.user_id = users.id AND pets.name = ?", user.Pets[0].Name).Joins("join accounts on accounts.user_id = users.id AND accounts.number = ?", user.Account.Number).Where("users.name = ?", user.Name).First(&users3) + if len(users3) != 1 { + t.Errorf("should find one users using multiple left join conditions, but got %v", len(users3)) + } + + var users4 []User + DB.Joins("left join pets on pets.user_id = users.id AND pets.name = ?", user.Pets[0].Name).Joins("join accounts on accounts.user_id = users.id AND accounts.number = ?", user.Account.Number+"non-exist").Where("users.name = ?", user.Name).First(&users4) + if len(users4) != 0 { + t.Errorf("should find no user when searching with unexisting credit card, but got %v", len(users4)) + } + + var users5 []User + db5 := DB.Joins("left join pets on pets.user_id = users.id AND pets.name = ?", user.Pets[0].Name).Joins("join accounts on accounts.user_id = users.id AND accounts.number = ?", user.Account.Number).Where(User{Model: gorm.Model{ID: 1}}).Where(Account{Model: gorm.Model{ID: 1}}).Not(Pet{Model: gorm.Model{ID: 1}}).Find(&users5) + if db5.Error != nil { + t.Errorf("Should not raise error for join where identical fields in different tables. Error: %s", db5.Error.Error()) + } +} + +func TestJoinsWithSelect(t *testing.T) { + type result struct { + ID uint + Name string + } + + user := *GetUser("joins_with_select", Config{Pets: 2}) + DB.Save(&user) + + var results []result + DB.Table("users").Select("users.id, pets.name").Joins("left join pets on pets.user_id = users.id").Where("users.name = ?", "joins_with_select").Scan(&results) + if len(results) != 2 || results[0].Name != user.Pets[0].Name || results[1].Name != user.Pets[1].Name { + t.Errorf("Should find all two pets with Join select") + } +} From e26abb84b322a5a6648b7135ae6ee90cfeedee2c Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 31 May 2020 20:42:07 +0800 Subject: [PATCH 370/881] Test block global update/delete --- callbacks/update.go | 5 +++++ tests/delete_test.go | 10 ++++++---- tests/joins_test.go | 18 ++++++++++++++---- tests/main_test.go | 14 ++++++++++++++ tests/non_std_test.go | 2 +- tests/update_test.go | 7 +++++++ 6 files changed, 47 insertions(+), 9 deletions(-) diff --git a/callbacks/update.go b/callbacks/update.go index cfa8c86b..c16b77d1 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -59,6 +59,11 @@ func Update(db *gorm.DB) { db.Statement.Build("UPDATE", "SET", "WHERE") } + if _, ok := db.Statement.Clauses["WHERE"]; !ok { + db.AddError(gorm.ErrMissingWhereClause) + return + } + result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) if err == nil { diff --git a/tests/delete_test.go b/tests/delete_test.go index 3f17f1a1..4288253f 100644 --- a/tests/delete_test.go +++ b/tests/delete_test.go @@ -36,10 +36,6 @@ func TestDelete(t *testing.T) { } } - if err := DB.Delete(&User{}).Error; err == nil || !errors.Is(err, gorm.ErrMissingWhereClause) { - t.Errorf("should returns missing WHERE clause while deleting error") - } - for _, user := range []User{users[0], users[2]} { if err := DB.Where("id = ?", user.ID).First(&result).Error; err != nil { t.Errorf("no error should returns when query %v, but got %v", user.ID, err) @@ -64,3 +60,9 @@ func TestInlineCondDelete(t *testing.T) { t.Errorf("User can't be found after delete") } } + +func TestBlockGlobalDelete(t *testing.T) { + if err := DB.Delete(&User{}).Error; err == nil || !errors.Is(err, gorm.ErrMissingWhereClause) { + t.Errorf("should returns missing WHERE clause while deleting error") + } +} diff --git a/tests/joins_test.go b/tests/joins_test.go index 8a9cdde5..d9cfd22f 100644 --- a/tests/joins_test.go +++ b/tests/joins_test.go @@ -92,16 +92,26 @@ func TestJoinConds(t *testing.T) { func TestJoinsWithSelect(t *testing.T) { type result struct { - ID uint - Name string + ID uint + PetID uint + Name string } user := *GetUser("joins_with_select", Config{Pets: 2}) DB.Save(&user) var results []result - DB.Table("users").Select("users.id, pets.name").Joins("left join pets on pets.user_id = users.id").Where("users.name = ?", "joins_with_select").Scan(&results) + DB.Table("users").Select("users.id, pets.id as pet_id, pets.name").Joins("left join pets on pets.user_id = users.id").Where("users.name = ?", "joins_with_select").Scan(&results) + + sort.Slice(results, func(i, j int) bool { + return results[i].PetID > results[j].PetID + }) + + sort.Slice(results, func(i, j int) bool { + return user.Pets[i].ID > user.Pets[j].ID + }) + if len(results) != 2 || results[0].Name != user.Pets[0].Name || results[1].Name != user.Pets[1].Name { - t.Errorf("Should find all two pets with Join select") + t.Errorf("Should find all two pets with Join select, got %+v", results) } } diff --git a/tests/main_test.go b/tests/main_test.go index da2003d6..095588a2 100644 --- a/tests/main_test.go +++ b/tests/main_test.go @@ -35,3 +35,17 @@ func TestExceptionsWithInvalidSql(t *testing.T) { t.Errorf("No user should not be deleted by invalid SQL") } } + +func TestSetAndGet(t *testing.T) { + if value, ok := DB.Set("hello", "world").Get("hello"); !ok { + t.Errorf("Should be able to get setting after set") + } else { + if value.(string) != "world" { + t.Errorf("Setted value should not be changed") + } + } + + if _, ok := DB.Get("non_existing"); ok { + t.Errorf("Get non existing key should return error") + } +} diff --git a/tests/non_std_test.go b/tests/non_std_test.go index e5e50141..606b4fc9 100644 --- a/tests/non_std_test.go +++ b/tests/non_std_test.go @@ -34,7 +34,7 @@ func TestNonStdPrimaryKeyAndDefaultValues(t *testing.T) { var animals []Animal DB.Find(&animals) - if count := DB.Model(Animal{}).Update("CreatedAt", time.Now().Add(2*time.Hour)).RowsAffected; count != int64(len(animals)) { + if count := DB.Model(Animal{}).Where("1=1").Update("CreatedAt", time.Now().Add(2*time.Hour)).RowsAffected; count != int64(len(animals)) { t.Error("RowsAffected should be correct when do batch update") } diff --git a/tests/update_test.go b/tests/update_test.go index 371a9f78..869ce4cd 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -1,6 +1,7 @@ package tests_test import ( + "errors" "testing" "time" @@ -211,3 +212,9 @@ func TestUpdateColumn(t *testing.T) { CheckUser(t, user5, *users[0]) CheckUser(t, user6, *users[1]) } + +func TestBlockGlobalUpdate(t *testing.T) { + if err := DB.Model(&User{}).Update("name", "jinzhu").Error; err == nil || !errors.Is(err, gorm.ErrMissingWhereClause) { + t.Errorf("should returns missing WHERE clause while updating error, got err %v", err) + } +} From 95a6539331aef3d7da0885540478463ff7f36b62 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 31 May 2020 21:11:20 +0800 Subject: [PATCH 371/881] Test Pluck --- finisher_api.go | 1 + scan.go | 32 +++++++++++++++++++++----------- tests/query_test.go | 32 ++++++++++++++++++++++++++++++++ 3 files changed, 54 insertions(+), 11 deletions(-) diff --git a/finisher_api.go b/finisher_api.go index 49b08fa4..334aea58 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -273,6 +273,7 @@ func (db *DB) Scan(dest interface{}) (tx *DB) { // db.Find(&users).Pluck("age", &ages) func (db *DB) Pluck(column string, dest interface{}) (tx *DB) { tx = db.getInstance() + tx.Statement.AddClause(clause.Select{Columns: []clause.Column{{Name: column}}}) tx.Statement.Dest = dest tx.callbacks.Query().Execute(tx) return diff --git a/scan.go b/scan.go index 66cb0b94..4d328fde 100644 --- a/scan.go +++ b/scan.go @@ -58,7 +58,12 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { default: switch db.Statement.ReflectValue.Kind() { case reflect.Slice, reflect.Array: - isPtr := db.Statement.ReflectValue.Type().Elem().Kind() == reflect.Ptr + reflectValueType := db.Statement.ReflectValue.Type().Elem() + isPtr := reflectValueType.Kind() == reflect.Ptr + if isPtr { + reflectValueType = reflectValueType.Elem() + } + db.Statement.ReflectValue.Set(reflect.MakeSlice(db.Statement.ReflectValue.Type(), 0, 0)) fields := make([]*schema.Field, len(columns)) joinFields := make([][2]*schema.Field, len(columns)) @@ -81,17 +86,22 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { for initialized || rows.Next() { initialized = false - elem := reflect.New(db.Statement.Schema.ModelType).Elem() - for idx, field := range fields { - if field != nil { - values[idx] = field.ReflectValueOf(elem).Addr().Interface() - } else if joinFields[idx][0] != nil { - relValue := joinFields[idx][0].ReflectValueOf(elem) - if relValue.Kind() == reflect.Ptr && relValue.IsNil() { - relValue.Set(reflect.New(relValue.Type().Elem())) - } + elem := reflect.New(reflectValueType).Elem() - values[idx] = joinFields[idx][1].ReflectValueOf(relValue).Addr().Interface() + if reflectValueType.Kind() != reflect.Struct && len(fields) == 1 { + values[0] = elem.Addr().Interface() + } else { + for idx, field := range fields { + if field != nil { + values[idx] = field.ReflectValueOf(elem).Addr().Interface() + } else if joinFields[idx][0] != nil { + relValue := joinFields[idx][0].ReflectValueOf(elem) + if relValue.Kind() == reflect.Ptr && relValue.IsNil() { + relValue.Set(reflect.New(relValue.Type().Elem())) + } + + values[idx] = joinFields[idx][1].ReflectValueOf(relValue).Addr().Interface() + } } } diff --git a/tests/query_test.go b/tests/query_test.go index 4388066f..b7c619d7 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -80,3 +80,35 @@ func TestFind(t *testing.T) { } } } + +func TestPluck(t *testing.T) { + users := []*User{ + GetUser("pluck-user1", Config{}), + GetUser("pluck-user2", Config{}), + GetUser("pluck-user3", Config{}), + } + + DB.Create(&users) + + var names []string + if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Order("name").Pluck("name", &names).Error; err != nil { + t.Errorf("Raise error when pluck name, got %v", err) + } + + var ids []int + if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Order("name").Pluck("id", &ids).Error; err != nil { + t.Errorf("Raise error when pluck id, got %v", err) + } + + for idx, name := range names { + if name != users[idx].Name { + t.Errorf("Unexpected result on pluck name, got %+v", names) + } + } + + for idx, id := range ids { + if int(id) != int(users[idx].ID) { + t.Errorf("Unexpected result on pluck id, got %+v", ids) + } + } +} From befef0c9a97e3816688074392c0762cefc414c9f Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 31 May 2020 23:55:56 +0800 Subject: [PATCH 372/881] Improve Hooks --- callbacks/associations.go | 4 +- callbacks/create.go | 252 +++++++++++++++++++------------------- callbacks/delete.go | 80 ++++++------ callbacks/query.go | 101 +++++++-------- callbacks/raw.go | 12 +- callbacks/row.go | 16 +-- callbacks/transaction.go | 18 ++- callbacks/update.go | 62 +++++----- errors.go | 2 +- gorm.go | 108 +++++++++++----- interfaces.go | 18 +-- schema/callbacks_test.go | 6 +- schema/schema.go | 4 +- tests/hooks_test.go | 201 ++++++++++++++++++++++++++++++ tests/tests.go | 4 +- tests/transaction_test.go | 42 +++---- 16 files changed, 610 insertions(+), 320 deletions(-) create mode 100644 tests/hooks_test.go diff --git a/callbacks/associations.go b/callbacks/associations.go index 76fc5b81..3c8c2a50 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -10,7 +10,7 @@ import ( ) func SaveBeforeAssociations(db *gorm.DB) { - if db.Statement.Schema != nil { + if db.Error == nil && db.Statement.Schema != nil { selectColumns, restricted := SelectAndOmitColumns(db.Statement, true, false) // Save Belongs To associations @@ -83,7 +83,7 @@ func SaveBeforeAssociations(db *gorm.DB) { } func SaveAfterAssociations(db *gorm.DB) { - if db.Statement.Schema != nil { + if db.Error == nil && db.Statement.Schema != nil { selectColumns, restricted := SelectAndOmitColumns(db.Statement, true, false) // Save Has One associations diff --git a/callbacks/create.go b/callbacks/create.go index f558d7ae..7a2b8bfe 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -9,20 +9,21 @@ import ( ) func BeforeCreate(db *gorm.DB) { - if db.Statement.Schema != nil && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeCreate) { + if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeCreate) { + tx := db.Session(&gorm.Session{}) callMethod := func(value interface{}) bool { var ok bool if db.Statement.Schema.BeforeSave { if i, ok := value.(gorm.BeforeSaveInterface); ok { ok = true - i.BeforeSave(db) + db.AddError(i.BeforeSave(tx)) } } if db.Statement.Schema.BeforeCreate { if i, ok := value.(gorm.BeforeCreateInterface); ok { ok = true - i.BeforeCreate(db) + db.AddError(i.BeforeCreate(tx)) } } return ok @@ -31,7 +32,7 @@ func BeforeCreate(db *gorm.DB) { if ok := callMethod(db.Statement.Dest); !ok { switch db.Statement.ReflectValue.Kind() { case reflect.Slice, reflect.Array: - for i := 0; i <= db.Statement.ReflectValue.Len(); i++ { + for i := 0; i < db.Statement.ReflectValue.Len(); i++ { callMethod(db.Statement.ReflectValue.Index(i).Interface()) } case reflect.Struct: @@ -46,48 +47,127 @@ func Create(config *Config) func(db *gorm.DB) { return CreateWithReturning } else { return func(db *gorm.DB) { - if db.Statement.Schema != nil && !db.Statement.Unscoped { - for _, c := range db.Statement.Schema.CreateClauses { - db.Statement.AddClause(c) - } - } - - if db.Statement.SQL.String() == "" { - db.Statement.AddClauseIfNotExists(clause.Insert{ - Table: clause.Table{Name: db.Statement.Table}, - }) - db.Statement.AddClause(ConvertToCreateValues(db.Statement)) - - db.Statement.Build("INSERT", "VALUES", "ON CONFLICT") - } - - result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) - - if err == nil { - if db.Statement.Schema != nil && db.Statement.Schema.PrioritizedPrimaryField != nil { - if _, ok := db.Statement.Schema.FieldsWithDefaultDBValue[db.Statement.Schema.PrioritizedPrimaryField.DBName]; ok { - if insertID, err := result.LastInsertId(); err == nil { - switch db.Statement.ReflectValue.Kind() { - case reflect.Slice, reflect.Array: - if config.LastInsertIDReversed { - for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- { - db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID) - insertID-- - } - } else { - for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID) - insertID++ - } - } - case reflect.Struct: - db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID) - } - } else { - db.AddError(err) - } + if db.Error == nil { + if db.Statement.Schema != nil && !db.Statement.Unscoped { + for _, c := range db.Statement.Schema.CreateClauses { + db.Statement.AddClause(c) } } + + if db.Statement.SQL.String() == "" { + db.Statement.AddClauseIfNotExists(clause.Insert{ + Table: clause.Table{Name: db.Statement.Table}, + }) + db.Statement.AddClause(ConvertToCreateValues(db.Statement)) + + db.Statement.Build("INSERT", "VALUES", "ON CONFLICT") + } + + result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) + + if err == nil { + if db.Statement.Schema != nil && db.Statement.Schema.PrioritizedPrimaryField != nil { + if _, ok := db.Statement.Schema.FieldsWithDefaultDBValue[db.Statement.Schema.PrioritizedPrimaryField.DBName]; ok { + if insertID, err := result.LastInsertId(); err == nil { + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + if config.LastInsertIDReversed { + for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- { + db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID) + insertID-- + } + } else { + for i := 0; i < db.Statement.ReflectValue.Len(); i++ { + db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID) + insertID++ + } + } + case reflect.Struct: + db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID) + } + } else { + db.AddError(err) + } + } + } + db.RowsAffected, _ = result.RowsAffected() + } else { + db.AddError(err) + } + } + } + } +} + +func CreateWithReturning(db *gorm.DB) { + if db.Error == nil { + if db.Statement.Schema != nil && !db.Statement.Unscoped { + for _, c := range db.Statement.Schema.CreateClauses { + db.Statement.AddClause(c) + } + } + + if db.Statement.SQL.String() == "" { + db.Statement.AddClauseIfNotExists(clause.Insert{ + Table: clause.Table{Name: db.Statement.Table}, + }) + db.Statement.AddClause(ConvertToCreateValues(db.Statement)) + + db.Statement.Build("INSERT", "VALUES", "ON CONFLICT") + } + + if sch := db.Statement.Schema; sch != nil && len(sch.FieldsWithDefaultDBValue) > 0 { + db.Statement.WriteString(" RETURNING ") + + var ( + idx int + fields = make([]*schema.Field, len(sch.FieldsWithDefaultDBValue)) + values = make([]interface{}, len(sch.FieldsWithDefaultDBValue)) + ) + + for dbName, field := range sch.FieldsWithDefaultDBValue { + if idx != 0 { + db.Statement.WriteByte(',') + } + + fields[idx] = field + db.Statement.WriteQuoted(dbName) + idx++ + } + + rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) + + if err == nil { + defer rows.Close() + + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + for rows.Next() { + for idx, field := range fields { + values[idx] = field.ReflectValueOf(db.Statement.ReflectValue.Index(int(db.RowsAffected))).Addr().Interface() + } + if err := rows.Scan(values...); err != nil { + db.AddError(err) + } + db.RowsAffected++ + } + case reflect.Struct: + for idx, field := range fields { + values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface() + } + + if rows.Next() { + db.RowsAffected++ + err = rows.Scan(values...) + } + } + } + + if err != nil { + db.AddError(err) + } + } else { + if result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); err == nil { db.RowsAffected, _ = result.RowsAffected() } else { db.AddError(err) @@ -96,96 +176,22 @@ func Create(config *Config) func(db *gorm.DB) { } } -func CreateWithReturning(db *gorm.DB) { - if db.Statement.Schema != nil && !db.Statement.Unscoped { - for _, c := range db.Statement.Schema.CreateClauses { - db.Statement.AddClause(c) - } - } - - if db.Statement.SQL.String() == "" { - db.Statement.AddClauseIfNotExists(clause.Insert{ - Table: clause.Table{Name: db.Statement.Table}, - }) - db.Statement.AddClause(ConvertToCreateValues(db.Statement)) - - db.Statement.Build("INSERT", "VALUES", "ON CONFLICT") - } - - if sch := db.Statement.Schema; sch != nil && len(sch.FieldsWithDefaultDBValue) > 0 { - db.Statement.WriteString(" RETURNING ") - - var ( - idx int - fields = make([]*schema.Field, len(sch.FieldsWithDefaultDBValue)) - values = make([]interface{}, len(sch.FieldsWithDefaultDBValue)) - ) - - for dbName, field := range sch.FieldsWithDefaultDBValue { - if idx != 0 { - db.Statement.WriteByte(',') - } - - fields[idx] = field - db.Statement.WriteQuoted(dbName) - idx++ - } - - rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) - - if err == nil { - defer rows.Close() - - switch db.Statement.ReflectValue.Kind() { - case reflect.Slice, reflect.Array: - for rows.Next() { - for idx, field := range fields { - values[idx] = field.ReflectValueOf(db.Statement.ReflectValue.Index(int(db.RowsAffected))).Addr().Interface() - } - if err := rows.Scan(values...); err != nil { - db.AddError(err) - } - db.RowsAffected++ - } - case reflect.Struct: - for idx, field := range fields { - values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface() - } - - if rows.Next() { - db.RowsAffected++ - err = rows.Scan(values...) - } - } - } - - if err != nil { - db.AddError(err) - } - } else { - if result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); err == nil { - db.RowsAffected, _ = result.RowsAffected() - } else { - db.AddError(err) - } - } -} - func AfterCreate(db *gorm.DB) { - if db.Statement.Schema != nil && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) { + if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) { + tx := db.Session(&gorm.Session{}) callMethod := func(value interface{}) bool { var ok bool if db.Statement.Schema.AfterSave { if i, ok := value.(gorm.AfterSaveInterface); ok { ok = true - i.AfterSave(db) + db.AddError(i.AfterSave(tx)) } } if db.Statement.Schema.AfterCreate { if i, ok := value.(gorm.AfterCreateInterface); ok { ok = true - i.AfterCreate(db) + db.AddError(i.AfterCreate(tx)) } } return ok @@ -194,7 +200,7 @@ func AfterCreate(db *gorm.DB) { if ok := callMethod(db.Statement.Dest); !ok { switch db.Statement.ReflectValue.Kind() { case reflect.Slice, reflect.Array: - for i := 0; i <= db.Statement.ReflectValue.Len(); i++ { + for i := 0; i < db.Statement.ReflectValue.Len(); i++ { callMethod(db.Statement.ReflectValue.Index(i).Interface()) } case reflect.Struct: diff --git a/callbacks/delete.go b/callbacks/delete.go index b3278c83..582a76f4 100644 --- a/callbacks/delete.go +++ b/callbacks/delete.go @@ -9,11 +9,12 @@ import ( ) func BeforeDelete(db *gorm.DB) { - if db.Statement.Schema != nil && db.Statement.Schema.BeforeDelete { + if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.BeforeDelete { + tx := db.Session(&gorm.Session{}) callMethod := func(value interface{}) bool { if db.Statement.Schema.BeforeDelete { if i, ok := value.(gorm.BeforeDeleteInterface); ok { - i.BeforeDelete(db) + db.AddError(i.BeforeDelete(tx)) return true } } @@ -23,7 +24,7 @@ func BeforeDelete(db *gorm.DB) { if ok := callMethod(db.Statement.Dest); !ok { switch db.Statement.ReflectValue.Kind() { case reflect.Slice, reflect.Array: - for i := 0; i <= db.Statement.ReflectValue.Len(); i++ { + for i := 0; i < db.Statement.ReflectValue.Len(); i++ { callMethod(db.Statement.ReflectValue.Index(i).Interface()) } case reflect.Struct: @@ -34,57 +35,60 @@ func BeforeDelete(db *gorm.DB) { } func Delete(db *gorm.DB) { - if db.Statement.Schema != nil && !db.Statement.Unscoped { - for _, c := range db.Statement.Schema.DeleteClauses { - db.Statement.AddClause(c) - } - } - - if db.Statement.SQL.String() == "" { - db.Statement.AddClauseIfNotExists(clause.Delete{}) - - if db.Statement.Schema != nil { - _, queryValues := schema.GetIdentityFieldValuesMap(db.Statement.ReflectValue, db.Statement.Schema.PrimaryFields) - column, values := schema.ToQueryValues(db.Statement.Schema.PrimaryFieldDBNames, queryValues) - - if len(values) > 0 { - db.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}}) + if db.Error == nil { + if db.Statement.Schema != nil && !db.Statement.Unscoped { + for _, c := range db.Statement.Schema.DeleteClauses { + db.Statement.AddClause(c) } + } - if db.Statement.Dest != db.Statement.Model && db.Statement.Model != nil { - _, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(db.Statement.Model), db.Statement.Schema.PrimaryFields) - column, values = schema.ToQueryValues(db.Statement.Schema.PrimaryFieldDBNames, queryValues) + if db.Statement.SQL.String() == "" { + db.Statement.AddClauseIfNotExists(clause.Delete{}) + + if db.Statement.Schema != nil { + _, queryValues := schema.GetIdentityFieldValuesMap(db.Statement.ReflectValue, db.Statement.Schema.PrimaryFields) + column, values := schema.ToQueryValues(db.Statement.Schema.PrimaryFieldDBNames, queryValues) if len(values) > 0 { db.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}}) } + + if db.Statement.Dest != db.Statement.Model && db.Statement.Model != nil { + _, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(db.Statement.Model), db.Statement.Schema.PrimaryFields) + column, values = schema.ToQueryValues(db.Statement.Schema.PrimaryFieldDBNames, queryValues) + + if len(values) > 0 { + db.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}}) + } + } } + + if _, ok := db.Statement.Clauses["WHERE"]; !ok { + db.AddError(gorm.ErrMissingWhereClause) + return + } + + db.Statement.AddClauseIfNotExists(clause.From{}) + db.Statement.Build("DELETE", "FROM", "WHERE") } - if _, ok := db.Statement.Clauses["WHERE"]; !ok { - db.AddError(gorm.ErrMissingWhereClause) - return + result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) + + if err == nil { + db.RowsAffected, _ = result.RowsAffected() + } else { + db.AddError(err) } - - db.Statement.AddClauseIfNotExists(clause.From{}) - db.Statement.Build("DELETE", "FROM", "WHERE") - } - - result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) - - if err == nil { - db.RowsAffected, _ = result.RowsAffected() - } else { - db.AddError(err) } } func AfterDelete(db *gorm.DB) { - if db.Statement.Schema != nil && db.Statement.Schema.AfterDelete { + if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.AfterDelete { + tx := db.Session(&gorm.Session{}) callMethod := func(value interface{}) bool { if db.Statement.Schema.AfterDelete { if i, ok := value.(gorm.AfterDeleteInterface); ok { - i.AfterDelete(db) + db.AddError(i.AfterDelete(tx)) return true } } @@ -94,7 +98,7 @@ func AfterDelete(db *gorm.DB) { if ok := callMethod(db.Statement.Dest); !ok { switch db.Statement.ReflectValue.Kind() { case reflect.Slice, reflect.Array: - for i := 0; i <= db.Statement.ReflectValue.Len(); i++ { + for i := 0; i < db.Statement.ReflectValue.Len(); i++ { callMethod(db.Statement.ReflectValue.Index(i).Interface()) } case reflect.Struct: diff --git a/callbacks/query.go b/callbacks/query.go index 55f2c65b..91948031 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -12,24 +12,26 @@ import ( ) func Query(db *gorm.DB) { - if db.Statement.Schema != nil && !db.Statement.Unscoped { - for _, c := range db.Statement.Schema.QueryClauses { - db.Statement.AddClause(c) + if db.Error == nil { + if db.Statement.Schema != nil && !db.Statement.Unscoped { + for _, c := range db.Statement.Schema.QueryClauses { + db.Statement.AddClause(c) + } } - } - if db.Statement.SQL.String() == "" { - BuildQuerySQL(db) - } + if db.Statement.SQL.String() == "" { + BuildQuerySQL(db) + } - rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) - if err != nil { - db.AddError(err) - return - } - defer rows.Close() + rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) + if err != nil { + db.AddError(err) + return + } + defer rows.Close() - gorm.Scan(rows, db, false) + gorm.Scan(rows, db, false) + } } func BuildQuerySQL(db *gorm.DB) { @@ -129,50 +131,53 @@ func BuildQuerySQL(db *gorm.DB) { } func Preload(db *gorm.DB) { - if len(db.Statement.Preloads) > 0 { - preloadMap := map[string][]string{} - for name := range db.Statement.Preloads { - preloadFields := strings.Split(name, ".") - for idx := range preloadFields { - preloadMap[strings.Join(preloadFields[:idx+1], ".")] = preloadFields[:idx+1] - } - } - - preloadNames := make([]string, len(preloadMap)) - idx := 0 - for key := range preloadMap { - preloadNames[idx] = key - idx++ - } - sort.Strings(preloadNames) - - for _, name := range preloadNames { - var ( - curSchema = db.Statement.Schema - preloadFields = preloadMap[name] - rels = make([]*schema.Relationship, len(preloadFields)) - ) - - for idx, preloadField := range preloadFields { - if rel := curSchema.Relationships.Relations[preloadField]; rel != nil { - rels[idx] = rel - curSchema = rel.FieldSchema - } else { - db.AddError(fmt.Errorf("%v: %w", name, gorm.ErrUnsupportedRelation)) + if db.Error == nil { + if len(db.Statement.Preloads) > 0 { + preloadMap := map[string][]string{} + for name := range db.Statement.Preloads { + preloadFields := strings.Split(name, ".") + for idx := range preloadFields { + preloadMap[strings.Join(preloadFields[:idx+1], ".")] = preloadFields[:idx+1] } } - preload(db, rels, db.Statement.Preloads[name]) + preloadNames := make([]string, len(preloadMap)) + idx := 0 + for key := range preloadMap { + preloadNames[idx] = key + idx++ + } + sort.Strings(preloadNames) + + for _, name := range preloadNames { + var ( + curSchema = db.Statement.Schema + preloadFields = preloadMap[name] + rels = make([]*schema.Relationship, len(preloadFields)) + ) + + for idx, preloadField := range preloadFields { + if rel := curSchema.Relationships.Relations[preloadField]; rel != nil { + rels[idx] = rel + curSchema = rel.FieldSchema + } else { + db.AddError(fmt.Errorf("%v: %w", name, gorm.ErrUnsupportedRelation)) + } + } + + preload(db, rels, db.Statement.Preloads[name]) + } } } } func AfterQuery(db *gorm.DB) { - if db.Statement.Schema != nil && db.Statement.Schema.AfterFind { + if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.AfterFind { + tx := db.Session(&gorm.Session{}) callMethod := func(value interface{}) bool { if db.Statement.Schema.AfterFind { if i, ok := value.(gorm.AfterFindInterface); ok { - i.AfterFind(db) + db.AddError(i.AfterFind(tx)) return true } } @@ -182,7 +187,7 @@ func AfterQuery(db *gorm.DB) { if ok := callMethod(db.Statement.Dest); !ok { switch db.Statement.ReflectValue.Kind() { case reflect.Slice, reflect.Array: - for i := 0; i <= db.Statement.ReflectValue.Len(); i++ { + for i := 0; i < db.Statement.ReflectValue.Len(); i++ { callMethod(db.Statement.ReflectValue.Index(i).Interface()) } case reflect.Struct: diff --git a/callbacks/raw.go b/callbacks/raw.go index ce125e61..cb0cd6c9 100644 --- a/callbacks/raw.go +++ b/callbacks/raw.go @@ -5,10 +5,12 @@ import ( ) func RawExec(db *gorm.DB) { - result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) - if err != nil { - db.AddError(err) - } else { - db.RowsAffected, _ = result.RowsAffected() + if db.Error == nil { + result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) + if err != nil { + db.AddError(err) + } else { + db.RowsAffected, _ = result.RowsAffected() + } } } diff --git a/callbacks/row.go b/callbacks/row.go index 004a89d5..f4ff734c 100644 --- a/callbacks/row.go +++ b/callbacks/row.go @@ -5,13 +5,15 @@ import ( ) func RowQuery(db *gorm.DB) { - if db.Statement.SQL.String() == "" { - BuildQuerySQL(db) - } + if db.Error == nil { + if db.Statement.SQL.String() == "" { + BuildQuerySQL(db) + } - if _, ok := db.Get("rows"); ok { - db.Statement.Dest, db.Error = db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) - } else { - db.Statement.Dest = db.Statement.ConnPool.QueryRowContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) + if _, ok := db.Get("rows"); ok { + db.Statement.Dest, db.Error = db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) + } else { + db.Statement.Dest = db.Statement.ConnPool.QueryRowContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) + } } } diff --git a/callbacks/transaction.go b/callbacks/transaction.go index 253c4e82..63015364 100644 --- a/callbacks/transaction.go +++ b/callbacks/transaction.go @@ -1,9 +1,25 @@ package callbacks -import "github.com/jinzhu/gorm" +import ( + "github.com/jinzhu/gorm" +) func BeginTransaction(db *gorm.DB) { + if tx := db.Begin(); tx.Error == nil { + db.Statement.ConnPool = tx.Statement.ConnPool + tx.InstanceSet("gorm:started_transaction", true) + } else { + tx.Error = nil + } } func CommitOrRollbackTransaction(db *gorm.DB) { + if _, ok := db.InstanceGet("gorm:started_transaction"); ok { + if db.Error == nil { + db.Commit() + } else { + db.Rollback() + } + db.Statement.ConnPool = db.ConnPool + } } diff --git a/callbacks/update.go b/callbacks/update.go index c16b77d1..cbbcddf7 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -10,20 +10,21 @@ import ( ) func BeforeUpdate(db *gorm.DB) { - if db.Statement.Schema != nil && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) { + if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) { + tx := db.Session(&gorm.Session{}) callMethod := func(value interface{}) bool { var ok bool if db.Statement.Schema.BeforeSave { if i, ok := value.(gorm.BeforeSaveInterface); ok { ok = true - i.BeforeSave(db) + db.AddError(i.BeforeSave(tx)) } } if db.Statement.Schema.BeforeUpdate { if i, ok := value.(gorm.BeforeUpdateInterface); ok { ok = true - i.BeforeUpdate(db) + db.AddError(i.BeforeUpdate(tx)) } } return ok @@ -32,7 +33,7 @@ func BeforeUpdate(db *gorm.DB) { if ok := callMethod(db.Statement.Dest); !ok { switch db.Statement.ReflectValue.Kind() { case reflect.Slice, reflect.Array: - for i := 0; i <= db.Statement.ReflectValue.Len(); i++ { + for i := 0; i < db.Statement.ReflectValue.Len(); i++ { callMethod(db.Statement.ReflectValue.Index(i).Interface()) } case reflect.Struct: @@ -43,51 +44,54 @@ func BeforeUpdate(db *gorm.DB) { } func Update(db *gorm.DB) { - if db.Statement.Schema != nil && !db.Statement.Unscoped { - for _, c := range db.Statement.Schema.UpdateClauses { - db.Statement.AddClause(c) + if db.Error == nil { + if db.Statement.Schema != nil && !db.Statement.Unscoped { + for _, c := range db.Statement.Schema.UpdateClauses { + db.Statement.AddClause(c) + } } - } - if db.Statement.SQL.String() == "" { - db.Statement.AddClauseIfNotExists(clause.Update{}) - if set := ConvertToAssignments(db.Statement); len(set) != 0 { - db.Statement.AddClause(set) - } else { + if db.Statement.SQL.String() == "" { + db.Statement.AddClauseIfNotExists(clause.Update{}) + if set := ConvertToAssignments(db.Statement); len(set) != 0 { + db.Statement.AddClause(set) + } else { + return + } + db.Statement.Build("UPDATE", "SET", "WHERE") + } + + if _, ok := db.Statement.Clauses["WHERE"]; !ok { + db.AddError(gorm.ErrMissingWhereClause) return } - db.Statement.Build("UPDATE", "SET", "WHERE") - } - if _, ok := db.Statement.Clauses["WHERE"]; !ok { - db.AddError(gorm.ErrMissingWhereClause) - return - } + result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) - result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) - - if err == nil { - db.RowsAffected, _ = result.RowsAffected() - } else { - db.AddError(err) + if err == nil { + db.RowsAffected, _ = result.RowsAffected() + } else { + db.AddError(err) + } } } func AfterUpdate(db *gorm.DB) { - if db.Statement.Schema != nil && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) { + if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) { + tx := db.Session(&gorm.Session{}) callMethod := func(value interface{}) bool { var ok bool if db.Statement.Schema.AfterSave { if i, ok := value.(gorm.AfterSaveInterface); ok { ok = true - i.AfterSave(db) + db.AddError(i.AfterSave(tx)) } } if db.Statement.Schema.AfterUpdate { if i, ok := value.(gorm.AfterUpdateInterface); ok { ok = true - i.AfterUpdate(db) + db.AddError(i.AfterUpdate(tx)) } } return ok @@ -96,7 +100,7 @@ func AfterUpdate(db *gorm.DB) { if ok := callMethod(db.Statement.Dest); !ok { switch db.Statement.ReflectValue.Kind() { case reflect.Slice, reflect.Array: - for i := 0; i <= db.Statement.ReflectValue.Len(); i++ { + for i := 0; i < db.Statement.ReflectValue.Len(); i++ { callMethod(db.Statement.ReflectValue.Index(i).Interface()) } case reflect.Struct: diff --git a/errors.go b/errors.go index 140a5186..82f24df2 100644 --- a/errors.go +++ b/errors.go @@ -16,7 +16,7 @@ var ( // ErrNotImplemented not implemented ErrNotImplemented = errors.New("not implemented") // ErrMissingWhereClause missing where clause - ErrMissingWhereClause = errors.New("missing WHERE clause while deleting") + ErrMissingWhereClause = errors.New("WHERE conditions required") // ErrUnsupportedRelation unsupported relations ErrUnsupportedRelation = errors.New("unsupported relations") // ErrPtrStructSupported only ptr of struct supported diff --git a/gorm.go b/gorm.go index ac4bff5e..c1d6f8da 100644 --- a/gorm.go +++ b/gorm.go @@ -40,14 +40,15 @@ type DB struct { Error error RowsAffected int64 Statement *Statement - clone bool + clone int } // Session session config when create session with Session() method type Session struct { - Context context.Context - Logger logger.Interface - NowFunc func() time.Time + WithConditions bool + Context context.Context + Logger logger.Interface + NowFunc func() time.Time } // Open initialize db session based on dialector @@ -76,10 +77,7 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) { config.cacheStore = &sync.Map{} } - db = &DB{ - Config: config, - clone: true, - } + db = &DB{Config: config, clone: 1} db.callbacks = initializeCallbacks(db) @@ -96,38 +94,54 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) { // Session create new db session func (db *DB) Session(config *Session) *DB { var ( - tx = db.getInstance() - stmt = tx.Statement.clone() - txConfig = *tx.Config + txConfig = *db.Config + tx = &DB{ + Config: &txConfig, + Statement: db.Statement, + clone: 1, + } ) if config.Context != nil { - stmt.Context = config.Context + if tx.Statement != nil { + tx.Statement = tx.Statement.clone() + } else { + tx.Statement = &Statement{ + DB: tx, + Clauses: map[string]clause.Clause{}, + ConnPool: tx.ConnPool, + } + } + + tx.Statement.Context = config.Context + } + + if config.WithConditions { + tx.clone = 3 } if config.Logger != nil { - txConfig.Logger = config.Logger + tx.Config.Logger = config.Logger } if config.NowFunc != nil { - txConfig.NowFunc = config.NowFunc + tx.Config.NowFunc = config.NowFunc } - return &DB{ - Config: &txConfig, - Statement: stmt, - clone: true, - } + return tx } // WithContext change current instance db's context to ctx func (db *DB) WithContext(ctx context.Context) *DB { - return db.Session(&Session{Context: ctx}) + return db.Session(&Session{WithConditions: true, Context: ctx}) } // Debug start debug mode func (db *DB) Debug() (tx *DB) { - return db.Session(&Session{Logger: db.Logger.LogMode(logger.Info)}) + return db.Session(&Session{ + WithConditions: true, + Logger: db.Logger.LogMode(logger.Info), + }) } // Set store value with key into current db instance's context @@ -145,6 +159,21 @@ func (db *DB) Get(key string) (interface{}, bool) { return nil, false } +// InstanceSet store value with key into current db instance's context +func (db *DB) InstanceSet(key string, value interface{}) *DB { + tx := db.getInstance() + tx.Statement.Settings.Store(fmt.Sprintf("%p", tx.Statement)+key, value) + return tx +} + +// InstanceGet get value with key from current db instance's context +func (db *DB) InstanceGet(key string) (interface{}, bool) { + if db.Statement != nil { + return db.Statement.Settings.Load(fmt.Sprintf("%p", db.Statement) + key) + } + return nil, false +} + // Callback returns callback manager func (db *DB) Callback() *callbacks { return db.callbacks @@ -166,18 +195,37 @@ func (db *DB) AddError(err error) error { } func (db *DB) getInstance() *DB { - if db.clone { - stmt := &Statement{DB: db, Clauses: map[string]clause.Clause{}} + if db.clone > 0 { + tx := &DB{Config: db.Config} - if db.Statement != nil { - stmt.Context = db.Statement.Context - stmt.ConnPool = db.Statement.ConnPool - } else { - stmt.Context = context.Background() - stmt.ConnPool = db.ConnPool + switch db.clone { + case 1: // clone with new statement + case 2: // with old statement, generate new statement for future call, used to pass to callbacks + db.clone = 1 + tx.Statement = db.Statement + case 3: // with clone statement + if db.Statement != nil { + tx.Statement = db.Statement.clone() + tx.Statement.DB = tx + } } - return &DB{Config: db.Config, Statement: stmt} + if tx.Statement == nil { + tx.Statement = &Statement{ + DB: tx, + Clauses: map[string]clause.Clause{}, + } + } + + if db.Statement != nil { + tx.Statement.Context = db.Statement.Context + tx.Statement.ConnPool = db.Statement.ConnPool + } else { + tx.Statement.Context = context.Background() + tx.Statement.ConnPool = db.ConnPool + } + + return tx } return db diff --git a/interfaces.go b/interfaces.go index 9dd00c15..14d8fa34 100644 --- a/interfaces.go +++ b/interfaces.go @@ -36,37 +36,37 @@ type TxCommiter interface { } type BeforeCreateInterface interface { - BeforeCreate(*DB) + BeforeCreate(*DB) error } type AfterCreateInterface interface { - AfterCreate(*DB) + AfterCreate(*DB) error } type BeforeUpdateInterface interface { - BeforeUpdate(*DB) + BeforeUpdate(*DB) error } type AfterUpdateInterface interface { - AfterUpdate(*DB) + AfterUpdate(*DB) error } type BeforeSaveInterface interface { - BeforeSave(*DB) + BeforeSave(*DB) error } type AfterSaveInterface interface { - AfterSave(*DB) + AfterSave(*DB) error } type BeforeDeleteInterface interface { - BeforeDelete(*DB) + BeforeDelete(*DB) error } type AfterDeleteInterface interface { - AfterDelete(*DB) + AfterDelete(*DB) error } type AfterFindInterface interface { - AfterFind(*DB) + AfterFind(*DB) error } diff --git a/schema/callbacks_test.go b/schema/callbacks_test.go index 720c9a5b..efa01e89 100644 --- a/schema/callbacks_test.go +++ b/schema/callbacks_test.go @@ -12,10 +12,12 @@ import ( type UserWithCallback struct { } -func (UserWithCallback) BeforeSave(*gorm.DB) { +func (UserWithCallback) BeforeSave(*gorm.DB) error { + return nil } -func (UserWithCallback) AfterCreate(*gorm.DB) { +func (UserWithCallback) AfterCreate(*gorm.DB) error { + return nil } func TestCallback(t *testing.T) { diff --git a/schema/schema.go b/schema/schema.go index 77b9832c..231ed1db 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -200,12 +200,12 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) } } - reflectValue := reflect.Indirect(reflect.New(modelType)) + reflectValue := reflect.New(modelType) callbacks := []string{"BeforeCreate", "AfterCreate", "BeforeUpdate", "AfterUpdate", "BeforeSave", "AfterSave", "BeforeDelete", "AfterDelete", "AfterFind"} for _, name := range callbacks { if methodValue := reflectValue.MethodByName(name); methodValue.IsValid() { switch methodValue.Type().String() { - case "func(*gorm.DB)": // TODO hack + case "func(*gorm.DB) error": // TODO hack reflect.Indirect(reflect.ValueOf(schema)).FieldByName(name).SetBool(true) default: logger.Default.Warn(context.Background(), "Model %v don't match %vInterface, should be %v(*gorm.DB)", schema, name, name) diff --git a/tests/hooks_test.go b/tests/hooks_test.go new file mode 100644 index 00000000..432226a3 --- /dev/null +++ b/tests/hooks_test.go @@ -0,0 +1,201 @@ +package tests_test + +import ( + "errors" + "reflect" + "testing" + + "github.com/jinzhu/gorm" + . "github.com/jinzhu/gorm/tests" +) + +type Product struct { + gorm.Model + Name string + Code string + Price float64 + AfterFindCallTimes int64 + BeforeCreateCallTimes int64 + AfterCreateCallTimes int64 + BeforeUpdateCallTimes int64 + AfterUpdateCallTimes int64 + BeforeSaveCallTimes int64 + AfterSaveCallTimes int64 + BeforeDeleteCallTimes int64 + AfterDeleteCallTimes int64 +} + +func (s *Product) BeforeCreate(tx *gorm.DB) (err error) { + if s.Code == "Invalid" { + err = errors.New("invalid product") + } + s.BeforeCreateCallTimes = s.BeforeCreateCallTimes + 1 + return +} + +func (s *Product) BeforeUpdate(tx *gorm.DB) (err error) { + if s.Code == "dont_update" { + err = errors.New("can't update") + } + s.BeforeUpdateCallTimes = s.BeforeUpdateCallTimes + 1 + return +} + +func (s *Product) BeforeSave(tx *gorm.DB) (err error) { + if s.Code == "dont_save" { + err = errors.New("can't save") + } + s.BeforeSaveCallTimes = s.BeforeSaveCallTimes + 1 + return +} + +func (s *Product) AfterFind(tx *gorm.DB) (err error) { + s.AfterFindCallTimes = s.AfterFindCallTimes + 1 + return +} + +func (s *Product) AfterCreate(tx *gorm.DB) (err error) { + return tx.Model(s).UpdateColumn("AfterCreateCallTimes", s.AfterCreateCallTimes+1).Error +} + +func (s *Product) AfterUpdate(tx *gorm.DB) (err error) { + s.AfterUpdateCallTimes = s.AfterUpdateCallTimes + 1 + return +} + +func (s *Product) AfterSave(tx *gorm.DB) (err error) { + if s.Code == "after_save_error" { + err = errors.New("can't save") + } + s.AfterSaveCallTimes = s.AfterSaveCallTimes + 1 + return +} + +func (s *Product) BeforeDelete(tx *gorm.DB) (err error) { + if s.Code == "dont_delete" { + err = errors.New("can't delete") + } + s.BeforeDeleteCallTimes = s.BeforeDeleteCallTimes + 1 + return +} + +func (s *Product) AfterDelete(tx *gorm.DB) (err error) { + if s.Code == "after_delete_error" { + err = errors.New("can't delete") + } + s.AfterDeleteCallTimes = s.AfterDeleteCallTimes + 1 + return +} + +func (s *Product) GetCallTimes() []int64 { + return []int64{s.BeforeCreateCallTimes, s.BeforeSaveCallTimes, s.BeforeUpdateCallTimes, s.AfterCreateCallTimes, s.AfterSaveCallTimes, s.AfterUpdateCallTimes, s.BeforeDeleteCallTimes, s.AfterDeleteCallTimes, s.AfterFindCallTimes} +} + +func TestRunCallbacks(t *testing.T) { + DB.Migrator().DropTable(&Product{}) + DB.AutoMigrate(&Product{}) + + p := Product{Code: "unique_code", Price: 100} + DB.Save(&p) + + if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 1, 0, 0, 0, 0}) { + t.Errorf("Callbacks should be invoked successfully, %v", p.GetCallTimes()) + } + + DB.Where("Code = ?", "unique_code").First(&p) + if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 0, 0, 0, 0, 1}) { + t.Fatalf("After callbacks values are not saved, %v", p.GetCallTimes()) + } + + p.Price = 200 + DB.Save(&p) + if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 1, 1, 0, 0, 1}) { + t.Fatalf("After update callbacks should be invoked successfully, %v", p.GetCallTimes()) + } + + var products []Product + DB.Find(&products, "code = ?", "unique_code") + if products[0].AfterFindCallTimes != 1 { + t.Fatalf("AfterFind callbacks should work with slice, called %v", products[0].AfterFindCallTimes) + } + + DB.Where("Code = ?", "unique_code").First(&p) + if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 0, 0, 2}) { + t.Fatalf("After update callbacks values are not saved, %v", p.GetCallTimes()) + } + + DB.Delete(&p) + if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 1, 1, 2}) { + t.Fatalf("After delete callbacks should be invoked successfully, %v", p.GetCallTimes()) + } + + if DB.Where("Code = ?", "unique_code").First(&p).Error == nil { + t.Fatalf("Can't find a deleted record") + } +} + +func TestCallbacksWithErrors(t *testing.T) { + DB.Migrator().DropTable(&Product{}) + DB.AutoMigrate(&Product{}) + + p := Product{Code: "Invalid", Price: 100} + if DB.Save(&p).Error == nil { + t.Fatalf("An error from before create callbacks happened when create with invalid value") + } + + if DB.Where("code = ?", "Invalid").First(&Product{}).Error == nil { + t.Fatalf("Should not save record that have errors") + } + + if DB.Save(&Product{Code: "dont_save", Price: 100}).Error == nil { + t.Fatalf("An error from after create callbacks happened when create with invalid value") + } + + p2 := Product{Code: "update_callback", Price: 100} + DB.Save(&p2) + + p2.Code = "dont_update" + if DB.Save(&p2).Error == nil { + t.Fatalf("An error from before update callbacks happened when update with invalid value") + } + + if DB.Where("code = ?", "update_callback").First(&Product{}).Error != nil { + t.Fatalf("Record Should not be updated due to errors happened in before update callback") + } + + if DB.Where("code = ?", "dont_update").First(&Product{}).Error == nil { + t.Fatalf("Record Should not be updated due to errors happened in before update callback") + } + + p2.Code = "dont_save" + if DB.Save(&p2).Error == nil { + t.Fatalf("An error from before save callbacks happened when update with invalid value") + } + + p3 := Product{Code: "dont_delete", Price: 100} + DB.Save(&p3) + if DB.Delete(&p3).Error == nil { + t.Fatalf("An error from before delete callbacks happened when delete") + } + + if DB.Where("Code = ?", "dont_delete").First(&p3).Error != nil { + t.Fatalf("An error from before delete callbacks happened") + } + + p4 := Product{Code: "after_save_error", Price: 100} + DB.Save(&p4) + if err := DB.First(&Product{}, "code = ?", "after_save_error").Error; err == nil { + t.Fatalf("Record should be reverted if get an error in after save callback") + } + + p5 := Product{Code: "after_delete_error", Price: 100} + DB.Save(&p5) + if err := DB.First(&Product{}, "code = ?", "after_delete_error").Error; err != nil { + t.Fatalf("Record should be found") + } + + DB.Delete(&p5) + if err := DB.First(&Product{}, "code = ?", "after_delete_error").Error; err != nil { + t.Fatalf("Record shouldn't be deleted because of an error happened in after delete callback") + } +} diff --git a/tests/tests.go b/tests/tests.go index 7e216776..d9257898 100644 --- a/tests/tests.go +++ b/tests/tests.go @@ -59,9 +59,9 @@ func OpenTestConnection() (db *gorm.DB, err error) { } if debug := os.Getenv("DEBUG"); debug == "true" { - db.Logger.LogMode(logger.Info) + db.Logger = db.Logger.LogMode(logger.Info) } else if debug == "false" { - db.Logger.LogMode(logger.Silent) + db.Logger = db.Logger.LogMode(logger.Silent) } return diff --git a/tests/transaction_test.go b/tests/transaction_test.go index 9405fd76..f39b3167 100644 --- a/tests/transaction_test.go +++ b/tests/transaction_test.go @@ -14,37 +14,37 @@ func TestTransaction(t *testing.T) { user := *GetUser("transcation", Config{}) if err := tx.Save(&user).Error; err != nil { - t.Errorf("No error should raise, but got %v", err) + t.Fatalf("No error should raise, but got %v", err) } if err := tx.First(&User{}, "name = ?", "transcation").Error; err != nil { - t.Errorf("Should find saved record, but got %v", err) + t.Fatalf("Should find saved record, but got %v", err) } if sqlTx, ok := tx.Statement.ConnPool.(*sql.Tx); !ok || sqlTx == nil { - t.Errorf("Should return the underlying sql.Tx") + t.Fatalf("Should return the underlying sql.Tx") } tx.Rollback() if err := DB.First(&User{}, "name = ?", "transcation").Error; err == nil { - t.Errorf("Should not find record after rollback, but got %v", err) + t.Fatalf("Should not find record after rollback, but got %v", err) } tx2 := DB.Begin() user2 := *GetUser("transcation-2", Config{}) if err := tx2.Save(&user2).Error; err != nil { - t.Errorf("No error should raise, but got %v", err) + t.Fatalf("No error should raise, but got %v", err) } if err := tx2.First(&User{}, "name = ?", "transcation-2").Error; err != nil { - t.Errorf("Should find saved record, but got %v", err) + t.Fatalf("Should find saved record, but got %v", err) } tx2.Commit() if err := DB.First(&User{}, "name = ?", "transcation-2").Error; err != nil { - t.Errorf("Should be able to find committed record, but got %v", err) + t.Fatalf("Should be able to find committed record, but got %v", err) } } @@ -52,7 +52,7 @@ func TestTransactionWithBlock(t *testing.T) { assertPanic := func(f func()) { defer func() { if r := recover(); r == nil { - t.Errorf("The code did not panic") + t.Fatalf("The code did not panic") } }() f() @@ -62,39 +62,39 @@ func TestTransactionWithBlock(t *testing.T) { err := DB.Transaction(func(tx *gorm.DB) error { user := *GetUser("transcation-block", Config{}) if err := tx.Save(&user).Error; err != nil { - t.Errorf("No error should raise") + t.Fatalf("No error should raise") } if err := tx.First(&User{}, "name = ?", user.Name).Error; err != nil { - t.Errorf("Should find saved record") + t.Fatalf("Should find saved record") } return errors.New("the error message") }) if err.Error() != "the error message" { - t.Errorf("Transaction return error will equal the block returns error") + t.Fatalf("Transaction return error will equal the block returns error") } if err := DB.First(&User{}, "name = ?", "transcation-block").Error; err == nil { - t.Errorf("Should not find record after rollback") + t.Fatalf("Should not find record after rollback") } // commit DB.Transaction(func(tx *gorm.DB) error { user := *GetUser("transcation-block-2", Config{}) if err := tx.Save(&user).Error; err != nil { - t.Errorf("No error should raise") + t.Fatalf("No error should raise") } if err := tx.First(&User{}, "name = ?", user.Name).Error; err != nil { - t.Errorf("Should find saved record") + t.Fatalf("Should find saved record") } return nil }) if err := DB.First(&User{}, "name = ?", "transcation-block-2").Error; err != nil { - t.Errorf("Should be able to find committed record") + t.Fatalf("Should be able to find committed record") } // panic will rollback @@ -102,11 +102,11 @@ func TestTransactionWithBlock(t *testing.T) { DB.Transaction(func(tx *gorm.DB) error { user := *GetUser("transcation-block-3", Config{}) if err := tx.Save(&user).Error; err != nil { - t.Errorf("No error should raise") + t.Fatalf("No error should raise") } if err := tx.First(&User{}, "name = ?", user.Name).Error; err != nil { - t.Errorf("Should find saved record") + t.Fatalf("Should find saved record") } panic("force panic") @@ -114,7 +114,7 @@ func TestTransactionWithBlock(t *testing.T) { }) if err := DB.First(&User{}, "name = ?", "transcation-block-3").Error; err == nil { - t.Errorf("Should not find record after panic rollback") + t.Fatalf("Should not find record after panic rollback") } } @@ -122,14 +122,14 @@ func TestTransactionRaiseErrorOnRollbackAfterCommit(t *testing.T) { tx := DB.Begin() user := User{Name: "transcation"} if err := tx.Save(&user).Error; err != nil { - t.Errorf("No error should raise") + t.Fatalf("No error should raise") } if err := tx.Commit().Error; err != nil { - t.Errorf("Commit should not raise error") + t.Fatalf("Commit should not raise error") } if err := tx.Rollback().Error; err == nil { - t.Errorf("Rollback after commit should raise error") + t.Fatalf("Rollback after commit should raise error") } } From a02cb39a45483955fb45cd16168c6ed68af8c7ed Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 1 Jun 2020 00:36:18 +0800 Subject: [PATCH 373/881] Add more tests --- finisher_api.go | 2 +- tests/query_test.go | 43 +++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 42 insertions(+), 3 deletions(-) diff --git a/finisher_api.go b/finisher_api.go index 334aea58..780de267 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -273,7 +273,7 @@ func (db *DB) Scan(dest interface{}) (tx *DB) { // db.Find(&users).Pluck("age", &ages) func (db *DB) Pluck(column string, dest interface{}) (tx *DB) { tx = db.getInstance() - tx.Statement.AddClause(clause.Select{Columns: []clause.Column{{Name: column}}}) + tx.Statement.AddClauseIfNotExists(clause.Select{Columns: []clause.Column{{Name: column}}}) tx.Statement.Dest = dest tx.callbacks.Query().Execute(tx) return diff --git a/tests/query_test.go b/tests/query_test.go index b7c619d7..a4fe1243 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -2,8 +2,10 @@ package tests_test import ( "reflect" + "sort" "strconv" "testing" + "time" . "github.com/jinzhu/gorm/tests" ) @@ -81,6 +83,24 @@ func TestFind(t *testing.T) { } } +func TestFillSmallerStruct(t *testing.T) { + user := User{Name: "SmallerUser", Age: 100} + DB.Save(&user) + type SimpleUser struct { + Name string + ID int64 + UpdatedAt time.Time + CreatedAt time.Time + } + + var simpleUser SimpleUser + if err := DB.Table("users").Where("name = ?", user.Name).First(&simpleUser).Error; err != nil { + t.Fatalf("Failed to query smaller user, got error %v", err) + } + + AssertObjEqual(t, user, simpleUser, "Name", "ID", "UpdatedAt", "CreatedAt") +} + func TestPluck(t *testing.T) { users := []*User{ GetUser("pluck-user1", Config{}), @@ -92,12 +112,12 @@ func TestPluck(t *testing.T) { var names []string if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Order("name").Pluck("name", &names).Error; err != nil { - t.Errorf("Raise error when pluck name, got %v", err) + t.Errorf("got error when pluck name: %v", err) } var ids []int if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Order("name").Pluck("id", &ids).Error; err != nil { - t.Errorf("Raise error when pluck id, got %v", err) + t.Errorf("got error when pluck id: %v", err) } for idx, name := range names { @@ -112,3 +132,22 @@ func TestPluck(t *testing.T) { } } } + +func TestPluckWithSelect(t *testing.T) { + users := []User{ + {Name: "pluck_with_select_1", Age: 25}, + {Name: "pluck_with_select_2", Age: 26}, + } + + DB.Create(&users) + + var userAges []int + err := DB.Model(&User{}).Where("name like ?", "pluck_with_select%").Select("age + 1 as user_age").Pluck("user_age", &userAges).Error + if err != nil { + t.Fatalf("got error when pluck user_age: %v", err) + } + + sort.Ints(userAges) + + AssertEqual(t, userAges, []int{26, 27}) +} From 76b8e78dcb40539ff7723fbf88e7d5b4cd4be9ef Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 1 Jun 2020 08:12:44 +0800 Subject: [PATCH 374/881] Add multi primary keys test --- callbacks/preload.go | 6 +- dialects/mssql/mssql.go | 4 + dialects/mysql/mysql.go | 4 + dialects/postgres/postgres.go | 4 + dialects/sqlite/sqlite.go | 4 + interfaces.go | 1 + schema/relationship_test.go | 48 ++++ tests/dummy_dialecter.go | 4 + tests/multi_primary_keys_test.go | 395 +++++++++++++++++++++++++++++++ 9 files changed, 467 insertions(+), 3 deletions(-) create mode 100644 tests/multi_primary_keys_test.go diff --git a/callbacks/preload.go b/callbacks/preload.go index a77db2b1..5b5beb06 100644 --- a/callbacks/preload.go +++ b/callbacks/preload.go @@ -52,8 +52,8 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { tx.Where(clause.IN{Column: column, Values: values}).Find(joinResults.Addr().Interface()) // convert join identity map to relation identity map - fieldValues := make([]interface{}, len(foreignFields)) - joinFieldValues := make([]interface{}, len(joinForeignFields)) + fieldValues := make([]interface{}, len(joinForeignFields)) + joinFieldValues := make([]interface{}, len(joinRelForeignFields)) for i := 0; i < joinResults.Len(); i++ { for idx, field := range joinForeignFields { fieldValues[idx], _ = field.ValueOf(joinResults.Index(i)) @@ -94,7 +94,7 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { column, values := schema.ToQueryValues(relForeignKeys, foreignValues) tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), conds...) - fieldValues := make([]interface{}, len(foreignFields)) + fieldValues := make([]interface{}, len(relForeignFields)) for i := 0; i < reflectResults.Len(); i++ { for idx, field := range relForeignFields { fieldValues[idx], _ = field.ValueOf(reflectResults.Index(i)) diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go index 3828c546..066aa38f 100644 --- a/dialects/mssql/mssql.go +++ b/dialects/mssql/mssql.go @@ -19,6 +19,10 @@ type Dialector struct { DSN string } +func (dialector Dialector) Name() string { + return "mssql" +} + func Open(dsn string) gorm.Dialector { return &Dialector{DSN: dsn} } diff --git a/dialects/mysql/mysql.go b/dialects/mysql/mysql.go index baeb79c7..e617a1e1 100644 --- a/dialects/mysql/mysql.go +++ b/dialects/mysql/mysql.go @@ -22,6 +22,10 @@ func Open(dsn string) gorm.Dialector { return &Dialector{DSN: dsn} } +func (dialector Dialector) Name() string { + return "mysql" +} + func (dialector Dialector) Initialize(db *gorm.DB) (err error) { // register callbacks callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{}) diff --git a/dialects/postgres/postgres.go b/dialects/postgres/postgres.go index db559b9d..fb3ecc68 100644 --- a/dialects/postgres/postgres.go +++ b/dialects/postgres/postgres.go @@ -23,6 +23,10 @@ func Open(dsn string) gorm.Dialector { return &Dialector{DSN: dsn} } +func (dialector Dialector) Name() string { + return "postgres" +} + func (dialector Dialector) Initialize(db *gorm.DB) (err error) { // register callbacks callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{ diff --git a/dialects/sqlite/sqlite.go b/dialects/sqlite/sqlite.go index 51829b17..1b9809af 100644 --- a/dialects/sqlite/sqlite.go +++ b/dialects/sqlite/sqlite.go @@ -20,6 +20,10 @@ func Open(dsn string) gorm.Dialector { return &Dialector{DSN: dsn} } +func (dialector Dialector) Name() string { + return "sqlite" +} + func (dialector Dialector) Initialize(db *gorm.DB) (err error) { // register callbacks callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{ diff --git a/interfaces.go b/interfaces.go index 14d8fa34..421428a3 100644 --- a/interfaces.go +++ b/interfaces.go @@ -10,6 +10,7 @@ import ( // Dialector GORM database dialector type Dialector interface { + Name() string Initialize(*DB) error Migrator(db *DB) Migrator DataTypeOf(*schema.Field) string diff --git a/schema/relationship_test.go b/schema/relationship_test.go index 41e8c7bd..0f62f45d 100644 --- a/schema/relationship_test.go +++ b/schema/relationship_test.go @@ -197,3 +197,51 @@ func TestMany2ManyOverrideJoinForeignKey(t *testing.T) { }, }) } + +func TestMany2ManyWithMultiPrimaryKeys(t *testing.T) { + type Tag struct { + ID uint `gorm:"primary_key"` + Locale string `gorm:"primary_key"` + Value string + } + + type Blog struct { + ID uint `gorm:"primary_key"` + Locale string `gorm:"primary_key"` + Subject string + Body string + Tags []Tag `gorm:"many2many:blog_tags;"` + SharedTags []Tag `gorm:"many2many:shared_blog_tags;ForeignKey:id;References:id"` + LocaleTags []Tag `gorm:"many2many:locale_blog_tags;ForeignKey:id,locale;References:id"` + } + + checkStructRelation(t, &Blog{}, + Relation{ + Name: "Tags", Type: schema.Many2Many, Schema: "Blog", FieldSchema: "Tag", + JoinTable: JoinTable{Name: "blog_tags", Table: "blog_tags"}, + References: []Reference{ + {"ID", "Blog", "BlogID", "blog_tags", "", true}, + {"Locale", "Blog", "BlogLocale", "blog_tags", "", true}, + {"ID", "Tag", "TagID", "blog_tags", "", false}, + {"Locale", "Tag", "TagLocale", "blog_tags", "", false}, + }, + }, + Relation{ + Name: "SharedTags", Type: schema.Many2Many, Schema: "Blog", FieldSchema: "Tag", + JoinTable: JoinTable{Name: "shared_blog_tags", Table: "shared_blog_tags"}, + References: []Reference{ + {"ID", "Blog", "BlogID", "shared_blog_tags", "", true}, + {"ID", "Tag", "TagID", "shared_blog_tags", "", false}, + }, + }, + Relation{ + Name: "LocaleTags", Type: schema.Many2Many, Schema: "Blog", FieldSchema: "Tag", + JoinTable: JoinTable{Name: "locale_blog_tags", Table: "locale_blog_tags"}, + References: []Reference{ + {"ID", "Blog", "BlogID", "locale_blog_tags", "", true}, + {"Locale", "Blog", "BlogLocale", "locale_blog_tags", "", true}, + {"ID", "Tag", "TagID", "locale_blog_tags", "", false}, + }, + }, + ) +} diff --git a/tests/dummy_dialecter.go b/tests/dummy_dialecter.go index 63af0c9c..4ea17a0f 100644 --- a/tests/dummy_dialecter.go +++ b/tests/dummy_dialecter.go @@ -10,6 +10,10 @@ import ( type DummyDialector struct { } +func (DummyDialector) Name() string { + return "dummy" +} + func (DummyDialector) Initialize(*gorm.DB) error { return nil } diff --git a/tests/multi_primary_keys_test.go b/tests/multi_primary_keys_test.go new file mode 100644 index 00000000..b3284f15 --- /dev/null +++ b/tests/multi_primary_keys_test.go @@ -0,0 +1,395 @@ +package tests_test + +import ( + "reflect" + "sort" + "testing" + + . "github.com/jinzhu/gorm/tests" +) + +type Blog struct { + ID uint `gorm:"primary_key"` + Locale string `gorm:"primary_key"` + Subject string + Body string + Tags []Tag `gorm:"many2many:blog_tags;"` + SharedTags []Tag `gorm:"many2many:shared_blog_tags;ForeignKey:id;References:id"` + LocaleTags []Tag `gorm:"many2many:locale_blog_tags;ForeignKey:id,locale;References:id"` +} + +type Tag struct { + ID uint `gorm:"primary_key"` + Locale string `gorm:"primary_key"` + Value string + Blogs []*Blog `gorm:"many2many:blogs_tags"` +} + +func compareTags(tags []Tag, contents []string) bool { + var tagContents []string + for _, tag := range tags { + tagContents = append(tagContents, tag.Value) + } + sort.Strings(tagContents) + sort.Strings(contents) + return reflect.DeepEqual(tagContents, contents) +} + +func TestManyToManyWithMultiPrimaryKeys(t *testing.T) { + if name := DB.Dialector.Name(); name == "sqlite" || name == "mssql" { + t.Skip("skip sqlite, mssql due to it doesn't support multiple primary keys with auto increment") + } + + DB.Migrator().DropTable(&Blog{}, &Tag{}, "blog_tags") + if err := DB.AutoMigrate(&Blog{}, &Tag{}); err != nil { + t.Fatalf("Failed to auto migrate, got error: %v", err) + } + + blog := Blog{ + Locale: "ZH", + Subject: "subject", + Body: "body", + Tags: []Tag{ + {Locale: "ZH", Value: "tag1"}, + {Locale: "ZH", Value: "tag2"}, + }, + } + + DB.Save(&blog) + if !compareTags(blog.Tags, []string{"tag1", "tag2"}) { + t.Fatalf("Blog should has two tags") + } + + // Append + var tag3 = &Tag{Locale: "ZH", Value: "tag3"} + DB.Model(&blog).Association("Tags").Append([]*Tag{tag3}) + + if !compareTags(blog.Tags, []string{"tag1", "tag2", "tag3"}) { + t.Fatalf("Blog should has three tags after Append") + } + + if count := DB.Model(&blog).Association("Tags").Count(); count != 3 { + t.Fatalf("Blog should has 3 tags after Append, got %v", count) + } + + var tags []Tag + DB.Model(&blog).Association("Tags").Find(&tags) + if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) { + t.Fatalf("Should find 3 tags") + } + + var blog1 Blog + DB.Preload("Tags").Find(&blog1) + if !compareTags(blog1.Tags, []string{"tag1", "tag2", "tag3"}) { + t.Fatalf("Preload many2many relations") + } + + // Replace + var tag5 = &Tag{Locale: "ZH", Value: "tag5"} + var tag6 = &Tag{Locale: "ZH", Value: "tag6"} + DB.Model(&blog).Association("Tags").Replace(tag5, tag6) + var tags2 []Tag + DB.Model(&blog).Association("Tags").Find(&tags2) + if !compareTags(tags2, []string{"tag5", "tag6"}) { + t.Fatalf("Should find 2 tags after Replace") + } + + if DB.Model(&blog).Association("Tags").Count() != 2 { + t.Fatalf("Blog should has three tags after Replace") + } + + // Delete + DB.Model(&blog).Association("Tags").Delete(tag5) + var tags3 []Tag + DB.Model(&blog).Association("Tags").Find(&tags3) + if !compareTags(tags3, []string{"tag6"}) { + t.Fatalf("Should find 1 tags after Delete") + } + + if DB.Model(&blog).Association("Tags").Count() != 1 { + t.Fatalf("Blog should has three tags after Delete") + } + + DB.Model(&blog).Association("Tags").Delete(tag3) + var tags4 []Tag + DB.Model(&blog).Association("Tags").Find(&tags4) + if !compareTags(tags4, []string{"tag6"}) { + t.Fatalf("Tag should not be deleted when Delete with a unrelated tag") + } + + // Clear + DB.Model(&blog).Association("Tags").Clear() + if DB.Model(&blog).Association("Tags").Count() != 0 { + t.Fatalf("All tags should be cleared") + } +} + +func TestManyToManyWithCustomizedForeignKeys(t *testing.T) { + if name := DB.Dialector.Name(); name == "sqlite" || name == "mssql" { + t.Skip("skip sqlite, mssql due to it doesn't support multiple primary keys with auto increment") + } + + DB.Migrator().DropTable(&Blog{}, &Tag{}, "blog_tags") + if err := DB.AutoMigrate(&Blog{}, &Tag{}); err != nil { + t.Fatalf("Failed to auto migrate, got error: %v", err) + } + + blog := Blog{ + Locale: "ZH", + Subject: "subject", + Body: "body", + SharedTags: []Tag{ + {Locale: "ZH", Value: "tag1"}, + {Locale: "ZH", Value: "tag2"}, + }, + } + DB.Save(&blog) + + blog2 := Blog{ + ID: blog.ID, + Locale: "EN", + } + DB.Create(&blog2) + + if !compareTags(blog.SharedTags, []string{"tag1", "tag2"}) { + t.Fatalf("Blog should has two tags") + } + + // Append + var tag3 = &Tag{Locale: "ZH", Value: "tag3"} + DB.Model(&blog).Association("SharedTags").Append([]*Tag{tag3}) + if !compareTags(blog.SharedTags, []string{"tag1", "tag2", "tag3"}) { + t.Fatalf("Blog should has three tags after Append") + } + + if DB.Model(&blog).Association("SharedTags").Count() != 3 { + t.Fatalf("Blog should has three tags after Append") + } + + if DB.Model(&blog2).Association("SharedTags").Count() != 3 { + t.Fatalf("Blog should has three tags after Append") + } + + var tags []Tag + DB.Model(&blog).Association("SharedTags").Find(&tags) + if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) { + t.Fatalf("Should find 3 tags") + } + + DB.Model(&blog2).Association("SharedTags").Find(&tags) + if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) { + t.Fatalf("Should find 3 tags") + } + + var blog1 Blog + DB.Preload("SharedTags").Find(&blog1) + if !compareTags(blog1.SharedTags, []string{"tag1", "tag2", "tag3"}) { + t.Fatalf("Preload many2many relations") + } + + var tag4 = &Tag{Locale: "ZH", Value: "tag4"} + DB.Model(&blog2).Association("SharedTags").Append(tag4) + + DB.Model(&blog).Association("SharedTags").Find(&tags) + if !compareTags(tags, []string{"tag1", "tag2", "tag3", "tag4"}) { + t.Fatalf("Should find 3 tags") + } + + DB.Model(&blog2).Association("SharedTags").Find(&tags) + if !compareTags(tags, []string{"tag1", "tag2", "tag3", "tag4"}) { + t.Fatalf("Should find 3 tags") + } + + // Replace + var tag5 = &Tag{Locale: "ZH", Value: "tag5"} + var tag6 = &Tag{Locale: "ZH", Value: "tag6"} + DB.Model(&blog2).Association("SharedTags").Replace(tag5, tag6) + var tags2 []Tag + DB.Model(&blog).Association("SharedTags").Find(&tags2) + if !compareTags(tags2, []string{"tag5", "tag6"}) { + t.Fatalf("Should find 2 tags after Replace") + } + + DB.Model(&blog2).Association("SharedTags").Find(&tags2) + if !compareTags(tags2, []string{"tag5", "tag6"}) { + t.Fatalf("Should find 2 tags after Replace") + } + + if DB.Model(&blog).Association("SharedTags").Count() != 2 { + t.Fatalf("Blog should has three tags after Replace") + } + + // Delete + DB.Model(&blog).Association("SharedTags").Delete(tag5) + var tags3 []Tag + DB.Model(&blog).Association("SharedTags").Find(&tags3) + if !compareTags(tags3, []string{"tag6"}) { + t.Fatalf("Should find 1 tags after Delete") + } + + if DB.Model(&blog).Association("SharedTags").Count() != 1 { + t.Fatalf("Blog should has three tags after Delete") + } + + DB.Model(&blog2).Association("SharedTags").Delete(tag3) + var tags4 []Tag + DB.Model(&blog).Association("SharedTags").Find(&tags4) + if !compareTags(tags4, []string{"tag6"}) { + t.Fatalf("Tag should not be deleted when Delete with a unrelated tag") + } + + // Clear + DB.Model(&blog2).Association("SharedTags").Clear() + if DB.Model(&blog).Association("SharedTags").Count() != 0 { + t.Fatalf("All tags should be cleared") + } +} + +func TestManyToManyWithCustomizedForeignKeys2(t *testing.T) { + if name := DB.Dialector.Name(); name == "sqlite" || name == "mssql" { + t.Skip("skip sqlite, mssql due to it doesn't support multiple primary keys with auto increment") + } + + DB.Migrator().DropTable(&Blog{}, &Tag{}, "blog_tags") + if err := DB.AutoMigrate(&Blog{}, &Tag{}); err != nil { + t.Fatalf("Failed to auto migrate, got error: %v", err) + } + + blog := Blog{ + Locale: "ZH", + Subject: "subject", + Body: "body", + LocaleTags: []Tag{ + {Locale: "ZH", Value: "tag1"}, + {Locale: "ZH", Value: "tag2"}, + }, + } + DB.Save(&blog) + + blog2 := Blog{ + ID: blog.ID, + Locale: "EN", + } + DB.Create(&blog2) + + // Append + var tag3 = &Tag{Locale: "ZH", Value: "tag3"} + DB.Model(&blog).Association("LocaleTags").Append([]*Tag{tag3}) + if !compareTags(blog.LocaleTags, []string{"tag1", "tag2", "tag3"}) { + t.Fatalf("Blog should has three tags after Append") + } + + if DB.Model(&blog).Association("LocaleTags").Count() != 3 { + t.Fatalf("Blog should has three tags after Append") + } + + if DB.Model(&blog2).Association("LocaleTags").Count() != 0 { + t.Fatalf("EN Blog should has 0 tags after ZH Blog Append") + } + + var tags []Tag + DB.Model(&blog).Association("LocaleTags").Find(&tags) + if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) { + t.Fatalf("Should find 3 tags") + } + + DB.Model(&blog2).Association("LocaleTags").Find(&tags) + if len(tags) != 0 { + t.Fatalf("Should find 0 tags for EN Blog") + } + + var blog1 Blog + DB.Preload("LocaleTags").Find(&blog1, "locale = ? AND id = ?", "ZH", blog.ID) + if !compareTags(blog1.LocaleTags, []string{"tag1", "tag2", "tag3"}) { + t.Fatalf("Preload many2many relations") + } + + var tag4 = &Tag{Locale: "ZH", Value: "tag4"} + DB.Model(&blog2).Association("LocaleTags").Append(tag4) + + DB.Model(&blog).Association("LocaleTags").Find(&tags) + if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) { + t.Fatalf("Should find 3 tags for EN Blog") + } + + DB.Model(&blog2).Association("LocaleTags").Find(&tags) + if !compareTags(tags, []string{"tag4"}) { + t.Fatalf("Should find 1 tags for EN Blog") + } + + // Replace + var tag5 = &Tag{Locale: "ZH", Value: "tag5"} + var tag6 = &Tag{Locale: "ZH", Value: "tag6"} + DB.Model(&blog2).Association("LocaleTags").Replace(tag5, tag6) + + var tags2 []Tag + DB.Model(&blog).Association("LocaleTags").Find(&tags2) + if !compareTags(tags2, []string{"tag1", "tag2", "tag3"}) { + t.Fatalf("CN Blog's tags should not be changed after EN Blog Replace") + } + + var blog11 Blog + DB.Preload("LocaleTags").First(&blog11, "id = ? AND locale = ?", blog.ID, blog.Locale) + if !compareTags(blog11.LocaleTags, []string{"tag1", "tag2", "tag3"}) { + t.Fatalf("CN Blog's tags should not be changed after EN Blog Replace") + } + + DB.Model(&blog2).Association("LocaleTags").Find(&tags2) + if !compareTags(tags2, []string{"tag5", "tag6"}) { + t.Fatalf("Should find 2 tags after Replace") + } + + var blog21 Blog + DB.Preload("LocaleTags").First(&blog21, "id = ? AND locale = ?", blog2.ID, blog2.Locale) + if !compareTags(blog21.LocaleTags, []string{"tag5", "tag6"}) { + t.Fatalf("EN Blog's tags should be changed after Replace") + } + + if DB.Model(&blog).Association("LocaleTags").Count() != 3 { + t.Fatalf("ZH Blog should has three tags after Replace") + } + + if DB.Model(&blog2).Association("LocaleTags").Count() != 2 { + t.Fatalf("EN Blog should has two tags after Replace") + } + + // Delete + DB.Model(&blog).Association("LocaleTags").Delete(tag5) + + if DB.Model(&blog).Association("LocaleTags").Count() != 3 { + t.Fatalf("ZH Blog should has three tags after Delete with EN's tag") + } + + if DB.Model(&blog2).Association("LocaleTags").Count() != 2 { + t.Fatalf("EN Blog should has two tags after ZH Blog Delete with EN's tag") + } + + DB.Model(&blog2).Association("LocaleTags").Delete(tag5) + + if DB.Model(&blog).Association("LocaleTags").Count() != 3 { + t.Fatalf("ZH Blog should has three tags after EN Blog Delete with EN's tag") + } + + if DB.Model(&blog2).Association("LocaleTags").Count() != 1 { + t.Fatalf("EN Blog should has 1 tags after EN Blog Delete with EN's tag") + } + + // Clear + DB.Model(&blog2).Association("LocaleTags").Clear() + if DB.Model(&blog).Association("LocaleTags").Count() != 3 { + t.Fatalf("ZH Blog's tags should not be cleared when clear EN Blog's tags") + } + + if DB.Model(&blog2).Association("LocaleTags").Count() != 0 { + t.Fatalf("EN Blog's tags should be cleared when clear EN Blog's tags") + } + + DB.Model(&blog).Association("LocaleTags").Clear() + if DB.Model(&blog).Association("LocaleTags").Count() != 0 { + t.Fatalf("ZH Blog's tags should be cleared when clear ZH Blog's tags") + } + + if DB.Model(&blog2).Association("LocaleTags").Count() != 0 { + t.Fatalf("EN Blog's tags should be cleared") + } +} From dffc2713f010c4253b61adae61810e27044ab157 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 1 Jun 2020 10:02:20 +0800 Subject: [PATCH 375/881] Add mores tests for query --- chainable_api.go | 12 ++- statement.go | 21 ++-- tests/query_test.go | 197 ++++++++++++++++++++++++++++++++++- tests/scanner_valuer_test.go | 41 ++++++++ tests/sql_builder_test.go | 42 ++++++++ 5 files changed, 299 insertions(+), 14 deletions(-) diff --git a/chainable_api.go b/chainable_api.go index afcdccd2..6fa605c6 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -111,21 +111,27 @@ func (db *DB) Omit(columns ...string) (tx *DB) { // Where add conditions func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) { tx = db.getInstance() - tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondtion(query, args...)}) + if conds := tx.Statement.BuildCondtion(query, args...); len(conds) > 0 { + tx.Statement.AddClause(clause.Where{Exprs: conds}) + } return } // Not add NOT conditions func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) { tx = db.getInstance() - tx.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.Not(tx.Statement.BuildCondtion(query, args...)...)}}) + if conds := tx.Statement.BuildCondtion(query, args...); len(conds) > 0 { + tx.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.Not(conds...)}}) + } return } // Or add OR conditions func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) { tx = db.getInstance() - tx.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(tx.Statement.BuildCondtion(query, args...)...)}}) + if conds := tx.Statement.BuildCondtion(query, args...); len(conds) > 0 { + tx.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(conds...)}}) + } return } diff --git a/statement.go b/statement.go index 444d5c37..aa7d193c 100644 --- a/statement.go +++ b/statement.go @@ -204,12 +204,15 @@ func (stmt *Statement) AddClauseIfNotExists(v clause.Interface) { // BuildCondtion build condition func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (conds []clause.Expression) { if sql, ok := query.(string); ok { - if i, err := strconv.Atoi(sql); err == nil { - query = i - } else if len(args) == 0 || (len(args) > 0 && strings.Contains(sql, "?")) || strings.Contains(sql, "@") { - return []clause.Expression{clause.Expr{SQL: sql, Vars: args}} - } else if len(args) == 1 { - return []clause.Expression{clause.Eq{Column: sql, Value: args[0]}} + // if it is a number, then treats it as primary key + if _, err := strconv.Atoi(sql); err != nil { + if sql == "" && len(args) == 0 { + return + } else if len(args) == 0 || (len(args) > 0 && strings.Contains(sql, "?")) || strings.Contains(sql, "@") { + return []clause.Expression{clause.Expr{SQL: sql, Vars: args}} + } else if len(args) == 1 { + return []clause.Expression{clause.Eq{Column: sql, Value: args[0]}} + } } } @@ -267,14 +270,12 @@ func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (con } } } + } else if len(conds) == 0 { + conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: args}) } } } - if len(conds) == 0 { - conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: args}) - } - return } diff --git a/tests/query_test.go b/tests/query_test.go index a4fe1243..6efadc8e 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -1,12 +1,14 @@ package tests_test import ( + "fmt" "reflect" "sort" "strconv" "testing" "time" + "github.com/jinzhu/gorm" . "github.com/jinzhu/gorm/tests" ) @@ -115,8 +117,14 @@ func TestPluck(t *testing.T) { t.Errorf("got error when pluck name: %v", err) } + var names2 []string + if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Order("name desc").Pluck("name", &names2).Error; err != nil { + t.Errorf("got error when pluck name: %v", err) + } + AssertEqual(t, names, sort.Reverse(sort.StringSlice(names2))) + var ids []int - if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Order("name").Pluck("id", &ids).Error; err != nil { + if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Pluck("id", &ids).Error; err != nil { t.Errorf("got error when pluck id: %v", err) } @@ -133,6 +141,21 @@ func TestPluck(t *testing.T) { } } +func TestSelect(t *testing.T) { + user := User{Name: "SelectUser1"} + DB.Save(&user) + + var result User + DB.Where("name = ?", user.Name).Select("name").Find(&result) + if result.ID != 0 { + t.Errorf("Should not have ID because only selected name, %+v", result.ID) + } + + if user.Name != result.Name { + t.Errorf("Should have user Name when selected it") + } +} + func TestPluckWithSelect(t *testing.T) { users := []User{ {Name: "pluck_with_select_1", Age: 25}, @@ -151,3 +174,175 @@ func TestPluckWithSelect(t *testing.T) { AssertEqual(t, userAges, []int{26, 27}) } + +func TestSelectWithVariables(t *testing.T) { + DB.Save(&User{Name: "select_with_variables"}) + + rows, _ := DB.Table("users").Where("name = ?", "select_with_variables").Select("? as fake", gorm.Expr("name")).Rows() + + if !rows.Next() { + t.Errorf("Should have returned at least one row") + } else { + columns, _ := rows.Columns() + AssertEqual(t, columns, []string{"fake"}) + } + + rows.Close() +} + +func TestSelectWithArrayInput(t *testing.T) { + DB.Save(&User{Name: "select_with_array", Age: 42}) + + var user User + DB.Select([]string{"name", "age"}).Where("age = 42 AND name = ?", "select_with_array").First(&user) + + if user.Name != "select_with_array" || user.Age != 42 { + t.Errorf("Should have selected both age and name") + } +} + +func TestCustomizedTypePrimaryKey(t *testing.T) { + type ID uint + type CustomizedTypePrimaryKey struct { + ID ID + Name string + } + + DB.Migrator().DropTable(&CustomizedTypePrimaryKey{}) + if err := DB.AutoMigrate(&CustomizedTypePrimaryKey{}); err != nil { + t.Fatalf("failed to migrate, got error %v", err) + } + + p1 := CustomizedTypePrimaryKey{Name: "p1"} + p2 := CustomizedTypePrimaryKey{Name: "p2"} + p3 := CustomizedTypePrimaryKey{Name: "p3"} + DB.Create(&p1) + DB.Create(&p2) + DB.Create(&p3) + + var p CustomizedTypePrimaryKey + + if err := DB.First(&p, p2.ID).Error; err != nil { + t.Errorf("No error should returns, but got %v", err) + } + + AssertEqual(t, p, p2) + + if err := DB.First(&p, "id = ?", p2.ID).Error; err != nil { + t.Errorf("No error should happen when querying with customized type for primary key, got err %v", err) + } + + AssertEqual(t, p, p2) +} + +func TestStringPrimaryKeyForNumericValueStartingWithZero(t *testing.T) { + type AddressByZipCode struct { + ZipCode string `gorm:"primary_key"` + Address string + } + + DB.Migrator().DropTable(&AddressByZipCode{}) + if err := DB.AutoMigrate(&AddressByZipCode{}); err != nil { + t.Fatalf("failed to migrate, got error %v", err) + } + + address := AddressByZipCode{ZipCode: "00501", Address: "Holtsville"} + DB.Create(&address) + + var result AddressByZipCode + DB.First(&result, "00501") + + AssertEqual(t, result, address) +} + +func TestSearchWithEmptyChain(t *testing.T) { + user := User{Name: "search_with_empty_chain", Age: 1} + DB.Create(&user) + + var result User + if DB.Where("").Where("").First(&result).Error != nil { + t.Errorf("Should not raise any error if searching with empty strings") + } + + if DB.Where(&User{}).Where("name = ?", user.Name).First(&result).Error != nil { + t.Errorf("Should not raise any error if searching with empty struct") + } + + if DB.Where(map[string]interface{}{}).Where("name = ?", user.Name).First(&result).Error != nil { + t.Errorf("Should not raise any error if searching with empty map") + } +} + +func TestLimit(t *testing.T) { + users := []User{ + {Name: "LimitUser1", Age: 1}, + {Name: "LimitUser2", Age: 10}, + {Name: "LimitUser3", Age: 20}, + {Name: "LimitUser4", Age: 10}, + {Name: "LimitUser5", Age: 20}, + } + + DB.Create(&users) + + var users1, users2, users3 []User + DB.Order("age desc").Limit(3).Find(&users1).Limit(5).Find(&users2).Limit(-1).Find(&users3) + + if len(users1) != 3 || len(users2) != 5 || len(users3) <= 5 { + t.Errorf("Limit should works") + } +} + +func TestOffset(t *testing.T) { + for i := 0; i < 20; i++ { + DB.Save(&User{Name: fmt.Sprintf("OffsetUser%v", i)}) + } + var users1, users2, users3, users4 []User + + DB.Limit(100).Where("name like ?", "OffsetUser%").Order("age desc").Find(&users1).Offset(3).Find(&users2).Offset(5).Find(&users3).Offset(-1).Find(&users4) + + if (len(users1) != len(users4)) || (len(users1)-len(users2) != 3) || (len(users1)-len(users3) != 5) { + t.Errorf("Offset should work") + } +} + +func TestSearchWithMap(t *testing.T) { + users := []User{ + *GetUser("map_search_user1", Config{}), + *GetUser("map_search_user2", Config{}), + *GetUser("map_search_user3", Config{}), + *GetUser("map_search_user4", Config{Company: true}), + } + + DB.Create(&users) + + var user User + DB.First(&user, map[string]interface{}{"name": users[0].Name}) + CheckUser(t, user, users[0]) + + DB.Where(map[string]interface{}{"name": users[1].Name}).First(&user) + CheckUser(t, user, users[1]) + + var results []User + DB.Where(map[string]interface{}{"name": users[2].Name}).Find(&results) + if len(results) != 1 { + t.Fatalf("Search all records with inline map") + } + + CheckUser(t, results[0], users[2]) + + var results2 []User + DB.Find(&results2, map[string]interface{}{"name": users[3].Name, "company_id": nil}) + if len(results2) != 0 { + t.Errorf("Search all records with inline map containing null value finding 0 records") + } + + DB.Find(&results2, map[string]interface{}{"name": users[0].Name, "company_id": nil}) + if len(results2) != 1 { + t.Errorf("Search all records with inline map containing null value finding 1 record") + } + + DB.Find(&results2, map[string]interface{}{"name": users[3].Name, "company_id": users[3].CompanyID}) + if len(results2) != 1 { + t.Errorf("Search all records with inline multiple value map") + } +} diff --git a/tests/scanner_valuer_test.go b/tests/scanner_valuer_test.go index 04c91ab2..9f91b5d8 100644 --- a/tests/scanner_valuer_test.go +++ b/tests/scanner_valuer_test.go @@ -50,6 +50,47 @@ func TestScannerValuer(t *testing.T) { AssertObjEqual(t, data, result, "Name", "Gender", "Age", "Male", "Height", "Birthday", "Password", "Num", "Strings", "Structs") } +func TestScannerValuerWithFirstOrCreate(t *testing.T) { + DB.Migrator().DropTable(&ScannerValuerStruct{}) + if err := DB.Migrator().AutoMigrate(&ScannerValuerStruct{}); err != nil { + t.Errorf("no error should happen when migrate scanner, valuer struct") + } + + data := ScannerValuerStruct{ + Name: sql.NullString{String: "name", Valid: true}, + Gender: &sql.NullString{String: "M", Valid: true}, + Age: sql.NullInt64{Int64: 18, Valid: true}, + } + + var result ScannerValuerStruct + tx := DB.Where(data).FirstOrCreate(&result) + + if tx.RowsAffected != 1 { + t.Errorf("RowsAffected should be 1 after create some record") + } + + if tx.Error != nil { + t.Errorf("Should not raise any error, but got %v", tx.Error) + } + + AssertObjEqual(t, result, data, "Name", "Gender", "Age") + + if err := DB.Where(data).Assign(ScannerValuerStruct{Age: sql.NullInt64{Int64: 18, Valid: true}}).FirstOrCreate(&result).Error; err != nil { + t.Errorf("Should not raise any error, but got %v", err) + } + + if result.Age.Int64 != 18 { + t.Errorf("should update age to 18") + } + + var result2 ScannerValuerStruct + if err := DB.First(&result2, result.ID).Error; err != nil { + t.Errorf("got error %v when query with %v", err, result.ID) + } + + AssertObjEqual(t, result2, result, "ID", "CreatedAt", "UpdatedAt", "Name", "Gender", "Age") +} + func TestInvalidValuer(t *testing.T) { DB.Migrator().DropTable(&ScannerValuerStruct{}) if err := DB.Migrator().AutoMigrate(&ScannerValuerStruct{}); err != nil { diff --git a/tests/sql_builder_test.go b/tests/sql_builder_test.go index 4cd40c7a..0aed82a2 100644 --- a/tests/sql_builder_test.go +++ b/tests/sql_builder_test.go @@ -80,3 +80,45 @@ func TestRaw(t *testing.T) { t.Error("Raw sql to update records") } } + +func TestRowsWithGroup(t *testing.T) { + users := []User{ + {Name: "having_user_1", Age: 1}, + {Name: "having_user_2", Age: 10}, + {Name: "having_user_1", Age: 20}, + {Name: "having_user_1", Age: 30}, + } + + DB.Create(&users) + + rows, err := DB.Select("name, count(*) as total").Table("users").Group("name").Having("name IN ?", []string{users[0].Name, users[1].Name}).Rows() + if err != nil { + t.Fatalf("got error %v", err) + } + + defer rows.Close() + for rows.Next() { + var name string + var total int64 + rows.Scan(&name, &total) + + if name == users[0].Name && total != 3 { + t.Errorf("Should have one user having name %v", users[0].Name) + } else if name == users[1].Name && total != 1 { + t.Errorf("Should have two users having name %v", users[1].Name) + } + } +} + +func TestQueryRaw(t *testing.T) { + users := []*User{ + GetUser("row_query_user", Config{}), + GetUser("row_query_user", Config{}), + GetUser("row_query_user", Config{}), + } + DB.Create(&users) + + var user User + DB.Raw("select * from users WHERE id = ?", users[1].ID).First(&user) + CheckUser(t, user, *users[1]) +} From 1559fe24e5d193a31ca31470482cb75137b1e080 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 1 Jun 2020 19:41:33 +0800 Subject: [PATCH 376/881] Add more updates test --- association.go | 8 + callbacks/associations.go | 7 + callbacks/callbacks.go | 1 + callbacks/query.go | 13 ++ callbacks/update.go | 64 +++++--- schema/field.go | 2 + tests/associations_test.go | 6 +- tests/delete_test.go | 2 + tests/query_test.go | 3 + tests/update_test.go | 303 +++++++++++++++++++++++++++++++++++++ tests/utils.go | 32 ++++ 11 files changed, 419 insertions(+), 22 deletions(-) diff --git a/association.go b/association.go index bed89837..55dd7772 100644 --- a/association.go +++ b/association.go @@ -86,6 +86,14 @@ func (association *Association) Replace(values ...interface{}) error { case schema.BelongsTo: if len(values) == 0 { updateMap := map[string]interface{}{} + switch reflectValue.Kind() { + case reflect.Slice, reflect.Array: + for i := 0; i < reflectValue.Len(); i++ { + rel.Field.Set(reflectValue.Index(i), reflect.Zero(rel.Field.FieldType).Interface()) + } + case reflect.Struct: + rel.Field.Set(reflectValue, reflect.Zero(rel.Field.FieldType).Interface()) + } for _, ref := range rel.References { updateMap[ref.ForeignKey.DBName] = nil diff --git a/callbacks/associations.go b/callbacks/associations.go index 3c8c2a50..d19f7339 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -24,6 +24,13 @@ func SaveBeforeAssociations(db *gorm.DB) { if !ref.OwnPrimaryKey { pv, _ := ref.PrimaryKey.ValueOf(elem) ref.ForeignKey.Set(obj, pv) + + if dest, ok := db.Statement.Dest.(map[string]interface{}); ok { + dest[ref.ForeignKey.DBName] = pv + if _, ok := dest[rel.Name]; ok { + dest[rel.Name] = elem.Interface() + } + } } } } diff --git a/callbacks/callbacks.go b/callbacks/callbacks.go index 1985aec2..1c1d6ade 100644 --- a/callbacks/callbacks.go +++ b/callbacks/callbacks.go @@ -37,6 +37,7 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) { updateCallback := db.Callback().Update() updateCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction) + updateCallback.Register("gorm:setup_reflect_value", SetupUpdateReflectValue) updateCallback.Register("gorm:before_update", BeforeUpdate) updateCallback.Register("gorm:save_before_associations", SaveBeforeAssociations) updateCallback.Register("gorm:update", Update) diff --git a/callbacks/query.go b/callbacks/query.go index 91948031..e4e76665 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -37,6 +37,19 @@ func Query(db *gorm.DB) { func BuildQuerySQL(db *gorm.DB) { clauseSelect := clause.Select{} + if db.Statement.ReflectValue.Kind() == reflect.Struct { + var conds []clause.Expression + for _, primaryField := range db.Statement.Schema.PrimaryFields { + if v, isZero := primaryField.ValueOf(db.Statement.ReflectValue); !isZero { + conds = append(conds, clause.Eq{Column: clause.Column{Table: db.Statement.Table, Name: primaryField.DBName}, Value: v}) + } + } + + if len(conds) > 0 { + db.Statement.AddClause(clause.Where{Exprs: conds}) + } + } + if len(db.Statement.Selects) > 0 { for _, name := range db.Statement.Selects { if db.Statement.Schema == nil { diff --git a/callbacks/update.go b/callbacks/update.go index cbbcddf7..fda07676 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -9,6 +9,25 @@ import ( "github.com/jinzhu/gorm/schema" ) +func SetupUpdateReflectValue(db *gorm.DB) { + if db.Error == nil { + if !db.Statement.ReflectValue.CanAddr() || db.Statement.Model != db.Statement.Dest { + db.Statement.ReflectValue = reflect.ValueOf(db.Statement.Model) + for db.Statement.ReflectValue.Kind() == reflect.Ptr { + db.Statement.ReflectValue = db.Statement.ReflectValue.Elem() + } + + if dest, ok := db.Statement.Dest.(map[string]interface{}); ok { + for _, rel := range db.Statement.Schema.Relationships.BelongsTo { + if _, ok := dest[rel.Name]; ok { + rel.Field.Set(db.Statement.ReflectValue, dest[rel.Name]) + } + } + } + } + } +} + func BeforeUpdate(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) { tx := db.Session(&gorm.Session{}) @@ -114,21 +133,20 @@ func AfterUpdate(db *gorm.DB) { func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { var ( selectColumns, restricted = SelectAndOmitColumns(stmt, false, true) - reflectModelValue = reflect.Indirect(reflect.ValueOf(stmt.Model)) assignValue func(field *schema.Field, value interface{}) ) - switch reflectModelValue.Kind() { + switch stmt.ReflectValue.Kind() { case reflect.Slice, reflect.Array: assignValue = func(field *schema.Field, value interface{}) { - for i := 0; i < reflectModelValue.Len(); i++ { - field.Set(reflectModelValue.Index(i), value) + for i := 0; i < stmt.ReflectValue.Len(); i++ { + field.Set(stmt.ReflectValue.Index(i), value) } } case reflect.Struct: assignValue = func(field *schema.Field, value interface{}) { - if reflectModelValue.CanAddr() { - field.Set(reflectModelValue, value) + if stmt.ReflectValue.CanAddr() { + field.Set(stmt.ReflectValue, value) } } default: @@ -136,7 +154,12 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { } } - switch value := stmt.Dest.(type) { + updatingValue := reflect.ValueOf(stmt.Dest) + for updatingValue.Kind() == reflect.Ptr { + updatingValue = updatingValue.Elem() + } + + switch value := updatingValue.Interface().(type) { case map[string]interface{}: set = make([]clause.Assignment, 0, len(value)) @@ -148,8 +171,12 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { for _, k := range keys { if field := stmt.Schema.LookUpField(k); field != nil { - if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { - set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value[k]}) + if field.DBName != "" { + if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { + set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value[k]}) + assignValue(field, value[k]) + } + } else if v, ok := selectColumns[field.Name]; (ok && v) || (!ok && !restricted) { assignValue(field, value[k]) } } else if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) { @@ -167,13 +194,13 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { } } default: - switch stmt.ReflectValue.Kind() { + switch updatingValue.Kind() { case reflect.Struct: set = make([]clause.Assignment, 0, len(stmt.Schema.FieldsByDBName)) for _, field := range stmt.Schema.FieldsByDBName { - if !field.PrimaryKey || (!stmt.ReflectValue.CanAddr() || stmt.Dest != stmt.Model) { + if !field.PrimaryKey || (!updatingValue.CanAddr() || stmt.Dest != stmt.Model) { if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { - value, isZero := field.ValueOf(stmt.ReflectValue) + value, isZero := field.ValueOf(updatingValue) if !stmt.DisableUpdateTime { if field.AutoUpdateTime > 0 { value = stmt.DB.NowFunc() @@ -187,7 +214,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { } } } else { - if value, isZero := field.ValueOf(stmt.ReflectValue); !isZero { + if value, isZero := field.ValueOf(updatingValue); !isZero { stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}}) } } @@ -195,16 +222,15 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { } } - if !stmt.ReflectValue.CanAddr() || stmt.Dest != stmt.Model { - reflectValue := reflect.Indirect(reflect.ValueOf(stmt.Model)) - switch reflectValue.Kind() { + if !updatingValue.CanAddr() || stmt.Dest != stmt.Model { + switch stmt.ReflectValue.Kind() { case reflect.Slice, reflect.Array: var priamryKeyExprs []clause.Expression - for i := 0; i < reflectValue.Len(); i++ { + for i := 0; i < stmt.ReflectValue.Len(); i++ { var exprs = make([]clause.Expression, len(stmt.Schema.PrimaryFields)) var notZero bool for idx, field := range stmt.Schema.PrimaryFields { - value, isZero := field.ValueOf(reflectValue.Index(i)) + value, isZero := field.ValueOf(stmt.ReflectValue.Index(i)) exprs[idx] = clause.Eq{Column: field.DBName, Value: value} notZero = notZero || !isZero } @@ -215,7 +241,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(priamryKeyExprs...)}}) case reflect.Struct: for _, field := range stmt.Schema.PrimaryFields { - if value, isZero := field.ValueOf(reflectValue); !isZero { + if value, isZero := field.ValueOf(stmt.ReflectValue); !isZero { stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}}) } } diff --git a/schema/field.go b/schema/field.go index f52dd6a6..8a0f01bf 100644 --- a/schema/field.go +++ b/schema/field.go @@ -347,6 +347,8 @@ func (field *Field) setupValuerAndSetter() { if v.Type().Elem().Kind() == reflect.Struct { if !v.IsNil() { v = v.Elem() + } else { + return nil, true } } else { return nil, true diff --git a/tests/associations_test.go b/tests/associations_test.go index 89bbe142..3668b44b 100644 --- a/tests/associations_test.go +++ b/tests/associations_test.go @@ -8,7 +8,7 @@ import ( func AssertAssociationCount(t *testing.T, data interface{}, name string, result int64, reason string) { if count := DB.Model(data).Association(name).Count(); count != result { - t.Errorf("invalid %v count %v, expects: %v got %v", name, reason, result, count) + t.Fatalf("invalid %v count %v, expects: %v got %v", name, reason, result, count) } var newUser User @@ -20,7 +20,7 @@ func AssertAssociationCount(t *testing.T, data interface{}, name string, result if newUser.ID != 0 { if count := DB.Model(&newUser).Association(name).Count(); count != result { - t.Errorf("invalid %v count %v, expects: %v got %v", name, reason, result, count) + t.Fatalf("invalid %v count %v, expects: %v got %v", name, reason, result, count) } } } @@ -28,6 +28,6 @@ func AssertAssociationCount(t *testing.T, data interface{}, name string, result func TestInvalidAssociation(t *testing.T) { var user = *GetUser("invalid", Config{Company: true, Manager: true}) if err := DB.Model(&user).Association("Invalid").Find(&user.Company).Error; err == nil { - t.Errorf("should return errors for invalid association, but got nil") + t.Fatalf("should return errors for invalid association, but got nil") } } diff --git a/tests/delete_test.go b/tests/delete_test.go index 4288253f..e7076aa6 100644 --- a/tests/delete_test.go +++ b/tests/delete_test.go @@ -31,12 +31,14 @@ func TestDelete(t *testing.T) { } for _, user := range []User{users[0], users[2]} { + result = User{} if err := DB.Where("id = ?", user.ID).First(&result).Error; err != nil { t.Errorf("no error should returns when query %v, but got %v", user.ID, err) } } for _, user := range []User{users[0], users[2]} { + result = User{} if err := DB.Where("id = ?", user.ID).First(&result).Error; err != nil { t.Errorf("no error should returns when query %v, but got %v", user.ID, err) } diff --git a/tests/query_test.go b/tests/query_test.go index 6efadc8e..73b6dca3 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -264,10 +264,12 @@ func TestSearchWithEmptyChain(t *testing.T) { t.Errorf("Should not raise any error if searching with empty strings") } + result = User{} if DB.Where(&User{}).Where("name = ?", user.Name).First(&result).Error != nil { t.Errorf("Should not raise any error if searching with empty struct") } + result = User{} if DB.Where(map[string]interface{}{}).Where("name = ?", user.Name).First(&result).Error != nil { t.Errorf("Should not raise any error if searching with empty map") } @@ -319,6 +321,7 @@ func TestSearchWithMap(t *testing.T) { DB.First(&user, map[string]interface{}{"name": users[0].Name}) CheckUser(t, user, users[0]) + user = User{} DB.Where(map[string]interface{}{"name": users[1].Name}).First(&user) CheckUser(t, user, users[1]) diff --git a/tests/update_test.go b/tests/update_test.go index 869ce4cd..a5a62237 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -2,6 +2,8 @@ package tests_test import ( "errors" + "sort" + "strings" "testing" "time" @@ -218,3 +220,304 @@ func TestBlockGlobalUpdate(t *testing.T) { t.Errorf("should returns missing WHERE clause while updating error, got err %v", err) } } + +func TestSelectWithUpdate(t *testing.T) { + user := *GetUser("select_update", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4}) + DB.Create(&user) + + var result User + DB.First(&result, user.ID) + + user2 := *GetUser("select_update_new", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4}) + result.Name = user2.Name + result.Age = 50 + result.Account = user2.Account + result.Pets = user2.Pets + result.Toys = user2.Toys + result.Company = user2.Company + result.Manager = user2.Manager + result.Team = user2.Team + result.Languages = user2.Languages + result.Friends = user2.Friends + + DB.Select("Name", "Account", "Toys", "Manager", "ManagerID", "Languages").Save(&result) + + var result2 User + DB.Preload("Account").Preload("Pets").Preload("Toys").Preload("Company").Preload("Manager").Preload("Team").Preload("Languages").Preload("Friends").First(&result2, user.ID) + + result.Languages = append(user.Languages, result.Languages...) + result.Toys = append(user.Toys, result.Toys...) + + sort.Slice(result.Languages, func(i, j int) bool { + return strings.Compare(result.Languages[i].Code, result.Languages[j].Code) > 0 + }) + + sort.Slice(result.Toys, func(i, j int) bool { + return result.Toys[i].ID < result.Toys[j].ID + }) + + sort.Slice(result2.Languages, func(i, j int) bool { + return strings.Compare(result2.Languages[i].Code, result2.Languages[j].Code) > 0 + }) + + sort.Slice(result2.Toys, func(i, j int) bool { + return result2.Toys[i].ID < result2.Toys[j].ID + }) + + AssertObjEqual(t, result2, result, "Name", "Account", "Toys", "Manager", "ManagerID", "Languages") +} + +func TestSelectWithUpdateWithMap(t *testing.T) { + user := *GetUser("select_update_map", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4}) + DB.Create(&user) + + var result User + DB.First(&result, user.ID) + + user2 := *GetUser("select_update_map_new", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4}) + updateValues := map[string]interface{}{ + "Name": user2.Name, + "Age": 50, + "Account": user2.Account, + "Pets": user2.Pets, + "Toys": user2.Toys, + "Company": user2.Company, + "Manager": user2.Manager, + "Team": user2.Team, + "Languages": user2.Languages, + "Friends": user2.Friends, + } + + DB.Model(&result).Select("Name", "Account", "Toys", "Manager", "ManagerID", "Languages").Updates(updateValues) + + var result2 User + DB.Preload("Account").Preload("Pets").Preload("Toys").Preload("Company").Preload("Manager").Preload("Team").Preload("Languages").Preload("Friends").First(&result2, user.ID) + + result.Languages = append(user.Languages, result.Languages...) + result.Toys = append(user.Toys, result.Toys...) + + sort.Slice(result.Languages, func(i, j int) bool { + return strings.Compare(result.Languages[i].Code, result.Languages[j].Code) > 0 + }) + + sort.Slice(result.Toys, func(i, j int) bool { + return result.Toys[i].ID < result.Toys[j].ID + }) + + sort.Slice(result2.Languages, func(i, j int) bool { + return strings.Compare(result2.Languages[i].Code, result2.Languages[j].Code) > 0 + }) + + sort.Slice(result2.Toys, func(i, j int) bool { + return result2.Toys[i].ID < result2.Toys[j].ID + }) + + AssertObjEqual(t, result2, result, "Name", "Account", "Toys", "Manager", "ManagerID", "Languages") +} + +func TestOmitWithUpdate(t *testing.T) { + user := *GetUser("omit_update", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4}) + DB.Create(&user) + + var result User + DB.First(&result, user.ID) + + user2 := *GetUser("omit_update_new", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4}) + result.Name = user2.Name + result.Age = 50 + result.Account = user2.Account + result.Pets = user2.Pets + result.Toys = user2.Toys + result.Company = user2.Company + result.Manager = user2.Manager + result.Team = user2.Team + result.Languages = user2.Languages + result.Friends = user2.Friends + + DB.Omit("Name", "Account", "Toys", "Manager", "ManagerID", "Languages").Save(&result) + + var result2 User + DB.Preload("Account").Preload("Pets").Preload("Toys").Preload("Company").Preload("Manager").Preload("Team").Preload("Languages").Preload("Friends").First(&result2, user.ID) + + result.Pets = append(user.Pets, result.Pets...) + result.Team = append(user.Team, result.Team...) + result.Friends = append(user.Friends, result.Friends...) + + sort.Slice(result.Pets, func(i, j int) bool { + return result.Pets[i].ID < result.Pets[j].ID + }) + sort.Slice(result.Team, func(i, j int) bool { + return result.Team[i].ID < result.Team[j].ID + }) + sort.Slice(result.Friends, func(i, j int) bool { + return result.Friends[i].ID < result.Friends[j].ID + }) + sort.Slice(result2.Pets, func(i, j int) bool { + return result2.Pets[i].ID < result2.Pets[j].ID + }) + sort.Slice(result2.Team, func(i, j int) bool { + return result2.Team[i].ID < result2.Team[j].ID + }) + sort.Slice(result2.Friends, func(i, j int) bool { + return result2.Friends[i].ID < result2.Friends[j].ID + }) + + AssertObjEqual(t, result2, result, "Age", "Pets", "Company", "CompanyID", "Team", "Friends") +} + +func TestOmitWithUpdateWithMap(t *testing.T) { + user := *GetUser("omit_update_map", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4}) + DB.Create(&user) + + var result User + DB.First(&result, user.ID) + + user2 := *GetUser("omit_update_map_new", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4}) + updateValues := map[string]interface{}{ + "Name": user2.Name, + "Age": 50, + "Account": user2.Account, + "Pets": user2.Pets, + "Toys": user2.Toys, + "Company": user2.Company, + "Manager": user2.Manager, + "Team": user2.Team, + "Languages": user2.Languages, + "Friends": user2.Friends, + } + + DB.Model(&result).Omit("Name", "Account", "Toys", "Manager", "ManagerID", "Languages").Updates(updateValues) + + var result2 User + DB.Preload("Account").Preload("Pets").Preload("Toys").Preload("Company").Preload("Manager").Preload("Team").Preload("Languages").Preload("Friends").First(&result2, user.ID) + + result.Pets = append(user.Pets, result.Pets...) + result.Team = append(user.Team, result.Team...) + result.Friends = append(user.Friends, result.Friends...) + + sort.Slice(result.Pets, func(i, j int) bool { + return result.Pets[i].ID < result.Pets[j].ID + }) + sort.Slice(result.Team, func(i, j int) bool { + return result.Team[i].ID < result.Team[j].ID + }) + sort.Slice(result.Friends, func(i, j int) bool { + return result.Friends[i].ID < result.Friends[j].ID + }) + sort.Slice(result2.Pets, func(i, j int) bool { + return result2.Pets[i].ID < result2.Pets[j].ID + }) + sort.Slice(result2.Team, func(i, j int) bool { + return result2.Team[i].ID < result2.Team[j].ID + }) + sort.Slice(result2.Friends, func(i, j int) bool { + return result2.Friends[i].ID < result2.Friends[j].ID + }) + + AssertObjEqual(t, result2, result, "Age", "Pets", "Company", "CompanyID", "Team", "Friends") +} + +func TestSelectWithUpdateColumn(t *testing.T) { + user := *GetUser("select_with_update_column", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4}) + DB.Create(&user) + + updateValues := map[string]interface{}{"Name": "new_name", "Age": 50} + + var result User + DB.First(&result, user.ID) + DB.Model(&result).Select("Name").UpdateColumns(updateValues) + + var result2 User + DB.First(&result2, user.ID) + + if result2.Name == user.Name || result2.Age != user.Age { + t.Errorf("Should only update users with name column") + } +} + +func TestOmitWithUpdateColumn(t *testing.T) { + user := *GetUser("omit_with_update_column", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4}) + DB.Create(&user) + + updateValues := map[string]interface{}{"Name": "new_name", "Age": 50} + + var result User + DB.First(&result, user.ID) + DB.Model(&result).Omit("Name").UpdateColumns(updateValues) + + var result2 User + DB.First(&result2, user.ID) + + if result2.Name != user.Name || result2.Age == user.Age { + t.Errorf("Should only update users with name column") + } +} + +func TestUpdateColumnsSkipsAssociations(t *testing.T) { + user := *GetUser("update_column_skips_association", Config{}) + DB.Create(&user) + + // Update a single field of the user and verify that the changed address is not stored. + newAge := uint(100) + user.Account.Number = "new_account_number" + db := DB.Model(&user).UpdateColumns(User{Age: newAge}) + + if db.RowsAffected != 1 { + t.Errorf("Expected RowsAffected=1 but instead RowsAffected=%v", db.RowsAffected) + } + + // Verify that Age now=`newAge`. + result := &User{} + result.ID = user.ID + DB.Preload("Account").First(result) + + if result.Age != newAge { + t.Errorf("Expected freshly queried user to have Age=%v but instead found Age=%v", newAge, result.Age) + } + + if result.Account.Number != user.Account.Number { + t.Errorf("account number should not been changed, expects: %v, got %v", user.Account.Number, result.Account.Number) + } +} + +func TestUpdatesWithBlankValues(t *testing.T) { + user := *GetUser("updates_with_blank_value", Config{}) + DB.Save(&user) + + var user2 User + user2.ID = user.ID + DB.Model(&user2).Updates(&User{Age: 100}) + + var result User + DB.First(&result, user.ID) + + if result.Name != user.Name || result.Age != 100 { + t.Errorf("user's name should not be updated") + } +} + +func TestUpdatesTableWithIgnoredValues(t *testing.T) { + type ElementWithIgnoredField struct { + Id int64 + Value string + IgnoredField int64 `gorm:"-"` + } + DB.Migrator().DropTable(&ElementWithIgnoredField{}) + DB.AutoMigrate(&ElementWithIgnoredField{}) + + elem := ElementWithIgnoredField{Value: "foo", IgnoredField: 10} + DB.Save(&elem) + + DB.Model(&ElementWithIgnoredField{}). + Where("id = ?", elem.Id). + Updates(&ElementWithIgnoredField{Value: "bar", IgnoredField: 100}) + + var result ElementWithIgnoredField + if err := DB.First(&result, elem.Id).Error; err != nil { + t.Errorf("error getting an element from database: %s", err.Error()) + } + + if result.IgnoredField != 0 { + t.Errorf("element's ignored field should not be updated") + } +} diff --git a/tests/utils.go b/tests/utils.go index 7cc6d2bc..97b5d5c8 100644 --- a/tests/utils.go +++ b/tests/utils.go @@ -3,6 +3,7 @@ package tests import ( "database/sql/driver" "fmt" + "go/ast" "reflect" "sort" "strconv" @@ -126,6 +127,37 @@ func AssertEqual(t *testing.T, got, expect interface{}) { return } + if reflect.ValueOf(got).Kind() == reflect.Slice { + if reflect.ValueOf(expect).Kind() == reflect.Slice { + if reflect.ValueOf(got).Len() == reflect.ValueOf(expect).Len() { + for i := 0; i < reflect.ValueOf(got).Len(); i++ { + name := fmt.Sprintf(reflect.ValueOf(got).Type().Name()+" #%v", i) + t.Run(name, func(t *testing.T) { + AssertEqual(t, reflect.ValueOf(got).Index(i).Interface(), reflect.ValueOf(expect).Index(i).Interface()) + }) + } + } else { + name := reflect.ValueOf(got).Type().Elem().Name() + t.Errorf("%v expects length: %v, got %v", name, reflect.ValueOf(expect).Len(), reflect.ValueOf(got).Len()) + } + return + } + } + + if reflect.ValueOf(got).Kind() == reflect.Struct { + if reflect.ValueOf(got).NumField() == reflect.ValueOf(expect).NumField() { + for i := 0; i < reflect.ValueOf(got).NumField(); i++ { + if fieldStruct := reflect.ValueOf(got).Type().Field(i); ast.IsExported(fieldStruct.Name) { + field := reflect.ValueOf(got).Field(i) + t.Run(fieldStruct.Name, func(t *testing.T) { + AssertEqual(t, field.Interface(), reflect.ValueOf(expect).Field(i).Interface()) + }) + } + } + return + } + } + if reflect.ValueOf(got).Type().ConvertibleTo(reflect.ValueOf(expect).Type()) { got = reflect.ValueOf(got).Convert(reflect.ValueOf(expect).Type()).Interface() isEqual() From 4e147e1256b7118eb4c0126bd866659738117617 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 1 Jun 2020 21:26:23 +0800 Subject: [PATCH 377/881] Test SubQuery --- callbacks.go | 2 +- callbacks/create.go | 104 ++++++++++++++++++++------------------- callbacks/delete.go | 12 +++-- callbacks/query.go | 16 +++--- callbacks/update.go | 12 +++-- dialects/mssql/create.go | 54 ++++++++++---------- gorm.go | 7 +++ logger/sql.go | 4 +- statement.go | 14 ++++-- tests/query_test.go | 86 ++++++++++++++++++++++++++++++++ 10 files changed, 212 insertions(+), 99 deletions(-) diff --git a/callbacks.go b/callbacks.go index d05947d9..d3cd8e62 100644 --- a/callbacks.go +++ b/callbacks.go @@ -80,7 +80,7 @@ func (p *processor) Execute(db *DB) { } if stmt.Model != nil { - if err := stmt.Parse(stmt.Model); err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || stmt.Table == "") { + if err := stmt.Parse(stmt.Model); err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || (stmt.Table == "" && stmt.SQL.Len() == 0)) { db.AddError(err) } } diff --git a/callbacks/create.go b/callbacks/create.go index 7a2b8bfe..01329141 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -63,36 +63,38 @@ func Create(config *Config) func(db *gorm.DB) { db.Statement.Build("INSERT", "VALUES", "ON CONFLICT") } - result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) + if !db.DryRun { + result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) - if err == nil { - if db.Statement.Schema != nil && db.Statement.Schema.PrioritizedPrimaryField != nil { - if _, ok := db.Statement.Schema.FieldsWithDefaultDBValue[db.Statement.Schema.PrioritizedPrimaryField.DBName]; ok { - if insertID, err := result.LastInsertId(); err == nil { - switch db.Statement.ReflectValue.Kind() { - case reflect.Slice, reflect.Array: - if config.LastInsertIDReversed { - for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- { - db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID) - insertID-- - } - } else { - for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID) - insertID++ + if err == nil { + if db.Statement.Schema != nil && db.Statement.Schema.PrioritizedPrimaryField != nil { + if _, ok := db.Statement.Schema.FieldsWithDefaultDBValue[db.Statement.Schema.PrioritizedPrimaryField.DBName]; ok { + if insertID, err := result.LastInsertId(); err == nil { + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + if config.LastInsertIDReversed { + for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- { + db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID) + insertID-- + } + } else { + for i := 0; i < db.Statement.ReflectValue.Len(); i++ { + db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID) + insertID++ + } } + case reflect.Struct: + db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID) } - case reflect.Struct: - db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID) + } else { + db.AddError(err) } - } else { - db.AddError(err) } } + db.RowsAffected, _ = result.RowsAffected() + } else { + db.AddError(err) } - db.RowsAffected, _ = result.RowsAffected() - } else { - db.AddError(err) } } } @@ -135,42 +137,44 @@ func CreateWithReturning(db *gorm.DB) { idx++ } - rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) + if !db.DryRun { + rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) - if err == nil { - defer rows.Close() + if err == nil { + defer rows.Close() - switch db.Statement.ReflectValue.Kind() { - case reflect.Slice, reflect.Array: - for rows.Next() { + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + for rows.Next() { + for idx, field := range fields { + values[idx] = field.ReflectValueOf(db.Statement.ReflectValue.Index(int(db.RowsAffected))).Addr().Interface() + } + if err := rows.Scan(values...); err != nil { + db.AddError(err) + } + db.RowsAffected++ + } + case reflect.Struct: for idx, field := range fields { - values[idx] = field.ReflectValueOf(db.Statement.ReflectValue.Index(int(db.RowsAffected))).Addr().Interface() + values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface() } - if err := rows.Scan(values...); err != nil { - db.AddError(err) - } - db.RowsAffected++ - } - case reflect.Struct: - for idx, field := range fields { - values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface() - } - if rows.Next() { - db.RowsAffected++ - err = rows.Scan(values...) + if rows.Next() { + db.RowsAffected++ + err = rows.Scan(values...) + } } } - } - if err != nil { - db.AddError(err) - } - } else { - if result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); err == nil { - db.RowsAffected, _ = result.RowsAffected() + if err != nil { + db.AddError(err) + } } else { - db.AddError(err) + if result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); err == nil { + db.RowsAffected, _ = result.RowsAffected() + } else { + db.AddError(err) + } } } } diff --git a/callbacks/delete.go b/callbacks/delete.go index 582a76f4..451569cf 100644 --- a/callbacks/delete.go +++ b/callbacks/delete.go @@ -72,12 +72,14 @@ func Delete(db *gorm.DB) { db.Statement.Build("DELETE", "FROM", "WHERE") } - result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) + if !db.DryRun { + result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) - if err == nil { - db.RowsAffected, _ = result.RowsAffected() - } else { - db.AddError(err) + if err == nil { + db.RowsAffected, _ = result.RowsAffected() + } else { + db.AddError(err) + } } } } diff --git a/callbacks/query.go b/callbacks/query.go index e4e76665..f7c3271f 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -23,14 +23,16 @@ func Query(db *gorm.DB) { BuildQuerySQL(db) } - rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) - if err != nil { - db.AddError(err) - return - } - defer rows.Close() + if !db.DryRun { + rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) + if err != nil { + db.AddError(err) + return + } + defer rows.Close() - gorm.Scan(rows, db, false) + gorm.Scan(rows, db, false) + } } } diff --git a/callbacks/update.go b/callbacks/update.go index fda07676..a52bd310 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -85,12 +85,14 @@ func Update(db *gorm.DB) { return } - result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) + if !db.DryRun { + result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) - if err == nil { - db.RowsAffected, _ = result.RowsAffected() - } else { - db.AddError(err) + if err == nil { + db.RowsAffected, _ = result.RowsAffected() + } else { + db.AddError(err) + } } } } diff --git a/dialects/mssql/create.go b/dialects/mssql/create.go index ebdeeab0..6820bb7b 100644 --- a/dialects/mssql/create.go +++ b/dialects/mssql/create.go @@ -61,41 +61,43 @@ func Create(db *gorm.DB) { } } - rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) + if !db.DryRun { + rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) - if err == nil { - defer rows.Close() + if err == nil { + defer rows.Close() - switch db.Statement.ReflectValue.Kind() { - case reflect.Slice, reflect.Array: - if len(db.Statement.Schema.PrimaryFields) > 0 { - values := make([]interface{}, len(db.Statement.Schema.PrimaryFields)) + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + if len(db.Statement.Schema.PrimaryFields) > 0 { + values := make([]interface{}, len(db.Statement.Schema.PrimaryFields)) + + for rows.Next() { + for idx, field := range db.Statement.Schema.PrimaryFields { + values[idx] = field.ReflectValueOf(db.Statement.ReflectValue.Index(int(db.RowsAffected))).Addr().Interface() + } + + db.RowsAffected++ + db.AddError(rows.Scan(values...)) + } + } + case reflect.Struct: + if len(db.Statement.Schema.PrimaryFields) > 0 { + values := make([]interface{}, len(db.Statement.Schema.PrimaryFields)) - for rows.Next() { for idx, field := range db.Statement.Schema.PrimaryFields { - values[idx] = field.ReflectValueOf(db.Statement.ReflectValue.Index(int(db.RowsAffected))).Addr().Interface() + values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface() } - db.RowsAffected++ - db.AddError(rows.Scan(values...)) - } - } - case reflect.Struct: - if len(db.Statement.Schema.PrimaryFields) > 0 { - values := make([]interface{}, len(db.Statement.Schema.PrimaryFields)) - - for idx, field := range db.Statement.Schema.PrimaryFields { - values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface() - } - - if rows.Next() { - db.RowsAffected++ - db.AddError(rows.Scan(values...)) + if rows.Next() { + db.RowsAffected++ + db.AddError(rows.Scan(values...)) + } } } + } else { + db.AddError(err) } - } else { - db.AddError(err) } } diff --git a/gorm.go b/gorm.go index c1d6f8da..7d6bd2ed 100644 --- a/gorm.go +++ b/gorm.go @@ -22,6 +22,8 @@ type Config struct { Logger logger.Interface // NowFunc the function to be used when creating a new timestamp NowFunc func() time.Time + // DryRun generate sql without execute + DryRun bool // ClauseBuilders clause builder ClauseBuilders map[string]clause.ClauseBuilder @@ -45,6 +47,7 @@ type DB struct { // Session session config when create session with Session() method type Session struct { + DryRun bool WithConditions bool Context context.Context Logger logger.Interface @@ -120,6 +123,10 @@ func (db *DB) Session(config *Session) *DB { tx.clone = 3 } + if config.DryRun { + tx.Config.DryRun = true + } + if config.Logger != nil { tx.Config.Logger = config.Logger } diff --git a/logger/sql.go b/logger/sql.go index dd502324..d3c0bf10 100644 --- a/logger/sql.go +++ b/logger/sql.go @@ -22,8 +22,10 @@ func isPrintable(s []byte) bool { var convertableTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeOf(false), reflect.TypeOf([]byte{})} -func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, vars ...interface{}) string { +func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, avars ...interface{}) string { var convertParams func(interface{}, int) + var vars = make([]interface{}, len(avars)) + copy(vars, avars) convertParams = func(v interface{}, idx int) { switch v := v.(type) { diff --git a/statement.go b/statement.go index aa7d193c..03d1b8a8 100644 --- a/statement.go +++ b/statement.go @@ -157,6 +157,10 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) { } else { writer.WriteString("(NULL)") } + case *DB: + result := v.Session(&Session{DryRun: true, WithConditions: true}).Find(nil).Statement + writer.WriteString(result.SQL.String()) + stmt.Vars = append(stmt.Vars, result.Vars...) default: switch rv := reflect.ValueOf(v); rv.Kind() { case reflect.Slice, reflect.Array: @@ -226,7 +230,7 @@ func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (con case clause.Expression: conds = append(conds, v) case *DB: - if v.Statement == nil { + if v.Statement != nil { if cs, ok := v.Statement.Clauses["WHERE"]; ok { conds = append(conds, cs.Expression) } @@ -367,7 +371,9 @@ func (stmt *Statement) reinit() { // }) // stmt.Schema = nil - stmt.SQL.Reset() - stmt.Vars = nil - stmt.NamedVars = nil + if !stmt.DB.DryRun { + stmt.SQL.Reset() + stmt.Vars = nil + stmt.NamedVars = nil + } } diff --git a/tests/query_test.go b/tests/query_test.go index 73b6dca3..12f29ace 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -349,3 +349,89 @@ func TestSearchWithMap(t *testing.T) { t.Errorf("Search all records with inline multiple value map") } } + +func TestSubQuery(t *testing.T) { + users := []User{ + {Name: "subquery_1", Age: 10}, + {Name: "subquery_2", Age: 20}, + {Name: "subquery_3", Age: 30}, + {Name: "subquery_4", Age: 40}, + } + + DB.Create(&users) + + if err := DB.Select("*").Where("name IN (?)", DB.Select("name").Table("users").Where("name LIKE ?", "subquery_%")).Find(&users).Error; err != nil { + t.Fatalf("got error: %v", err) + } + + if len(users) != 4 { + t.Errorf("Four users should be found, instead found %d", len(users)) + } + + DB.Select("*").Where("name LIKE ?", "subquery%").Where("age >= (?)", DB. + Select("AVG(age)").Table("users").Where("name LIKE ?", "subquery%")).Find(&users) + + if len(users) != 2 { + t.Errorf("Two users should be found, instead found %d", len(users)) + } +} + +func TestSubQueryWithRaw(t *testing.T) { + users := []User{ + {Name: "subquery_raw_1", Age: 10}, + {Name: "subquery_raw_2", Age: 20}, + {Name: "subquery_raw_3", Age: 30}, + {Name: "subquery_raw_4", Age: 40}, + } + DB.Create(&users) + + var count int64 + err := DB.Raw("select count(*) from (?) tmp", + DB.Table("users"). + Select("name"). + Where("age >= ? and name in (?)", 20, []string{"subquery_raw_1", "subquery_raw_3"}). + Group("name"), + ).Count(&count).Error + + if err != nil { + t.Errorf("Expected to get no errors, but got %v", err) + } + + if count != 1 { + t.Errorf("Row count must be 1, instead got %d", count) + } + + err = DB.Raw("select count(*) from (?) tmp", + DB.Table("users"). + Select("name"). + Where("name LIKE ?", "subquery_raw%"). + Not("age <= ?", 10).Not("name IN (?)", []string{"subquery_raw_1", "subquery_raw_3"}). + Group("name"), + ).Count(&count).Error + + if err != nil { + t.Errorf("Expected to get no errors, but got %v", err) + } + + if count != 2 { + t.Errorf("Row count must be 2, instead got %d", count) + } +} + +func TestSubQueryWithHaving(t *testing.T) { + users := []User{ + {Name: "subquery_having_1", Age: 10}, + {Name: "subquery_having_2", Age: 20}, + {Name: "subquery_having_3", Age: 30}, + {Name: "subquery_having_4", Age: 40}, + } + DB.Create(&users) + + var results []User + DB.Select("AVG(age) as avgage").Where("name LIKE ?", "subquery_having%").Group("name").Having("AVG(age) > (?)", DB. + Select("AVG(age)").Where("name LIKE ?", "subquery_having%").Table("users")).Find(&results) + + if len(results) != 2 { + t.Errorf("Two user group should be found, instead found %d", len(results)) + } +} From db03616993a9693a578a70401d28779cd15e5382 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 1 Jun 2020 21:39:08 +0800 Subject: [PATCH 378/881] Add customize column test --- tests/customize_column_test.go | 58 ++++++++++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) create mode 100644 tests/customize_column_test.go diff --git a/tests/customize_column_test.go b/tests/customize_column_test.go new file mode 100644 index 00000000..49447dab --- /dev/null +++ b/tests/customize_column_test.go @@ -0,0 +1,58 @@ +package tests_test + +import ( + "testing" + "time" + + . "github.com/jinzhu/gorm/tests" +) + +func TestCustomizeColumn(t *testing.T) { + type CustomizeColumn struct { + ID int64 `gorm:"column:mapped_id; primary_key:yes"` + Name string `gorm:"column:mapped_name"` + Date *time.Time `gorm:"column:mapped_time"` + } + + DB.Migrator().DropTable(&CustomizeColumn{}) + DB.AutoMigrate(&CustomizeColumn{}) + + expected := "foo" + now := time.Now() + cc := CustomizeColumn{ID: 666, Name: expected, Date: &now} + + if count := DB.Create(&cc).RowsAffected; count != 1 { + t.Error("There should be one record be affected when create record") + } + + var cc1 CustomizeColumn + DB.First(&cc1, "mapped_name = ?", "foo") + + if cc1.Name != expected { + t.Errorf("Failed to query CustomizeColumn") + } + + cc.Name = "bar" + DB.Save(&cc) + + var cc2 CustomizeColumn + DB.First(&cc2, "mapped_id = ?", 666) + if cc2.Name != "bar" { + t.Errorf("Failed to query CustomizeColumn") + } +} + +func TestCustomColumnAndIgnoredFieldClash(t *testing.T) { + // Make sure an ignored field does not interfere with another field's custom + // column name that matches the ignored field. + type CustomColumnAndIgnoredFieldClash struct { + Body string `gorm:"-"` + RawBody string `gorm:"column:body"` + } + + DB.Migrator().DropTable(&CustomColumnAndIgnoredFieldClash{}) + + if err := DB.AutoMigrate(&CustomColumnAndIgnoredFieldClash{}); err != nil { + t.Errorf("Should not raise error: %v", err) + } +} From e490e09db5bbc707e5bb4cee360b2f58a29d2b7b Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 1 Jun 2020 22:31:50 +0800 Subject: [PATCH 379/881] Add SetupJoinTable support --- association.go | 19 ++++++-- callbacks/create.go | 10 ++-- gorm.go | 37 +++++++++++++++ schema/relationship.go | 10 ++-- statement.go | 8 ++-- tests/joins_table_test.go | 99 +++++++++++++++++++++++++++++++++++++++ 6 files changed, 167 insertions(+), 16 deletions(-) create mode 100644 tests/joins_table_test.go diff --git a/association.go b/association.go index 55dd7772..23e5a82f 100644 --- a/association.go +++ b/association.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "reflect" + "strings" "github.com/jinzhu/gorm/clause" "github.com/jinzhu/gorm/schema" @@ -44,7 +45,16 @@ func (association *Association) Find(out interface{}, conds ...interface{}) erro tx = association.DB.Model(out) ) - if association.Relationship.JoinTable != nil && !tx.Statement.Unscoped { + if association.Relationship.JoinTable != nil { + if !tx.Statement.Unscoped && len(association.Relationship.JoinTable.QueryClauses) > 0 { + joinStmt := Statement{DB: tx, Schema: association.Relationship.JoinTable, Table: association.Relationship.JoinTable.Table, Clauses: map[string]clause.Clause{}} + for _, queryClause := range association.Relationship.JoinTable.QueryClauses { + joinStmt.AddClause(queryClause) + } + joinStmt.Build("WHERE", "LIMIT") + tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars}) + } + tx.Clauses(clause.From{Joins: []clause.Join{{ Table: clause.Table{Name: association.Relationship.JoinTable.Table}, ON: clause.Where{Exprs: queryConds}, @@ -321,10 +331,13 @@ func (association *Association) Count() (count int64) { ) if association.Relationship.JoinTable != nil { - if !tx.Statement.Unscoped { + if !tx.Statement.Unscoped && len(association.Relationship.JoinTable.QueryClauses) > 0 { + joinStmt := Statement{DB: tx, Schema: association.Relationship.JoinTable, Table: association.Relationship.JoinTable.Table, Clauses: map[string]clause.Clause{}} for _, queryClause := range association.Relationship.JoinTable.QueryClauses { - tx.Clauses(queryClause) + joinStmt.AddClause(queryClause) } + joinStmt.Build("WHERE", "LIMIT") + tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars}) } tx.Clauses(clause.From{Joins: []clause.Join{{ diff --git a/callbacks/create.go b/callbacks/create.go index 01329141..0277407e 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -169,12 +169,12 @@ func CreateWithReturning(db *gorm.DB) { if err != nil { db.AddError(err) } + } + } else if !db.DryRun { + if result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); err == nil { + db.RowsAffected, _ = result.RowsAffected() } else { - if result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); err == nil { - db.RowsAffected, _ = result.RowsAffected() - } else { - db.AddError(err) - } + db.AddError(err) } } } diff --git a/gorm.go b/gorm.go index 7d6bd2ed..fd0d4b7e 100644 --- a/gorm.go +++ b/gorm.go @@ -108,6 +108,7 @@ func (db *DB) Session(config *Session) *DB { if config.Context != nil { if tx.Statement != nil { tx.Statement = tx.Statement.clone() + tx.Statement.DB = tx } else { tx.Statement = &Statement{ DB: tx, @@ -181,6 +182,42 @@ func (db *DB) InstanceGet(key string) (interface{}, bool) { return nil, false } +func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interface{}) error { + var ( + tx = db.getInstance() + stmt = tx.Statement + modelSchema, joinSchema *schema.Schema + ) + + if err := stmt.Parse(model); err == nil { + modelSchema = stmt.Schema + } else { + return err + } + + if err := stmt.Parse(joinTable); err == nil { + joinSchema = stmt.Schema + } else { + return err + } + + if relation, ok := modelSchema.Relationships.Relations[field]; ok && relation.JoinTable != nil { + for _, ref := range relation.References { + if f := joinSchema.LookUpField(ref.ForeignKey.DBName); f != nil { + ref.ForeignKey = f + } else { + return fmt.Errorf("missing field %v for join table", ref.ForeignKey.DBName) + } + } + + relation.JoinTable = joinSchema + } else { + return fmt.Errorf("failed to found relation: %v", field) + } + + return nil +} + // Callback returns callback manager func (db *DB) Callback() *callbacks { return db.callbacks diff --git a/schema/relationship.go b/schema/relationship.go index dffe5988..194fbeff 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -33,7 +33,7 @@ type Relationship struct { Type RelationshipType Field *Field Polymorphic *Polymorphic - References []Reference + References []*Reference Schema *Schema FieldSchema *Schema JoinTable *Schema @@ -139,7 +139,7 @@ func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Fi } if schema.err == nil { - relation.References = append(relation.References, Reference{ + relation.References = append(relation.References, &Reference{ PrimaryValue: relation.Polymorphic.Value, ForeignKey: relation.Polymorphic.PolymorphicType, }) @@ -150,7 +150,7 @@ func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Fi schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %v", relation.foreignKeys, schema, field.Name) } } - relation.References = append(relation.References, Reference{ + relation.References = append(relation.References, &Reference{ PrimaryKey: primaryKeyField, ForeignKey: relation.Polymorphic.PolymorphicID, OwnPrimaryKey: true, @@ -246,7 +246,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel // build references for _, f := range relation.JoinTable.Fields { - relation.References = append(relation.References, Reference{ + relation.References = append(relation.References, &Reference{ PrimaryKey: fieldsMap[f.Name], ForeignKey: f, OwnPrimaryKey: schema == fieldsMap[f.Name].Schema && ownFieldsMap[f.Name], @@ -326,7 +326,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, guessH // build references for idx, foreignField := range foreignFields { - relation.References = append(relation.References, Reference{ + relation.References = append(relation.References, &Reference{ PrimaryKey: primaryFields[idx], ForeignKey: foreignField, OwnPrimaryKey: schema == primarySchema && guessHas, diff --git a/statement.go b/statement.go index 03d1b8a8..e78dfea9 100644 --- a/statement.go +++ b/statement.go @@ -158,9 +158,11 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) { writer.WriteString("(NULL)") } case *DB: - result := v.Session(&Session{DryRun: true, WithConditions: true}).Find(nil).Statement - writer.WriteString(result.SQL.String()) - stmt.Vars = append(stmt.Vars, result.Vars...) + subdb := v.Session(&Session{DryRun: true, WithConditions: true}).getInstance() + subdb.Statement.Vars = append(subdb.Statement.Vars, stmt.Vars...) + subdb.callbacks.Query().Execute(subdb) + writer.WriteString(subdb.Statement.SQL.String()) + stmt.Vars = subdb.Statement.Vars default: switch rv := reflect.ValueOf(v); rv.Kind() { case reflect.Slice, reflect.Array: diff --git a/tests/joins_table_test.go b/tests/joins_table_test.go new file mode 100644 index 00000000..091ca65c --- /dev/null +++ b/tests/joins_table_test.go @@ -0,0 +1,99 @@ +package tests_test + +import ( + "testing" + "time" + + "github.com/jinzhu/gorm" + . "github.com/jinzhu/gorm/tests" +) + +type Person struct { + ID int + Name string + Addresses []Address `gorm:"many2many:person_addresses;"` +} + +type Address struct { + ID uint + Name string +} + +type PersonAddress struct { + PersonID int + AddressID int + CreatedAt time.Time + DeletedAt gorm.DeletedAt +} + +func TestOverrideJoinTable(t *testing.T) { + DB.Migrator().DropTable(&Person{}, &Address{}, &PersonAddress{}) + + if err := DB.SetupJoinTable(&Person{}, "Addresses", &PersonAddress{}); err != nil { + t.Fatalf("Failed to setup join table for person, got error %v", err) + } + + if err := DB.AutoMigrate(&Person{}, &Address{}); err != nil { + t.Fatalf("Failed to migrate, got %v", err) + } + + address1 := Address{Name: "address 1"} + address2 := Address{Name: "address 2"} + person := Person{Name: "person", Addresses: []Address{address1, address2}} + DB.Create(&person) + + var addresses1 []Address + if err := DB.Model(&person).Association("Addresses").Find(&addresses1); err != nil || len(addresses1) != 2 { + t.Fatalf("Failed to find address, got error %v, length: %v", err, len(addresses1)) + } + + if err := DB.Model(&person).Association("Addresses").Delete(&person.Addresses[0]); err != nil { + t.Fatalf("Failed to delete address, got error %v", err) + } + + if len(person.Addresses) != 1 { + t.Fatalf("Should have one address left") + } + + if DB.Find(&[]PersonAddress{}, "person_id = ?", person.ID).RowsAffected != 1 { + t.Fatalf("Should found one address") + } + + var addresses2 []Address + if err := DB.Model(&person).Association("Addresses").Find(&addresses2); err != nil || len(addresses2) != 1 { + t.Fatalf("Failed to find address, got error %v, length: %v", err, len(addresses2)) + } + + if DB.Model(&person).Association("Addresses").Count() != 1 { + t.Fatalf("Should found one address") + } + + var addresses3 []Address + if err := DB.Unscoped().Model(&person).Association("Addresses").Find(&addresses3); err != nil || len(addresses3) != 2 { + t.Fatalf("Failed to find address, got error %v, length: %v", err, len(addresses3)) + } + + if DB.Unscoped().Find(&[]PersonAddress{}, "person_id = ?", person.ID).RowsAffected != 2 { + t.Fatalf("Should found soft deleted addresses with unscoped") + } + + if DB.Unscoped().Model(&person).Association("Addresses").Count() != 2 { + t.Fatalf("Should found soft deleted addresses with unscoped") + } + + DB.Model(&person).Association("Addresses").Clear() + + if DB.Model(&person).Association("Addresses").Count() != 0 { + t.Fatalf("Should deleted all addresses") + } + + if DB.Unscoped().Model(&person).Association("Addresses").Count() != 2 { + t.Fatalf("Should found soft deleted addresses with unscoped") + } + + DB.Unscoped().Model(&person).Association("Addresses").Clear() + + if DB.Unscoped().Model(&person).Association("Addresses").Count() != 0 { + t.Fatalf("address should be deleted when clear with unscoped") + } +} From 9807fffdbce47865d911eca391a76c8ba0f02db1 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 2 Jun 2020 00:03:38 +0800 Subject: [PATCH 380/881] Fix mssql tests --- dialects/mssql/create.go | 95 ++++++++++++++++++++++++++-------------- 1 file changed, 62 insertions(+), 33 deletions(-) diff --git a/dialects/mssql/create.go b/dialects/mssql/create.go index 6820bb7b..84732427 100644 --- a/dialects/mssql/create.go +++ b/dialects/mssql/create.go @@ -2,6 +2,7 @@ package mssql import ( "reflect" + "sort" "github.com/jinzhu/gorm" "github.com/jinzhu/gorm/callbacks" @@ -17,10 +18,35 @@ func Create(db *gorm.DB) { } if db.Statement.SQL.String() == "" { + setIdentityInsert := false c := db.Statement.Clauses["ON CONFLICT"] onConflict, hasConflict := c.Expression.(clause.OnConflict) - if hasConflict { + if field := db.Statement.Schema.PrioritizedPrimaryField; field != nil { + setIdentityInsert = false + switch db.Statement.ReflectValue.Kind() { + case reflect.Struct: + _, isZero := field.ValueOf(db.Statement.ReflectValue) + setIdentityInsert = !isZero + case reflect.Slice: + for i := 0; i < db.Statement.ReflectValue.Len(); i++ { + _, isZero := field.ValueOf(db.Statement.ReflectValue.Index(i)) + setIdentityInsert = !isZero + break + } + } + + if setIdentityInsert && (field.DataType == schema.Int || field.DataType == schema.Uint) { + setIdentityInsert = true + db.Statement.WriteString("SET IDENTITY_INSERT ") + db.Statement.WriteQuoted(db.Statement.Table) + db.Statement.WriteString(" ON;") + } else { + setIdentityInsert = false + } + } + + if hasConflict && len(db.Statement.Schema.PrimaryFields) > 0 { MergeCreate(db, onConflict) } else { db.Statement.AddClauseIfNotExists(clause.Insert{Table: clause.Table{Name: db.Statement.Table}}) @@ -55,10 +81,16 @@ func Create(db *gorm.DB) { db.Statement.WriteString(";") } else { - db.Statement.WriteString("DEFAULT VALUES") + db.Statement.WriteString("DEFAULT VALUES;") } } } + + if setIdentityInsert { + db.Statement.WriteString("SET IDENTITY_INSERT ") + db.Statement.WriteQuoted(db.Statement.Table) + db.Statement.WriteString(" OFF;") + } } if !db.DryRun { @@ -67,25 +99,32 @@ func Create(db *gorm.DB) { if err == nil { defer rows.Close() - switch db.Statement.ReflectValue.Kind() { - case reflect.Slice, reflect.Array: - if len(db.Statement.Schema.PrimaryFields) > 0 { - values := make([]interface{}, len(db.Statement.Schema.PrimaryFields)) + if len(db.Statement.Schema.FieldsWithDefaultDBValue) > 0 { + sortedKeys := []string{} + for _, field := range db.Statement.Schema.FieldsWithDefaultDBValue { + sortedKeys = append(sortedKeys, field.DBName) + } + sort.Strings(sortedKeys) + returnningFields := make([]*schema.Field, len(sortedKeys)) + for idx, key := range sortedKeys { + returnningFields[idx] = db.Statement.Schema.LookUpField(key) + } + + values := make([]interface{}, len(returnningFields)) + + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: for rows.Next() { - for idx, field := range db.Statement.Schema.PrimaryFields { + for idx, field := range returnningFields { values[idx] = field.ReflectValueOf(db.Statement.ReflectValue.Index(int(db.RowsAffected))).Addr().Interface() } db.RowsAffected++ db.AddError(rows.Scan(values...)) } - } - case reflect.Struct: - if len(db.Statement.Schema.PrimaryFields) > 0 { - values := make([]interface{}, len(db.Statement.Schema.PrimaryFields)) - - for idx, field := range db.Statement.Schema.PrimaryFields { + case reflect.Struct: + for idx, field := range returnningFields { values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface() } @@ -103,16 +142,6 @@ func Create(db *gorm.DB) { func MergeCreate(db *gorm.DB, onConflict clause.OnConflict) { values := callbacks.ConvertToCreateValues(db.Statement) - setIdentityInsert := false - - if field := db.Statement.Schema.PrioritizedPrimaryField; field != nil { - if field.DataType == schema.Int || field.DataType == schema.Uint { - setIdentityInsert = true - db.Statement.WriteString("SET IDENTITY_INSERT ") - db.Statement.WriteQuoted(db.Statement.Table) - db.Statement.WriteString("ON;") - } - } db.Statement.WriteString("MERGE INTO ") db.Statement.WriteQuoted(db.Statement.Table) @@ -174,23 +203,23 @@ func MergeCreate(db *gorm.DB, onConflict clause.OnConflict) { db.Statement.WriteString(")") outputInserted(db) db.Statement.WriteString(";") - - if setIdentityInsert { - db.Statement.WriteString("SET IDENTITY_INSERT ") - db.Statement.WriteQuoted(db.Statement.Table) - db.Statement.WriteString("OFF;") - } } func outputInserted(db *gorm.DB) { - if len(db.Statement.Schema.PrimaryFields) > 0 { - db.Statement.WriteString(" OUTPUT ") - for idx, field := range db.Statement.Schema.PrimaryFields { + if len(db.Statement.Schema.FieldsWithDefaultDBValue) > 0 { + sortedKeys := []string{} + for _, field := range db.Statement.Schema.FieldsWithDefaultDBValue { + sortedKeys = append(sortedKeys, field.DBName) + } + sort.Strings(sortedKeys) + + db.Statement.WriteString(" OUTPUT") + for idx, key := range sortedKeys { if idx > 0 { db.Statement.WriteString(",") } db.Statement.WriteString(" INSERTED.") - db.Statement.AddVar(db.Statement, clause.Column{Name: field.DBName}) + db.Statement.AddVar(db.Statement, clause.Column{Name: key}) } } } From bc01eb28ada22b7413fd2452b1260c0787b79388 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 2 Jun 2020 00:24:16 +0800 Subject: [PATCH 381/881] Fix tests script --- tests/main_test.go | 5 +++++ tests/migrate_test.go | 2 +- tests/tests.go | 4 +--- tests/tests_all.sh | 10 ++++++++-- 4 files changed, 15 insertions(+), 6 deletions(-) diff --git a/tests/main_test.go b/tests/main_test.go index 095588a2..60cc4611 100644 --- a/tests/main_test.go +++ b/tests/main_test.go @@ -6,6 +6,11 @@ import ( . "github.com/jinzhu/gorm/tests" ) +func TestMain(m *testing.M) { + RunMigrations() + m.Run() +} + func TestExceptionsWithInvalidSql(t *testing.T) { var columns []string if DB.Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns).Error == nil { diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 957db8d6..e786b1cc 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -25,7 +25,7 @@ func TestMigrate(t *testing.T) { for _, m := range allModels { if !DB.Migrator().HasTable(m) { - t.Fatalf("Failed to create table for %#v", m) + t.Fatalf("Failed to create table for %#v---", m) } } } diff --git a/tests/tests.go b/tests/tests.go index d9257898..fa8ac836 100644 --- a/tests/tests.go +++ b/tests/tests.go @@ -19,9 +19,7 @@ var DB *gorm.DB func init() { var err error - if DB, err = OpenTestConnection(); err == nil { - RunMigrations() - } else { + if DB, err = OpenTestConnection(); err != nil { log.Printf("failed to connect database, got error %v\n", err) os.Exit(1) } diff --git a/tests/tests_all.sh b/tests/tests_all.sh index 243af787..3a1b45c8 100755 --- a/tests/tests_all.sh +++ b/tests/tests_all.sh @@ -9,11 +9,17 @@ for dialect in "${dialects[@]}" ; do then echo "testing ${dialect}..." + race="" if [ "$GORM_VERBOSE" = "" ] then - DEBUG=false GORM_DIALECT=${dialect} go test -race -count=1 ./... + race="-race" + fi + + if [ "$GORM_VERBOSE" = "" ] + then + DEBUG=false GORM_DIALECT=${dialect} go test $race -count=1 ./... else - DEBUG=false GORM_DIALECT=${dialect} go test -race -count=1 -v ./... + DEBUG=false GORM_DIALECT=${dialect} go test $race -count=1 -v ./... fi fi done From b71171dd92cafcca395ee9131a6b40d41d72217e Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 2 Jun 2020 00:44:48 +0800 Subject: [PATCH 382/881] Add more preload tests --- callbacks/preload.go | 24 +- schema/utils.go | 18 +- tests/preload_suits_test.go | 1510 +++++++++++++++++++++++++++++++++++ 3 files changed, 1544 insertions(+), 8 deletions(-) create mode 100644 tests/preload_suits_test.go diff --git a/callbacks/preload.go b/callbacks/preload.go index 5b5beb06..6c763da4 100644 --- a/callbacks/preload.go +++ b/callbacks/preload.go @@ -19,6 +19,7 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { foreignFields []*schema.Field foreignValues [][]interface{} identityMap = map[string][]reflect.Value{} + inlineConds []interface{} ) if len(rels) > 1 { @@ -64,7 +65,8 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { } if results, ok := joinIdentityMap[utils.ToStringKey(fieldValues...)]; ok { - identityMap[utils.ToStringKey(joinFieldValues...)] = results + joinKey := utils.ToStringKey(joinFieldValues...) + identityMap[joinKey] = append(identityMap[joinKey], results...) } } @@ -92,12 +94,23 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { reflectResults := rel.FieldSchema.MakeSlice().Elem() column, values := schema.ToQueryValues(relForeignKeys, foreignValues) - tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), conds...) + + for _, cond := range conds { + if fc, ok := cond.(func(*gorm.DB) *gorm.DB); ok { + tx = fc(tx) + } else { + inlineConds = append(inlineConds, cond) + } + } + + tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), inlineConds...) fieldValues := make([]interface{}, len(relForeignFields)) + for i := 0; i < reflectResults.Len(); i++ { + elem := reflectResults.Index(i) for idx, field := range relForeignFields { - fieldValues[idx], _ = field.ValueOf(reflectResults.Index(i)) + fieldValues[idx], _ = field.ValueOf(elem) } for _, data := range identityMap[utils.ToStringKey(fieldValues...)] { @@ -105,15 +118,16 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { if reflectFieldValue.Kind() == reflect.Ptr && reflectFieldValue.IsNil() { reflectFieldValue.Set(reflect.New(rel.Field.FieldType.Elem())) } + reflectFieldValue = reflect.Indirect(reflectFieldValue) switch reflectFieldValue.Kind() { case reflect.Struct: rel.Field.Set(data, reflectResults.Index(i).Interface()) case reflect.Slice, reflect.Array: if reflectFieldValue.Type().Elem().Kind() == reflect.Ptr { - rel.Field.Set(data, reflect.Append(reflectFieldValue, reflectResults.Index(i)).Interface()) + rel.Field.Set(data, reflect.Append(reflectFieldValue, elem).Interface()) } else { - rel.Field.Set(data, reflect.Append(reflectFieldValue, reflectResults.Index(i).Elem()).Interface()) + rel.Field.Set(data, reflect.Append(reflectFieldValue, elem.Elem()).Interface()) } } } diff --git a/schema/utils.go b/schema/utils.go index f7808f0e..ca4ef2f4 100644 --- a/schema/utils.go +++ b/schema/utils.go @@ -95,6 +95,7 @@ func GetIdentityFieldValuesMap(reflectValue reflect.Value, fields []*Field) (map var ( results = [][]interface{}{} dataResults = map[string][]reflect.Value{} + loaded = map[interface{}]bool{} notZero, zero bool ) @@ -114,10 +115,21 @@ func GetIdentityFieldValuesMap(reflectValue reflect.Value, fields []*Field) (map dataResults[utils.ToStringKey(results[0]...)] = []reflect.Value{reflectValue} case reflect.Slice, reflect.Array: for i := 0; i < reflectValue.Len(); i++ { + elem := reflectValue.Index(i) + elemKey := elem.Interface() + if elem.Kind() != reflect.Ptr { + elemKey = elem.Addr().Interface() + } + + if _, ok := loaded[elemKey]; ok { + continue + } + loaded[elemKey] = true + fieldValues := make([]interface{}, len(fields)) notZero = false for idx, field := range fields { - fieldValues[idx], zero = field.ValueOf(reflectValue.Index(i)) + fieldValues[idx], zero = field.ValueOf(elem) notZero = notZero || !zero } @@ -125,9 +137,9 @@ func GetIdentityFieldValuesMap(reflectValue reflect.Value, fields []*Field) (map dataKey := utils.ToStringKey(fieldValues...) if _, ok := dataResults[dataKey]; !ok { results = append(results, fieldValues[:]) - dataResults[dataKey] = []reflect.Value{reflectValue.Index(i)} + dataResults[dataKey] = []reflect.Value{elem} } else { - dataResults[dataKey] = append(dataResults[dataKey], reflectValue.Index(i)) + dataResults[dataKey] = append(dataResults[dataKey], elem) } } } diff --git a/tests/preload_suits_test.go b/tests/preload_suits_test.go new file mode 100644 index 00000000..2e7eeb1f --- /dev/null +++ b/tests/preload_suits_test.go @@ -0,0 +1,1510 @@ +package tests_test + +import ( + "database/sql" + "encoding/json" + "reflect" + "testing" + + "github.com/jinzhu/gorm" + . "github.com/jinzhu/gorm/tests" +) + +func toJSONString(v interface{}) []byte { + r, _ := json.Marshal(v) + return r +} + +func TestNestedPreload1(t *testing.T) { + type ( + Level1 struct { + ID uint + Value string + Level2ID uint + } + Level2 struct { + ID uint + Level1 Level1 + Level3ID uint + } + Level3 struct { + ID uint + Name string + Level2 Level2 + } + ) + DB.Migrator().DropTable(&Level3{}, &Level2{}, &Level1{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}); err != nil { + t.Error(err) + } + + want := Level3{Level2: Level2{Level1: Level1{Value: "value"}}} + if err := DB.Create(&want).Error; err != nil { + t.Error(err) + } + + var got Level3 + if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got).Error; err != nil { + t.Error(err) + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } + + if err := DB.Preload("Level2").Preload("Level2.Level1").First(&got, "name = ?", "not_found").Error; err != gorm.ErrRecordNotFound { + t.Error(err) + } +} + +func TestNestedPreload2(t *testing.T) { + type ( + Level1 struct { + ID uint + Value string + Level2ID uint + } + Level2 struct { + ID uint + Level1s []*Level1 + Level3ID uint + } + Level3 struct { + ID uint + Name string + Level2s []Level2 + } + ) + DB.Migrator().DropTable(&Level3{}, &Level2{}, &Level1{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}); err != nil { + t.Error(err) + } + + want := Level3{ + Level2s: []Level2{ + { + Level1s: []*Level1{ + {Value: "value1"}, + {Value: "value2"}, + }, + }, + { + Level1s: []*Level1{ + {Value: "value3"}, + }, + }, + }, + } + if err := DB.Create(&want).Error; err != nil { + t.Error(err) + } + + var got Level3 + if err := DB.Preload("Level2s.Level1s").Find(&got).Error; err != nil { + t.Error(err) + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } +} + +func TestNestedPreload3(t *testing.T) { + type ( + Level1 struct { + ID uint + Value string + Level2ID uint + } + Level2 struct { + ID uint + Level1 Level1 + Level3ID uint + } + Level3 struct { + Name string + ID uint + Level2s []Level2 + } + ) + DB.Migrator().DropTable(&Level3{}, &Level2{}, &Level1{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}); err != nil { + t.Error(err) + } + + want := Level3{ + Level2s: []Level2{ + {Level1: Level1{Value: "value1"}}, + {Level1: Level1{Value: "value2"}}, + }, + } + if err := DB.Create(&want).Error; err != nil { + t.Error(err) + } + + var got Level3 + if err := DB.Preload("Level2s.Level1").Find(&got).Error; err != nil { + t.Error(err) + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } +} + +func TestNestedPreload4(t *testing.T) { + type ( + Level1 struct { + ID uint + Value string + Level2ID uint + } + Level2 struct { + ID uint + Level1s []Level1 + Level3ID uint + } + Level3 struct { + ID uint + Name string + Level2 Level2 + } + ) + DB.Migrator().DropTable(&Level3{}, &Level2{}, &Level1{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}); err != nil { + t.Error(err) + } + + want := Level3{ + Level2: Level2{ + Level1s: []Level1{ + {Value: "value1"}, + {Value: "value2"}, + }, + }, + } + if err := DB.Create(&want).Error; err != nil { + t.Error(err) + } + + var got Level3 + if err := DB.Preload("Level2.Level1s").Find(&got).Error; err != nil { + t.Error(err) + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } +} + +// Slice: []Level3 +func TestNestedPreload5(t *testing.T) { + type ( + Level1 struct { + ID uint + Value string + Level2ID uint + } + Level2 struct { + ID uint + Level1 Level1 + Level3ID uint + } + Level3 struct { + ID uint + Name string + Level2 Level2 + } + ) + DB.Migrator().DropTable(&Level3{}, &Level2{}, &Level1{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}); err != nil { + t.Error(err) + } + + want := make([]Level3, 2) + want[0] = Level3{Level2: Level2{Level1: Level1{Value: "value"}}} + if err := DB.Create(&want[0]).Error; err != nil { + t.Error(err) + } + want[1] = Level3{Level2: Level2{Level1: Level1{Value: "value2"}}} + if err := DB.Create(&want[1]).Error; err != nil { + t.Error(err) + } + + var got []Level3 + if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got).Error; err != nil { + t.Error(err) + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } +} + +func TestNestedPreload6(t *testing.T) { + type ( + Level1 struct { + ID uint + Value string + Level2ID uint + } + Level2 struct { + ID uint + Level1s []Level1 + Level3ID uint + } + Level3 struct { + ID uint + Name string + Level2s []Level2 + } + ) + DB.Migrator().DropTable(&Level3{}, &Level2{}, &Level1{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}); err != nil { + t.Error(err) + } + + want := make([]Level3, 2) + want[0] = Level3{ + Level2s: []Level2{ + { + Level1s: []Level1{ + {Value: "value1"}, + {Value: "value2"}, + }, + }, + { + Level1s: []Level1{ + {Value: "value3"}, + }, + }, + }, + } + if err := DB.Create(&want[0]).Error; err != nil { + t.Error(err) + } + + want[1] = Level3{ + Level2s: []Level2{ + { + Level1s: []Level1{ + {Value: "value3"}, + {Value: "value4"}, + }, + }, + { + Level1s: []Level1{ + {Value: "value5"}, + }, + }, + }, + } + if err := DB.Create(&want[1]).Error; err != nil { + t.Error(err) + } + + var got []Level3 + if err := DB.Preload("Level2s.Level1s").Find(&got).Error; err != nil { + t.Error(err) + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } +} + +func TestNestedPreload7(t *testing.T) { + type ( + Level1 struct { + ID uint + Value string + Level2ID uint + } + Level2 struct { + ID uint + Level1 Level1 + Level3ID uint + } + Level3 struct { + ID uint + Name string + Level2s []Level2 + } + ) + DB.Migrator().DropTable(&Level3{}, &Level2{}, &Level1{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}); err != nil { + t.Error(err) + } + + want := make([]Level3, 2) + want[0] = Level3{ + Level2s: []Level2{ + {Level1: Level1{Value: "value1"}}, + {Level1: Level1{Value: "value2"}}, + }, + } + if err := DB.Create(&want[0]).Error; err != nil { + t.Error(err) + } + + want[1] = Level3{ + Level2s: []Level2{ + {Level1: Level1{Value: "value3"}}, + {Level1: Level1{Value: "value4"}}, + }, + } + if err := DB.Create(&want[1]).Error; err != nil { + t.Error(err) + } + + var got []Level3 + if err := DB.Preload("Level2s.Level1").Find(&got).Error; err != nil { + t.Error(err) + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } +} + +func TestNestedPreload8(t *testing.T) { + type ( + Level1 struct { + ID uint + Value string + Level2ID uint + } + Level2 struct { + ID uint + Level1s []Level1 + Level3ID uint + } + Level3 struct { + ID uint + Name string + Level2 Level2 + } + ) + DB.Migrator().DropTable(&Level3{}, &Level2{}, &Level1{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}); err != nil { + t.Error(err) + } + + want := make([]Level3, 2) + want[0] = Level3{ + Level2: Level2{ + Level1s: []Level1{ + {Value: "value1"}, + {Value: "value2"}, + }, + }, + } + if err := DB.Create(&want[0]).Error; err != nil { + t.Error(err) + } + want[1] = Level3{ + Level2: Level2{ + Level1s: []Level1{ + {Value: "value3"}, + {Value: "value4"}, + }, + }, + } + if err := DB.Create(&want[1]).Error; err != nil { + t.Error(err) + } + + var got []Level3 + if err := DB.Preload("Level2.Level1s").Find(&got).Error; err != nil { + t.Error(err) + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } +} + +func TestNestedPreload9(t *testing.T) { + type ( + Level0 struct { + ID uint + Value string + Level1ID uint + } + Level1 struct { + ID uint + Value string + Level2ID uint + Level2_1ID uint + Level0s []Level0 `json:",omitempty"` + } + Level2 struct { + ID uint + Level1s []Level1 + Level3ID uint + } + Level2_1 struct { + ID uint + Level1s []Level1 `json:",omitempty"` + Level3ID uint + } + Level3 struct { + ID uint + Name string + Level2 Level2 + Level2_1 Level2_1 + } + ) + DB.Migrator().DropTable(&Level3{}, &Level2{}, &Level1{}, &Level2_1{}, &Level0{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}, &Level2_1{}, &Level0{}); err != nil { + t.Error(err) + } + + want := make([]Level3, 2) + want[0] = Level3{ + Level2: Level2{ + Level1s: []Level1{ + {Value: "value1"}, + {Value: "value2"}, + }, + }, + Level2_1: Level2_1{ + Level1s: []Level1{ + { + Value: "value1-1", + Level0s: []Level0{{Value: "Level0-1"}}, + }, + { + Value: "value2-2", + Level0s: []Level0{{Value: "Level0-2"}}, + }, + }, + }, + } + if err := DB.Create(&want[0]).Error; err != nil { + t.Error(err) + } + want[1] = Level3{ + Level2: Level2{ + Level1s: []Level1{ + {Value: "value3"}, + {Value: "value4"}, + }, + }, + Level2_1: Level2_1{ + Level1s: []Level1{ + { + Value: "value3-3", + Level0s: []Level0{}, + }, + { + Value: "value4-4", + Level0s: []Level0{}, + }, + }, + }, + } + if err := DB.Create(&want[1]).Error; err != nil { + t.Error(err) + } + + var got []Level3 + if err := DB.Preload("Level2").Preload("Level2.Level1s").Preload("Level2_1").Preload("Level2_1.Level1s").Preload("Level2_1.Level1s.Level0s").Find(&got).Error; err != nil { + t.Error(err) + } + + if string(toJSONString(got)) != string(toJSONString(want)) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } +} + +type LevelA1 struct { + ID uint + Value string +} + +type LevelA2 struct { + ID uint + Value string + LevelA3s []*LevelA3 `json:",omitempty"` +} + +type LevelA3 struct { + ID uint + Value string + LevelA1ID sql.NullInt64 + LevelA1 *LevelA1 + LevelA2ID sql.NullInt64 + LevelA2 *LevelA2 +} + +func TestNestedPreload10(t *testing.T) { + DB.Migrator().DropTable(&LevelA3{}, &LevelA2{}, &LevelA1{}) + if err := DB.AutoMigrate(&LevelA1{}, &LevelA2{}, &LevelA3{}); err != nil { + t.Error(err) + } + + levelA1 := &LevelA1{Value: "foo"} + if err := DB.Save(levelA1).Error; err != nil { + t.Error(err) + } + + want := []*LevelA2{ + { + Value: "bar", + LevelA3s: []*LevelA3{ + { + Value: "qux", + LevelA1: levelA1, + }, + }, + }, + { + Value: "bar 2", + LevelA3s: []*LevelA3{}, + }, + } + for _, levelA2 := range want { + if err := DB.Save(levelA2).Error; err != nil { + t.Error(err) + } + } + + var got []*LevelA2 + if err := DB.Preload("LevelA3s.LevelA1").Find(&got).Error; err != nil { + t.Error(err) + } + + if !reflect.DeepEqual(toJSONString(got), toJSONString(want)) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } +} + +type LevelB1 struct { + ID uint + Value string + LevelB3s []*LevelB3 +} + +type LevelB2 struct { + ID uint + Value string +} + +type LevelB3 struct { + ID uint + Value string + LevelB1ID sql.NullInt64 + LevelB1 *LevelB1 + LevelB2s []*LevelB2 `gorm:"many2many:levelb1_levelb3_levelb2s" json:",omitempty"` +} + +func TestNestedPreload11(t *testing.T) { + DB.Migrator().DropTable(&LevelB3{}, &LevelB2{}, &LevelB1{}) + if err := DB.AutoMigrate(&LevelB1{}, &LevelB2{}, &LevelB3{}); err != nil { + t.Error(err) + } + + levelB1 := &LevelB1{Value: "foo"} + if err := DB.Create(levelB1).Error; err != nil { + t.Error(err) + } + + levelB3 := &LevelB3{ + Value: "bar", + LevelB1ID: sql.NullInt64{Valid: true, Int64: int64(levelB1.ID)}, + LevelB2s: []*LevelB2{}, + } + if err := DB.Create(levelB3).Error; err != nil { + t.Error(err) + } + levelB1.LevelB3s = []*LevelB3{levelB3} + + want := []*LevelB1{levelB1} + var got []*LevelB1 + if err := DB.Preload("LevelB3s.LevelB2s").Find(&got).Error; err != nil { + t.Error(err) + } + + if !reflect.DeepEqual(toJSONString(got), toJSONString(want)) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } +} + +type LevelC1 struct { + ID uint + Value string + LevelC2ID uint +} + +type LevelC2 struct { + ID uint + Value string + LevelC1 LevelC1 +} + +type LevelC3 struct { + ID uint + Value string + LevelC2ID uint + LevelC2 LevelC2 +} + +func TestNestedPreload12(t *testing.T) { + DB.Migrator().DropTable(&LevelC3{}, &LevelC2{}, &LevelC1{}) + if err := DB.AutoMigrate(&LevelC1{}, &LevelC2{}, &LevelC3{}); err != nil { + t.Error(err) + } + + level2 := LevelC2{ + Value: "c2", + LevelC1: LevelC1{ + Value: "c1", + }, + } + DB.Create(&level2) + + want := []LevelC3{ + { + Value: "c3-1", + LevelC2: level2, + }, { + Value: "c3-2", + LevelC2: level2, + }, + } + + for i := range want { + if err := DB.Create(&want[i]).Error; err != nil { + t.Error(err) + } + } + + var got []LevelC3 + if err := DB.Preload("LevelC2").Preload("LevelC2.LevelC1").Find(&got).Error; err != nil { + t.Error(err) + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } +} + +func TestManyToManyPreloadWithMultiPrimaryKeys(t *testing.T) { + if name := DB.Dialector.Name(); name == "sqlite" || name == "mssql" { + t.Skip("skip sqlite, mssql due to it doesn't support multiple primary keys with auto increment") + } + + type ( + Level1 struct { + ID uint `gorm:"primary_key;"` + LanguageCode string `gorm:"primary_key"` + Value string + } + Level2 struct { + ID uint `gorm:"primary_key;"` + LanguageCode string `gorm:"primary_key"` + Value string + Level1s []Level1 `gorm:"many2many:levels;"` + } + ) + + DB.Migrator().DropTable(&Level2{}, &Level1{}) + DB.Migrator().DropTable("levels") + + if err := DB.AutoMigrate(&Level2{}, &Level1{}); err != nil { + t.Error(err) + } + + want := Level2{Value: "Bob", LanguageCode: "ru", Level1s: []Level1{ + {Value: "ru", LanguageCode: "ru"}, + {Value: "en", LanguageCode: "en"}, + }} + if err := DB.Save(&want).Error; err != nil { + t.Error(err) + } + + want2 := Level2{Value: "Tom", LanguageCode: "zh", Level1s: []Level1{ + {Value: "zh", LanguageCode: "zh"}, + {Value: "de", LanguageCode: "de"}, + }} + if err := DB.Save(&want2).Error; err != nil { + t.Error(err) + } + + var got Level2 + if err := DB.Preload("Level1s").Find(&got, "value = ?", "Bob").Error; err != nil { + t.Error(err) + } + return + + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } + + var got2 Level2 + if err := DB.Preload("Level1s").Find(&got2, "value = ?", "Tom").Error; err != nil { + t.Error(err) + } + + if !reflect.DeepEqual(got2, want2) { + t.Errorf("got %s; want %s", toJSONString(got2), toJSONString(want2)) + } + + var got3 []Level2 + if err := DB.Preload("Level1s").Find(&got3, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil { + t.Error(err) + } + + if !reflect.DeepEqual(got3, []Level2{got, got2}) { + t.Errorf("got %s; want %s", toJSONString(got3), toJSONString([]Level2{got, got2})) + } + + var got4 []Level2 + if err := DB.Preload("Level1s", "value IN (?)", []string{"zh", "ru"}).Find(&got4, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil { + t.Error(err) + } + + var ruLevel1 Level1 + var zhLevel1 Level1 + DB.First(&ruLevel1, "value = ?", "ru") + DB.First(&zhLevel1, "value = ?", "zh") + + got.Level1s = []Level1{ruLevel1} + got2.Level1s = []Level1{zhLevel1} + if !reflect.DeepEqual(got4, []Level2{got, got2}) { + t.Errorf("got %s; want %s", toJSONString(got4), toJSONString([]Level2{got, got2})) + } + + if err := DB.Preload("Level1s").Find(&got4, "value IN (?)", []string{"non-existing"}).Error; err != nil { + t.Error(err) + } +} + +func TestManyToManyPreloadForNestedPointer(t *testing.T) { + type ( + Level1 struct { + ID uint + Value string + } + Level2 struct { + ID uint + Value string + Level1s []*Level1 `gorm:"many2many:levels;"` + } + Level3 struct { + ID uint + Value string + Level2ID sql.NullInt64 + Level2 *Level2 + } + ) + + DB.Migrator().DropTable(&Level3{}, &Level2{}, &Level1{}) + DB.Migrator().DropTable("levels") + + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}); err != nil { + t.Error(err) + } + + want := Level3{ + Value: "Bob", + Level2: &Level2{ + Value: "Foo", + Level1s: []*Level1{ + {Value: "ru"}, + {Value: "en"}, + }, + }, + } + if err := DB.Save(&want).Error; err != nil { + t.Error(err) + } + + want2 := Level3{ + Value: "Tom", + Level2: &Level2{ + Value: "Bar", + Level1s: []*Level1{ + {Value: "zh"}, + {Value: "de"}, + }, + }, + } + if err := DB.Save(&want2).Error; err != nil { + t.Error(err) + } + + var got Level3 + if err := DB.Preload("Level2.Level1s").Find(&got, "value = ?", "Bob").Error; err != nil { + t.Error(err) + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } + + var got2 Level3 + if err := DB.Preload("Level2.Level1s").Find(&got2, "value = ?", "Tom").Error; err != nil { + t.Error(err) + } + + if !reflect.DeepEqual(got2, want2) { + t.Errorf("got %s; want %s", toJSONString(got2), toJSONString(want2)) + } + + var got3 []Level3 + if err := DB.Preload("Level2.Level1s").Find(&got3, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil { + t.Error(err) + } + + if !reflect.DeepEqual(got3, []Level3{got, got2}) { + t.Errorf("got %s; want %s", toJSONString(got3), toJSONString([]Level3{got, got2})) + } + + var got4 []Level3 + if err := DB.Preload("Level2.Level1s", "value IN (?)", []string{"zh", "ru"}).Find(&got4, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil { + t.Error(err) + } + + var got5 Level3 + DB.Preload("Level2.Level1s").Find(&got5, "value = ?", "bogus") + + var ruLevel1 Level1 + var zhLevel1 Level1 + DB.First(&ruLevel1, "value = ?", "ru") + DB.First(&zhLevel1, "value = ?", "zh") + + got.Level2.Level1s = []*Level1{&ruLevel1} + got2.Level2.Level1s = []*Level1{&zhLevel1} + if !reflect.DeepEqual(got4, []Level3{got, got2}) { + t.Errorf("got %s; want %s", toJSONString(got4), toJSONString([]Level3{got, got2})) + } +} + +func TestNestedManyToManyPreload(t *testing.T) { + type ( + Level1 struct { + ID uint + Value string + } + Level2 struct { + ID uint + Value string + Level1s []*Level1 `gorm:"many2many:level1_level2;"` + } + Level3 struct { + ID uint + Value string + Level2s []Level2 `gorm:"many2many:level2_level3;"` + } + ) + + DB.Migrator().DropTable(&Level3{}, &Level2{}, &Level1{}, "level1_level2", "level2_level3") + + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}); err != nil { + t.Error(err) + } + + want := Level3{ + Value: "Level3", + Level2s: []Level2{ + { + Value: "Bob", + Level1s: []*Level1{ + {Value: "ru"}, + {Value: "en"}, + }, + }, { + Value: "Tom", + Level1s: []*Level1{ + {Value: "zh"}, + {Value: "de"}, + }, + }, + }, + } + + if err := DB.Save(&want).Error; err != nil { + t.Error(err) + } + + var got Level3 + if err := DB.Preload("Level2s").Preload("Level2s.Level1s").Find(&got, "value = ?", "Level3").Error; err != nil { + t.Error(err) + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } + + if err := DB.Preload("Level2s.Level1s").First(&got, "value = ?", "not_found").Error; err != gorm.ErrRecordNotFound { + t.Error(err) + } +} + +func TestNestedManyToManyPreload2(t *testing.T) { + type ( + Level1 struct { + ID uint + Value string + } + Level2 struct { + ID uint + Value string + Level1s []*Level1 `gorm:"many2many:level1_level2;"` + } + Level3 struct { + ID uint + Value string + Level2ID sql.NullInt64 + Level2 *Level2 + } + ) + + DB.Migrator().DropTable(&Level3{}, &Level2{}, &Level1{}) + DB.Migrator().DropTable("level1_level2") + + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}); err != nil { + t.Error(err) + } + + want := Level3{ + Value: "Level3", + Level2: &Level2{ + Value: "Bob", + Level1s: []*Level1{ + {Value: "ru"}, + {Value: "en"}, + }, + }, + } + + if err := DB.Save(&want).Error; err != nil { + t.Error(err) + } + + var got Level3 + if err := DB.Preload("Level2.Level1s").Find(&got, "value = ?", "Level3").Error; err != nil { + t.Error(err) + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } + + if err := DB.Preload("Level2.Level1s").First(&got, "value = ?", "not_found").Error; err != gorm.ErrRecordNotFound { + t.Error(err) + } +} + +func TestNestedManyToManyPreload3(t *testing.T) { + type ( + Level1 struct { + ID uint + Value string + } + Level2 struct { + ID uint + Value string + Level1s []*Level1 `gorm:"many2many:level1_level2;"` + } + Level3 struct { + ID uint + Value string + Level2ID sql.NullInt64 + Level2 *Level2 + } + ) + + DB.Migrator().DropTable(&Level3{}, &Level2{}, &Level1{}, "level1_level2") + + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}); err != nil { + t.Error(err) + } + + level1Zh := &Level1{Value: "zh"} + level1Ru := &Level1{Value: "ru"} + level1En := &Level1{Value: "en"} + + level21 := &Level2{ + Value: "Level2-1", + Level1s: []*Level1{level1Zh, level1Ru}, + } + + level22 := &Level2{ + Value: "Level2-2", + Level1s: []*Level1{level1Zh, level1En}, + } + + wants := []*Level3{ + { + Value: "Level3-1", + Level2: level21, + }, + { + Value: "Level3-2", + Level2: level22, + }, + { + Value: "Level3-3", + Level2: level21, + }, + } + + for _, want := range wants { + if err := DB.Save(want).Error; err != nil { + t.Error(err) + } + } + + var gots []*Level3 + if err := DB.Preload("Level2.Level1s", func(db *gorm.DB) *gorm.DB { + return db.Order("level1.id ASC") + }).Find(&gots).Error; err != nil { + t.Error(err) + } + + if !reflect.DeepEqual(gots, wants) { + t.Errorf("got %s; want %s", toJSONString(gots), toJSONString(wants)) + } +} + +func TestNestedManyToManyPreload3ForStruct(t *testing.T) { + type ( + Level1 struct { + ID uint + Value string + } + Level2 struct { + ID uint + Value string + Level1s []Level1 `gorm:"many2many:level1_level2;"` + } + Level3 struct { + ID uint + Value string + Level2ID sql.NullInt64 + Level2 Level2 + } + ) + + DB.Migrator().DropTable(&Level3{}, &Level2{}, &Level1{}) + DB.Migrator().DropTable("level1_level2") + + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}); err != nil { + t.Error(err) + } + + level1Zh := Level1{Value: "zh"} + level1Ru := Level1{Value: "ru"} + level1En := Level1{Value: "en"} + + level21 := Level2{ + Value: "Level2-1", + Level1s: []Level1{level1Zh, level1Ru}, + } + + level22 := Level2{ + Value: "Level2-2", + Level1s: []Level1{level1Zh, level1En}, + } + + wants := []*Level3{ + { + Value: "Level3-1", + Level2: level21, + }, + { + Value: "Level3-2", + Level2: level22, + }, + { + Value: "Level3-3", + Level2: level21, + }, + } + + for _, want := range wants { + if err := DB.Save(want).Error; err != nil { + t.Error(err) + } + } + + var gots []*Level3 + if err := DB.Preload("Level2.Level1s", func(db *gorm.DB) *gorm.DB { + return db.Order("level1.id ASC") + }).Find(&gots).Error; err != nil { + t.Error(err) + } + + if !reflect.DeepEqual(gots, wants) { + t.Errorf("got %s; want %s", toJSONString(gots), toJSONString(wants)) + } +} + +func TestNestedManyToManyPreload4(t *testing.T) { + type ( + Level4 struct { + ID uint + Value string + Level3ID uint + } + Level3 struct { + ID uint + Value string + Level4s []*Level4 + } + Level2 struct { + ID uint + Value string + Level3s []*Level3 `gorm:"many2many:level2_level3;"` + } + Level1 struct { + ID uint + Value string + Level2s []*Level2 `gorm:"many2many:level1_level2;"` + } + ) + + DB.Migrator().DropTable(&Level3{}, &Level2{}, &Level1{}, &Level4{}) + DB.Migrator().DropTable("level1_level2") + DB.Migrator().DropTable("level2_level3") + + dummy := Level1{ + Value: "Level1", + Level2s: []*Level2{{ + Value: "Level2", + Level3s: []*Level3{{ + Value: "Level3", + Level4s: []*Level4{{ + Value: "Level4", + }}, + }}, + }}, + } + + if err := DB.AutoMigrate(&Level4{}, &Level3{}, &Level2{}, &Level1{}); err != nil { + t.Error(err) + } + + if err := DB.Save(&dummy).Error; err != nil { + t.Error(err) + } + + var level1 Level1 + if err := DB.Preload("Level2s").Preload("Level2s.Level3s").Preload("Level2s.Level3s.Level4s").First(&level1).Error; err != nil { + t.Error(err) + } +} + +func TestManyToManyPreloadForPointer(t *testing.T) { + type ( + Level1 struct { + ID uint + Value string + } + Level2 struct { + ID uint + Value string + Level1s []*Level1 `gorm:"many2many:levels;"` + } + ) + + DB.Migrator().DropTable(&Level2{}, &Level1{}) + DB.Migrator().DropTable("levels") + + if err := DB.AutoMigrate(&Level2{}, &Level1{}); err != nil { + t.Error(err) + } + + want := Level2{Value: "Bob", Level1s: []*Level1{ + {Value: "ru"}, + {Value: "en"}, + }} + if err := DB.Save(&want).Error; err != nil { + t.Error(err) + } + + want2 := Level2{Value: "Tom", Level1s: []*Level1{ + {Value: "zh"}, + {Value: "de"}, + }} + if err := DB.Save(&want2).Error; err != nil { + t.Error(err) + } + + var got Level2 + if err := DB.Preload("Level1s").Find(&got, "value = ?", "Bob").Error; err != nil { + t.Error(err) + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } + + var got2 Level2 + if err := DB.Preload("Level1s").Find(&got2, "value = ?", "Tom").Error; err != nil { + t.Error(err) + } + + if !reflect.DeepEqual(got2, want2) { + t.Errorf("got %s; want %s", toJSONString(got2), toJSONString(want2)) + } + + var got3 []Level2 + if err := DB.Preload("Level1s").Find(&got3, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil { + t.Error(err) + } + + if !reflect.DeepEqual(got3, []Level2{got, got2}) { + t.Errorf("got %s; want %s", toJSONString(got3), toJSONString([]Level2{got, got2})) + } + + var got4 []Level2 + if err := DB.Preload("Level1s", "value IN (?)", []string{"zh", "ru"}).Find(&got4, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil { + t.Error(err) + } + + var got5 Level2 + DB.Preload("Level1s").First(&got5, "value = ?", "bogus") + + var ruLevel1 Level1 + var zhLevel1 Level1 + DB.First(&ruLevel1, "value = ?", "ru") + DB.First(&zhLevel1, "value = ?", "zh") + + got.Level1s = []*Level1{&ruLevel1} + got2.Level1s = []*Level1{&zhLevel1} + if !reflect.DeepEqual(got4, []Level2{got, got2}) { + t.Errorf("got %s; want %s", toJSONString(got4), toJSONString([]Level2{got, got2})) + } +} + +func TestNilPointerSlice(t *testing.T) { + type ( + Level3 struct { + ID uint + Value string + } + Level2 struct { + ID uint + Value string + Level3ID uint + Level3 *Level3 + } + Level1 struct { + ID uint + Value string + Level2ID uint + Level2 *Level2 + } + ) + + DB.Migrator().DropTable(&Level3{}, &Level2{}, &Level1{}) + + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}); err != nil { + t.Error(err) + } + + want := Level1{ + Value: "Bob", + Level2: &Level2{ + Value: "en", + Level3: &Level3{ + Value: "native", + }, + }, + } + if err := DB.Save(&want).Error; err != nil { + t.Error(err) + } + + want2 := Level1{ + Value: "Tom", + Level2: nil, + } + if err := DB.Save(&want2).Error; err != nil { + t.Error(err) + } + + var got []Level1 + if err := DB.Preload("Level2").Preload("Level2.Level3").Find(&got).Error; err != nil { + t.Error(err) + } + + if len(got) != 2 { + t.Errorf("got %v items, expected 2", len(got)) + } + + if !reflect.DeepEqual(got[0], want) && !reflect.DeepEqual(got[1], want) { + t.Errorf("got %s; want array containing %s", toJSONString(got), toJSONString(want)) + } + + if !reflect.DeepEqual(got[0], want2) && !reflect.DeepEqual(got[1], want2) { + t.Errorf("got %s; want array containing %s", toJSONString(got), toJSONString(want2)) + } +} + +func TestNilPointerSlice2(t *testing.T) { + type ( + Level4 struct { + ID uint + } + Level3 struct { + ID uint + Level4ID sql.NullInt64 `sql:"index"` + Level4 *Level4 + } + Level2 struct { + ID uint + Level3s []*Level3 `gorm:"many2many:level2_level3s"` + } + Level1 struct { + ID uint + Level2ID sql.NullInt64 `sql:"index"` + Level2 *Level2 + } + ) + + DB.Migrator().DropTable(&Level3{}, &Level2{}, &Level1{}, &Level4{}) + + if err := DB.AutoMigrate(new(Level4), new(Level3), new(Level2), new(Level1)); err != nil { + t.Error(err) + } + + want := new(Level1) + if err := DB.Save(want).Error; err != nil { + t.Error(err) + } + + got := new(Level1) + err := DB.Preload("Level2.Level3s.Level4").Last(&got).Error + if err != nil { + t.Error(err) + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } +} + +func TestPrefixedPreloadDuplication(t *testing.T) { + type ( + Level4 struct { + ID uint + Name string + Level3ID uint + } + Level3 struct { + ID uint + Name string + Level4s []*Level4 `json:",omitempty"` + } + Level2 struct { + ID uint + Name string + Level3ID sql.NullInt64 `sql:"index"` + Level3 *Level3 + } + Level1 struct { + ID uint + Name string + Level2ID sql.NullInt64 `sql:"index"` + Level2 *Level2 + } + ) + + DB.Migrator().DropTable(&Level3{}, &Level2{}, &Level1{}, &Level4{}) + + if err := DB.AutoMigrate(new(Level3), new(Level4), new(Level2), new(Level1)); err != nil { + t.Error(err) + } + + lvl := &Level3{} + if err := DB.Save(lvl).Error; err != nil { + t.Error(err) + } + + sublvl1 := &Level4{Level3ID: lvl.ID} + if err := DB.Save(sublvl1).Error; err != nil { + t.Error(err) + } + sublvl2 := &Level4{Level3ID: lvl.ID} + if err := DB.Save(sublvl2).Error; err != nil { + t.Error(err) + } + + lvl.Level4s = []*Level4{sublvl1, sublvl2} + + want1 := Level1{ + Level2: &Level2{ + Level3: lvl, + }, + } + if err := DB.Save(&want1).Error; err != nil { + t.Error(err) + } + + want2 := Level1{ + Level2: &Level2{ + Level3: lvl, + }, + } + if err := DB.Save(&want2).Error; err != nil { + t.Error(err) + } + + want := []Level1{want1, want2} + + var got []Level1 + err := DB.Preload("Level2.Level3.Level4s").Find(&got).Error + if err != nil { + t.Error(err) + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } +} + +func TestPreloadManyToManyCallbacks(t *testing.T) { + type ( + Level2 struct { + ID uint + Name string + } + Level1 struct { + ID uint + Name string + Level2s []Level2 `gorm:"many2many:level1_level2s"` + } + ) + + DB.Migrator().DropTable(&Level2{}, &Level1{}) + DB.Migrator().DropTable("level1_level2s") + + if err := DB.AutoMigrate(new(Level1), new(Level2)); err != nil { + t.Error(err) + } + + lvl := Level1{ + Name: "l1", + Level2s: []Level2{ + {Name: "l2-1"}, {Name: "l2-2"}, + }, + } + DB.Save(&lvl) + + called := 0 + + DB.Callback().Query().After("gorm:query").Register("TestPreloadManyToManyCallbacks", func(_ *gorm.DB) { + called = called + 1 + }) + + DB.Preload("Level2s").First(&Level1{}, "id = ?", lvl.ID) + + if called != 3 { + t.Errorf("Wanted callback to be called 3 times but got %d", called) + } +} From 5ecbf25b225b824660c70dba134051888e78ee76 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 2 Jun 2020 07:28:29 +0800 Subject: [PATCH 383/881] Drop table with CASCADE option --- dialects/mysql/migrator.go | 15 +++++++++++++++ dialects/postgres/migrator.go | 13 +++++++++++++ gorm.go | 1 + migrator/migrator.go | 13 +++++-------- schema/relationship.go | 10 ++++++++++ tests/preload_suits_test.go | 13 +++++-------- 6 files changed, 49 insertions(+), 16 deletions(-) diff --git a/dialects/mysql/migrator.go b/dialects/mysql/migrator.go index 74c11277..467da9a2 100644 --- a/dialects/mysql/migrator.go +++ b/dialects/mysql/migrator.go @@ -24,6 +24,21 @@ func (m Migrator) AlterColumn(value interface{}, field string) error { }) } +func (m Migrator) DropTable(values ...interface{}) error { + values = m.ReorderModels(values, false) + tx := m.DB.Session(&gorm.Session{}) + tx.Exec("SET FOREIGN_KEY_CHECKS = 0;") + for i := len(values) - 1; i >= 0; i-- { + if err := m.RunWithValue(values[i], func(stmt *gorm.Statement) error { + return tx.Exec("DROP TABLE IF EXISTS ? CASCADE", clause.Table{Name: stmt.Table}).Error + }); err != nil { + return err + } + } + tx.Exec("SET FOREIGN_KEY_CHECKS = 1;") + return nil +} + func (m Migrator) DropConstraint(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { for _, chk := range stmt.Schema.ParseCheckConstraints() { diff --git a/dialects/postgres/migrator.go b/dialects/postgres/migrator.go index d93f681c..ef582f00 100644 --- a/dialects/postgres/migrator.go +++ b/dialects/postgres/migrator.go @@ -108,6 +108,19 @@ func (m Migrator) HasTable(value interface{}) bool { return count > 0 } +func (m Migrator) DropTable(values ...interface{}) error { + values = m.ReorderModels(values, false) + tx := m.DB.Session(&gorm.Session{}) + for i := len(values) - 1; i >= 0; i-- { + if err := m.RunWithValue(values[i], func(stmt *gorm.Statement) error { + return tx.Exec("DROP TABLE IF EXISTS ? CASCADE", clause.Table{Name: stmt.Table}).Error + }); err != nil { + return err + } + } + return nil +} + func (m Migrator) HasColumn(value interface{}, field string) bool { var count int64 m.RunWithValue(value, func(stmt *gorm.Statement) error { diff --git a/gorm.go b/gorm.go index fd0d4b7e..07f94266 100644 --- a/gorm.go +++ b/gorm.go @@ -204,6 +204,7 @@ func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interfac if relation, ok := modelSchema.Relationships.Relations[field]; ok && relation.JoinTable != nil { for _, ref := range relation.References { if f := joinSchema.LookUpField(ref.ForeignKey.DBName); f != nil { + f.DataType = ref.ForeignKey.DataType ref.ForeignKey = f } else { return fmt.Errorf("missing field %v for join table", ref.ForeignKey.DBName) diff --git a/migrator/migrator.go b/migrator/migrator.go index 4e0f28b5..d78c6224 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -203,14 +203,11 @@ func (m Migrator) CreateTable(values ...interface{}) error { func (m Migrator) DropTable(values ...interface{}) error { values = m.ReorderModels(values, false) for i := len(values) - 1; i >= 0; i-- { - value := values[i] - if m.DB.Migrator().HasTable(value) { - tx := m.DB.Session(&gorm.Session{}) - if err := m.RunWithValue(value, func(stmt *gorm.Statement) error { - return tx.Exec("DROP TABLE ?", clause.Table{Name: stmt.Table}).Error - }); err != nil { - return err - } + tx := m.DB.Session(&gorm.Session{}) + if err := m.RunWithValue(values[i], func(stmt *gorm.Statement) error { + return tx.Exec("DROP TABLE IF EXISTS ?", clause.Table{Name: stmt.Table}).Error + }); err != nil { + return err } } return nil diff --git a/schema/relationship.go b/schema/relationship.go index 194fbeff..8b5e987c 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -150,6 +150,10 @@ func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Fi schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %v", relation.foreignKeys, schema, field.Name) } } + + // use same data type for foreign keys + relation.Polymorphic.PolymorphicID.DataType = primaryKeyField.DataType + relation.References = append(relation.References, &Reference{ PrimaryKey: primaryKeyField, ForeignKey: relation.Polymorphic.PolymorphicID, @@ -246,6 +250,9 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel // build references for _, f := range relation.JoinTable.Fields { + // use same data type for foreign keys + f.DataType = fieldsMap[f.Name].DataType + relation.References = append(relation.References, &Reference{ PrimaryKey: fieldsMap[f.Name], ForeignKey: f, @@ -326,6 +333,9 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, guessH // build references for idx, foreignField := range foreignFields { + // use same data type for foreign keys + foreignField.DataType = primaryFields[idx].DataType + relation.References = append(relation.References, &Reference{ PrimaryKey: primaryFields[idx], ForeignKey: foreignField, diff --git a/tests/preload_suits_test.go b/tests/preload_suits_test.go index 2e7eeb1f..b71b7299 100644 --- a/tests/preload_suits_test.go +++ b/tests/preload_suits_test.go @@ -1167,9 +1167,8 @@ func TestNestedManyToManyPreload4(t *testing.T) { } ) + DB.Migrator().DropTable("level1_level2", "level2_level3") DB.Migrator().DropTable(&Level3{}, &Level2{}, &Level1{}, &Level4{}) - DB.Migrator().DropTable("level1_level2") - DB.Migrator().DropTable("level2_level3") dummy := Level1{ Value: "Level1", @@ -1211,8 +1210,7 @@ func TestManyToManyPreloadForPointer(t *testing.T) { } ) - DB.Migrator().DropTable(&Level2{}, &Level1{}) - DB.Migrator().DropTable("levels") + DB.Migrator().DropTable("levels", &Level2{}, &Level1{}) if err := DB.AutoMigrate(&Level2{}, &Level1{}); err != nil { t.Error(err) @@ -1296,7 +1294,7 @@ func TestNilPointerSlice(t *testing.T) { Level1 struct { ID uint Value string - Level2ID uint + Level2ID *uint Level2 *Level2 } ) @@ -1325,7 +1323,7 @@ func TestNilPointerSlice(t *testing.T) { Level2: nil, } if err := DB.Save(&want2).Error; err != nil { - t.Error(err) + t.Fatalf("Got error %v", err) } var got []Level1 @@ -1481,8 +1479,7 @@ func TestPreloadManyToManyCallbacks(t *testing.T) { } ) - DB.Migrator().DropTable(&Level2{}, &Level1{}) - DB.Migrator().DropTable("level1_level2s") + DB.Migrator().DropTable("level1_level2s", &Level2{}, &Level1{}) if err := DB.AutoMigrate(new(Level1), new(Level2)); err != nil { t.Error(err) From e986371a42bb5ded77ac65b46e07c80d0f450eae Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 2 Jun 2020 09:16:07 +0800 Subject: [PATCH 384/881] Rename package name --- README.md | 5 +++-- association.go | 6 +++--- callbacks.go | 6 +++--- callbacks/associations.go | 8 ++++---- callbacks/callbacks.go | 2 +- callbacks/create.go | 6 +++--- callbacks/delete.go | 6 +++--- callbacks/helper.go | 4 ++-- callbacks/interface.go | 2 +- callbacks/preload.go | 8 ++++---- callbacks/query.go | 6 +++--- callbacks/raw.go | 2 +- callbacks/row.go | 2 +- callbacks/transaction.go | 2 +- callbacks/update.go | 6 +++--- chainable_api.go | 4 ++-- clause/benchmarks_test.go | 8 ++++---- clause/clause_test.go | 8 ++++---- clause/delete_test.go | 2 +- clause/expression_test.go | 8 ++++---- clause/from_test.go | 2 +- clause/group_by_test.go | 2 +- clause/insert_test.go | 2 +- clause/limit_test.go | 2 +- clause/locking_test.go | 2 +- clause/order_by_test.go | 2 +- clause/returning_test.go | 2 +- clause/select_test.go | 2 +- clause/set_test.go | 2 +- clause/update_test.go | 2 +- clause/values_test.go | 2 +- clause/where_test.go | 2 +- dialects/mssql/create.go | 8 ++++---- dialects/mssql/migrator.go | 6 +++--- dialects/mssql/mssql.go | 12 ++++++------ dialects/mysql/migrator.go | 6 +++--- dialects/mysql/mysql.go | 12 ++++++------ dialects/postgres/migrator.go | 8 ++++---- dialects/postgres/postgres.go | 12 ++++++------ dialects/sqlite/migrator.go | 8 ++++---- dialects/sqlite/sqlite.go | 12 ++++++------ finisher_api.go | 2 +- go.mod | 7 ++++--- gorm.go | 6 +++--- interfaces.go | 4 ++-- logger/logger.go | 2 +- logger/sql_test.go | 4 ++-- migrator/migrator.go | 6 +++--- scan.go | 2 +- schema/callbacks_test.go | 4 ++-- schema/check_test.go | 2 +- schema/field.go | 4 ++-- schema/field_test.go | 6 +++--- schema/index_test.go | 2 +- schema/model_test.go | 4 ++-- schema/naming.go | 2 +- schema/relationship.go | 4 ++-- schema/relationship_test.go | 4 ++-- schema/schema.go | 4 ++-- schema/schema_helper_test.go | 4 ++-- schema/schema_test.go | 4 ++-- schema/utils.go | 2 +- soft_delete.go | 4 ++-- statement.go | 4 ++-- tests/associations_belongs_to_test.go | 2 +- tests/associations_has_many_test.go | 2 +- tests/associations_has_one_test.go | 2 +- tests/associations_many2many_test.go | 2 +- tests/associations_test.go | 2 +- tests/callbacks_test.go | 2 +- tests/count_test.go | 2 +- tests/create_test.go | 6 +++--- tests/customize_column_test.go | 2 +- tests/delete_test.go | 4 ++-- tests/dummy_dialecter.go | 8 ++++---- tests/embedded_struct_test.go | 4 ++-- tests/group_by_test.go | 2 +- tests/hooks_test.go | 4 ++-- tests/joins_table_test.go | 4 ++-- tests/joins_test.go | 4 ++-- tests/main_test.go | 2 +- tests/migrate_test.go | 4 ++-- tests/model.go | 2 +- tests/multi_primary_keys_test.go | 2 +- tests/named_polymorphic_test.go | 2 +- tests/non_std_test.go | 2 +- tests/preload_suits_test.go | 4 ++-- tests/preload_test.go | 4 ++-- tests/query_test.go | 4 ++-- tests/scan_test.go | 2 +- tests/scanner_valuer_test.go | 4 ++-- tests/scopes_test.go | 4 ++-- tests/soft_delete_test.go | 2 +- tests/sql_builder_test.go | 4 ++-- tests/tests.go | 12 ++++++------ tests/transaction_test.go | 4 ++-- tests/update_belongs_to_test.go | 2 +- tests/update_has_many_test.go | 2 +- tests/update_has_one_test.go | 2 +- tests/update_many2many_test.go | 2 +- tests/update_test.go | 4 ++-- tests/upsert_test.go | 4 ++-- tests/utils.go | 2 +- 103 files changed, 213 insertions(+), 211 deletions(-) diff --git a/README.md b/README.md index 6d231103..84236bb9 100644 --- a/README.md +++ b/README.md @@ -2,14 +2,14 @@ The fantastic ORM library for Golang, aims to be developer friendly. -[![go report card](https://goreportcard.com/badge/github.com/jinzhu/gorm "go report card")](https://goreportcard.com/report/github.com/jinzhu/gorm) +[![go report card](https://goreportcard.com/badge/gorm.io/gorm "go report card")](https://goreportcard.com/report/gorm.io/gorm) [![wercker status](https://app.wercker.com/status/8596cace912c9947dd9c8542ecc8cb8b/s/master "wercker status")](https://app.wercker.com/project/byKey/8596cace912c9947dd9c8542ecc8cb8b) [![codecov](https://codecov.io/gh/jinzhu/gorm/branch/master/graph/badge.svg)](https://codecov.io/gh/jinzhu/gorm) [![Join the chat at https://gitter.im/jinzhu/gorm](https://img.shields.io/gitter/room/jinzhu/gorm.svg)](https://gitter.im/jinzhu/gorm?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) [![Open Collective Backer](https://opencollective.com/gorm/tiers/backer/badge.svg?label=backer&color=brightgreen "Open Collective Backer")](https://opencollective.com/gorm) [![Open Collective Sponsor](https://opencollective.com/gorm/tiers/sponsor/badge.svg?label=sponsor&color=brightgreen "Open Collective Sponsor")](https://opencollective.com/gorm) [![MIT license](https://img.shields.io/badge/license-MIT-brightgreen.svg)](https://opensource.org/licenses/MIT) -[![GoDoc](https://godoc.org/github.com/jinzhu/gorm?status.svg)](https://godoc.org/github.com/jinzhu/gorm) +[![GoDoc](https://godoc.org/gorm.io/gorm?status.svg)](https://godoc.org/gorm.io/gorm) ## Overview @@ -39,3 +39,4 @@ The fantastic ORM library for Golang, aims to be developer friendly. © Jinzhu, 2013~time.Now Released under the [MIT License](https://github.com/jinzhu/gorm/blob/master/License) + diff --git a/association.go b/association.go index 23e5a82f..928dcf3e 100644 --- a/association.go +++ b/association.go @@ -6,9 +6,9 @@ import ( "reflect" "strings" - "github.com/jinzhu/gorm/clause" - "github.com/jinzhu/gorm/schema" - "github.com/jinzhu/gorm/utils" + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" + "gorm.io/gorm/utils" ) // Association Mode contains some helper methods to handle relationship things easily. diff --git a/callbacks.go b/callbacks.go index d3cd8e62..c5654c50 100644 --- a/callbacks.go +++ b/callbacks.go @@ -7,9 +7,9 @@ import ( "reflect" "time" - "github.com/jinzhu/gorm/logger" - "github.com/jinzhu/gorm/schema" - "github.com/jinzhu/gorm/utils" + "gorm.io/gorm/logger" + "gorm.io/gorm/schema" + "gorm.io/gorm/utils" ) func initializeCallbacks(db *DB) *callbacks { diff --git a/callbacks/associations.go b/callbacks/associations.go index d19f7339..5ff63cc4 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -3,10 +3,10 @@ package callbacks import ( "reflect" - "github.com/jinzhu/gorm" - "github.com/jinzhu/gorm/clause" - "github.com/jinzhu/gorm/schema" - "github.com/jinzhu/gorm/utils" + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" + "gorm.io/gorm/utils" ) func SaveBeforeAssociations(db *gorm.DB) { diff --git a/callbacks/callbacks.go b/callbacks/callbacks.go index 1c1d6ade..f61252d4 100644 --- a/callbacks/callbacks.go +++ b/callbacks/callbacks.go @@ -1,7 +1,7 @@ package callbacks import ( - "github.com/jinzhu/gorm" + "gorm.io/gorm" ) type Config struct { diff --git a/callbacks/create.go b/callbacks/create.go index 0277407e..0b88e263 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -3,9 +3,9 @@ package callbacks import ( "reflect" - "github.com/jinzhu/gorm" - "github.com/jinzhu/gorm/clause" - "github.com/jinzhu/gorm/schema" + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" ) func BeforeCreate(db *gorm.DB) { diff --git a/callbacks/delete.go b/callbacks/delete.go index 451569cf..b8691ff9 100644 --- a/callbacks/delete.go +++ b/callbacks/delete.go @@ -3,9 +3,9 @@ package callbacks import ( "reflect" - "github.com/jinzhu/gorm" - "github.com/jinzhu/gorm/clause" - "github.com/jinzhu/gorm/schema" + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" ) func BeforeDelete(db *gorm.DB) { diff --git a/callbacks/helper.go b/callbacks/helper.go index 818d9c2c..828e025a 100644 --- a/callbacks/helper.go +++ b/callbacks/helper.go @@ -3,8 +3,8 @@ package callbacks import ( "sort" - "github.com/jinzhu/gorm" - "github.com/jinzhu/gorm/clause" + "gorm.io/gorm" + "gorm.io/gorm/clause" ) // SelectAndOmitColumns get select and omit columns, select -> true, omit -> false diff --git a/callbacks/interface.go b/callbacks/interface.go index 0ef64fcd..ee0044e8 100644 --- a/callbacks/interface.go +++ b/callbacks/interface.go @@ -1,6 +1,6 @@ package callbacks -import "github.com/jinzhu/gorm" +import "gorm.io/gorm" type beforeSaveInterface interface { BeforeSave(*gorm.DB) error diff --git a/callbacks/preload.go b/callbacks/preload.go index 6c763da4..a9907d68 100644 --- a/callbacks/preload.go +++ b/callbacks/preload.go @@ -3,10 +3,10 @@ package callbacks import ( "reflect" - "github.com/jinzhu/gorm" - "github.com/jinzhu/gorm/clause" - "github.com/jinzhu/gorm/schema" - "github.com/jinzhu/gorm/utils" + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" + "gorm.io/gorm/utils" ) func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { diff --git a/callbacks/query.go b/callbacks/query.go index f7c3271f..b3293576 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -6,9 +6,9 @@ import ( "sort" "strings" - "github.com/jinzhu/gorm" - "github.com/jinzhu/gorm/clause" - "github.com/jinzhu/gorm/schema" + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" ) func Query(db *gorm.DB) { diff --git a/callbacks/raw.go b/callbacks/raw.go index cb0cd6c9..4093a5ab 100644 --- a/callbacks/raw.go +++ b/callbacks/raw.go @@ -1,7 +1,7 @@ package callbacks import ( - "github.com/jinzhu/gorm" + "gorm.io/gorm" ) func RawExec(db *gorm.DB) { diff --git a/callbacks/row.go b/callbacks/row.go index f4ff734c..b25503ff 100644 --- a/callbacks/row.go +++ b/callbacks/row.go @@ -1,7 +1,7 @@ package callbacks import ( - "github.com/jinzhu/gorm" + "gorm.io/gorm" ) func RowQuery(db *gorm.DB) { diff --git a/callbacks/transaction.go b/callbacks/transaction.go index 63015364..430a341d 100644 --- a/callbacks/transaction.go +++ b/callbacks/transaction.go @@ -1,7 +1,7 @@ package callbacks import ( - "github.com/jinzhu/gorm" + "gorm.io/gorm" ) func BeginTransaction(db *gorm.DB) { diff --git a/callbacks/update.go b/callbacks/update.go index a52bd310..9b2e924b 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -4,9 +4,9 @@ import ( "reflect" "sort" - "github.com/jinzhu/gorm" - "github.com/jinzhu/gorm/clause" - "github.com/jinzhu/gorm/schema" + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" ) func SetupUpdateReflectValue(db *gorm.DB) { diff --git a/chainable_api.go b/chainable_api.go index 6fa605c6..b1ae3132 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -4,8 +4,8 @@ import ( "fmt" "strings" - "github.com/jinzhu/gorm/clause" - "github.com/jinzhu/gorm/utils" + "gorm.io/gorm/clause" + "gorm.io/gorm/utils" ) // Model specify the model you would like to run db operations diff --git a/clause/benchmarks_test.go b/clause/benchmarks_test.go index 47001cd1..2faed773 100644 --- a/clause/benchmarks_test.go +++ b/clause/benchmarks_test.go @@ -4,10 +4,10 @@ import ( "sync" "testing" - "github.com/jinzhu/gorm" - "github.com/jinzhu/gorm/clause" - "github.com/jinzhu/gorm/schema" - "github.com/jinzhu/gorm/tests" + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" + "gorm.io/gorm/tests" ) func BenchmarkSelect(b *testing.B) { diff --git a/clause/clause_test.go b/clause/clause_test.go index 30ea9343..f9d26a4a 100644 --- a/clause/clause_test.go +++ b/clause/clause_test.go @@ -6,10 +6,10 @@ import ( "sync" "testing" - "github.com/jinzhu/gorm" - "github.com/jinzhu/gorm/clause" - "github.com/jinzhu/gorm/schema" - "github.com/jinzhu/gorm/tests" + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" + "gorm.io/gorm/tests" ) var db, _ = gorm.Open(tests.DummyDialector{}, nil) diff --git a/clause/delete_test.go b/clause/delete_test.go index 2faf8364..a9a659b3 100644 --- a/clause/delete_test.go +++ b/clause/delete_test.go @@ -4,7 +4,7 @@ import ( "fmt" "testing" - "github.com/jinzhu/gorm/clause" + "gorm.io/gorm/clause" ) func TestDelete(t *testing.T) { diff --git a/clause/expression_test.go b/clause/expression_test.go index e51d189e..4e937650 100644 --- a/clause/expression_test.go +++ b/clause/expression_test.go @@ -5,10 +5,10 @@ import ( "sync" "testing" - "github.com/jinzhu/gorm" - "github.com/jinzhu/gorm/clause" - "github.com/jinzhu/gorm/schema" - "github.com/jinzhu/gorm/tests" + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" + "gorm.io/gorm/tests" ) func TestExpr(t *testing.T) { diff --git a/clause/from_test.go b/clause/from_test.go index 4b7b0e18..3ebb754c 100644 --- a/clause/from_test.go +++ b/clause/from_test.go @@ -4,7 +4,7 @@ import ( "fmt" "testing" - "github.com/jinzhu/gorm/clause" + "gorm.io/gorm/clause" ) func TestFrom(t *testing.T) { diff --git a/clause/group_by_test.go b/clause/group_by_test.go index 98aad3eb..589f9613 100644 --- a/clause/group_by_test.go +++ b/clause/group_by_test.go @@ -4,7 +4,7 @@ import ( "fmt" "testing" - "github.com/jinzhu/gorm/clause" + "gorm.io/gorm/clause" ) func TestGroupBy(t *testing.T) { diff --git a/clause/insert_test.go b/clause/insert_test.go index b1a57803..70810bce 100644 --- a/clause/insert_test.go +++ b/clause/insert_test.go @@ -4,7 +4,7 @@ import ( "fmt" "testing" - "github.com/jinzhu/gorm/clause" + "gorm.io/gorm/clause" ) func TestInsert(t *testing.T) { diff --git a/clause/limit_test.go b/clause/limit_test.go index 7b76aaf4..80317dc3 100644 --- a/clause/limit_test.go +++ b/clause/limit_test.go @@ -4,7 +4,7 @@ import ( "fmt" "testing" - "github.com/jinzhu/gorm/clause" + "gorm.io/gorm/clause" ) func TestLimit(t *testing.T) { diff --git a/clause/locking_test.go b/clause/locking_test.go index 6b054404..6f507692 100644 --- a/clause/locking_test.go +++ b/clause/locking_test.go @@ -4,7 +4,7 @@ import ( "fmt" "testing" - "github.com/jinzhu/gorm/clause" + "gorm.io/gorm/clause" ) func TestFor(t *testing.T) { diff --git a/clause/order_by_test.go b/clause/order_by_test.go index 2c74a322..2ea2d192 100644 --- a/clause/order_by_test.go +++ b/clause/order_by_test.go @@ -4,7 +4,7 @@ import ( "fmt" "testing" - "github.com/jinzhu/gorm/clause" + "gorm.io/gorm/clause" ) func TestOrderBy(t *testing.T) { diff --git a/clause/returning_test.go b/clause/returning_test.go index e9fed1cb..bd0ecce8 100644 --- a/clause/returning_test.go +++ b/clause/returning_test.go @@ -4,7 +4,7 @@ import ( "fmt" "testing" - "github.com/jinzhu/gorm/clause" + "gorm.io/gorm/clause" ) func TestReturning(t *testing.T) { diff --git a/clause/select_test.go b/clause/select_test.go index 0863d086..b7296434 100644 --- a/clause/select_test.go +++ b/clause/select_test.go @@ -4,7 +4,7 @@ import ( "fmt" "testing" - "github.com/jinzhu/gorm/clause" + "gorm.io/gorm/clause" ) func TestSelect(t *testing.T) { diff --git a/clause/set_test.go b/clause/set_test.go index 48131218..dbc1e970 100644 --- a/clause/set_test.go +++ b/clause/set_test.go @@ -4,7 +4,7 @@ import ( "fmt" "testing" - "github.com/jinzhu/gorm/clause" + "gorm.io/gorm/clause" ) func TestSet(t *testing.T) { diff --git a/clause/update_test.go b/clause/update_test.go index adc48f03..c704bf5e 100644 --- a/clause/update_test.go +++ b/clause/update_test.go @@ -4,7 +4,7 @@ import ( "fmt" "testing" - "github.com/jinzhu/gorm/clause" + "gorm.io/gorm/clause" ) func TestUpdate(t *testing.T) { diff --git a/clause/values_test.go b/clause/values_test.go index ced4f1e6..9c02c8a5 100644 --- a/clause/values_test.go +++ b/clause/values_test.go @@ -4,7 +4,7 @@ import ( "fmt" "testing" - "github.com/jinzhu/gorm/clause" + "gorm.io/gorm/clause" ) func TestValues(t *testing.T) { diff --git a/clause/where_test.go b/clause/where_test.go index 450a0c89..894e11f4 100644 --- a/clause/where_test.go +++ b/clause/where_test.go @@ -4,7 +4,7 @@ import ( "fmt" "testing" - "github.com/jinzhu/gorm/clause" + "gorm.io/gorm/clause" ) func TestWhere(t *testing.T) { diff --git a/dialects/mssql/create.go b/dialects/mssql/create.go index 84732427..b07f13c5 100644 --- a/dialects/mssql/create.go +++ b/dialects/mssql/create.go @@ -4,10 +4,10 @@ import ( "reflect" "sort" - "github.com/jinzhu/gorm" - "github.com/jinzhu/gorm/callbacks" - "github.com/jinzhu/gorm/clause" - "github.com/jinzhu/gorm/schema" + "gorm.io/gorm" + "gorm.io/gorm/callbacks" + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" ) func Create(db *gorm.DB) { diff --git a/dialects/mssql/migrator.go b/dialects/mssql/migrator.go index 1de49ae9..3bb2086d 100644 --- a/dialects/mssql/migrator.go +++ b/dialects/mssql/migrator.go @@ -3,9 +3,9 @@ package mssql import ( "fmt" - "github.com/jinzhu/gorm" - "github.com/jinzhu/gorm/clause" - "github.com/jinzhu/gorm/migrator" + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/migrator" ) type Migrator struct { diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go index 066aa38f..3f87180c 100644 --- a/dialects/mssql/mssql.go +++ b/dialects/mssql/mssql.go @@ -7,12 +7,12 @@ import ( "strconv" _ "github.com/denisenkom/go-mssqldb" - "github.com/jinzhu/gorm" - "github.com/jinzhu/gorm/callbacks" - "github.com/jinzhu/gorm/clause" - "github.com/jinzhu/gorm/logger" - "github.com/jinzhu/gorm/migrator" - "github.com/jinzhu/gorm/schema" + "gorm.io/gorm" + "gorm.io/gorm/callbacks" + "gorm.io/gorm/clause" + "gorm.io/gorm/logger" + "gorm.io/gorm/migrator" + "gorm.io/gorm/schema" ) type Dialector struct { diff --git a/dialects/mysql/migrator.go b/dialects/mysql/migrator.go index 467da9a2..8d3d20c6 100644 --- a/dialects/mysql/migrator.go +++ b/dialects/mysql/migrator.go @@ -3,9 +3,9 @@ package mysql import ( "fmt" - "github.com/jinzhu/gorm" - "github.com/jinzhu/gorm/clause" - "github.com/jinzhu/gorm/migrator" + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/migrator" ) type Migrator struct { diff --git a/dialects/mysql/mysql.go b/dialects/mysql/mysql.go index e617a1e1..035a6d79 100644 --- a/dialects/mysql/mysql.go +++ b/dialects/mysql/mysql.go @@ -6,12 +6,12 @@ import ( "math" _ "github.com/go-sql-driver/mysql" - "github.com/jinzhu/gorm" - "github.com/jinzhu/gorm/callbacks" - "github.com/jinzhu/gorm/clause" - "github.com/jinzhu/gorm/logger" - "github.com/jinzhu/gorm/migrator" - "github.com/jinzhu/gorm/schema" + "gorm.io/gorm" + "gorm.io/gorm/callbacks" + "gorm.io/gorm/clause" + "gorm.io/gorm/logger" + "gorm.io/gorm/migrator" + "gorm.io/gorm/schema" ) type Dialector struct { diff --git a/dialects/postgres/migrator.go b/dialects/postgres/migrator.go index ef582f00..6b1085e3 100644 --- a/dialects/postgres/migrator.go +++ b/dialects/postgres/migrator.go @@ -3,10 +3,10 @@ package postgres import ( "fmt" - "github.com/jinzhu/gorm" - "github.com/jinzhu/gorm/clause" - "github.com/jinzhu/gorm/migrator" - "github.com/jinzhu/gorm/schema" + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/migrator" + "gorm.io/gorm/schema" ) type Migrator struct { diff --git a/dialects/postgres/postgres.go b/dialects/postgres/postgres.go index fb3ecc68..57e51d58 100644 --- a/dialects/postgres/postgres.go +++ b/dialects/postgres/postgres.go @@ -6,12 +6,12 @@ import ( "regexp" "strconv" - "github.com/jinzhu/gorm" - "github.com/jinzhu/gorm/callbacks" - "github.com/jinzhu/gorm/clause" - "github.com/jinzhu/gorm/logger" - "github.com/jinzhu/gorm/migrator" - "github.com/jinzhu/gorm/schema" + "gorm.io/gorm" + "gorm.io/gorm/callbacks" + "gorm.io/gorm/clause" + "gorm.io/gorm/logger" + "gorm.io/gorm/migrator" + "gorm.io/gorm/schema" _ "github.com/lib/pq" ) diff --git a/dialects/sqlite/migrator.go b/dialects/sqlite/migrator.go index 252e4183..14c682ca 100644 --- a/dialects/sqlite/migrator.go +++ b/dialects/sqlite/migrator.go @@ -5,10 +5,10 @@ import ( "regexp" "strings" - "github.com/jinzhu/gorm" - "github.com/jinzhu/gorm/clause" - "github.com/jinzhu/gorm/migrator" - "github.com/jinzhu/gorm/schema" + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/migrator" + "gorm.io/gorm/schema" ) type Migrator struct { diff --git a/dialects/sqlite/sqlite.go b/dialects/sqlite/sqlite.go index 1b9809af..238ad7f9 100644 --- a/dialects/sqlite/sqlite.go +++ b/dialects/sqlite/sqlite.go @@ -3,12 +3,12 @@ package sqlite import ( "database/sql" - "github.com/jinzhu/gorm" - "github.com/jinzhu/gorm/callbacks" - "github.com/jinzhu/gorm/clause" - "github.com/jinzhu/gorm/logger" - "github.com/jinzhu/gorm/migrator" - "github.com/jinzhu/gorm/schema" + "gorm.io/gorm" + "gorm.io/gorm/callbacks" + "gorm.io/gorm/clause" + "gorm.io/gorm/logger" + "gorm.io/gorm/migrator" + "gorm.io/gorm/schema" _ "github.com/mattn/go-sqlite3" ) diff --git a/finisher_api.go b/finisher_api.go index 780de267..5023150c 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -6,7 +6,7 @@ import ( "reflect" "strings" - "github.com/jinzhu/gorm/clause" + "gorm.io/gorm/clause" ) // Create insert the value into database diff --git a/go.mod b/go.mod index 7dabdd39..fe07494e 100644 --- a/go.mod +++ b/go.mod @@ -1,4 +1,4 @@ -module github.com/jinzhu/gorm +module gorm.io/gorm go 1.14 @@ -6,8 +6,9 @@ require ( github.com/denisenkom/go-mssqldb v0.0.0-20200428022330-06a60b6afbbc github.com/erikstmartin/go-testdb v0.0.0-20160219214506-8d10e4a1bae5 github.com/go-sql-driver/mysql v1.5.0 - github.com/jinzhu/inflection v1.0.0 - github.com/jinzhu/now v1.1.1 + gorm.io/gorm v1.9.12 + gorm.io/inflection v1.0.0 + gorm.io/now v1.1.1 github.com/lib/pq v1.1.1 github.com/mattn/go-sqlite3 v2.0.1+incompatible ) diff --git a/gorm.go b/gorm.go index 07f94266..1ab3fd64 100644 --- a/gorm.go +++ b/gorm.go @@ -6,9 +6,9 @@ import ( "sync" "time" - "github.com/jinzhu/gorm/clause" - "github.com/jinzhu/gorm/logger" - "github.com/jinzhu/gorm/schema" + "gorm.io/gorm/clause" + "gorm.io/gorm/logger" + "gorm.io/gorm/schema" ) // Config GORM config diff --git a/interfaces.go b/interfaces.go index 421428a3..6d9c6212 100644 --- a/interfaces.go +++ b/interfaces.go @@ -4,8 +4,8 @@ import ( "context" "database/sql" - "github.com/jinzhu/gorm/clause" - "github.com/jinzhu/gorm/schema" + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" ) // Dialector GORM database dialector diff --git a/logger/logger.go b/logger/logger.go index 694adedc..2a5e445c 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -6,7 +6,7 @@ import ( "os" "time" - "github.com/jinzhu/gorm/utils" + "gorm.io/gorm/utils" ) // Colors diff --git a/logger/sql_test.go b/logger/sql_test.go index dd7b80c8..bd852479 100644 --- a/logger/sql_test.go +++ b/logger/sql_test.go @@ -4,8 +4,8 @@ import ( "regexp" "testing" - "github.com/jinzhu/gorm/logger" - "github.com/jinzhu/now" + "gorm.io/gorm/logger" + "gorm.io/now" ) func TestExplainSQL(t *testing.T) { diff --git a/migrator/migrator.go b/migrator/migrator.go index d78c6224..afef65c3 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -6,9 +6,9 @@ import ( "reflect" "strings" - "github.com/jinzhu/gorm" - "github.com/jinzhu/gorm/clause" - "github.com/jinzhu/gorm/schema" + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" ) // Migrator m struct diff --git a/scan.go b/scan.go index 4d328fde..fc6b211b 100644 --- a/scan.go +++ b/scan.go @@ -5,7 +5,7 @@ import ( "reflect" "strings" - "github.com/jinzhu/gorm/schema" + "gorm.io/gorm/schema" ) func Scan(rows *sql.Rows, db *DB, initialized bool) { diff --git a/schema/callbacks_test.go b/schema/callbacks_test.go index efa01e89..dec41eba 100644 --- a/schema/callbacks_test.go +++ b/schema/callbacks_test.go @@ -5,8 +5,8 @@ import ( "sync" "testing" - "github.com/jinzhu/gorm" - "github.com/jinzhu/gorm/schema" + "gorm.io/gorm" + "gorm.io/gorm/schema" ) type UserWithCallback struct { diff --git a/schema/check_test.go b/schema/check_test.go index e4bc9ebe..eda043b7 100644 --- a/schema/check_test.go +++ b/schema/check_test.go @@ -5,7 +5,7 @@ import ( "sync" "testing" - "github.com/jinzhu/gorm/schema" + "gorm.io/gorm/schema" ) type UserCheck struct { diff --git a/schema/field.go b/schema/field.go index 8a0f01bf..438dadab 100644 --- a/schema/field.go +++ b/schema/field.go @@ -10,8 +10,8 @@ import ( "sync" "time" - "github.com/jinzhu/gorm/utils" - "github.com/jinzhu/now" + "gorm.io/gorm/utils" + "gorm.io/now" ) type DataType string diff --git a/schema/field_test.go b/schema/field_test.go index aac46de9..7a47f195 100644 --- a/schema/field_test.go +++ b/schema/field_test.go @@ -7,9 +7,9 @@ import ( "testing" "time" - "github.com/jinzhu/gorm" - "github.com/jinzhu/gorm/schema" - "github.com/jinzhu/gorm/tests" + "gorm.io/gorm" + "gorm.io/gorm/schema" + "gorm.io/gorm/tests" ) func TestFieldValuerAndSetter(t *testing.T) { diff --git a/schema/index_test.go b/schema/index_test.go index 398ddbb7..384e902b 100644 --- a/schema/index_test.go +++ b/schema/index_test.go @@ -5,7 +5,7 @@ import ( "sync" "testing" - "github.com/jinzhu/gorm/schema" + "gorm.io/gorm/schema" ) type UserIndex struct { diff --git a/schema/model_test.go b/schema/model_test.go index 343e324e..068b3050 100644 --- a/schema/model_test.go +++ b/schema/model_test.go @@ -4,8 +4,8 @@ import ( "database/sql" "time" - "github.com/jinzhu/gorm" - "github.com/jinzhu/gorm/tests" + "gorm.io/gorm" + "gorm.io/gorm/tests" ) type User struct { diff --git a/schema/naming.go b/schema/naming.go index f7c82f32..1af45257 100644 --- a/schema/naming.go +++ b/schema/naming.go @@ -7,7 +7,7 @@ import ( "sync" "unicode/utf8" - "github.com/jinzhu/inflection" + "gorm.io/inflection" ) // Namer namer interface diff --git a/schema/relationship.go b/schema/relationship.go index 8b5e987c..f24c6e6d 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -6,8 +6,8 @@ import ( "regexp" "strings" - "github.com/jinzhu/gorm/clause" - "github.com/jinzhu/inflection" + "gorm.io/gorm/clause" + "gorm.io/inflection" ) // RelationshipType relationship type diff --git a/schema/relationship_test.go b/schema/relationship_test.go index 0f62f45d..defba9ce 100644 --- a/schema/relationship_test.go +++ b/schema/relationship_test.go @@ -4,8 +4,8 @@ import ( "sync" "testing" - "github.com/jinzhu/gorm" - "github.com/jinzhu/gorm/schema" + "gorm.io/gorm" + "gorm.io/gorm/schema" ) func checkStructRelation(t *testing.T, data interface{}, relations ...Relation) { diff --git a/schema/schema.go b/schema/schema.go index 231ed1db..60e621de 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -8,8 +8,8 @@ import ( "reflect" "sync" - "github.com/jinzhu/gorm/clause" - "github.com/jinzhu/gorm/logger" + "gorm.io/gorm/clause" + "gorm.io/gorm/logger" ) // ErrUnsupportedDataType unsupported data type diff --git a/schema/schema_helper_test.go b/schema/schema_helper_test.go index b5474fe7..b966164e 100644 --- a/schema/schema_helper_test.go +++ b/schema/schema_helper_test.go @@ -6,8 +6,8 @@ import ( "strings" "testing" - "github.com/jinzhu/gorm/schema" - "github.com/jinzhu/gorm/tests" + "gorm.io/gorm/schema" + "gorm.io/gorm/tests" ) func checkSchema(t *testing.T, s *schema.Schema, v schema.Schema, primaryFields []string) { diff --git a/schema/schema_test.go b/schema/schema_test.go index 958e035f..6902cbf2 100644 --- a/schema/schema_test.go +++ b/schema/schema_test.go @@ -4,8 +4,8 @@ import ( "sync" "testing" - "github.com/jinzhu/gorm/schema" - "github.com/jinzhu/gorm/tests" + "gorm.io/gorm/schema" + "gorm.io/gorm/tests" ) func TestParseSchema(t *testing.T) { diff --git a/schema/utils.go b/schema/utils.go index ca4ef2f4..da236a18 100644 --- a/schema/utils.go +++ b/schema/utils.go @@ -5,7 +5,7 @@ import ( "regexp" "strings" - "github.com/jinzhu/gorm/utils" + "gorm.io/gorm/utils" ) func ParseTagSetting(str string, sep string) map[string]string { diff --git a/soft_delete.go b/soft_delete.go index 09cfff37..4ffceba6 100644 --- a/soft_delete.go +++ b/soft_delete.go @@ -5,8 +5,8 @@ import ( "database/sql/driver" "reflect" - "github.com/jinzhu/gorm/clause" - "github.com/jinzhu/gorm/schema" + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" ) type DeletedAt sql.NullTime diff --git a/statement.go b/statement.go index e78dfea9..8f4762e7 100644 --- a/statement.go +++ b/statement.go @@ -10,8 +10,8 @@ import ( "strings" "sync" - "github.com/jinzhu/gorm/clause" - "github.com/jinzhu/gorm/schema" + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" ) // Statement statement diff --git a/tests/associations_belongs_to_test.go b/tests/associations_belongs_to_test.go index 236af191..27b82ecb 100644 --- a/tests/associations_belongs_to_test.go +++ b/tests/associations_belongs_to_test.go @@ -3,7 +3,7 @@ package tests_test import ( "testing" - . "github.com/jinzhu/gorm/tests" + . "gorm.io/gorm/tests" ) func TestBelongsToAssociation(t *testing.T) { diff --git a/tests/associations_has_many_test.go b/tests/associations_has_many_test.go index 2269d701..88df8532 100644 --- a/tests/associations_has_many_test.go +++ b/tests/associations_has_many_test.go @@ -3,7 +3,7 @@ package tests_test import ( "testing" - . "github.com/jinzhu/gorm/tests" + . "gorm.io/gorm/tests" ) func TestHasManyAssociation(t *testing.T) { diff --git a/tests/associations_has_one_test.go b/tests/associations_has_one_test.go index a863cb36..9ddfa9c5 100644 --- a/tests/associations_has_one_test.go +++ b/tests/associations_has_one_test.go @@ -3,7 +3,7 @@ package tests_test import ( "testing" - . "github.com/jinzhu/gorm/tests" + . "gorm.io/gorm/tests" ) func TestHasOneAssociation(t *testing.T) { diff --git a/tests/associations_many2many_test.go b/tests/associations_many2many_test.go index a2db9675..d79cdc17 100644 --- a/tests/associations_many2many_test.go +++ b/tests/associations_many2many_test.go @@ -3,7 +3,7 @@ package tests_test import ( "testing" - . "github.com/jinzhu/gorm/tests" + . "gorm.io/gorm/tests" ) func TestMany2ManyAssociation(t *testing.T) { diff --git a/tests/associations_test.go b/tests/associations_test.go index 3668b44b..2e30df8b 100644 --- a/tests/associations_test.go +++ b/tests/associations_test.go @@ -3,7 +3,7 @@ package tests_test import ( "testing" - . "github.com/jinzhu/gorm/tests" + . "gorm.io/gorm/tests" ) func AssertAssociationCount(t *testing.T, data interface{}, name string, result int64, reason string) { diff --git a/tests/callbacks_test.go b/tests/callbacks_test.go index f8dc3e81..1dbae441 100644 --- a/tests/callbacks_test.go +++ b/tests/callbacks_test.go @@ -7,7 +7,7 @@ import ( "strings" "testing" - "github.com/jinzhu/gorm" + "gorm.io/gorm" ) func assertCallbacks(v interface{}, fnames []string) (result bool, msg string) { diff --git a/tests/count_test.go b/tests/count_test.go index 257959c3..d8cfa405 100644 --- a/tests/count_test.go +++ b/tests/count_test.go @@ -4,7 +4,7 @@ import ( "fmt" "testing" - . "github.com/jinzhu/gorm/tests" + . "gorm.io/gorm/tests" ) func TestCount(t *testing.T) { diff --git a/tests/create_test.go b/tests/create_test.go index 4b9694b6..4ef14ddb 100644 --- a/tests/create_test.go +++ b/tests/create_test.go @@ -4,9 +4,9 @@ import ( "testing" "time" - "github.com/jinzhu/gorm" - . "github.com/jinzhu/gorm/tests" - "github.com/jinzhu/now" + "gorm.io/gorm" + . "gorm.io/gorm/tests" + "gorm.io/now" ) func TestCreate(t *testing.T) { diff --git a/tests/customize_column_test.go b/tests/customize_column_test.go index 49447dab..0db40869 100644 --- a/tests/customize_column_test.go +++ b/tests/customize_column_test.go @@ -4,7 +4,7 @@ import ( "testing" "time" - . "github.com/jinzhu/gorm/tests" + . "gorm.io/gorm/tests" ) func TestCustomizeColumn(t *testing.T) { diff --git a/tests/delete_test.go b/tests/delete_test.go index e7076aa6..0fe2ee75 100644 --- a/tests/delete_test.go +++ b/tests/delete_test.go @@ -4,8 +4,8 @@ import ( "errors" "testing" - "github.com/jinzhu/gorm" - . "github.com/jinzhu/gorm/tests" + "gorm.io/gorm" + . "gorm.io/gorm/tests" ) func TestDelete(t *testing.T) { diff --git a/tests/dummy_dialecter.go b/tests/dummy_dialecter.go index 4ea17a0f..cd4bbd45 100644 --- a/tests/dummy_dialecter.go +++ b/tests/dummy_dialecter.go @@ -1,10 +1,10 @@ package tests import ( - "github.com/jinzhu/gorm" - "github.com/jinzhu/gorm/clause" - "github.com/jinzhu/gorm/logger" - "github.com/jinzhu/gorm/schema" + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/logger" + "gorm.io/gorm/schema" ) type DummyDialector struct { diff --git a/tests/embedded_struct_test.go b/tests/embedded_struct_test.go index af003786..74829460 100644 --- a/tests/embedded_struct_test.go +++ b/tests/embedded_struct_test.go @@ -3,8 +3,8 @@ package tests_test import ( "testing" - "github.com/jinzhu/gorm" - . "github.com/jinzhu/gorm/tests" + "gorm.io/gorm" + . "gorm.io/gorm/tests" ) func TestEmbeddedStruct(t *testing.T) { diff --git a/tests/group_by_test.go b/tests/group_by_test.go index 66a733aa..5a954348 100644 --- a/tests/group_by_test.go +++ b/tests/group_by_test.go @@ -3,7 +3,7 @@ package tests_test import ( "testing" - . "github.com/jinzhu/gorm/tests" + . "gorm.io/gorm/tests" ) func TestGroupBy(t *testing.T) { diff --git a/tests/hooks_test.go b/tests/hooks_test.go index 432226a3..418713a6 100644 --- a/tests/hooks_test.go +++ b/tests/hooks_test.go @@ -5,8 +5,8 @@ import ( "reflect" "testing" - "github.com/jinzhu/gorm" - . "github.com/jinzhu/gorm/tests" + "gorm.io/gorm" + . "gorm.io/gorm/tests" ) type Product struct { diff --git a/tests/joins_table_test.go b/tests/joins_table_test.go index 091ca65c..5738d8f4 100644 --- a/tests/joins_table_test.go +++ b/tests/joins_table_test.go @@ -4,8 +4,8 @@ import ( "testing" "time" - "github.com/jinzhu/gorm" - . "github.com/jinzhu/gorm/tests" + "gorm.io/gorm" + . "gorm.io/gorm/tests" ) type Person struct { diff --git a/tests/joins_test.go b/tests/joins_test.go index d9cfd22f..651b20c6 100644 --- a/tests/joins_test.go +++ b/tests/joins_test.go @@ -4,8 +4,8 @@ import ( "sort" "testing" - "github.com/jinzhu/gorm" - . "github.com/jinzhu/gorm/tests" + "gorm.io/gorm" + . "gorm.io/gorm/tests" ) func TestJoins(t *testing.T) { diff --git a/tests/main_test.go b/tests/main_test.go index 60cc4611..2d466125 100644 --- a/tests/main_test.go +++ b/tests/main_test.go @@ -3,7 +3,7 @@ package tests_test import ( "testing" - . "github.com/jinzhu/gorm/tests" + . "gorm.io/gorm/tests" ) func TestMain(m *testing.M) { diff --git a/tests/migrate_test.go b/tests/migrate_test.go index e786b1cc..b511ab40 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -6,8 +6,8 @@ import ( "testing" "time" - "github.com/jinzhu/gorm" - . "github.com/jinzhu/gorm/tests" + "gorm.io/gorm" + . "gorm.io/gorm/tests" ) func TestMigrate(t *testing.T) { diff --git a/tests/model.go b/tests/model.go index 1ae7c160..878129e8 100644 --- a/tests/model.go +++ b/tests/model.go @@ -4,7 +4,7 @@ import ( "database/sql" "time" - "github.com/jinzhu/gorm" + "gorm.io/gorm" ) // User has one `Account` (has one), many `Pets` (has many) and `Toys` (has many - polymorphic) diff --git a/tests/multi_primary_keys_test.go b/tests/multi_primary_keys_test.go index b3284f15..139cde69 100644 --- a/tests/multi_primary_keys_test.go +++ b/tests/multi_primary_keys_test.go @@ -5,7 +5,7 @@ import ( "sort" "testing" - . "github.com/jinzhu/gorm/tests" + . "gorm.io/gorm/tests" ) type Blog struct { diff --git a/tests/named_polymorphic_test.go b/tests/named_polymorphic_test.go index 95b8ec7d..99a7865a 100644 --- a/tests/named_polymorphic_test.go +++ b/tests/named_polymorphic_test.go @@ -3,7 +3,7 @@ package tests_test import ( "testing" - . "github.com/jinzhu/gorm/tests" + . "gorm.io/gorm/tests" ) type Hamster struct { diff --git a/tests/non_std_test.go b/tests/non_std_test.go index 606b4fc9..b3ac6545 100644 --- a/tests/non_std_test.go +++ b/tests/non_std_test.go @@ -4,7 +4,7 @@ import ( "testing" "time" - . "github.com/jinzhu/gorm/tests" + . "gorm.io/gorm/tests" ) type Animal struct { diff --git a/tests/preload_suits_test.go b/tests/preload_suits_test.go index b71b7299..42e94fa0 100644 --- a/tests/preload_suits_test.go +++ b/tests/preload_suits_test.go @@ -6,8 +6,8 @@ import ( "reflect" "testing" - "github.com/jinzhu/gorm" - . "github.com/jinzhu/gorm/tests" + "gorm.io/gorm" + . "gorm.io/gorm/tests" ) func toJSONString(v interface{}) []byte { diff --git a/tests/preload_test.go b/tests/preload_test.go index b14c5b90..e4ecdc87 100644 --- a/tests/preload_test.go +++ b/tests/preload_test.go @@ -5,8 +5,8 @@ import ( "strconv" "testing" - "github.com/jinzhu/gorm/clause" - . "github.com/jinzhu/gorm/tests" + "gorm.io/gorm/clause" + . "gorm.io/gorm/tests" ) func TestNestedPreload(t *testing.T) { diff --git a/tests/query_test.go b/tests/query_test.go index 12f29ace..9d15a41f 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -8,8 +8,8 @@ import ( "testing" "time" - "github.com/jinzhu/gorm" - . "github.com/jinzhu/gorm/tests" + "gorm.io/gorm" + . "gorm.io/gorm/tests" ) func TestFind(t *testing.T) { diff --git a/tests/scan_test.go b/tests/scan_test.go index fc6c1721..262ac9a7 100644 --- a/tests/scan_test.go +++ b/tests/scan_test.go @@ -6,7 +6,7 @@ import ( "strings" "testing" - . "github.com/jinzhu/gorm/tests" + . "gorm.io/gorm/tests" ) func TestScan(t *testing.T) { diff --git a/tests/scanner_valuer_test.go b/tests/scanner_valuer_test.go index 9f91b5d8..7dad081f 100644 --- a/tests/scanner_valuer_test.go +++ b/tests/scanner_valuer_test.go @@ -10,8 +10,8 @@ import ( "testing" "time" - "github.com/jinzhu/gorm" - . "github.com/jinzhu/gorm/tests" + "gorm.io/gorm" + . "gorm.io/gorm/tests" ) func TestScannerValuer(t *testing.T) { diff --git a/tests/scopes_test.go b/tests/scopes_test.go index c0530da5..a2a7de3f 100644 --- a/tests/scopes_test.go +++ b/tests/scopes_test.go @@ -3,8 +3,8 @@ package tests_test import ( "testing" - "github.com/jinzhu/gorm" - . "github.com/jinzhu/gorm/tests" + "gorm.io/gorm" + . "gorm.io/gorm/tests" ) func NameIn1And2(d *gorm.DB) *gorm.DB { diff --git a/tests/soft_delete_test.go b/tests/soft_delete_test.go index f91052c1..24b06498 100644 --- a/tests/soft_delete_test.go +++ b/tests/soft_delete_test.go @@ -3,7 +3,7 @@ package tests_test import ( "testing" - . "github.com/jinzhu/gorm/tests" + . "gorm.io/gorm/tests" ) func TestSoftDelete(t *testing.T) { diff --git a/tests/sql_builder_test.go b/tests/sql_builder_test.go index 0aed82a2..0f3a56ed 100644 --- a/tests/sql_builder_test.go +++ b/tests/sql_builder_test.go @@ -3,8 +3,8 @@ package tests_test import ( "testing" - "github.com/jinzhu/gorm" - . "github.com/jinzhu/gorm/tests" + "gorm.io/gorm" + . "gorm.io/gorm/tests" ) func TestRow(t *testing.T) { diff --git a/tests/tests.go b/tests/tests.go index fa8ac836..42902685 100644 --- a/tests/tests.go +++ b/tests/tests.go @@ -7,12 +7,12 @@ import ( "path/filepath" "time" - "github.com/jinzhu/gorm" - "github.com/jinzhu/gorm/dialects/mssql" - "github.com/jinzhu/gorm/dialects/mysql" - "github.com/jinzhu/gorm/dialects/postgres" - "github.com/jinzhu/gorm/dialects/sqlite" - "github.com/jinzhu/gorm/logger" + "gorm.io/gorm" + "gorm.io/gorm/dialects/mssql" + "gorm.io/gorm/dialects/mysql" + "gorm.io/gorm/dialects/postgres" + "gorm.io/gorm/dialects/sqlite" + "gorm.io/gorm/logger" ) var DB *gorm.DB diff --git a/tests/transaction_test.go b/tests/transaction_test.go index f39b3167..4ff1b485 100644 --- a/tests/transaction_test.go +++ b/tests/transaction_test.go @@ -5,8 +5,8 @@ import ( "errors" "testing" - "github.com/jinzhu/gorm" - . "github.com/jinzhu/gorm/tests" + "gorm.io/gorm" + . "gorm.io/gorm/tests" ) func TestTransaction(t *testing.T) { diff --git a/tests/update_belongs_to_test.go b/tests/update_belongs_to_test.go index 267fd4e8..7c578b38 100644 --- a/tests/update_belongs_to_test.go +++ b/tests/update_belongs_to_test.go @@ -3,7 +3,7 @@ package tests_test import ( "testing" - . "github.com/jinzhu/gorm/tests" + . "gorm.io/gorm/tests" ) func TestUpdateBelongsTo(t *testing.T) { diff --git a/tests/update_has_many_test.go b/tests/update_has_many_test.go index e723b940..5501c519 100644 --- a/tests/update_has_many_test.go +++ b/tests/update_has_many_test.go @@ -3,7 +3,7 @@ package tests_test import ( "testing" - . "github.com/jinzhu/gorm/tests" + . "gorm.io/gorm/tests" ) func TestUpdateHasManyAssociations(t *testing.T) { diff --git a/tests/update_has_one_test.go b/tests/update_has_one_test.go index 4c5036cf..721c302a 100644 --- a/tests/update_has_one_test.go +++ b/tests/update_has_one_test.go @@ -3,7 +3,7 @@ package tests_test import ( "testing" - . "github.com/jinzhu/gorm/tests" + . "gorm.io/gorm/tests" ) func TestUpdateHasOne(t *testing.T) { diff --git a/tests/update_many2many_test.go b/tests/update_many2many_test.go index bc7a60af..5548444f 100644 --- a/tests/update_many2many_test.go +++ b/tests/update_many2many_test.go @@ -3,7 +3,7 @@ package tests_test import ( "testing" - . "github.com/jinzhu/gorm/tests" + . "gorm.io/gorm/tests" ) func TestUpdateMany2ManyAssociations(t *testing.T) { diff --git a/tests/update_test.go b/tests/update_test.go index a5a62237..aef7f4ce 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -7,8 +7,8 @@ import ( "testing" "time" - "github.com/jinzhu/gorm" - . "github.com/jinzhu/gorm/tests" + "gorm.io/gorm" + . "gorm.io/gorm/tests" ) func TestUpdate(t *testing.T) { diff --git a/tests/upsert_test.go b/tests/upsert_test.go index 6f67f603..87b223b4 100644 --- a/tests/upsert_test.go +++ b/tests/upsert_test.go @@ -4,8 +4,8 @@ import ( "testing" "time" - "github.com/jinzhu/gorm/clause" - . "github.com/jinzhu/gorm/tests" + "gorm.io/gorm/clause" + . "gorm.io/gorm/tests" ) func TestUpsert(t *testing.T) { diff --git a/tests/utils.go b/tests/utils.go index 97b5d5c8..0b4b138e 100644 --- a/tests/utils.go +++ b/tests/utils.go @@ -11,7 +11,7 @@ import ( "testing" "time" - "github.com/jinzhu/gorm/utils" + "gorm.io/gorm/utils" ) type Config struct { From 5790ba9ef40351a86f531abc8bbef4d0d64efba7 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 2 Jun 2020 09:25:55 +0800 Subject: [PATCH 385/881] Fix package path --- go.mod | 6 +++--- logger/sql_test.go | 2 +- schema/field.go | 2 +- schema/naming.go | 2 +- schema/relationship.go | 2 +- tests/create_test.go | 2 +- 6 files changed, 8 insertions(+), 8 deletions(-) diff --git a/go.mod b/go.mod index fe07494e..26877c7a 100644 --- a/go.mod +++ b/go.mod @@ -6,9 +6,9 @@ require ( github.com/denisenkom/go-mssqldb v0.0.0-20200428022330-06a60b6afbbc github.com/erikstmartin/go-testdb v0.0.0-20160219214506-8d10e4a1bae5 github.com/go-sql-driver/mysql v1.5.0 - gorm.io/gorm v1.9.12 - gorm.io/inflection v1.0.0 - gorm.io/now v1.1.1 + github.com/jinzhu/inflection v1.0.0 + github.com/jinzhu/now v1.1.1 github.com/lib/pq v1.1.1 github.com/mattn/go-sqlite3 v2.0.1+incompatible + gorm.io/gorm v1.9.12 ) diff --git a/logger/sql_test.go b/logger/sql_test.go index bd852479..8bc48116 100644 --- a/logger/sql_test.go +++ b/logger/sql_test.go @@ -4,8 +4,8 @@ import ( "regexp" "testing" + "github.com/jinzhu/now" "gorm.io/gorm/logger" - "gorm.io/now" ) func TestExplainSQL(t *testing.T) { diff --git a/schema/field.go b/schema/field.go index 438dadab..4f92aae7 100644 --- a/schema/field.go +++ b/schema/field.go @@ -10,8 +10,8 @@ import ( "sync" "time" + "github.com/jinzhu/now" "gorm.io/gorm/utils" - "gorm.io/now" ) type DataType string diff --git a/schema/naming.go b/schema/naming.go index 1af45257..f7c82f32 100644 --- a/schema/naming.go +++ b/schema/naming.go @@ -7,7 +7,7 @@ import ( "sync" "unicode/utf8" - "gorm.io/inflection" + "github.com/jinzhu/inflection" ) // Namer namer interface diff --git a/schema/relationship.go b/schema/relationship.go index f24c6e6d..efa44554 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -6,8 +6,8 @@ import ( "regexp" "strings" + "github.com/jinzhu/inflection" "gorm.io/gorm/clause" - "gorm.io/inflection" ) // RelationshipType relationship type diff --git a/tests/create_test.go b/tests/create_test.go index 4ef14ddb..2f853c61 100644 --- a/tests/create_test.go +++ b/tests/create_test.go @@ -4,9 +4,9 @@ import ( "testing" "time" + "github.com/jinzhu/now" "gorm.io/gorm" . "gorm.io/gorm/tests" - "gorm.io/now" ) func TestCreate(t *testing.T) { From 8bb05a5a692f080eaa756b985cac7d9171909194 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 2 Jun 2020 10:34:50 +0800 Subject: [PATCH 386/881] Refactor tests files --- clause/benchmarks_test.go | 2 +- clause/clause_test.go | 2 +- clause/expression_test.go | 2 +- dialects/mssql/create.go | 225 ---------------------- dialects/mssql/migrator.go | 142 -------------- dialects/mssql/mssql.go | 127 ------------ dialects/mysql/migrator.go | 58 ------ dialects/mysql/mysql.go | 169 ---------------- dialects/postgres/migrator.go | 139 ------------- dialects/postgres/postgres.go | 102 ---------- dialects/sqlite/migrator.go | 211 -------------------- dialects/sqlite/sqlite.go | 80 -------- go.mod | 6 - schema/field_test.go | 2 +- schema/model_test.go | 2 +- schema/schema_helper_test.go | 2 +- schema/schema_test.go | 2 +- tests/associations_belongs_to_test.go | 2 +- tests/associations_has_many_test.go | 2 +- tests/associations_has_one_test.go | 2 +- tests/associations_many2many_test.go | 2 +- tests/associations_test.go | 2 +- tests/count_test.go | 2 +- tests/create_test.go | 2 +- tests/customize_column_test.go | 2 - tests/delete_test.go | 2 +- tests/embedded_struct_test.go | 1 - tests/go.mod | 14 ++ tests/group_by_test.go | 2 +- tests/{utils.go => helper_test.go} | 103 +--------- tests/hooks_test.go | 1 - tests/joins_table_test.go | 1 - tests/joins_test.go | 2 +- tests/main_test.go | 2 +- tests/migrate_test.go | 2 +- tests/multi_primary_keys_test.go | 14 +- tests/named_polymorphic_test.go | 2 +- tests/non_std_test.go | 2 - tests/preload_suits_test.go | 5 +- tests/preload_test.go | 2 +- tests/query_test.go | 2 +- tests/scan_test.go | 2 +- tests/scanner_valuer_test.go | 2 +- tests/scopes_test.go | 2 +- tests/soft_delete_test.go | 2 +- tests/sql_builder_test.go | 2 +- tests/tests_all.sh | 5 + tests/{tests.go => tests_test.go} | 22 +-- tests/transaction_test.go | 2 +- tests/update_belongs_to_test.go | 2 +- tests/update_has_many_test.go | 2 +- tests/update_has_one_test.go | 2 +- tests/update_many2many_test.go | 2 +- tests/update_test.go | 2 +- tests/upsert_test.go | 2 +- {tests => utils/tests}/dummy_dialecter.go | 0 tests/model.go => utils/tests/models.go | 0 utils/tests/utils.go | 112 +++++++++++ 58 files changed, 184 insertions(+), 1425 deletions(-) delete mode 100644 dialects/mssql/create.go delete mode 100644 dialects/mssql/migrator.go delete mode 100644 dialects/mssql/mssql.go delete mode 100644 dialects/mysql/migrator.go delete mode 100644 dialects/mysql/mysql.go delete mode 100644 dialects/postgres/migrator.go delete mode 100644 dialects/postgres/postgres.go delete mode 100644 dialects/sqlite/migrator.go delete mode 100644 dialects/sqlite/sqlite.go create mode 100644 tests/go.mod rename tests/{utils.go => helper_test.go} (66%) rename tests/{tests.go => tests_test.go} (87%) rename {tests => utils/tests}/dummy_dialecter.go (100%) rename tests/model.go => utils/tests/models.go (100%) create mode 100644 utils/tests/utils.go diff --git a/clause/benchmarks_test.go b/clause/benchmarks_test.go index 2faed773..88a238e3 100644 --- a/clause/benchmarks_test.go +++ b/clause/benchmarks_test.go @@ -7,7 +7,7 @@ import ( "gorm.io/gorm" "gorm.io/gorm/clause" "gorm.io/gorm/schema" - "gorm.io/gorm/tests" + "gorm.io/gorm/utils/tests" ) func BenchmarkSelect(b *testing.B) { diff --git a/clause/clause_test.go b/clause/clause_test.go index f9d26a4a..6239ff39 100644 --- a/clause/clause_test.go +++ b/clause/clause_test.go @@ -9,7 +9,7 @@ import ( "gorm.io/gorm" "gorm.io/gorm/clause" "gorm.io/gorm/schema" - "gorm.io/gorm/tests" + "gorm.io/gorm/utils/tests" ) var db, _ = gorm.Open(tests.DummyDialector{}, nil) diff --git a/clause/expression_test.go b/clause/expression_test.go index 4e937650..3059aea6 100644 --- a/clause/expression_test.go +++ b/clause/expression_test.go @@ -8,7 +8,7 @@ import ( "gorm.io/gorm" "gorm.io/gorm/clause" "gorm.io/gorm/schema" - "gorm.io/gorm/tests" + "gorm.io/gorm/utils/tests" ) func TestExpr(t *testing.T) { diff --git a/dialects/mssql/create.go b/dialects/mssql/create.go deleted file mode 100644 index b07f13c5..00000000 --- a/dialects/mssql/create.go +++ /dev/null @@ -1,225 +0,0 @@ -package mssql - -import ( - "reflect" - "sort" - - "gorm.io/gorm" - "gorm.io/gorm/callbacks" - "gorm.io/gorm/clause" - "gorm.io/gorm/schema" -) - -func Create(db *gorm.DB) { - if db.Statement.Schema != nil && !db.Statement.Unscoped { - for _, c := range db.Statement.Schema.CreateClauses { - db.Statement.AddClause(c) - } - } - - if db.Statement.SQL.String() == "" { - setIdentityInsert := false - c := db.Statement.Clauses["ON CONFLICT"] - onConflict, hasConflict := c.Expression.(clause.OnConflict) - - if field := db.Statement.Schema.PrioritizedPrimaryField; field != nil { - setIdentityInsert = false - switch db.Statement.ReflectValue.Kind() { - case reflect.Struct: - _, isZero := field.ValueOf(db.Statement.ReflectValue) - setIdentityInsert = !isZero - case reflect.Slice: - for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - _, isZero := field.ValueOf(db.Statement.ReflectValue.Index(i)) - setIdentityInsert = !isZero - break - } - } - - if setIdentityInsert && (field.DataType == schema.Int || field.DataType == schema.Uint) { - setIdentityInsert = true - db.Statement.WriteString("SET IDENTITY_INSERT ") - db.Statement.WriteQuoted(db.Statement.Table) - db.Statement.WriteString(" ON;") - } else { - setIdentityInsert = false - } - } - - if hasConflict && len(db.Statement.Schema.PrimaryFields) > 0 { - MergeCreate(db, onConflict) - } else { - db.Statement.AddClauseIfNotExists(clause.Insert{Table: clause.Table{Name: db.Statement.Table}}) - db.Statement.Build("INSERT") - db.Statement.WriteByte(' ') - - db.Statement.AddClause(callbacks.ConvertToCreateValues(db.Statement)) - if values, ok := db.Statement.Clauses["VALUES"].Expression.(clause.Values); ok { - if len(values.Columns) > 0 { - db.Statement.WriteByte('(') - for idx, column := range values.Columns { - if idx > 0 { - db.Statement.WriteByte(',') - } - db.Statement.WriteQuoted(column) - } - db.Statement.WriteByte(')') - - outputInserted(db) - - db.Statement.WriteString(" VALUES ") - - for idx, value := range values.Values { - if idx > 0 { - db.Statement.WriteByte(',') - } - - db.Statement.WriteByte('(') - db.Statement.AddVar(db.Statement, value...) - db.Statement.WriteByte(')') - } - - db.Statement.WriteString(";") - } else { - db.Statement.WriteString("DEFAULT VALUES;") - } - } - } - - if setIdentityInsert { - db.Statement.WriteString("SET IDENTITY_INSERT ") - db.Statement.WriteQuoted(db.Statement.Table) - db.Statement.WriteString(" OFF;") - } - } - - if !db.DryRun { - rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) - - if err == nil { - defer rows.Close() - - if len(db.Statement.Schema.FieldsWithDefaultDBValue) > 0 { - sortedKeys := []string{} - for _, field := range db.Statement.Schema.FieldsWithDefaultDBValue { - sortedKeys = append(sortedKeys, field.DBName) - } - sort.Strings(sortedKeys) - - returnningFields := make([]*schema.Field, len(sortedKeys)) - for idx, key := range sortedKeys { - returnningFields[idx] = db.Statement.Schema.LookUpField(key) - } - - values := make([]interface{}, len(returnningFields)) - - switch db.Statement.ReflectValue.Kind() { - case reflect.Slice, reflect.Array: - for rows.Next() { - for idx, field := range returnningFields { - values[idx] = field.ReflectValueOf(db.Statement.ReflectValue.Index(int(db.RowsAffected))).Addr().Interface() - } - - db.RowsAffected++ - db.AddError(rows.Scan(values...)) - } - case reflect.Struct: - for idx, field := range returnningFields { - values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface() - } - - if rows.Next() { - db.RowsAffected++ - db.AddError(rows.Scan(values...)) - } - } - } - } else { - db.AddError(err) - } - } -} - -func MergeCreate(db *gorm.DB, onConflict clause.OnConflict) { - values := callbacks.ConvertToCreateValues(db.Statement) - - db.Statement.WriteString("MERGE INTO ") - db.Statement.WriteQuoted(db.Statement.Table) - db.Statement.WriteString(" USING (VALUES") - for idx, value := range values.Values { - if idx > 0 { - db.Statement.WriteByte(',') - } - - db.Statement.WriteByte('(') - db.Statement.AddVar(db.Statement, value...) - db.Statement.WriteByte(')') - } - - db.Statement.WriteString(") AS source (") - for idx, column := range values.Columns { - if idx > 0 { - db.Statement.WriteByte(',') - } - db.Statement.WriteQuoted(column.Name) - } - db.Statement.WriteString(") ON ") - - var where clause.Where - for _, field := range db.Statement.Schema.PrimaryFields { - where.Exprs = append(where.Exprs, clause.Eq{ - Column: clause.Column{Table: db.Statement.Table, Name: field.DBName}, - Value: clause.Column{Table: "source", Name: field.DBName}, - }) - } - where.Build(db.Statement) - - if len(onConflict.DoUpdates) > 0 { - db.Statement.WriteString(" WHEN MATCHED THEN UPDATE SET ") - onConflict.DoUpdates.Build(db.Statement) - } - - db.Statement.WriteString(" WHEN NOT MATCHED THEN INSERT (") - - for idx, column := range values.Columns { - if idx > 0 { - db.Statement.WriteByte(',') - } - db.Statement.WriteQuoted(column.Name) - } - - db.Statement.WriteString(") VALUES (") - - for idx, column := range values.Columns { - if idx > 0 { - db.Statement.WriteByte(',') - } - db.Statement.WriteQuoted(clause.Column{ - Table: "source", - Name: column.Name, - }) - } - - db.Statement.WriteString(")") - outputInserted(db) - db.Statement.WriteString(";") -} - -func outputInserted(db *gorm.DB) { - if len(db.Statement.Schema.FieldsWithDefaultDBValue) > 0 { - sortedKeys := []string{} - for _, field := range db.Statement.Schema.FieldsWithDefaultDBValue { - sortedKeys = append(sortedKeys, field.DBName) - } - sort.Strings(sortedKeys) - - db.Statement.WriteString(" OUTPUT") - for idx, key := range sortedKeys { - if idx > 0 { - db.Statement.WriteString(",") - } - db.Statement.WriteString(" INSERTED.") - db.Statement.AddVar(db.Statement, clause.Column{Name: key}) - } - } -} diff --git a/dialects/mssql/migrator.go b/dialects/mssql/migrator.go deleted file mode 100644 index 3bb2086d..00000000 --- a/dialects/mssql/migrator.go +++ /dev/null @@ -1,142 +0,0 @@ -package mssql - -import ( - "fmt" - - "gorm.io/gorm" - "gorm.io/gorm/clause" - "gorm.io/gorm/migrator" -) - -type Migrator struct { - migrator.Migrator -} - -func (m Migrator) HasTable(value interface{}) bool { - var count int - m.RunWithValue(value, func(stmt *gorm.Statement) error { - return m.DB.Raw( - "SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", - stmt.Table, m.CurrentDatabase(), - ).Row().Scan(&count) - }) - return count > 0 -} - -func (m Migrator) RenameTable(oldName, newName interface{}) error { - var oldTable, newTable string - if v, ok := oldName.(string); ok { - oldTable = v - } else { - stmt := &gorm.Statement{DB: m.DB} - if err := stmt.Parse(oldName); err == nil { - oldTable = stmt.Table - } else { - return err - } - } - - if v, ok := newName.(string); ok { - newTable = v - } else { - stmt := &gorm.Statement{DB: m.DB} - if err := stmt.Parse(newName); err == nil { - newTable = stmt.Table - } else { - return err - } - } - - return m.DB.Exec( - "sp_rename @objname = ?, @newname = ?;", - clause.Table{Name: oldTable}, clause.Table{Name: newTable}, - ).Error -} - -func (m Migrator) HasColumn(value interface{}, field string) bool { - var count int64 - m.RunWithValue(value, func(stmt *gorm.Statement) error { - currentDatabase := m.DB.Migrator().CurrentDatabase() - name := field - if field := stmt.Schema.LookUpField(field); field != nil { - name = field.DBName - } - - return m.DB.Raw( - "SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", - currentDatabase, stmt.Table, name, - ).Row().Scan(&count) - }) - - return count > 0 -} - -func (m Migrator) AlterColumn(value interface{}, field string) error { - return m.RunWithValue(value, func(stmt *gorm.Statement) error { - if field := stmt.Schema.LookUpField(field); field != nil { - return m.DB.Exec( - "ALTER TABLE ? ALTER COLUMN ? ?", - clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, m.FullDataTypeOf(field), - ).Error - } - return fmt.Errorf("failed to look up field with name: %s", field) - }) -} - -func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error { - return m.RunWithValue(value, func(stmt *gorm.Statement) error { - if field := stmt.Schema.LookUpField(oldName); field != nil { - oldName = field.DBName - } - - if field := stmt.Schema.LookUpField(newName); field != nil { - newName = field.DBName - } - - return m.DB.Exec( - "sp_rename @objname = ?, @newname = ?, @objtype = 'COLUMN';", - fmt.Sprintf("%s.%s", stmt.Table, oldName), clause.Column{Name: newName}, - ).Error - }) -} - -func (m Migrator) HasIndex(value interface{}, name string) bool { - var count int - m.RunWithValue(value, func(stmt *gorm.Statement) error { - if idx := stmt.Schema.LookIndex(name); idx != nil { - name = idx.Name - } - - return m.DB.Raw( - "SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", - name, stmt.Table, - ).Row().Scan(&count) - }) - return count > 0 -} - -func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error { - return m.RunWithValue(value, func(stmt *gorm.Statement) error { - - return m.DB.Exec( - "sp_rename @objname = ?, @newname = ?, @objtype = 'INDEX';", - fmt.Sprintf("%s.%s", stmt.Table, oldName), clause.Column{Name: newName}, - ).Error - }) -} - -func (m Migrator) HasConstraint(value interface{}, name string) bool { - var count int64 - m.RunWithValue(value, func(stmt *gorm.Statement) error { - return m.DB.Raw( - `SELECT count(*) FROM sys.foreign_keys as F inner join sys.tables as T on F.parent_object_id=T.object_id inner join information_schema.tables as I on I.TABLE_NAME = T.name WHERE F.name = ? AND T.Name = ? AND I.TABLE_CATALOG = ?;`, - name, stmt.Table, m.CurrentDatabase(), - ).Row().Scan(&count) - }) - return count > 0 -} - -func (m Migrator) CurrentDatabase() (name string) { - m.DB.Raw("SELECT DB_NAME() AS [Current Database]").Row().Scan(&name) - return -} diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go deleted file mode 100644 index 3f87180c..00000000 --- a/dialects/mssql/mssql.go +++ /dev/null @@ -1,127 +0,0 @@ -package mssql - -import ( - "database/sql" - "fmt" - "regexp" - "strconv" - - _ "github.com/denisenkom/go-mssqldb" - "gorm.io/gorm" - "gorm.io/gorm/callbacks" - "gorm.io/gorm/clause" - "gorm.io/gorm/logger" - "gorm.io/gorm/migrator" - "gorm.io/gorm/schema" -) - -type Dialector struct { - DSN string -} - -func (dialector Dialector) Name() string { - return "mssql" -} - -func Open(dsn string) gorm.Dialector { - return &Dialector{DSN: dsn} -} - -func (dialector Dialector) Initialize(db *gorm.DB) (err error) { - // register callbacks - callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{}) - db.Callback().Create().Replace("gorm:create", Create) - db.ConnPool, err = sql.Open("sqlserver", dialector.DSN) - - for k, v := range dialector.ClauseBuilders() { - db.ClauseBuilders[k] = v - } - return -} - -func (dialector Dialector) ClauseBuilders() map[string]clause.ClauseBuilder { - return map[string]clause.ClauseBuilder{ - "LIMIT": func(c clause.Clause, builder clause.Builder) { - if limit, ok := c.Expression.(clause.Limit); ok { - if limit.Offset > 0 { - builder.WriteString("OFFSET ") - builder.WriteString(strconv.Itoa(limit.Offset)) - builder.WriteString("ROWS") - } - - if limit.Limit > 0 { - if limit.Offset == 0 { - builder.WriteString(" OFFSET 0 ROWS") - } - builder.WriteString(" FETCH NEXT ") - builder.WriteString(strconv.Itoa(limit.Limit)) - builder.WriteString(" ROWS ONLY") - } - } - }, - } -} - -func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator { - return Migrator{migrator.Migrator{Config: migrator.Config{ - DB: db, - Dialector: dialector, - CreateIndexAfterCreateTable: true, - }}} -} - -func (dialector Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) { - writer.WriteString("@p") - writer.WriteString(strconv.Itoa(len(stmt.Vars))) -} - -func (dialector Dialector) QuoteTo(writer clause.Writer, str string) { - writer.WriteByte('"') - writer.WriteString(str) - writer.WriteByte('"') -} - -var numericPlaceholder = regexp.MustCompile("@p(\\d+)") - -func (dialector Dialector) Explain(sql string, vars ...interface{}) string { - return logger.ExplainSQL(sql, numericPlaceholder, `'`, vars...) -} - -func (dialector Dialector) DataTypeOf(field *schema.Field) string { - switch field.DataType { - case schema.Bool: - return "bit" - case schema.Int, schema.Uint: - var sqlType string - switch { - case field.Size < 16: - sqlType = "smallint" - case field.Size < 31: - sqlType = "int" - default: - sqlType = "bigint" - } - - if field.AutoIncrement || field == field.Schema.PrioritizedPrimaryField { - return sqlType + " IDENTITY(1,1)" - } - return sqlType - case schema.Float: - return "float" - case schema.String: - size := field.Size - if field.PrimaryKey && size == 0 { - size = 256 - } - if size > 0 && size <= 4000 { - return fmt.Sprintf("nvarchar(%d)", size) - } - return "nvarchar(MAX)" - case schema.Time: - return "datetimeoffset" - case schema.Bytes: - return "varbinary(MAX)" - } - - return "" -} diff --git a/dialects/mysql/migrator.go b/dialects/mysql/migrator.go deleted file mode 100644 index 8d3d20c6..00000000 --- a/dialects/mysql/migrator.go +++ /dev/null @@ -1,58 +0,0 @@ -package mysql - -import ( - "fmt" - - "gorm.io/gorm" - "gorm.io/gorm/clause" - "gorm.io/gorm/migrator" -) - -type Migrator struct { - migrator.Migrator -} - -func (m Migrator) AlterColumn(value interface{}, field string) error { - return m.RunWithValue(value, func(stmt *gorm.Statement) error { - if field := stmt.Schema.LookUpField(field); field != nil { - return m.DB.Exec( - "ALTER TABLE ? MODIFY COLUMN ? ?", - clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, m.FullDataTypeOf(field), - ).Error - } - return fmt.Errorf("failed to look up field with name: %s", field) - }) -} - -func (m Migrator) DropTable(values ...interface{}) error { - values = m.ReorderModels(values, false) - tx := m.DB.Session(&gorm.Session{}) - tx.Exec("SET FOREIGN_KEY_CHECKS = 0;") - for i := len(values) - 1; i >= 0; i-- { - if err := m.RunWithValue(values[i], func(stmt *gorm.Statement) error { - return tx.Exec("DROP TABLE IF EXISTS ? CASCADE", clause.Table{Name: stmt.Table}).Error - }); err != nil { - return err - } - } - tx.Exec("SET FOREIGN_KEY_CHECKS = 1;") - return nil -} - -func (m Migrator) DropConstraint(value interface{}, name string) error { - return m.RunWithValue(value, func(stmt *gorm.Statement) error { - for _, chk := range stmt.Schema.ParseCheckConstraints() { - if chk.Name == name { - return m.DB.Exec( - "ALTER TABLE ? DROP CHECK ?", - clause.Table{Name: stmt.Table}, clause.Column{Name: name}, - ).Error - } - } - - return m.DB.Exec( - "ALTER TABLE ? DROP FOREIGN KEY ?", - clause.Table{Name: stmt.Table}, clause.Column{Name: name}, - ).Error - }) -} diff --git a/dialects/mysql/mysql.go b/dialects/mysql/mysql.go deleted file mode 100644 index 035a6d79..00000000 --- a/dialects/mysql/mysql.go +++ /dev/null @@ -1,169 +0,0 @@ -package mysql - -import ( - "database/sql" - "fmt" - "math" - - _ "github.com/go-sql-driver/mysql" - "gorm.io/gorm" - "gorm.io/gorm/callbacks" - "gorm.io/gorm/clause" - "gorm.io/gorm/logger" - "gorm.io/gorm/migrator" - "gorm.io/gorm/schema" -) - -type Dialector struct { - DSN string -} - -func Open(dsn string) gorm.Dialector { - return &Dialector{DSN: dsn} -} - -func (dialector Dialector) Name() string { - return "mysql" -} - -func (dialector Dialector) Initialize(db *gorm.DB) (err error) { - // register callbacks - callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{}) - db.ConnPool, err = sql.Open("mysql", dialector.DSN) - - for k, v := range dialector.ClauseBuilders() { - db.ClauseBuilders[k] = v - } - return -} - -func (dialector Dialector) ClauseBuilders() map[string]clause.ClauseBuilder { - return map[string]clause.ClauseBuilder{ - "ON CONFLICT": func(c clause.Clause, builder clause.Builder) { - if onConflict, ok := c.Expression.(clause.OnConflict); ok { - builder.WriteString("ON DUPLICATE KEY UPDATE ") - if len(onConflict.DoUpdates) == 0 { - if s := builder.(*gorm.Statement).Schema; s != nil { - var column clause.Column - onConflict.DoNothing = false - - if s.PrioritizedPrimaryField != nil { - column = clause.Column{Name: s.PrioritizedPrimaryField.DBName} - } else { - for _, field := range s.FieldsByDBName { - column = clause.Column{Name: field.DBName} - break - } - } - onConflict.DoUpdates = []clause.Assignment{{Column: column, Value: column}} - } - } - - onConflict.DoUpdates.Build(builder) - } else { - c.Build(builder) - } - }, - "VALUES": func(c clause.Clause, builder clause.Builder) { - if values, ok := c.Expression.(clause.Values); ok && len(values.Columns) == 0 { - builder.WriteString("VALUES()") - return - } - c.Build(builder) - }, - } -} - -func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator { - return Migrator{migrator.Migrator{Config: migrator.Config{ - DB: db, - Dialector: dialector, - }}} -} - -func (dialector Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) { - writer.WriteByte('?') -} - -func (dialector Dialector) QuoteTo(writer clause.Writer, str string) { - writer.WriteByte('`') - writer.WriteString(str) - writer.WriteByte('`') -} - -func (dialector Dialector) Explain(sql string, vars ...interface{}) string { - return logger.ExplainSQL(sql, nil, `"`, vars...) -} - -func (dialector Dialector) DataTypeOf(field *schema.Field) string { - switch field.DataType { - case schema.Bool: - return "boolean" - case schema.Int, schema.Uint: - sqlType := "int" - switch { - case field.Size <= 8: - sqlType = "tinyint" - case field.Size <= 16: - sqlType = "smallint" - case field.Size <= 32: - sqlType = "int" - default: - sqlType = "bigint" - } - - if field.DataType == schema.Uint { - sqlType += " unsigned" - } - - if field.AutoIncrement || field == field.Schema.PrioritizedPrimaryField { - sqlType += " AUTO_INCREMENT" - } - return sqlType - case schema.Float: - if field.Size <= 32 { - return "float" - } - return "double" - case schema.String: - size := field.Size - if size == 0 { - if field.PrimaryKey || field.HasDefaultValue { - size = 256 - } - } - - if size >= 65536 && size <= int(math.Pow(2, 24)) { - return "mediumtext" - } else if size > int(math.Pow(2, 24)) || size <= 0 { - return "longtext" - } - return fmt.Sprintf("varchar(%d)", size) - case schema.Time: - precision := "" - if field.Precision == 0 { - field.Precision = 3 - } - - if field.Precision > 0 { - precision = fmt.Sprintf("(%d)", field.Precision) - } - - if field.NotNull || field.PrimaryKey { - return "datetime" + precision - } - return "datetime" + precision + " NULL" - case schema.Bytes: - if field.Size > 0 && field.Size < 65536 { - return fmt.Sprintf("varbinary(%d)", field.Size) - } - - if field.Size >= 65536 && field.Size <= int(math.Pow(2, 24)) { - return "mediumblob" - } - - return "longblob" - } - - return "" -} diff --git a/dialects/postgres/migrator.go b/dialects/postgres/migrator.go deleted file mode 100644 index 6b1085e3..00000000 --- a/dialects/postgres/migrator.go +++ /dev/null @@ -1,139 +0,0 @@ -package postgres - -import ( - "fmt" - - "gorm.io/gorm" - "gorm.io/gorm/clause" - "gorm.io/gorm/migrator" - "gorm.io/gorm/schema" -) - -type Migrator struct { - migrator.Migrator -} - -func (m Migrator) CurrentDatabase() (name string) { - m.DB.Raw("SELECT CURRENT_DATABASE()").Row().Scan(&name) - return -} - -func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) { - for _, opt := range opts { - str := stmt.Quote(opt.DBName) - if opt.Expression != "" { - str = opt.Expression - } - - if opt.Collate != "" { - str += " COLLATE " + opt.Collate - } - - if opt.Sort != "" { - str += " " + opt.Sort - } - results = append(results, clause.Expr{SQL: str}) - } - return -} - -func (m Migrator) HasIndex(value interface{}, name string) bool { - var count int64 - m.RunWithValue(value, func(stmt *gorm.Statement) error { - if idx := stmt.Schema.LookIndex(name); idx != nil { - name = idx.Name - } - - return m.DB.Raw( - "SELECT count(*) FROM pg_indexes WHERE tablename = ? AND indexname = ? AND schemaname = CURRENT_SCHEMA()", stmt.Table, name, - ).Row().Scan(&count) - }) - - return count > 0 -} - -func (m Migrator) CreateIndex(value interface{}, name string) error { - return m.RunWithValue(value, func(stmt *gorm.Statement) error { - if idx := stmt.Schema.LookIndex(name); idx != nil { - opts := m.BuildIndexOptions(idx.Fields, stmt) - values := []interface{}{clause.Column{Name: idx.Name}, clause.Table{Name: stmt.Table}, opts} - - createIndexSQL := "CREATE " - if idx.Class != "" { - createIndexSQL += idx.Class + " " - } - createIndexSQL += "INDEX ?" - - if idx.Type != "" { - createIndexSQL += " USING " + idx.Type - } - createIndexSQL += " ON ??" - - if idx.Where != "" { - createIndexSQL += " WHERE " + idx.Where - } - - return m.DB.Exec(createIndexSQL, values...).Error - } - - return fmt.Errorf("failed to create index with name %v", name) - }) -} - -func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error { - return m.RunWithValue(value, func(stmt *gorm.Statement) error { - return m.DB.Exec( - "ALTER INDEX ? RENAME TO ?", - clause.Column{Name: oldName}, clause.Column{Name: newName}, - ).Error - }) -} - -func (m Migrator) DropIndex(value interface{}, name string) error { - return m.RunWithValue(value, func(stmt *gorm.Statement) error { - if idx := stmt.Schema.LookIndex(name); idx != nil { - name = idx.Name - } - - return m.DB.Exec("DROP INDEX ?", clause.Column{Name: name}).Error - }) -} - -func (m Migrator) HasTable(value interface{}) bool { - var count int64 - m.RunWithValue(value, func(stmt *gorm.Statement) error { - return m.DB.Raw("SELECT count(*) FROM information_schema.tables WHERE table_schema = CURRENT_SCHEMA() AND table_name = ? AND table_type = ?", stmt.Table, "BASE TABLE").Row().Scan(&count) - }) - - return count > 0 -} - -func (m Migrator) DropTable(values ...interface{}) error { - values = m.ReorderModels(values, false) - tx := m.DB.Session(&gorm.Session{}) - for i := len(values) - 1; i >= 0; i-- { - if err := m.RunWithValue(values[i], func(stmt *gorm.Statement) error { - return tx.Exec("DROP TABLE IF EXISTS ? CASCADE", clause.Table{Name: stmt.Table}).Error - }); err != nil { - return err - } - } - return nil -} - -func (m Migrator) HasColumn(value interface{}, field string) bool { - var count int64 - m.RunWithValue(value, func(stmt *gorm.Statement) error { - name := field - if field := stmt.Schema.LookUpField(field); field != nil { - name = field.DBName - } - - return m.DB.Raw( - "SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_schema = CURRENT_SCHEMA() AND table_name = ? AND column_name = ?", - stmt.Table, name, - ).Row().Scan(&count) - }) - - return count > 0 -} diff --git a/dialects/postgres/postgres.go b/dialects/postgres/postgres.go deleted file mode 100644 index 57e51d58..00000000 --- a/dialects/postgres/postgres.go +++ /dev/null @@ -1,102 +0,0 @@ -package postgres - -import ( - "database/sql" - "fmt" - "regexp" - "strconv" - - "gorm.io/gorm" - "gorm.io/gorm/callbacks" - "gorm.io/gorm/clause" - "gorm.io/gorm/logger" - "gorm.io/gorm/migrator" - "gorm.io/gorm/schema" - _ "github.com/lib/pq" -) - -type Dialector struct { - DSN string -} - -func Open(dsn string) gorm.Dialector { - return &Dialector{DSN: dsn} -} - -func (dialector Dialector) Name() string { - return "postgres" -} - -func (dialector Dialector) Initialize(db *gorm.DB) (err error) { - // register callbacks - callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{ - WithReturning: true, - }) - db.ConnPool, err = sql.Open("postgres", dialector.DSN) - return -} - -func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator { - return Migrator{migrator.Migrator{Config: migrator.Config{ - DB: db, - Dialector: dialector, - CreateIndexAfterCreateTable: true, - }}} -} - -func (dialector Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) { - writer.WriteByte('$') - writer.WriteString(strconv.Itoa(len(stmt.Vars))) -} - -func (dialector Dialector) QuoteTo(writer clause.Writer, str string) { - writer.WriteByte('"') - writer.WriteString(str) - writer.WriteByte('"') -} - -var numericPlaceholder = regexp.MustCompile("\\$(\\d+)") - -func (dialector Dialector) Explain(sql string, vars ...interface{}) string { - return logger.ExplainSQL(sql, numericPlaceholder, `'`, vars...) -} - -func (dialector Dialector) DataTypeOf(field *schema.Field) string { - switch field.DataType { - case schema.Bool: - return "boolean" - case schema.Int, schema.Uint: - if field.AutoIncrement || field == field.Schema.PrioritizedPrimaryField { - switch { - case field.Size < 16: - return "smallserial" - case field.Size < 31: - return "serial" - default: - return "bigserial" - } - } else { - switch { - case field.Size < 16: - return "smallint" - case field.Size < 31: - return "integer" - default: - return "bigint" - } - } - case schema.Float: - return "decimal" - case schema.String: - if field.Size > 0 { - return fmt.Sprintf("varchar(%d)", field.Size) - } - return "text" - case schema.Time: - return "timestamptz" - case schema.Bytes: - return "bytea" - } - - return "" -} diff --git a/dialects/sqlite/migrator.go b/dialects/sqlite/migrator.go deleted file mode 100644 index 14c682ca..00000000 --- a/dialects/sqlite/migrator.go +++ /dev/null @@ -1,211 +0,0 @@ -package sqlite - -import ( - "fmt" - "regexp" - "strings" - - "gorm.io/gorm" - "gorm.io/gorm/clause" - "gorm.io/gorm/migrator" - "gorm.io/gorm/schema" -) - -type Migrator struct { - migrator.Migrator -} - -func (m Migrator) HasTable(value interface{}) bool { - var count int - m.Migrator.RunWithValue(value, func(stmt *gorm.Statement) error { - return m.DB.Raw("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", stmt.Table).Row().Scan(&count) - }) - return count > 0 -} - -func (m Migrator) HasColumn(value interface{}, name string) bool { - var count int - m.Migrator.RunWithValue(value, func(stmt *gorm.Statement) error { - if field := stmt.Schema.LookUpField(name); field != nil { - name = field.DBName - } - - return m.DB.Raw( - "SELECT count(*) FROM sqlite_master WHERE type = ? AND tbl_name = ? AND (sql LIKE ? OR sql LIKE ? OR sql LIKE ?)", - "table", stmt.Table, `%"`+name+`" %`, `%`+name+` %`, "%`"+name+"`%", - ).Row().Scan(&count) - }) - return count > 0 -} - -func (m Migrator) AlterColumn(value interface{}, name string) error { - return m.RunWithValue(value, func(stmt *gorm.Statement) error { - if field := stmt.Schema.LookUpField(name); field != nil { - var ( - createSQL string - newTableName = stmt.Table + "__temp" - ) - - m.DB.Raw("SELECT sql FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "table", stmt.Table, stmt.Table).Row().Scan(&createSQL) - - if reg, err := regexp.Compile("(`|'|\"| )" + name + "(`|'|\"| ) .*?,"); err == nil { - tableReg, err := regexp.Compile(" ('|`|\"| )" + stmt.Table + "('|`|\"| ) ") - if err != nil { - return err - } - - createSQL = tableReg.ReplaceAllString(createSQL, fmt.Sprintf(" `%v` ", newTableName)) - createSQL = reg.ReplaceAllString(createSQL, "?") - - var columns []string - columnTypes, _ := m.DB.Migrator().ColumnTypes(value) - for _, columnType := range columnTypes { - columns = append(columns, fmt.Sprintf("`%v`", columnType.Name())) - } - - createSQL = fmt.Sprintf("PRAGMA foreign_keys=off;BEGIN TRANSACTION;"+createSQL+";INSERT INTO `%v`(%v) SELECT %v FROM `%v`;DROP TABLE `%v`;ALTER TABLE `%v` RENAME TO `%v`;COMMIT;", newTableName, strings.Join(columns, ","), strings.Join(columns, ","), stmt.Table, stmt.Table, newTableName, stmt.Table) - return m.DB.Exec(createSQL, m.FullDataTypeOf(field)).Error - } else { - return err - } - } else { - return fmt.Errorf("failed to alter field with name %v", name) - } - }) -} - -func (m Migrator) DropColumn(value interface{}, name string) error { - return m.RunWithValue(value, func(stmt *gorm.Statement) error { - if field := stmt.Schema.LookUpField(name); field != nil { - name = field.DBName - } - - var ( - createSQL string - newTableName = stmt.Table + "__temp" - ) - - m.DB.Raw("SELECT sql FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "table", stmt.Table, stmt.Table).Row().Scan(&createSQL) - - if reg, err := regexp.Compile("(`|'|\"| )" + name + "(`|'|\"| ) .*?,"); err == nil { - tableReg, err := regexp.Compile(" ('|`|\"| )" + stmt.Table + "('|`|\"| ) ") - if err != nil { - return err - } - - createSQL = tableReg.ReplaceAllString(createSQL, fmt.Sprintf(" `%v` ", newTableName)) - createSQL = reg.ReplaceAllString(createSQL, "") - - var columns []string - columnTypes, _ := m.DB.Migrator().ColumnTypes(value) - for _, columnType := range columnTypes { - if columnType.Name() != name { - columns = append(columns, fmt.Sprintf("`%v`", columnType.Name())) - } - } - - createSQL = fmt.Sprintf("PRAGMA foreign_keys=off;BEGIN TRANSACTION;"+createSQL+";INSERT INTO `%v`(%v) SELECT %v FROM `%v`;DROP TABLE `%v`;ALTER TABLE `%v` RENAME TO `%v`;COMMIT;", newTableName, strings.Join(columns, ","), strings.Join(columns, ","), stmt.Table, stmt.Table, newTableName, stmt.Table) - - return m.DB.Exec(createSQL).Error - } else { - return err - } - }) -} - -func (m Migrator) CreateConstraint(interface{}, string) error { - return gorm.ErrNotImplemented -} - -func (m Migrator) DropConstraint(interface{}, string) error { - return gorm.ErrNotImplemented -} - -func (m Migrator) CurrentDatabase() (name string) { - var null interface{} - m.DB.Raw("PRAGMA database_list").Row().Scan(&null, &name, &null) - return -} - -func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) { - for _, opt := range opts { - str := stmt.Quote(opt.DBName) - if opt.Expression != "" { - str = opt.Expression - } - - if opt.Collate != "" { - str += " COLLATE " + opt.Collate - } - - if opt.Sort != "" { - str += " " + opt.Sort - } - results = append(results, clause.Expr{SQL: str}) - } - return -} - -func (m Migrator) CreateIndex(value interface{}, name string) error { - return m.RunWithValue(value, func(stmt *gorm.Statement) error { - if idx := stmt.Schema.LookIndex(name); idx != nil { - opts := m.BuildIndexOptions(idx.Fields, stmt) - values := []interface{}{clause.Column{Name: idx.Name}, clause.Table{Name: stmt.Table}, opts} - - createIndexSQL := "CREATE " - if idx.Class != "" { - createIndexSQL += idx.Class + " " - } - createIndexSQL += "INDEX ?" - - if idx.Type != "" { - createIndexSQL += " USING " + idx.Type - } - createIndexSQL += " ON ??" - - if idx.Where != "" { - createIndexSQL += " WHERE " + idx.Where - } - - return m.DB.Exec(createIndexSQL, values...).Error - } - - return fmt.Errorf("failed to create index with name %v", name) - }) -} - -func (m Migrator) HasIndex(value interface{}, name string) bool { - var count int - m.RunWithValue(value, func(stmt *gorm.Statement) error { - if idx := stmt.Schema.LookIndex(name); idx != nil { - name = idx.Name - } - - m.DB.Raw( - "SELECT count(*) FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "index", stmt.Table, name, - ).Row().Scan(&count) - return nil - }) - return count > 0 -} - -func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error { - return m.RunWithValue(value, func(stmt *gorm.Statement) error { - var sql string - m.DB.Raw("SELECT sql FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "index", stmt.Table, oldName).Row().Scan(&sql) - if sql != "" { - return m.DB.Exec(strings.Replace(sql, oldName, newName, 1)).Error - } - return fmt.Errorf("failed to find index with name %v", oldName) - }) -} - -func (m Migrator) DropIndex(value interface{}, name string) error { - return m.RunWithValue(value, func(stmt *gorm.Statement) error { - if idx := stmt.Schema.LookIndex(name); idx != nil { - name = idx.Name - } - - return m.DB.Exec("DROP INDEX ?", clause.Column{Name: name}).Error - }) -} diff --git a/dialects/sqlite/sqlite.go b/dialects/sqlite/sqlite.go deleted file mode 100644 index 238ad7f9..00000000 --- a/dialects/sqlite/sqlite.go +++ /dev/null @@ -1,80 +0,0 @@ -package sqlite - -import ( - "database/sql" - - "gorm.io/gorm" - "gorm.io/gorm/callbacks" - "gorm.io/gorm/clause" - "gorm.io/gorm/logger" - "gorm.io/gorm/migrator" - "gorm.io/gorm/schema" - _ "github.com/mattn/go-sqlite3" -) - -type Dialector struct { - DSN string -} - -func Open(dsn string) gorm.Dialector { - return &Dialector{DSN: dsn} -} - -func (dialector Dialector) Name() string { - return "sqlite" -} - -func (dialector Dialector) Initialize(db *gorm.DB) (err error) { - // register callbacks - callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{ - LastInsertIDReversed: true, - }) - db.ConnPool, err = sql.Open("sqlite3", dialector.DSN) - return -} - -func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator { - return Migrator{migrator.Migrator{Config: migrator.Config{ - DB: db, - Dialector: dialector, - CreateIndexAfterCreateTable: true, - }}} -} - -func (dialector Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) { - writer.WriteByte('?') -} - -func (dialector Dialector) QuoteTo(writer clause.Writer, str string) { - writer.WriteByte('`') - writer.WriteString(str) - writer.WriteByte('`') -} - -func (dialector Dialector) Explain(sql string, vars ...interface{}) string { - return logger.ExplainSQL(sql, nil, `"`, vars...) -} - -func (dialector Dialector) DataTypeOf(field *schema.Field) string { - switch field.DataType { - case schema.Bool: - return "numeric" - case schema.Int, schema.Uint: - if field.AutoIncrement { - // https://www.sqlite.org/autoinc.html - return "integer PRIMARY KEY AUTOINCREMENT" - } else { - return "integer" - } - case schema.Float: - return "real" - case schema.String: - return "text" - case schema.Time: - return "datetime" - case schema.Bytes: - return "blob" - } - - return "" -} diff --git a/go.mod b/go.mod index 26877c7a..faf63a46 100644 --- a/go.mod +++ b/go.mod @@ -3,12 +3,6 @@ module gorm.io/gorm go 1.14 require ( - github.com/denisenkom/go-mssqldb v0.0.0-20200428022330-06a60b6afbbc - github.com/erikstmartin/go-testdb v0.0.0-20160219214506-8d10e4a1bae5 - github.com/go-sql-driver/mysql v1.5.0 github.com/jinzhu/inflection v1.0.0 github.com/jinzhu/now v1.1.1 - github.com/lib/pq v1.1.1 - github.com/mattn/go-sqlite3 v2.0.1+incompatible - gorm.io/gorm v1.9.12 ) diff --git a/schema/field_test.go b/schema/field_test.go index 7a47f195..fe88891f 100644 --- a/schema/field_test.go +++ b/schema/field_test.go @@ -9,7 +9,7 @@ import ( "gorm.io/gorm" "gorm.io/gorm/schema" - "gorm.io/gorm/tests" + "gorm.io/gorm/utils/tests" ) func TestFieldValuerAndSetter(t *testing.T) { diff --git a/schema/model_test.go b/schema/model_test.go index 068b3050..a13372b5 100644 --- a/schema/model_test.go +++ b/schema/model_test.go @@ -5,7 +5,7 @@ import ( "time" "gorm.io/gorm" - "gorm.io/gorm/tests" + "gorm.io/gorm/utils/tests" ) type User struct { diff --git a/schema/schema_helper_test.go b/schema/schema_helper_test.go index b966164e..f2ed4145 100644 --- a/schema/schema_helper_test.go +++ b/schema/schema_helper_test.go @@ -7,7 +7,7 @@ import ( "testing" "gorm.io/gorm/schema" - "gorm.io/gorm/tests" + "gorm.io/gorm/utils/tests" ) func checkSchema(t *testing.T, s *schema.Schema, v schema.Schema, primaryFields []string) { diff --git a/schema/schema_test.go b/schema/schema_test.go index 6902cbf2..1029f74f 100644 --- a/schema/schema_test.go +++ b/schema/schema_test.go @@ -5,7 +5,7 @@ import ( "testing" "gorm.io/gorm/schema" - "gorm.io/gorm/tests" + "gorm.io/gorm/utils/tests" ) func TestParseSchema(t *testing.T) { diff --git a/tests/associations_belongs_to_test.go b/tests/associations_belongs_to_test.go index 27b82ecb..35419666 100644 --- a/tests/associations_belongs_to_test.go +++ b/tests/associations_belongs_to_test.go @@ -3,7 +3,7 @@ package tests_test import ( "testing" - . "gorm.io/gorm/tests" + . "gorm.io/gorm/utils/tests" ) func TestBelongsToAssociation(t *testing.T) { diff --git a/tests/associations_has_many_test.go b/tests/associations_has_many_test.go index 88df8532..7ef0c218 100644 --- a/tests/associations_has_many_test.go +++ b/tests/associations_has_many_test.go @@ -3,7 +3,7 @@ package tests_test import ( "testing" - . "gorm.io/gorm/tests" + . "gorm.io/gorm/utils/tests" ) func TestHasManyAssociation(t *testing.T) { diff --git a/tests/associations_has_one_test.go b/tests/associations_has_one_test.go index 9ddfa9c5..f32a692d 100644 --- a/tests/associations_has_one_test.go +++ b/tests/associations_has_one_test.go @@ -3,7 +3,7 @@ package tests_test import ( "testing" - . "gorm.io/gorm/tests" + . "gorm.io/gorm/utils/tests" ) func TestHasOneAssociation(t *testing.T) { diff --git a/tests/associations_many2many_test.go b/tests/associations_many2many_test.go index d79cdc17..ba9695b7 100644 --- a/tests/associations_many2many_test.go +++ b/tests/associations_many2many_test.go @@ -3,7 +3,7 @@ package tests_test import ( "testing" - . "gorm.io/gorm/tests" + . "gorm.io/gorm/utils/tests" ) func TestMany2ManyAssociation(t *testing.T) { diff --git a/tests/associations_test.go b/tests/associations_test.go index 2e30df8b..44262109 100644 --- a/tests/associations_test.go +++ b/tests/associations_test.go @@ -3,7 +3,7 @@ package tests_test import ( "testing" - . "gorm.io/gorm/tests" + . "gorm.io/gorm/utils/tests" ) func AssertAssociationCount(t *testing.T, data interface{}, name string, result int64, reason string) { diff --git a/tests/count_test.go b/tests/count_test.go index d8cfa405..63238089 100644 --- a/tests/count_test.go +++ b/tests/count_test.go @@ -4,7 +4,7 @@ import ( "fmt" "testing" - . "gorm.io/gorm/tests" + . "gorm.io/gorm/utils/tests" ) func TestCount(t *testing.T) { diff --git a/tests/create_test.go b/tests/create_test.go index 2f853c61..c497014e 100644 --- a/tests/create_test.go +++ b/tests/create_test.go @@ -6,7 +6,7 @@ import ( "github.com/jinzhu/now" "gorm.io/gorm" - . "gorm.io/gorm/tests" + . "gorm.io/gorm/utils/tests" ) func TestCreate(t *testing.T) { diff --git a/tests/customize_column_test.go b/tests/customize_column_test.go index 0db40869..98dea494 100644 --- a/tests/customize_column_test.go +++ b/tests/customize_column_test.go @@ -3,8 +3,6 @@ package tests_test import ( "testing" "time" - - . "gorm.io/gorm/tests" ) func TestCustomizeColumn(t *testing.T) { diff --git a/tests/delete_test.go b/tests/delete_test.go index 0fe2ee75..66c396d1 100644 --- a/tests/delete_test.go +++ b/tests/delete_test.go @@ -5,7 +5,7 @@ import ( "testing" "gorm.io/gorm" - . "gorm.io/gorm/tests" + . "gorm.io/gorm/utils/tests" ) func TestDelete(t *testing.T) { diff --git a/tests/embedded_struct_test.go b/tests/embedded_struct_test.go index 74829460..9a1436fe 100644 --- a/tests/embedded_struct_test.go +++ b/tests/embedded_struct_test.go @@ -4,7 +4,6 @@ import ( "testing" "gorm.io/gorm" - . "gorm.io/gorm/tests" ) func TestEmbeddedStruct(t *testing.T) { diff --git a/tests/go.mod b/tests/go.mod new file mode 100644 index 00000000..3954c442 --- /dev/null +++ b/tests/go.mod @@ -0,0 +1,14 @@ +module gorm.io/gorm/tests + +go 1.14 + +require ( + github.com/jinzhu/now v1.1.1 + gorm.io/driver/mysql v0.0.0-20200602015408-0407d0c21cf0 + gorm.io/driver/postgres v0.0.0-20200602015520-15fcc29eb286 + gorm.io/driver/sqlite v0.0.0-20200602015323-284b563f81c8 + gorm.io/driver/sqlserver v0.0.0-20200602015206-ef9f739c6a30 + gorm.io/gorm v1.9.12 +) + +replace gorm.io/gorm => ../ diff --git a/tests/group_by_test.go b/tests/group_by_test.go index 5a954348..cb4c4f43 100644 --- a/tests/group_by_test.go +++ b/tests/group_by_test.go @@ -3,7 +3,7 @@ package tests_test import ( "testing" - . "gorm.io/gorm/tests" + . "gorm.io/gorm/utils/tests" ) func TestGroupBy(t *testing.T) { diff --git a/tests/utils.go b/tests/helper_test.go similarity index 66% rename from tests/utils.go rename to tests/helper_test.go index 0b4b138e..b05f5297 100644 --- a/tests/utils.go +++ b/tests/helper_test.go @@ -1,17 +1,13 @@ -package tests +package tests_test import ( - "database/sql/driver" - "fmt" - "go/ast" - "reflect" "sort" "strconv" "strings" "testing" "time" - "gorm.io/gorm/utils" + . "gorm.io/gorm/utils/tests" ) type Config struct { @@ -73,101 +69,6 @@ func GetUser(name string, config Config) *User { return &user } -func AssertObjEqual(t *testing.T, r, e interface{}, names ...string) { - for _, name := range names { - got := reflect.Indirect(reflect.ValueOf(r)).FieldByName(name).Interface() - expect := reflect.Indirect(reflect.ValueOf(e)).FieldByName(name).Interface() - t.Run(name, func(t *testing.T) { - AssertEqual(t, got, expect) - }) - } -} - -func AssertEqual(t *testing.T, got, expect interface{}) { - if !reflect.DeepEqual(got, expect) { - isEqual := func() { - if curTime, ok := got.(time.Time); ok { - format := "2006-01-02T15:04:05Z07:00" - - if curTime.Round(time.Second).Format(format) != expect.(time.Time).Round(time.Second).Format(format) && curTime.Truncate(time.Second).Format(format) != expect.(time.Time).Truncate(time.Second).Format(format) { - t.Errorf("%v: expect: %v, got %v after time round", utils.FileWithLineNum(), expect.(time.Time), curTime) - } - } else if fmt.Sprint(got) != fmt.Sprint(expect) { - t.Errorf("%v: expect: %#v, got %#v", utils.FileWithLineNum(), expect, got) - } - } - - if fmt.Sprint(got) == fmt.Sprint(expect) { - return - } - - if reflect.Indirect(reflect.ValueOf(got)).IsValid() != reflect.Indirect(reflect.ValueOf(expect)).IsValid() { - t.Errorf("%v: expect: %+v, got %+v", utils.FileWithLineNum(), expect, got) - return - } - - if valuer, ok := got.(driver.Valuer); ok { - got, _ = valuer.Value() - } - - if valuer, ok := expect.(driver.Valuer); ok { - expect, _ = valuer.Value() - } - - if got != nil { - got = reflect.Indirect(reflect.ValueOf(got)).Interface() - } - - if expect != nil { - expect = reflect.Indirect(reflect.ValueOf(expect)).Interface() - } - - if reflect.ValueOf(got).IsValid() != reflect.ValueOf(expect).IsValid() { - t.Errorf("%v: expect: %+v, got %+v", utils.FileWithLineNum(), expect, got) - return - } - - if reflect.ValueOf(got).Kind() == reflect.Slice { - if reflect.ValueOf(expect).Kind() == reflect.Slice { - if reflect.ValueOf(got).Len() == reflect.ValueOf(expect).Len() { - for i := 0; i < reflect.ValueOf(got).Len(); i++ { - name := fmt.Sprintf(reflect.ValueOf(got).Type().Name()+" #%v", i) - t.Run(name, func(t *testing.T) { - AssertEqual(t, reflect.ValueOf(got).Index(i).Interface(), reflect.ValueOf(expect).Index(i).Interface()) - }) - } - } else { - name := reflect.ValueOf(got).Type().Elem().Name() - t.Errorf("%v expects length: %v, got %v", name, reflect.ValueOf(expect).Len(), reflect.ValueOf(got).Len()) - } - return - } - } - - if reflect.ValueOf(got).Kind() == reflect.Struct { - if reflect.ValueOf(got).NumField() == reflect.ValueOf(expect).NumField() { - for i := 0; i < reflect.ValueOf(got).NumField(); i++ { - if fieldStruct := reflect.ValueOf(got).Type().Field(i); ast.IsExported(fieldStruct.Name) { - field := reflect.ValueOf(got).Field(i) - t.Run(fieldStruct.Name, func(t *testing.T) { - AssertEqual(t, field.Interface(), reflect.ValueOf(expect).Field(i).Interface()) - }) - } - } - return - } - } - - if reflect.ValueOf(got).Type().ConvertibleTo(reflect.ValueOf(expect).Type()) { - got = reflect.ValueOf(got).Convert(reflect.ValueOf(expect).Type()).Interface() - isEqual() - } else if reflect.ValueOf(expect).Type().ConvertibleTo(reflect.ValueOf(got).Type()) { - expect = reflect.ValueOf(got).Convert(reflect.ValueOf(got).Type()).Interface() - isEqual() - } - } -} - func CheckPet(t *testing.T, pet Pet, expect Pet) { if pet.ID != 0 { var newPet Pet diff --git a/tests/hooks_test.go b/tests/hooks_test.go index 418713a6..e2850c27 100644 --- a/tests/hooks_test.go +++ b/tests/hooks_test.go @@ -6,7 +6,6 @@ import ( "testing" "gorm.io/gorm" - . "gorm.io/gorm/tests" ) type Product struct { diff --git a/tests/joins_table_test.go b/tests/joins_table_test.go index 5738d8f4..b8c1be77 100644 --- a/tests/joins_table_test.go +++ b/tests/joins_table_test.go @@ -5,7 +5,6 @@ import ( "time" "gorm.io/gorm" - . "gorm.io/gorm/tests" ) type Person struct { diff --git a/tests/joins_test.go b/tests/joins_test.go index 651b20c6..f01c8211 100644 --- a/tests/joins_test.go +++ b/tests/joins_test.go @@ -5,7 +5,7 @@ import ( "testing" "gorm.io/gorm" - . "gorm.io/gorm/tests" + . "gorm.io/gorm/utils/tests" ) func TestJoins(t *testing.T) { diff --git a/tests/main_test.go b/tests/main_test.go index 2d466125..ff293e6e 100644 --- a/tests/main_test.go +++ b/tests/main_test.go @@ -3,7 +3,7 @@ package tests_test import ( "testing" - . "gorm.io/gorm/tests" + . "gorm.io/gorm/utils/tests" ) func TestMain(m *testing.M) { diff --git a/tests/migrate_test.go b/tests/migrate_test.go index b511ab40..5293898f 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -7,7 +7,7 @@ import ( "time" "gorm.io/gorm" - . "gorm.io/gorm/tests" + . "gorm.io/gorm/utils/tests" ) func TestMigrate(t *testing.T) { diff --git a/tests/multi_primary_keys_test.go b/tests/multi_primary_keys_test.go index 139cde69..05267bbb 100644 --- a/tests/multi_primary_keys_test.go +++ b/tests/multi_primary_keys_test.go @@ -4,8 +4,6 @@ import ( "reflect" "sort" "testing" - - . "gorm.io/gorm/tests" ) type Blog struct { @@ -36,8 +34,8 @@ func compareTags(tags []Tag, contents []string) bool { } func TestManyToManyWithMultiPrimaryKeys(t *testing.T) { - if name := DB.Dialector.Name(); name == "sqlite" || name == "mssql" { - t.Skip("skip sqlite, mssql due to it doesn't support multiple primary keys with auto increment") + if name := DB.Dialector.Name(); name == "sqlite" || name == "sqlserver" { + t.Skip("skip sqlite, sqlserver due to it doesn't support multiple primary keys with auto increment") } DB.Migrator().DropTable(&Blog{}, &Tag{}, "blog_tags") @@ -125,8 +123,8 @@ func TestManyToManyWithMultiPrimaryKeys(t *testing.T) { } func TestManyToManyWithCustomizedForeignKeys(t *testing.T) { - if name := DB.Dialector.Name(); name == "sqlite" || name == "mssql" { - t.Skip("skip sqlite, mssql due to it doesn't support multiple primary keys with auto increment") + if name := DB.Dialector.Name(); name == "sqlite" || name == "sqlserver" { + t.Skip("skip sqlite, sqlserver due to it doesn't support multiple primary keys with auto increment") } DB.Migrator().DropTable(&Blog{}, &Tag{}, "blog_tags") @@ -246,8 +244,8 @@ func TestManyToManyWithCustomizedForeignKeys(t *testing.T) { } func TestManyToManyWithCustomizedForeignKeys2(t *testing.T) { - if name := DB.Dialector.Name(); name == "sqlite" || name == "mssql" { - t.Skip("skip sqlite, mssql due to it doesn't support multiple primary keys with auto increment") + if name := DB.Dialector.Name(); name == "sqlite" || name == "sqlserver" { + t.Skip("skip sqlite, sqlserver due to it doesn't support multiple primary keys with auto increment") } DB.Migrator().DropTable(&Blog{}, &Tag{}, "blog_tags") diff --git a/tests/named_polymorphic_test.go b/tests/named_polymorphic_test.go index 99a7865a..61655784 100644 --- a/tests/named_polymorphic_test.go +++ b/tests/named_polymorphic_test.go @@ -3,7 +3,7 @@ package tests_test import ( "testing" - . "gorm.io/gorm/tests" + . "gorm.io/gorm/utils/tests" ) type Hamster struct { diff --git a/tests/non_std_test.go b/tests/non_std_test.go index b3ac6545..d3561b11 100644 --- a/tests/non_std_test.go +++ b/tests/non_std_test.go @@ -3,8 +3,6 @@ package tests_test import ( "testing" "time" - - . "gorm.io/gorm/tests" ) type Animal struct { diff --git a/tests/preload_suits_test.go b/tests/preload_suits_test.go index 42e94fa0..98f24daf 100644 --- a/tests/preload_suits_test.go +++ b/tests/preload_suits_test.go @@ -7,7 +7,6 @@ import ( "testing" "gorm.io/gorm" - . "gorm.io/gorm/tests" ) func toJSONString(v interface{}) []byte { @@ -691,8 +690,8 @@ func TestNestedPreload12(t *testing.T) { } func TestManyToManyPreloadWithMultiPrimaryKeys(t *testing.T) { - if name := DB.Dialector.Name(); name == "sqlite" || name == "mssql" { - t.Skip("skip sqlite, mssql due to it doesn't support multiple primary keys with auto increment") + if name := DB.Dialector.Name(); name == "sqlite" || name == "sqlserver" { + t.Skip("skip sqlite, sqlserver due to it doesn't support multiple primary keys with auto increment") } type ( diff --git a/tests/preload_test.go b/tests/preload_test.go index e4ecdc87..06e38f09 100644 --- a/tests/preload_test.go +++ b/tests/preload_test.go @@ -6,7 +6,7 @@ import ( "testing" "gorm.io/gorm/clause" - . "gorm.io/gorm/tests" + . "gorm.io/gorm/utils/tests" ) func TestNestedPreload(t *testing.T) { diff --git a/tests/query_test.go b/tests/query_test.go index 9d15a41f..f6fb1081 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -9,7 +9,7 @@ import ( "time" "gorm.io/gorm" - . "gorm.io/gorm/tests" + . "gorm.io/gorm/utils/tests" ) func TestFind(t *testing.T) { diff --git a/tests/scan_test.go b/tests/scan_test.go index 262ac9a7..d6a372bb 100644 --- a/tests/scan_test.go +++ b/tests/scan_test.go @@ -6,7 +6,7 @@ import ( "strings" "testing" - . "gorm.io/gorm/tests" + . "gorm.io/gorm/utils/tests" ) func TestScan(t *testing.T) { diff --git a/tests/scanner_valuer_test.go b/tests/scanner_valuer_test.go index 7dad081f..7d72db15 100644 --- a/tests/scanner_valuer_test.go +++ b/tests/scanner_valuer_test.go @@ -11,7 +11,7 @@ import ( "time" "gorm.io/gorm" - . "gorm.io/gorm/tests" + . "gorm.io/gorm/utils/tests" ) func TestScannerValuer(t *testing.T) { diff --git a/tests/scopes_test.go b/tests/scopes_test.go index a2a7de3f..c9787d36 100644 --- a/tests/scopes_test.go +++ b/tests/scopes_test.go @@ -4,7 +4,7 @@ import ( "testing" "gorm.io/gorm" - . "gorm.io/gorm/tests" + . "gorm.io/gorm/utils/tests" ) func NameIn1And2(d *gorm.DB) *gorm.DB { diff --git a/tests/soft_delete_test.go b/tests/soft_delete_test.go index 24b06498..c632c753 100644 --- a/tests/soft_delete_test.go +++ b/tests/soft_delete_test.go @@ -3,7 +3,7 @@ package tests_test import ( "testing" - . "gorm.io/gorm/tests" + . "gorm.io/gorm/utils/tests" ) func TestSoftDelete(t *testing.T) { diff --git a/tests/sql_builder_test.go b/tests/sql_builder_test.go index 0f3a56ed..278a5b96 100644 --- a/tests/sql_builder_test.go +++ b/tests/sql_builder_test.go @@ -4,7 +4,7 @@ import ( "testing" "gorm.io/gorm" - . "gorm.io/gorm/tests" + . "gorm.io/gorm/utils/tests" ) func TestRow(t *testing.T) { diff --git a/tests/tests_all.sh b/tests/tests_all.sh index 3a1b45c8..95245804 100755 --- a/tests/tests_all.sh +++ b/tests/tests_all.sh @@ -18,8 +18,13 @@ for dialect in "${dialects[@]}" ; do if [ "$GORM_VERBOSE" = "" ] then DEBUG=false GORM_DIALECT=${dialect} go test $race -count=1 ./... + cd tests + DEBUG=false GORM_DIALECT=${dialect} go test $race -count=1 ./... else DEBUG=false GORM_DIALECT=${dialect} go test $race -count=1 -v ./... + cd tests + DEBUG=false GORM_DIALECT=${dialect} go test $race -count=1 -v ./... fi + cd .. fi done diff --git a/tests/tests.go b/tests/tests_test.go similarity index 87% rename from tests/tests.go rename to tests/tests_test.go index 42902685..40816c3c 100644 --- a/tests/tests.go +++ b/tests/tests_test.go @@ -1,4 +1,4 @@ -package tests +package tests_test import ( "log" @@ -7,12 +7,13 @@ import ( "path/filepath" "time" + "gorm.io/driver/mysql" + "gorm.io/driver/postgres" + "gorm.io/driver/sqlite" + "gorm.io/driver/sqlserver" "gorm.io/gorm" - "gorm.io/gorm/dialects/mssql" - "gorm.io/gorm/dialects/mysql" - "gorm.io/gorm/dialects/postgres" - "gorm.io/gorm/dialects/sqlite" "gorm.io/gorm/logger" + . "gorm.io/gorm/utils/tests" ) var DB *gorm.DB @@ -40,17 +41,17 @@ func OpenTestConnection() (db *gorm.DB, err error) { dbDSN = "user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable TimeZone=Asia/Shanghai" } db, err = gorm.Open(postgres.Open(dbDSN), &gorm.Config{}) - case "mssql": + case "sqlserver": // CREATE LOGIN gorm WITH PASSWORD = 'LoremIpsum86'; // CREATE DATABASE gorm; // USE gorm; // CREATE USER gorm FROM LOGIN gorm; // sp_changedbowner 'gorm'; - log.Println("testing mssql...") + log.Println("testing sqlserver...") if dbDSN == "" { dbDSN = "sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm" } - db, err = gorm.Open(mssql.Open(dbDSN), &gorm.Config{}) + db, err = gorm.Open(sqlserver.Open(dbDSN), &gorm.Config{}) default: log.Println("testing sqlite3...") db, err = gorm.Open(sqlite.Open(filepath.Join(os.TempDir(), "gorm.db")), &gorm.Config{}) @@ -90,8 +91,3 @@ func RunMigrations() { } } } - -func Now() *time.Time { - now := time.Now() - return &now -} diff --git a/tests/transaction_test.go b/tests/transaction_test.go index 4ff1b485..b810e3bb 100644 --- a/tests/transaction_test.go +++ b/tests/transaction_test.go @@ -6,7 +6,7 @@ import ( "testing" "gorm.io/gorm" - . "gorm.io/gorm/tests" + . "gorm.io/gorm/utils/tests" ) func TestTransaction(t *testing.T) { diff --git a/tests/update_belongs_to_test.go b/tests/update_belongs_to_test.go index 7c578b38..47076e69 100644 --- a/tests/update_belongs_to_test.go +++ b/tests/update_belongs_to_test.go @@ -3,7 +3,7 @@ package tests_test import ( "testing" - . "gorm.io/gorm/tests" + . "gorm.io/gorm/utils/tests" ) func TestUpdateBelongsTo(t *testing.T) { diff --git a/tests/update_has_many_test.go b/tests/update_has_many_test.go index 5501c519..01ea2e3a 100644 --- a/tests/update_has_many_test.go +++ b/tests/update_has_many_test.go @@ -3,7 +3,7 @@ package tests_test import ( "testing" - . "gorm.io/gorm/tests" + . "gorm.io/gorm/utils/tests" ) func TestUpdateHasManyAssociations(t *testing.T) { diff --git a/tests/update_has_one_test.go b/tests/update_has_one_test.go index 721c302a..7b29f424 100644 --- a/tests/update_has_one_test.go +++ b/tests/update_has_one_test.go @@ -3,7 +3,7 @@ package tests_test import ( "testing" - . "gorm.io/gorm/tests" + . "gorm.io/gorm/utils/tests" ) func TestUpdateHasOne(t *testing.T) { diff --git a/tests/update_many2many_test.go b/tests/update_many2many_test.go index 5548444f..a46deeb0 100644 --- a/tests/update_many2many_test.go +++ b/tests/update_many2many_test.go @@ -3,7 +3,7 @@ package tests_test import ( "testing" - . "gorm.io/gorm/tests" + . "gorm.io/gorm/utils/tests" ) func TestUpdateMany2ManyAssociations(t *testing.T) { diff --git a/tests/update_test.go b/tests/update_test.go index aef7f4ce..524e9ea6 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -8,7 +8,7 @@ import ( "time" "gorm.io/gorm" - . "gorm.io/gorm/tests" + . "gorm.io/gorm/utils/tests" ) func TestUpdate(t *testing.T) { diff --git a/tests/upsert_test.go b/tests/upsert_test.go index 87b223b4..412be305 100644 --- a/tests/upsert_test.go +++ b/tests/upsert_test.go @@ -5,7 +5,7 @@ import ( "time" "gorm.io/gorm/clause" - . "gorm.io/gorm/tests" + . "gorm.io/gorm/utils/tests" ) func TestUpsert(t *testing.T) { diff --git a/tests/dummy_dialecter.go b/utils/tests/dummy_dialecter.go similarity index 100% rename from tests/dummy_dialecter.go rename to utils/tests/dummy_dialecter.go diff --git a/tests/model.go b/utils/tests/models.go similarity index 100% rename from tests/model.go rename to utils/tests/models.go diff --git a/utils/tests/utils.go b/utils/tests/utils.go new file mode 100644 index 00000000..5248e620 --- /dev/null +++ b/utils/tests/utils.go @@ -0,0 +1,112 @@ +package tests + +import ( + "database/sql/driver" + "fmt" + "go/ast" + "reflect" + "testing" + "time" + + "gorm.io/gorm/utils" +) + +func AssertObjEqual(t *testing.T, r, e interface{}, names ...string) { + for _, name := range names { + got := reflect.Indirect(reflect.ValueOf(r)).FieldByName(name).Interface() + expect := reflect.Indirect(reflect.ValueOf(e)).FieldByName(name).Interface() + t.Run(name, func(t *testing.T) { + AssertEqual(t, got, expect) + }) + } +} + +func AssertEqual(t *testing.T, got, expect interface{}) { + if !reflect.DeepEqual(got, expect) { + isEqual := func() { + if curTime, ok := got.(time.Time); ok { + format := "2006-01-02T15:04:05Z07:00" + + if curTime.Round(time.Second).Format(format) != expect.(time.Time).Round(time.Second).Format(format) && curTime.Truncate(time.Second).Format(format) != expect.(time.Time).Truncate(time.Second).Format(format) { + t.Errorf("%v: expect: %v, got %v after time round", utils.FileWithLineNum(), expect.(time.Time), curTime) + } + } else if fmt.Sprint(got) != fmt.Sprint(expect) { + t.Errorf("%v: expect: %#v, got %#v", utils.FileWithLineNum(), expect, got) + } + } + + if fmt.Sprint(got) == fmt.Sprint(expect) { + return + } + + if reflect.Indirect(reflect.ValueOf(got)).IsValid() != reflect.Indirect(reflect.ValueOf(expect)).IsValid() { + t.Errorf("%v: expect: %+v, got %+v", utils.FileWithLineNum(), expect, got) + return + } + + if valuer, ok := got.(driver.Valuer); ok { + got, _ = valuer.Value() + } + + if valuer, ok := expect.(driver.Valuer); ok { + expect, _ = valuer.Value() + } + + if got != nil { + got = reflect.Indirect(reflect.ValueOf(got)).Interface() + } + + if expect != nil { + expect = reflect.Indirect(reflect.ValueOf(expect)).Interface() + } + + if reflect.ValueOf(got).IsValid() != reflect.ValueOf(expect).IsValid() { + t.Errorf("%v: expect: %+v, got %+v", utils.FileWithLineNum(), expect, got) + return + } + + if reflect.ValueOf(got).Kind() == reflect.Slice { + if reflect.ValueOf(expect).Kind() == reflect.Slice { + if reflect.ValueOf(got).Len() == reflect.ValueOf(expect).Len() { + for i := 0; i < reflect.ValueOf(got).Len(); i++ { + name := fmt.Sprintf(reflect.ValueOf(got).Type().Name()+" #%v", i) + t.Run(name, func(t *testing.T) { + AssertEqual(t, reflect.ValueOf(got).Index(i).Interface(), reflect.ValueOf(expect).Index(i).Interface()) + }) + } + } else { + name := reflect.ValueOf(got).Type().Elem().Name() + t.Errorf("%v expects length: %v, got %v", name, reflect.ValueOf(expect).Len(), reflect.ValueOf(got).Len()) + } + return + } + } + + if reflect.ValueOf(got).Kind() == reflect.Struct { + if reflect.ValueOf(got).NumField() == reflect.ValueOf(expect).NumField() { + for i := 0; i < reflect.ValueOf(got).NumField(); i++ { + if fieldStruct := reflect.ValueOf(got).Type().Field(i); ast.IsExported(fieldStruct.Name) { + field := reflect.ValueOf(got).Field(i) + t.Run(fieldStruct.Name, func(t *testing.T) { + AssertEqual(t, field.Interface(), reflect.ValueOf(expect).Field(i).Interface()) + }) + } + } + return + } + } + + if reflect.ValueOf(got).Type().ConvertibleTo(reflect.ValueOf(expect).Type()) { + got = reflect.ValueOf(got).Convert(reflect.ValueOf(expect).Type()).Interface() + isEqual() + } else if reflect.ValueOf(expect).Type().ConvertibleTo(reflect.ValueOf(got).Type()) { + expect = reflect.ValueOf(got).Convert(reflect.ValueOf(got).Type()).Interface() + isEqual() + } + } +} + +func Now() *time.Time { + now := time.Now() + return &now +} From 64ed645e4da552703257f3a3b37bf92714368859 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 2 Jun 2020 11:09:17 +0800 Subject: [PATCH 387/881] Returns ping error --- gorm.go | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/gorm.go b/gorm.go index 1ab3fd64..8a801d68 100644 --- a/gorm.go +++ b/gorm.go @@ -91,6 +91,17 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) { if dialector != nil { err = dialector.Initialize(db) } + + if err == nil { + if pinger, ok := db.ConnPool.(interface{ Ping() error }); ok { + err = pinger.Ping() + } + } + + if err != nil { + config.Logger.Error(context.Background(), "failed to initialize database, got error %v", err) + } + return } From 669ce48f1924d1d67cbaca2fcccec94c074cb5ca Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 2 Jun 2020 11:30:21 +0800 Subject: [PATCH 388/881] Fix order by primary key if it is not defined --- statement.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/statement.go b/statement.go index 8f4762e7..ebd6e234 100644 --- a/statement.go +++ b/statement.go @@ -90,6 +90,8 @@ func (stmt Statement) QuoteTo(writer clause.Writer, field interface{}) { if v.Name == clause.PrimaryKey { if stmt.Schema != nil && stmt.Schema.PrioritizedPrimaryField != nil { stmt.DB.Dialector.QuoteTo(writer, stmt.Schema.PrioritizedPrimaryField.DBName) + } else if len(stmt.Schema.DBNames) > 0 { + stmt.DB.Dialector.QuoteTo(writer, stmt.Schema.DBNames[0]) } } else if v.Raw { writer.WriteString(v.Name) From e959a67f87d5a7264724fdabc759bce92a1de68c Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 2 Jun 2020 12:46:55 +0800 Subject: [PATCH 389/881] Fix callbacks with Match --- callbacks.go | 1 + 1 file changed, 1 insertion(+) diff --git a/callbacks.go b/callbacks.go index c5654c50..a9a6dd85 100644 --- a/callbacks.go +++ b/callbacks.go @@ -150,6 +150,7 @@ func (p *processor) compile() (err error) { callbacks = append(callbacks, callback) } } + p.callbacks = callbacks if p.fns, err = sortCallbacks(p.callbacks); err != nil { logger.Default.Error(context.Background(), "Got error when compile callbacks, got %v", err) From 2218e32999cb1f205c16a139e28e1bd877e4d151 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 2 Jun 2020 15:48:19 +0800 Subject: [PATCH 390/881] Allow customize table name with TableName --- schema/schema.go | 15 ++++++++++++--- schema/schema_test.go | 18 ++++++++++++++++++ 2 files changed, 30 insertions(+), 3 deletions(-) diff --git a/schema/schema.go b/schema/schema.go index 60e621de..9e05303a 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -82,6 +82,10 @@ func (schema Schema) LookUpField(name string) *Field { return nil } +type Tabler interface { + TableName() string +} + // get data type from dialector func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) { modelType := reflect.ValueOf(dest).Type() @@ -100,10 +104,16 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) return v.(*Schema), nil } + modelValue := reflect.New(modelType) + tableName := namer.TableName(modelType.Name()) + if tabler, ok := modelValue.Interface().(Tabler); ok { + tableName = tabler.TableName() + } + schema := &Schema{ Name: modelType.Name(), ModelType: modelType, - Table: namer.TableName(modelType.Name()), + Table: tableName, FieldsByName: map[string]*Field{}, FieldsByDBName: map[string]*Field{}, Relationships: Relationships{Relations: map[string]*Relationship{}}, @@ -200,10 +210,9 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) } } - reflectValue := reflect.New(modelType) callbacks := []string{"BeforeCreate", "AfterCreate", "BeforeUpdate", "AfterUpdate", "BeforeSave", "AfterSave", "BeforeDelete", "AfterDelete", "AfterFind"} for _, name := range callbacks { - if methodValue := reflectValue.MethodByName(name); methodValue.IsValid() { + if methodValue := modelValue.MethodByName(name); methodValue.IsValid() { switch methodValue.Type().String() { case "func(*gorm.DB) error": // TODO hack reflect.Indirect(reflect.ValueOf(schema)).FieldByName(name).SetBool(true) diff --git a/schema/schema_test.go b/schema/schema_test.go index 1029f74f..82f07fa8 100644 --- a/schema/schema_test.go +++ b/schema/schema_test.go @@ -142,3 +142,21 @@ func TestParseSchemaWithAdvancedDataType(t *testing.T) { }) } } + +type CustomizeTable struct { +} + +func (CustomizeTable) TableName() string { + return "customize" +} + +func TestCustomizeTableName(t *testing.T) { + customize, err := schema.Parse(&CustomizeTable{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("failed to parse pointer user, got error %v", err) + } + + if customize.Table != "customize" { + t.Errorf("Failed to customize table with TableName method") + } +} From 94685d102430d8549aa60180dff83e3970e2fb91 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 2 Jun 2020 22:13:53 +0800 Subject: [PATCH 391/881] Fix can't scan null value into normal data types --- finisher_api.go | 2 +- scan.go | 158 ++++++++++++++++++++++++------------ schema/field.go | 121 ++++++++++++++------------- statement.go | 12 ++- tests/main_test.go | 5 -- tests/preload_suits_test.go | 1 - tests/query_test.go | 32 ++++++++ tests/tests_all.sh | 4 +- tests/tests_test.go | 2 + tests/update_test.go | 6 +- tests/upsert_test.go | 5 ++ 11 files changed, 226 insertions(+), 122 deletions(-) diff --git a/finisher_api.go b/finisher_api.go index 5023150c..b97f2301 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -168,7 +168,7 @@ func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) { return tx.Create(dest) } else if len(tx.Statement.assigns) > 0 { - exprs := tx.Statement.BuildCondtion(tx.Statement.assigns[0], tx.Statement.assigns[1:]) + exprs := tx.Statement.BuildCondtion(tx.Statement.assigns[0], tx.Statement.assigns[1:]...) assigns := map[string]interface{}{} for _, expr := range exprs { if eq, ok := expr.(clause.Eq); ok { diff --git a/scan.go b/scan.go index fc6b211b..14a4699d 100644 --- a/scan.go +++ b/scan.go @@ -14,40 +14,53 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { switch dest := db.Statement.Dest.(type) { case map[string]interface{}, *map[string]interface{}: - for idx, _ := range columns { - values[idx] = new(interface{}) - } - if initialized || rows.Next() { + for idx := range columns { + values[idx] = new(interface{}) + } + db.RowsAffected++ db.AddError(rows.Scan(values...)) - } - mapValue, ok := dest.(map[string]interface{}) - if ok { - if v, ok := dest.(*map[string]interface{}); ok { - mapValue = *v + mapValue, ok := dest.(map[string]interface{}) + if !ok { + if v, ok := dest.(*map[string]interface{}); ok { + mapValue = *v + } + } + + for idx, column := range columns { + if v, ok := values[idx].(*interface{}); ok { + if v == nil { + mapValue[column] = nil + } else { + mapValue[column] = *v + } + } } } - - for idx, column := range columns { - mapValue[column] = *(values[idx].(*interface{})) - } case *[]map[string]interface{}: - for idx, _ := range columns { - values[idx] = new(interface{}) - } - for initialized || rows.Next() { + for idx := range columns { + values[idx] = new(interface{}) + } + initialized = false db.RowsAffected++ db.AddError(rows.Scan(values...)) - v := map[string]interface{}{} + mapValue := map[string]interface{}{} for idx, column := range columns { - v[column] = *(values[idx].(*interface{})) + if v, ok := values[idx].(*interface{}); ok { + if v == nil { + mapValue[column] = nil + } else { + mapValue[column] = *v + } + } } - *dest = append(*dest, v) + + *dest = append(*dest, mapValue) } case *int, *int64, *uint, *uint64: for initialized || rows.Next() { @@ -85,28 +98,52 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { } for initialized || rows.Next() { + for idx := range columns { + values[idx] = new(interface{}) + } + initialized = false + db.RowsAffected++ + elem := reflect.New(reflectValueType).Elem() if reflectValueType.Kind() != reflect.Struct && len(fields) == 1 { + // pluck values[0] = elem.Addr().Interface() + db.AddError(rows.Scan(values...)) } else { - for idx, field := range fields { - if field != nil { - values[idx] = field.ReflectValueOf(elem).Addr().Interface() - } else if joinFields[idx][0] != nil { - relValue := joinFields[idx][0].ReflectValueOf(elem) - if relValue.Kind() == reflect.Ptr && relValue.IsNil() { - relValue.Set(reflect.New(relValue.Type().Elem())) - } + db.AddError(rows.Scan(values...)) - values[idx] = joinFields[idx][1].ReflectValueOf(relValue).Addr().Interface() + for idx, field := range fields { + if v, ok := values[idx].(*interface{}); ok { + if field != nil { + if v == nil { + field.Set(elem, v) + } else { + field.Set(elem, *v) + } + } else if joinFields[idx][0] != nil { + relValue := joinFields[idx][0].ReflectValueOf(elem) + if relValue.Kind() == reflect.Ptr && relValue.IsNil() { + if v == nil { + continue + } + relValue.Set(reflect.New(relValue.Type().Elem())) + } + + if v == nil { + joinFields[idx][1].Set(relValue, nil) + } else { + joinFields[idx][1].Set(relValue, *v) + } + } } } - } - db.RowsAffected++ - db.AddError(rows.Scan(values...)) + for idx := range columns { + values[idx] = new(interface{}) + } + } if isPtr { db.Statement.ReflectValue.Set(reflect.Append(db.Statement.ReflectValue, elem.Addr())) @@ -115,30 +152,45 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { } } case reflect.Struct: - for idx, column := range columns { - if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable { - values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface() - } else if names := strings.Split(column, "__"); len(names) > 1 { - if rel, ok := db.Statement.Schema.Relationships.Relations[names[0]]; ok { - relValue := rel.Field.ReflectValueOf(db.Statement.ReflectValue) - if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { - if relValue.Kind() == reflect.Ptr && relValue.IsNil() { - relValue.Set(reflect.New(relValue.Type().Elem())) - } - - values[idx] = field.ReflectValueOf(relValue).Addr().Interface() - continue - } - } - values[idx] = &sql.RawBytes{} - } else { - values[idx] = &sql.RawBytes{} - } - } - if initialized || rows.Next() { + for idx := range columns { + values[idx] = new(interface{}) + } + db.RowsAffected++ db.AddError(rows.Scan(values...)) + + for idx, column := range columns { + if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable { + if v, ok := values[idx].(*interface{}); ok { + if v == nil { + field.Set(db.Statement.ReflectValue, v) + } else { + field.Set(db.Statement.ReflectValue, *v) + } + } + } else if names := strings.Split(column, "__"); len(names) > 1 { + if rel, ok := db.Statement.Schema.Relationships.Relations[names[0]]; ok { + relValue := rel.Field.ReflectValueOf(db.Statement.ReflectValue) + if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { + if v, ok := values[idx].(*interface{}); ok { + if relValue.Kind() == reflect.Ptr && relValue.IsNil() { + if v == nil { + continue + } + relValue.Set(reflect.New(relValue.Type().Elem())) + } + + if v == nil { + field.Set(relValue, nil) + } else { + field.Set(relValue, *v) + } + } + } + } + } + } } } } diff --git a/schema/field.go b/schema/field.go index 4f92aae7..8861a00d 100644 --- a/schema/field.go +++ b/schema/field.go @@ -402,34 +402,48 @@ func (field *Field) setupValuerAndSetter() { } } - recoverFunc := func(value reflect.Value, v interface{}, setter func(reflect.Value, interface{}) error) (err error) { + fallbackSetter := func(value reflect.Value, v interface{}, setter func(reflect.Value, interface{}) error) (err error) { if v == nil { field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) } else { reflectV := reflect.ValueOf(v) - - if reflectV.Type().ConvertibleTo(field.FieldType) { + if reflectV.Type().AssignableTo(field.FieldType) { + field.ReflectValueOf(value).Set(reflectV) + return + } else if reflectV.Type().ConvertibleTo(field.FieldType) { field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType)) - } else if valuer, ok := v.(driver.Valuer); ok { - if v, err = valuer.Value(); err == nil { - return setter(value, v) - } - } else if field.FieldType.Kind() == reflect.Ptr && reflectV.Type().ConvertibleTo(field.FieldType.Elem()) { + return + } else if field.FieldType.Kind() == reflect.Ptr { fieldValue := field.ReflectValueOf(value) - if fieldValue.IsNil() { - if v == nil { - return nil + + if reflectV.Type().AssignableTo(field.FieldType.Elem()) { + if fieldValue.IsNil() { + fieldValue.Set(reflect.New(field.FieldType.Elem())) } - fieldValue.Set(reflect.New(field.FieldType.Elem())) + fieldValue.Elem().Set(reflectV) + return + } else if reflectV.Type().ConvertibleTo(field.FieldType.Elem()) { + if fieldValue.IsNil() { + fieldValue.Set(reflect.New(field.FieldType.Elem())) + } + + fieldValue.Elem().Set(reflectV.Convert(field.FieldType.Elem())) + return + } + } + + if valuer, ok := v.(driver.Valuer); ok { + if v, err = valuer.Value(); err == nil { + setter(value, v) } - fieldValue.Elem().Set(reflectV.Convert(field.FieldType.Elem())) } else if reflectV.Kind() == reflect.Ptr { - return field.Set(value, reflectV.Elem().Interface()) + setter(value, reflectV.Elem().Interface()) } else { return fmt.Errorf("failed to set value %+v to field %v", v, field.Name) } } - return err + + return } // Set @@ -441,8 +455,17 @@ func (field *Field) setupValuerAndSetter() { field.ReflectValueOf(value).SetBool(data) case *bool: field.ReflectValueOf(value).SetBool(*data) + case int64: + if data > 0 { + field.ReflectValueOf(value).SetBool(true) + } else { + field.ReflectValueOf(value).SetBool(false) + } + case string: + b, _ := strconv.ParseBool(data) + field.ReflectValueOf(value).SetBool(b) default: - return recoverFunc(value, v, field.Set) + return fallbackSetter(value, v, field.Set) } return nil } @@ -498,7 +521,7 @@ func (field *Field) setupValuerAndSetter() { field.ReflectValueOf(value).SetInt(0) } default: - return recoverFunc(value, v, field.Set) + return fallbackSetter(value, v, field.Set) } return err } @@ -538,7 +561,7 @@ func (field *Field) setupValuerAndSetter() { return err } default: - return recoverFunc(value, v, field.Set) + return fallbackSetter(value, v, field.Set) } return err } @@ -578,7 +601,7 @@ func (field *Field) setupValuerAndSetter() { return err } default: - return recoverFunc(value, v, field.Set) + return fallbackSetter(value, v, field.Set) } return err } @@ -594,7 +617,7 @@ func (field *Field) setupValuerAndSetter() { case float64, float32: field.ReflectValueOf(value).SetString(fmt.Sprintf("%."+strconv.Itoa(field.Precision)+"f", data)) default: - return recoverFunc(value, v, field.Set) + return fallbackSetter(value, v, field.Set) } return err } @@ -615,7 +638,7 @@ func (field *Field) setupValuerAndSetter() { return fmt.Errorf("failed to set string %v to time.Time field %v, failed to parse it as time, got error %v", v, field.Name, err) } default: - return recoverFunc(value, v, field.Set) + return fallbackSetter(value, v, field.Set) } return nil } @@ -625,9 +648,6 @@ func (field *Field) setupValuerAndSetter() { case time.Time: fieldValue := field.ReflectValueOf(value) if fieldValue.IsNil() { - if v == nil { - return nil - } fieldValue.Set(reflect.New(field.FieldType.Elem())) } fieldValue.Elem().Set(reflect.ValueOf(v)) @@ -647,7 +667,7 @@ func (field *Field) setupValuerAndSetter() { return fmt.Errorf("failed to set string %v to time.Time field %v, failed to parse it as time, got error %v", v, field.Name, err) } default: - return recoverFunc(value, v, field.Set) + return fallbackSetter(value, v, field.Set) } return nil } @@ -655,53 +675,42 @@ func (field *Field) setupValuerAndSetter() { if _, ok := fieldValue.Interface().(sql.Scanner); ok { // struct scanner field.Set = func(value reflect.Value, v interface{}) (err error) { - if v == nil { + if valuer, ok := v.(driver.Valuer); ok { + v, _ = valuer.Value() + } + + reflectV := reflect.ValueOf(v) + if reflectV.Kind() == reflect.Ptr && reflectV.IsNil() { field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) } else { - reflectV := reflect.ValueOf(v) - if reflectV.Type().ConvertibleTo(field.FieldType) { - field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType)) - } else if valuer, ok := v.(driver.Valuer); ok { - if v, err = valuer.Value(); err == nil { - err = field.ReflectValueOf(value).Addr().Interface().(sql.Scanner).Scan(v) - } - } else { - err = field.ReflectValueOf(value).Addr().Interface().(sql.Scanner).Scan(v) - } + err = field.ReflectValueOf(value).Addr().Interface().(sql.Scanner).Scan(v) } return } } else if _, ok := fieldValue.Elem().Interface().(sql.Scanner); ok { // pointer scanner field.Set = func(value reflect.Value, v interface{}) (err error) { - if v == nil { + if valuer, ok := v.(driver.Valuer); ok { + v, _ = valuer.Value() + } + + reflectV := reflect.ValueOf(v) + if reflectV.Type().ConvertibleTo(field.FieldType) { + field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType)) + } else if reflectV.Kind() == reflect.Ptr && reflectV.IsNil() { field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) } else { - reflectV := reflect.ValueOf(v) - if reflectV.Type().ConvertibleTo(field.FieldType) { - field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType)) - } else if reflectV.Type().ConvertibleTo(field.FieldType.Elem()) { - fieldValue := field.ReflectValueOf(value) - if fieldValue.IsNil() { - if v == nil { - return nil - } - fieldValue.Set(reflect.New(field.FieldType.Elem())) - } - fieldValue.Elem().Set(reflectV.Convert(field.FieldType.Elem())) - } else if valuer, ok := v.(driver.Valuer); ok { - if v, err = valuer.Value(); err == nil { - err = field.ReflectValueOf(value).Interface().(sql.Scanner).Scan(v) - } - } else { - err = field.ReflectValueOf(value).Interface().(sql.Scanner).Scan(v) + fieldValue := field.ReflectValueOf(value) + if fieldValue.IsNil() { + fieldValue.Set(reflect.New(field.FieldType.Elem())) } + err = fieldValue.Interface().(sql.Scanner).Scan(v) } return } } else { field.Set = func(value reflect.Value, v interface{}) (err error) { - return recoverFunc(value, v, field.Set) + return fallbackSetter(value, v, field.Set) } } } diff --git a/statement.go b/statement.go index ebd6e234..ffe3c75b 100644 --- a/statement.go +++ b/statement.go @@ -146,8 +146,16 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) { case clause.Column, clause.Table: stmt.QuoteTo(writer, v) case clause.Expr: - writer.WriteString(v.SQL) - stmt.Vars = append(stmt.Vars, v.Vars...) + var varStr strings.Builder + var sql = v.SQL + for _, arg := range v.Vars { + stmt.Vars = append(stmt.Vars, arg) + stmt.DB.Dialector.BindVarTo(&varStr, stmt, arg) + sql = strings.Replace(sql, "?", varStr.String(), 1) + varStr.Reset() + } + + writer.WriteString(sql) case driver.Valuer: stmt.Vars = append(stmt.Vars, v) stmt.DB.Dialector.BindVarTo(writer, stmt, v) diff --git a/tests/main_test.go b/tests/main_test.go index ff293e6e..9d933caf 100644 --- a/tests/main_test.go +++ b/tests/main_test.go @@ -6,11 +6,6 @@ import ( . "gorm.io/gorm/utils/tests" ) -func TestMain(m *testing.M) { - RunMigrations() - m.Run() -} - func TestExceptionsWithInvalidSql(t *testing.T) { var columns []string if DB.Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns).Error == nil { diff --git a/tests/preload_suits_test.go b/tests/preload_suits_test.go index 98f24daf..8f678b21 100644 --- a/tests/preload_suits_test.go +++ b/tests/preload_suits_test.go @@ -1299,7 +1299,6 @@ func TestNilPointerSlice(t *testing.T) { ) DB.Migrator().DropTable(&Level3{}, &Level2{}, &Level1{}) - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}); err != nil { t.Error(err) } diff --git a/tests/query_test.go b/tests/query_test.go index f6fb1081..18ffb3fb 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -435,3 +435,35 @@ func TestSubQueryWithHaving(t *testing.T) { t.Errorf("Two user group should be found, instead found %d", len(results)) } } + +func TestScanNullValue(t *testing.T) { + user := GetUser("scan_null_value", Config{}) + DB.Create(&user) + + if err := DB.Model(&user).Update("age", nil).Error; err != nil { + t.Fatalf("failed to update column age for struct, got error %v", err) + } + + var result User + if err := DB.First(&result, "id = ?", user.ID).Error; err != nil { + t.Fatalf("failed to query struct data with null age, got error %v", err) + } + + AssertEqual(t, result, user) + + users := []User{ + *GetUser("scan_null_value_for_slice_1", Config{}), + *GetUser("scan_null_value_for_slice_2", Config{}), + *GetUser("scan_null_value_for_slice_3", Config{}), + } + DB.Create(&users) + + if err := DB.Model(&users[0]).Update("age", nil).Error; err != nil { + t.Fatalf("failed to update column age for struct, got error %v", err) + } + + var results []User + if err := DB.Find(&results, "name like ?", "scan_null_value_for_slice%").Error; err != nil { + t.Fatalf("failed to query slice data with null age, got error %v", err) + } +} diff --git a/tests/tests_all.sh b/tests/tests_all.sh index 95245804..92a28f3b 100755 --- a/tests/tests_all.sh +++ b/tests/tests_all.sh @@ -1,4 +1,4 @@ -dialects=("sqlite" "mysql" "postgres" "mssql") +dialects=("sqlite" "mysql" "postgres" "sqlserver") if [[ $(pwd) == *"gorm/tests"* ]]; then cd .. @@ -10,7 +10,7 @@ for dialect in "${dialects[@]}" ; do echo "testing ${dialect}..." race="" - if [ "$GORM_VERBOSE" = "" ] + if [ "$GORM_DIALECT" = "sqlserver" ] then race="-race" fi diff --git a/tests/tests_test.go b/tests/tests_test.go index 40816c3c..09850003 100644 --- a/tests/tests_test.go +++ b/tests/tests_test.go @@ -23,6 +23,8 @@ func init() { if DB, err = OpenTestConnection(); err != nil { log.Printf("failed to connect database, got error %v\n", err) os.Exit(1) + } else { + RunMigrations() } } diff --git a/tests/update_test.go b/tests/update_test.go index 524e9ea6..220d3e76 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -155,12 +155,14 @@ func TestUpdates(t *testing.T) { AssertEqual(t, user2.UpdatedAt, user3.UpdatedAt) // update with gorm exprs - DB.Model(&user3).Updates(map[string]interface{}{"age": gorm.Expr("age + ?", 100)}) + if err := DB.Model(&user3).Updates(map[string]interface{}{"age": gorm.Expr("age + ?", 100)}).Error; err != nil { + t.Errorf("Not error should happen when updating with gorm expr, but got %v", err) + } var user4 User DB.First(&user4, user3.ID) user3.Age += 100 - AssertEqual(t, user4.UpdatedAt, user3.UpdatedAt) + AssertObjEqual(t, user4, user3, "UpdatedAt", "Age") } func TestUpdateColumn(t *testing.T) { diff --git a/tests/upsert_test.go b/tests/upsert_test.go index 412be305..f132a7da 100644 --- a/tests/upsert_test.go +++ b/tests/upsert_test.go @@ -121,6 +121,11 @@ func TestFindOrCreate(t *testing.T) { updatedAt1 := user4.UpdatedAt DB.Where(&User{Name: "find or create 3"}).Assign("age", 55).FirstOrCreate(&user4) + + if user4.Age != 55 { + t.Errorf("Failed to set change to 55, got %v", user4.Age) + } + if updatedAt1.Format(time.RFC3339Nano) == user4.UpdatedAt.Format(time.RFC3339Nano) { t.Errorf("UpdateAt should be changed when update values with assign") } From b32658358cd0bd5ee76f1229dfaa4613c0045fee Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 3 Jun 2020 08:44:13 +0800 Subject: [PATCH 392/881] Fix can't scan null value into normal data types --- scan.go | 94 ++++++++++++++++++++++--------------------------- schema/field.go | 37 ++++++++++++++----- tests/go.mod | 4 +-- 3 files changed, 73 insertions(+), 62 deletions(-) diff --git a/scan.go b/scan.go index 14a4699d..acba4e9f 100644 --- a/scan.go +++ b/scan.go @@ -87,6 +87,7 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { } else if names := strings.Split(column, "__"); len(names) > 1 { if rel, ok := db.Statement.Schema.Relationships.Relations[names[0]]; ok { if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { + fields[idx] = field joinFields[idx] = [2]*schema.Field{rel.Field, field} continue } @@ -98,50 +99,39 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { } for initialized || rows.Next() { - for idx := range columns { - values[idx] = new(interface{}) - } - initialized = false db.RowsAffected++ elem := reflect.New(reflectValueType).Elem() - if reflectValueType.Kind() != reflect.Struct && len(fields) == 1 { // pluck values[0] = elem.Addr().Interface() db.AddError(rows.Scan(values...)) } else { - db.AddError(rows.Scan(values...)) - for idx, field := range fields { - if v, ok := values[idx].(*interface{}); ok { - if field != nil { - if v == nil { - field.Set(elem, v) - } else { - field.Set(elem, *v) - } - } else if joinFields[idx][0] != nil { - relValue := joinFields[idx][0].ReflectValueOf(elem) - if relValue.Kind() == reflect.Ptr && relValue.IsNil() { - if v == nil { - continue - } - relValue.Set(reflect.New(relValue.Type().Elem())) - } - - if v == nil { - joinFields[idx][1].Set(relValue, nil) - } else { - joinFields[idx][1].Set(relValue, *v) - } - } + if field != nil { + values[idx] = reflect.New(reflect.PtrTo(field.FieldType)).Interface() } } - for idx := range columns { - values[idx] = new(interface{}) + db.AddError(rows.Scan(values...)) + + for idx, field := range fields { + if joinFields[idx][0] != nil { + value := reflect.ValueOf(values[idx]).Elem() + relValue := joinFields[idx][0].ReflectValueOf(elem) + + if relValue.Kind() == reflect.Ptr && relValue.IsNil() { + if value.IsNil() { + continue + } + relValue.Set(reflect.New(relValue.Type().Elem())) + } + + field.Set(relValue, values[idx]) + } else if field != nil { + field.Set(elem, values[idx]) + } } } @@ -153,8 +143,20 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { } case reflect.Struct: if initialized || rows.Next() { - for idx := range columns { - values[idx] = new(interface{}) + for idx, column := range columns { + if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable { + values[idx] = reflect.New(reflect.PtrTo(field.FieldType)).Interface() + } else if names := strings.Split(column, "__"); len(names) > 1 { + if rel, ok := db.Statement.Schema.Relationships.Relations[names[0]]; ok { + if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { + values[idx] = reflect.New(reflect.PtrTo(field.FieldType)).Interface() + continue + } + } + values[idx] = &sql.RawBytes{} + } else { + values[idx] = &sql.RawBytes{} + } } db.RowsAffected++ @@ -162,31 +164,21 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { for idx, column := range columns { if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable { - if v, ok := values[idx].(*interface{}); ok { - if v == nil { - field.Set(db.Statement.ReflectValue, v) - } else { - field.Set(db.Statement.ReflectValue, *v) - } - } + field.Set(db.Statement.ReflectValue, values[idx]) } else if names := strings.Split(column, "__"); len(names) > 1 { if rel, ok := db.Statement.Schema.Relationships.Relations[names[0]]; ok { relValue := rel.Field.ReflectValueOf(db.Statement.ReflectValue) if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { - if v, ok := values[idx].(*interface{}); ok { - if relValue.Kind() == reflect.Ptr && relValue.IsNil() { - if v == nil { - continue - } - relValue.Set(reflect.New(relValue.Type().Elem())) - } + value := reflect.ValueOf(values[idx]).Elem() - if v == nil { - field.Set(relValue, nil) - } else { - field.Set(relValue, *v) + if relValue.Kind() == reflect.Ptr && relValue.IsNil() { + if value.IsNil() { + continue } + relValue.Set(reflect.New(relValue.Type().Elem())) } + + field.Set(relValue, values[idx]) } } } diff --git a/schema/field.go b/schema/field.go index 8861a00d..a27fdd87 100644 --- a/schema/field.go +++ b/schema/field.go @@ -247,7 +247,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } } - if v, ok := field.TagSettings["AUTOCREATETIME"]; ok || (field.Name == "CreatedAt" && (field.DataType == Time || field.DataType == Int)) { + if v, ok := field.TagSettings["AUTOCREATETIME"]; ok || (field.Name == "CreatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) { if strings.ToUpper(v) == "NANO" { field.AutoCreateTime = UnixNanosecond } else { @@ -255,7 +255,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } } - if v, ok := field.TagSettings["AUTOUPDATETIME"]; ok || (field.Name == "UpdatedAt" && (field.DataType == Time || field.DataType == Int)) { + if v, ok := field.TagSettings["AUTOUPDATETIME"]; ok || (field.Name == "UpdatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) { if strings.ToUpper(v) == "NANO" { field.AutoUpdateTime = UnixNanosecond } else { @@ -407,6 +407,7 @@ func (field *Field) setupValuerAndSetter() { field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) } else { reflectV := reflect.ValueOf(v) + if reflectV.Type().AssignableTo(field.FieldType) { field.ReflectValueOf(value).Set(reflectV) return @@ -437,7 +438,11 @@ func (field *Field) setupValuerAndSetter() { setter(value, v) } } else if reflectV.Kind() == reflect.Ptr { - setter(value, reflectV.Elem().Interface()) + if reflectV.IsNil() { + field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) + } else { + setter(value, reflectV.Elem().Interface()) + } } else { return fmt.Errorf("failed to set value %+v to field %v", v, field.Name) } @@ -680,8 +685,14 @@ func (field *Field) setupValuerAndSetter() { } reflectV := reflect.ValueOf(v) - if reflectV.Kind() == reflect.Ptr && reflectV.IsNil() { + if !reflectV.IsValid() { field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) + } else if reflectV.Kind() == reflect.Ptr { + if reflectV.Elem().IsNil() || !reflectV.Elem().IsValid() { + field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) + } else { + return field.Set(value, reflectV.Elem().Interface()) + } } else { err = field.ReflectValueOf(value).Addr().Interface().(sql.Scanner).Scan(v) } @@ -691,14 +702,22 @@ func (field *Field) setupValuerAndSetter() { // pointer scanner field.Set = func(value reflect.Value, v interface{}) (err error) { if valuer, ok := v.(driver.Valuer); ok { - v, _ = valuer.Value() + if valuer == nil { + field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) + } else { + v, _ = valuer.Value() + } } reflectV := reflect.ValueOf(v) - if reflectV.Type().ConvertibleTo(field.FieldType) { - field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType)) - } else if reflectV.Kind() == reflect.Ptr && reflectV.IsNil() { - field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) + if reflectV.Type().AssignableTo(field.FieldType) { + field.ReflectValueOf(value).Set(reflectV) + } else if reflectV.Kind() == reflect.Ptr { + if reflectV.IsNil() { + field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) + } else { + field.Set(value, reflectV.Elem().Interface()) + } } else { fieldValue := field.ReflectValueOf(value) if fieldValue.IsNil() { diff --git a/tests/go.mod b/tests/go.mod index 3954c442..3401b9b2 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -7,8 +7,8 @@ require ( gorm.io/driver/mysql v0.0.0-20200602015408-0407d0c21cf0 gorm.io/driver/postgres v0.0.0-20200602015520-15fcc29eb286 gorm.io/driver/sqlite v0.0.0-20200602015323-284b563f81c8 - gorm.io/driver/sqlserver v0.0.0-20200602015206-ef9f739c6a30 - gorm.io/gorm v1.9.12 + gorm.io/driver/sqlserver v0.0.0-20200602144728-79c224f6c1a2 + gorm.io/gorm v0.0.0-00010101000000-000000000000 ) replace gorm.io/gorm => ../ From 9934207c42df1d2e587f0523b8cefefe17212b30 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 3 Jun 2020 14:39:36 +0800 Subject: [PATCH 393/881] Fix logger panic on windows --- utils/utils.go | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/utils/utils.go b/utils/utils.go index e177999e..ce42b218 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -5,25 +5,24 @@ import ( "fmt" "path/filepath" "reflect" - "regexp" "runtime" "strconv" "strings" "unicode" ) -var goSrcRegexp, goTestRegexp *regexp.Regexp +var gormSourceDir string func init() { _, file, _, _ := runtime.Caller(0) - goSrcRegexp = regexp.MustCompile(filepath.Join(filepath.Dir(filepath.Dir(file)), ".*.go")) - goTestRegexp = regexp.MustCompile(filepath.Join(filepath.Dir(filepath.Dir(file)), ".*test.go")) + gormSourceDir = filepath.Dir(filepath.Dir(file)) } func FileWithLineNum() string { for i := 2; i < 15; i++ { _, file, line, ok := runtime.Caller(i) - if ok && (!goSrcRegexp.MatchString(file) || goTestRegexp.MatchString(file)) { + + if ok && (!strings.HasPrefix(file, gormSourceDir) || strings.HasSuffix(file, "_test.go")) { return fmt.Sprintf("%v:%v", file, line) } } From c8e7878b3ed2265d2255b55e93cd49101b3f6ee8 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 5 Jun 2020 10:08:22 +0800 Subject: [PATCH 394/881] Add PrepareStmt support --- finisher_api.go | 44 +++++++++++-------- gorm.go | 49 ++++++++++++++------- interfaces.go | 8 +++- prepare_stmt.go | 92 +++++++++++++++++++++++++++++++++++++++ tests/transaction_test.go | 3 +- 5 files changed, 159 insertions(+), 37 deletions(-) create mode 100644 prepare_stmt.go diff --git a/finisher_api.go b/finisher_api.go index b97f2301..e493b406 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -310,28 +310,36 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er } // Begin begins a transaction -func (db *DB) Begin(opts ...*sql.TxOptions) (tx *DB) { - tx = db.getInstance() - if beginner, ok := tx.Statement.ConnPool.(TxBeginner); ok { - var opt *sql.TxOptions - var err error - if len(opts) > 0 { - opt = opts[0] - } +func (db *DB) Begin(opts ...*sql.TxOptions) *DB { + var ( + tx = db.getInstance() + opt *sql.TxOptions + err error + ) - if tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt); err != nil { - tx.AddError(err) - } - } else { - tx.AddError(ErrInvalidTransaction) + if len(opts) > 0 { + opt = opts[0] } - return + + if beginner, ok := tx.Statement.ConnPool.(TxBeginner); ok { + tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) + } else if beginner, ok := tx.Statement.ConnPool.(ConnPoolBeginner); ok { + tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) + } else { + err = ErrInvalidTransaction + } + + if err != nil { + tx.AddError(err) + } + + return tx } // Commit commit a transaction func (db *DB) Commit() *DB { - if comminter, ok := db.Statement.ConnPool.(TxCommiter); ok && comminter != nil { - db.AddError(comminter.Commit()) + if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil { + db.AddError(committer.Commit()) } else { db.AddError(ErrInvalidTransaction) } @@ -340,8 +348,8 @@ func (db *DB) Commit() *DB { // Rollback rollback a transaction func (db *DB) Rollback() *DB { - if comminter, ok := db.Statement.ConnPool.(TxCommiter); ok && comminter != nil { - db.AddError(comminter.Rollback()) + if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil { + db.AddError(committer.Rollback()) } else { db.AddError(ErrInvalidTransaction) } diff --git a/gorm.go b/gorm.go index 8a801d68..e6a28635 100644 --- a/gorm.go +++ b/gorm.go @@ -2,6 +2,7 @@ package gorm import ( "context" + "database/sql" "fmt" "sync" "time" @@ -25,6 +26,9 @@ type Config struct { // DryRun generate sql without execute DryRun bool + // PrepareStmt executes the given query in cached statement + PrepareStmt bool + // ClauseBuilders clause builder ClauseBuilders map[string]clause.ClauseBuilder // ConnPool db conn pool @@ -48,6 +52,7 @@ type DB struct { // Session session config when create session with Session() method type Session struct { DryRun bool + PrepareStmt bool WithConditions bool Context context.Context Logger logger.Interface @@ -92,6 +97,22 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) { err = dialector.Initialize(db) } + if config.PrepareStmt { + db.ConnPool = &PreparedStmtDB{ + ConnPool: db.ConnPool, + stmts: map[string]*sql.Stmt{}, + } + } + + if db.Statement == nil { + db.Statement = &Statement{ + DB: db, + ConnPool: db.ConnPool, + Context: context.Background(), + Clauses: map[string]clause.Clause{}, + } + } + if err == nil { if pinger, ok := db.ConnPool.(interface{ Ping() error }); ok { err = pinger.Ping() @@ -131,6 +152,13 @@ func (db *DB) Session(config *Session) *DB { tx.Statement.Context = config.Context } + if config.PrepareStmt { + tx.Statement.ConnPool = &PreparedStmtDB{ + ConnPool: db.Config.ConnPool, + stmts: map[string]*sql.Stmt{}, + } + } + if config.WithConditions { tx.clone = 3 } @@ -256,6 +284,12 @@ func (db *DB) getInstance() *DB { switch db.clone { case 1: // clone with new statement + tx.Statement = &Statement{ + DB: tx, + ConnPool: db.Statement.ConnPool, + Context: db.Statement.Context, + Clauses: map[string]clause.Clause{}, + } case 2: // with old statement, generate new statement for future call, used to pass to callbacks db.clone = 1 tx.Statement = db.Statement @@ -266,21 +300,6 @@ func (db *DB) getInstance() *DB { } } - if tx.Statement == nil { - tx.Statement = &Statement{ - DB: tx, - Clauses: map[string]clause.Clause{}, - } - } - - if db.Statement != nil { - tx.Statement.Context = db.Statement.Context - tx.Statement.ConnPool = db.Statement.ConnPool - } else { - tx.Statement.Context = context.Background() - tx.Statement.ConnPool = db.ConnPool - } - return tx } diff --git a/interfaces.go b/interfaces.go index 6d9c6212..4be54565 100644 --- a/interfaces.go +++ b/interfaces.go @@ -21,8 +21,8 @@ type Dialector interface { // ConnPool db conns pool interface type ConnPool interface { - ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) + ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row } @@ -31,7 +31,11 @@ type TxBeginner interface { BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) } -type TxCommiter interface { +type ConnPoolBeginner interface { + BeginTx(ctx context.Context, opts *sql.TxOptions) (ConnPool, error) +} + +type TxCommitter interface { Commit() error Rollback() error } diff --git a/prepare_stmt.go b/prepare_stmt.go new file mode 100644 index 00000000..bc11abbf --- /dev/null +++ b/prepare_stmt.go @@ -0,0 +1,92 @@ +package gorm + +import ( + "context" + "database/sql" + "sync" +) + +type PreparedStmtDB struct { + stmts map[string]*sql.Stmt + mux sync.RWMutex + ConnPool +} + +func (db *PreparedStmtDB) prepare(query string) (*sql.Stmt, error) { + db.mux.RLock() + if stmt, ok := db.stmts[query]; ok { + db.mux.RUnlock() + return stmt, nil + } + db.mux.RUnlock() + + db.mux.Lock() + stmt, err := db.ConnPool.PrepareContext(context.Background(), query) + if err == nil { + db.stmts[query] = stmt + } + db.mux.Unlock() + + return stmt, err +} + +func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (ConnPool, error) { + if beginner, ok := db.ConnPool.(TxBeginner); ok { + tx, err := beginner.BeginTx(ctx, opt) + return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err + } + return nil, ErrInvalidTransaction +} + +func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { + stmt, err := db.prepare(query) + if err == nil { + return stmt.ExecContext(ctx, args...) + } + return nil, err +} + +func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { + stmt, err := db.prepare(query) + if err == nil { + return stmt.QueryContext(ctx, args...) + } + return nil, err +} + +func (db *PreparedStmtDB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { + stmt, err := db.prepare(query) + if err == nil { + return stmt.QueryRowContext(ctx, args...) + } + return &sql.Row{} +} + +type PreparedStmtTX struct { + *sql.Tx + PreparedStmtDB *PreparedStmtDB +} + +func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { + stmt, err := tx.PreparedStmtDB.prepare(query) + if err == nil { + return tx.Tx.Stmt(stmt).ExecContext(ctx, args...) + } + return nil, err +} + +func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { + stmt, err := tx.PreparedStmtDB.prepare(query) + if err == nil { + return tx.Tx.Stmt(stmt).QueryContext(ctx, args...) + } + return nil, err +} + +func (tx *PreparedStmtTX) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { + stmt, err := tx.PreparedStmtDB.prepare(query) + if err == nil { + return tx.Tx.Stmt(stmt).QueryRowContext(ctx, args...) + } + return &sql.Row{} +} diff --git a/tests/transaction_test.go b/tests/transaction_test.go index b810e3bb..0c04e2ed 100644 --- a/tests/transaction_test.go +++ b/tests/transaction_test.go @@ -1,7 +1,6 @@ package tests_test import ( - "database/sql" "errors" "testing" @@ -21,7 +20,7 @@ func TestTransaction(t *testing.T) { t.Fatalf("Should find saved record, but got %v", err) } - if sqlTx, ok := tx.Statement.ConnPool.(*sql.Tx); !ok || sqlTx == nil { + if sqlTx, ok := tx.Statement.ConnPool.(gorm.TxCommitter); !ok || sqlTx == nil { t.Fatalf("Should return the underlying sql.Tx") } From d50879cc280520f944a965577ce3198cb1933161 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 5 Jun 2020 19:18:22 +0800 Subject: [PATCH 395/881] Add field permission test --- callbacks/update.go | 40 +++++--- schema/field.go | 64 ++++++------ schema/field_test.go | 12 ++- schema/schema_helper_test.go | 12 ++- tests/customize_column_test.go | 56 ----------- tests/customize_field_test.go | 172 +++++++++++++++++++++++++++++++++ tests/go.mod | 2 +- tests/query_test.go | 47 ++++++--- tests/sql_builder_test.go | 16 +++ 9 files changed, 300 insertions(+), 121 deletions(-) delete mode 100644 tests/customize_column_test.go create mode 100644 tests/customize_field_test.go diff --git a/callbacks/update.go b/callbacks/update.go index 9b2e924b..2589370f 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -10,7 +10,7 @@ import ( ) func SetupUpdateReflectValue(db *gorm.DB) { - if db.Error == nil { + if db.Error == nil && db.Statement.Schema != nil { if !db.Statement.ReflectValue.CanAddr() || db.Statement.Model != db.Statement.Dest { db.Statement.ReflectValue = reflect.ValueOf(db.Statement.Model) for db.Statement.ReflectValue.Kind() == reflect.Ptr { @@ -172,26 +172,38 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { sort.Strings(keys) for _, k := range keys { - if field := stmt.Schema.LookUpField(k); field != nil { - if field.DBName != "" { - if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { - set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value[k]}) + if stmt.Schema != nil { + if field := stmt.Schema.LookUpField(k); field != nil { + if field.DBName != "" { + if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { + set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value[k]}) + assignValue(field, value[k]) + } + } else if v, ok := selectColumns[field.Name]; (ok && v) || (!ok && !restricted) { assignValue(field, value[k]) } - } else if v, ok := selectColumns[field.Name]; (ok && v) || (!ok && !restricted) { - assignValue(field, value[k]) + continue } - } else if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) { + } + + if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) { set = append(set, clause.Assignment{Column: clause.Column{Name: k}, Value: value[k]}) } } - if !stmt.DisableUpdateTime { + if !stmt.DisableUpdateTime && stmt.Schema != nil { for _, field := range stmt.Schema.FieldsByDBName { if field.AutoUpdateTime > 0 && value[field.Name] == nil && value[field.DBName] == nil { now := stmt.DB.NowFunc() - set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now}) assignValue(field, now) + + if field.AutoUpdateTime == schema.UnixNanosecond { + set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano()}) + } else if field.DataType == schema.Time { + set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now}) + } else { + set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.Unix()}) + } } } } @@ -205,7 +217,13 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { value, isZero := field.ValueOf(updatingValue) if !stmt.DisableUpdateTime { if field.AutoUpdateTime > 0 { - value = stmt.DB.NowFunc() + if field.AutoUpdateTime == schema.UnixNanosecond { + value = stmt.DB.NowFunc().UnixNano() + } else if field.DataType == schema.Time { + value = stmt.DB.NowFunc() + } else { + value = stmt.DB.NowFunc().Unix() + } isZero = false } } diff --git a/schema/field.go b/schema/field.go index a27fdd87..854ec520 100644 --- a/schema/field.go +++ b/schema/field.go @@ -133,33 +133,6 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } } - // setup permission - if _, ok := field.TagSettings["-"]; ok { - field.Creatable = false - field.Updatable = false - field.Readable = false - } - - if v, ok := field.TagSettings["<-"]; ok { - if v != "<-" { - if !strings.Contains(v, "create") { - field.Creatable = false - } - - if !strings.Contains(v, "update") { - field.Updatable = false - } - } - - field.Readable = false - } - - if _, ok := field.TagSettings["->"]; ok { - field.Creatable = false - field.Updatable = false - field.Readable = true - } - if dbName, ok := field.TagSettings["COLUMN"]; ok { field.DBName = dbName } @@ -276,6 +249,39 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } } + // setup permission + if _, ok := field.TagSettings["-"]; ok { + field.Creatable = false + field.Updatable = false + field.Readable = false + field.DataType = "" + } + + if v, ok := field.TagSettings["->"]; ok { + field.Creatable = false + field.Updatable = false + if strings.ToLower(v) == "false" { + field.Readable = false + } else { + field.Readable = true + } + } + + if v, ok := field.TagSettings["<-"]; ok { + field.Creatable = true + field.Updatable = true + + if v != "<-" { + if !strings.Contains(v, "create") { + field.Creatable = false + } + + if !strings.Contains(v, "update") { + field.Updatable = false + } + } + } + if _, ok := field.TagSettings["EMBEDDED"]; ok || fieldStruct.Anonymous { var err error field.Creatable = false @@ -510,14 +516,14 @@ func (field *Field) setupValuerAndSetter() { return err } case time.Time: - if field.AutoCreateTime == UnixNanosecond { + if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond { field.ReflectValueOf(value).SetInt(data.UnixNano()) } else { field.ReflectValueOf(value).SetInt(data.Unix()) } case *time.Time: if data != nil { - if field.AutoCreateTime == UnixNanosecond { + if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond { field.ReflectValueOf(value).SetInt(data.UnixNano()) } else { field.ReflectValueOf(value).SetInt(data.Unix()) diff --git a/schema/field_test.go b/schema/field_test.go index fe88891f..cc4b53fc 100644 --- a/schema/field_test.go +++ b/schema/field_test.go @@ -225,6 +225,7 @@ type UserWithPermissionControl struct { Name4 string `gorm:"<-:create"` Name5 string `gorm:"<-:update"` Name6 string `gorm:"<-:create,update"` + Name7 string `gorm:"->:false;<-:create,update"` } func TestParseFieldWithPermission(t *testing.T) { @@ -235,12 +236,13 @@ func TestParseFieldWithPermission(t *testing.T) { fields := []schema.Field{ {Name: "ID", DBName: "id", BindNames: []string{"ID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, Creatable: true, Updatable: true, Readable: true}, - {Name: "Name", DBName: "name", BindNames: []string{"Name"}, DataType: schema.String, Tag: `gorm:"-"`, Creatable: false, Updatable: false, Readable: false}, + {Name: "Name", DBName: "", BindNames: []string{"Name"}, DataType: "", Tag: `gorm:"-"`, Creatable: false, Updatable: false, Readable: false}, {Name: "Name2", DBName: "name2", BindNames: []string{"Name2"}, DataType: schema.String, Tag: `gorm:"->"`, Creatable: false, Updatable: false, Readable: true}, - {Name: "Name3", DBName: "name3", BindNames: []string{"Name3"}, DataType: schema.String, Tag: `gorm:"<-"`, Creatable: true, Updatable: true, Readable: false}, - {Name: "Name4", DBName: "name4", BindNames: []string{"Name4"}, DataType: schema.String, Tag: `gorm:"<-:create"`, Creatable: true, Updatable: false, Readable: false}, - {Name: "Name5", DBName: "name5", BindNames: []string{"Name5"}, DataType: schema.String, Tag: `gorm:"<-:update"`, Creatable: false, Updatable: true, Readable: false}, - {Name: "Name6", DBName: "name6", BindNames: []string{"Name6"}, DataType: schema.String, Tag: `gorm:"<-:create,update"`, Creatable: true, Updatable: true, Readable: false}, + {Name: "Name3", DBName: "name3", BindNames: []string{"Name3"}, DataType: schema.String, Tag: `gorm:"<-"`, Creatable: true, Updatable: true, Readable: true}, + {Name: "Name4", DBName: "name4", BindNames: []string{"Name4"}, DataType: schema.String, Tag: `gorm:"<-:create"`, Creatable: true, Updatable: false, Readable: true}, + {Name: "Name5", DBName: "name5", BindNames: []string{"Name5"}, DataType: schema.String, Tag: `gorm:"<-:update"`, Creatable: false, Updatable: true, Readable: true}, + {Name: "Name6", DBName: "name6", BindNames: []string{"Name6"}, DataType: schema.String, Tag: `gorm:"<-:create,update"`, Creatable: true, Updatable: true, Readable: true}, + {Name: "Name7", DBName: "name7", BindNames: []string{"Name7"}, DataType: schema.String, Tag: `gorm:"->:false;<-:create,update"`, Creatable: true, Updatable: true, Readable: false}, } for _, f := range fields { diff --git a/schema/schema_helper_test.go b/schema/schema_helper_test.go index f2ed4145..d2e68536 100644 --- a/schema/schema_helper_test.go +++ b/schema/schema_helper_test.go @@ -54,13 +54,17 @@ func checkSchemaField(t *testing.T, s *schema.Schema, f *schema.Field, fc func(* } else { tests.AssertObjEqual(t, parsedField, f, "Name", "DBName", "BindNames", "DataType", "DBDataType", "PrimaryKey", "AutoIncrement", "Creatable", "Updatable", "Readable", "HasDefaultValue", "DefaultValue", "NotNull", "Unique", "Comment", "Size", "Precision", "Tag", "TagSettings") - if field, ok := s.FieldsByDBName[f.DBName]; !ok || parsedField != field { - t.Errorf("schema %v failed to look up field with dbname %v", s, f.DBName) + if f.DBName != "" { + if field, ok := s.FieldsByDBName[f.DBName]; !ok || parsedField != field { + t.Errorf("schema %v failed to look up field with dbname %v", s, f.DBName) + } } for _, name := range []string{f.DBName, f.Name} { - if field := s.LookUpField(name); field == nil || parsedField != field { - t.Errorf("schema %v failed to look up field with dbname %v", s, f.DBName) + if name != "" { + if field := s.LookUpField(name); field == nil || parsedField != field { + t.Errorf("schema %v failed to look up field with dbname %v", s, f.DBName) + } } } diff --git a/tests/customize_column_test.go b/tests/customize_column_test.go deleted file mode 100644 index 98dea494..00000000 --- a/tests/customize_column_test.go +++ /dev/null @@ -1,56 +0,0 @@ -package tests_test - -import ( - "testing" - "time" -) - -func TestCustomizeColumn(t *testing.T) { - type CustomizeColumn struct { - ID int64 `gorm:"column:mapped_id; primary_key:yes"` - Name string `gorm:"column:mapped_name"` - Date *time.Time `gorm:"column:mapped_time"` - } - - DB.Migrator().DropTable(&CustomizeColumn{}) - DB.AutoMigrate(&CustomizeColumn{}) - - expected := "foo" - now := time.Now() - cc := CustomizeColumn{ID: 666, Name: expected, Date: &now} - - if count := DB.Create(&cc).RowsAffected; count != 1 { - t.Error("There should be one record be affected when create record") - } - - var cc1 CustomizeColumn - DB.First(&cc1, "mapped_name = ?", "foo") - - if cc1.Name != expected { - t.Errorf("Failed to query CustomizeColumn") - } - - cc.Name = "bar" - DB.Save(&cc) - - var cc2 CustomizeColumn - DB.First(&cc2, "mapped_id = ?", 666) - if cc2.Name != "bar" { - t.Errorf("Failed to query CustomizeColumn") - } -} - -func TestCustomColumnAndIgnoredFieldClash(t *testing.T) { - // Make sure an ignored field does not interfere with another field's custom - // column name that matches the ignored field. - type CustomColumnAndIgnoredFieldClash struct { - Body string `gorm:"-"` - RawBody string `gorm:"column:body"` - } - - DB.Migrator().DropTable(&CustomColumnAndIgnoredFieldClash{}) - - if err := DB.AutoMigrate(&CustomColumnAndIgnoredFieldClash{}); err != nil { - t.Errorf("Should not raise error: %v", err) - } -} diff --git a/tests/customize_field_test.go b/tests/customize_field_test.go new file mode 100644 index 00000000..910fa6ae --- /dev/null +++ b/tests/customize_field_test.go @@ -0,0 +1,172 @@ +package tests_test + +import ( + "testing" + "time" + + "gorm.io/gorm" + . "gorm.io/gorm/utils/tests" +) + +func TestCustomizeColumn(t *testing.T) { + type CustomizeColumn struct { + ID int64 `gorm:"column:mapped_id; primary_key:yes"` + Name string `gorm:"column:mapped_name"` + Date *time.Time `gorm:"column:mapped_time"` + } + + DB.Migrator().DropTable(&CustomizeColumn{}) + DB.AutoMigrate(&CustomizeColumn{}) + + expected := "foo" + now := time.Now() + cc := CustomizeColumn{ID: 666, Name: expected, Date: &now} + + if count := DB.Create(&cc).RowsAffected; count != 1 { + t.Error("There should be one record be affected when create record") + } + + var cc1 CustomizeColumn + DB.First(&cc1, "mapped_name = ?", "foo") + + if cc1.Name != expected { + t.Errorf("Failed to query CustomizeColumn") + } + + cc.Name = "bar" + DB.Save(&cc) + + var cc2 CustomizeColumn + DB.First(&cc2, "mapped_id = ?", 666) + if cc2.Name != "bar" { + t.Errorf("Failed to query CustomizeColumn") + } +} + +func TestCustomColumnAndIgnoredFieldClash(t *testing.T) { + // Make sure an ignored field does not interfere with another field's custom + // column name that matches the ignored field. + type CustomColumnAndIgnoredFieldClash struct { + Body string `gorm:"-"` + RawBody string `gorm:"column:body"` + } + + DB.Migrator().DropTable(&CustomColumnAndIgnoredFieldClash{}) + + if err := DB.AutoMigrate(&CustomColumnAndIgnoredFieldClash{}); err != nil { + t.Errorf("Should not raise error: %v", err) + } +} + +func TestCustomizeField(t *testing.T) { + type CustomizeFieldStruct struct { + gorm.Model + Name string + FieldAllowCreate string `gorm:"<-:create"` + FieldAllowUpdate string `gorm:"<-:update"` + FieldAllowSave string `gorm:"<-"` + FieldAllowSave2 string `gorm:"<-:create,update"` + FieldAllowSave3 string `gorm:"->:false;<-:create"` + FieldReadonly string `gorm:"->"` + FieldIgnore string `gorm:"-"` + AutoUnixCreateTime int64 `gorm:"autocreatetime"` + AutoUnixNanoCreateTime int64 `gorm:"autocreatetime:nano"` + AutoUnixUpdateTime int64 `gorm:"autoupdatetime"` + AutoUnixNanoUpdateTime int64 `gorm:"autoupdatetime:nano"` + } + + DB.Migrator().DropTable(&CustomizeFieldStruct{}) + + if err := DB.AutoMigrate(&CustomizeFieldStruct{}); err != nil { + t.Errorf("Failed to migrate, got error: %v", err) + } + + if DB.Migrator().HasColumn(&CustomizeFieldStruct{}, "FieldIgnore") { + t.Errorf("FieldIgnore should not be created") + } + + if DB.Migrator().HasColumn(&CustomizeFieldStruct{}, "field_ignore") { + t.Errorf("FieldIgnore should not be created") + } + + generateStruct := func(name string) *CustomizeFieldStruct { + return &CustomizeFieldStruct{ + Name: name, + FieldAllowCreate: name + "_allow_create", + FieldAllowUpdate: name + "_allow_update", + FieldAllowSave: name + "_allow_save", + FieldAllowSave2: name + "_allow_save2", + FieldAllowSave3: name + "_allow_save3", + FieldReadonly: name + "_allow_readonly", + FieldIgnore: name + "_allow_ignore", + } + } + + create := generateStruct("create") + DB.Create(&create) + + var result CustomizeFieldStruct + DB.Find(&result, "name = ?", "create") + + AssertObjEqual(t, result, create, "Name", "FieldAllowCreate", "FieldAllowSave", "FieldAllowSave2") + + if result.FieldAllowUpdate != "" || result.FieldReadonly != "" || result.FieldIgnore != "" || result.FieldAllowSave3 != "" { + t.Fatalf("invalid result: %#v", result) + } + + if result.AutoUnixCreateTime != result.AutoUnixUpdateTime || result.AutoUnixCreateTime == 0 { + t.Fatalf("invalid create/update unix time: %#v", result) + } + + if result.AutoUnixNanoCreateTime != result.AutoUnixNanoUpdateTime || result.AutoUnixNanoCreateTime == 0 || result.AutoUnixNanoCreateTime/result.AutoUnixCreateTime < 1e6 { + t.Fatalf("invalid create/update unix nano time: %#v", result) + } + + result.FieldAllowUpdate = "field_allow_update_updated" + result.FieldReadonly = "field_readonly_updated" + result.FieldIgnore = "field_ignore_updated" + DB.Save(&result) + + var result2 CustomizeFieldStruct + DB.Find(&result2, "name = ?", "create") + + if result2.FieldAllowUpdate != result.FieldAllowUpdate || result2.FieldReadonly != "" || result2.FieldIgnore != "" { + t.Fatalf("invalid updated result: %#v", result2) + } + + if err := DB.Table("customize_field_structs").Where("1 = 1").UpdateColumn("field_readonly", "readonly").Error; err != nil { + t.Fatalf("failed to update field_readonly column") + } + + var result3 CustomizeFieldStruct + DB.Find(&result3, "name = ?", "create") + + if result3.FieldReadonly != "readonly" { + t.Fatalf("invalid updated result: %#v", result3) + } + + var result4 CustomizeFieldStruct + if err := DB.First(&result4, "field_allow_save3 = ?", create.FieldAllowSave3).Error; err != nil { + t.Fatalf("failed to query with inserted field, got error %v", err) + } + + AssertEqual(t, result3, result4) + + createWithDefaultTime := generateStruct("create_with_default_time") + createWithDefaultTime.AutoUnixCreateTime = 100 + createWithDefaultTime.AutoUnixUpdateTime = 100 + createWithDefaultTime.AutoUnixNanoCreateTime = 100 + createWithDefaultTime.AutoUnixNanoUpdateTime = 100 + DB.Create(&createWithDefaultTime) + + var createWithDefaultTimeResult CustomizeFieldStruct + DB.Find(&createWithDefaultTimeResult, "name = ?", createWithDefaultTime.Name) + + if createWithDefaultTimeResult.AutoUnixCreateTime != createWithDefaultTimeResult.AutoUnixUpdateTime || createWithDefaultTimeResult.AutoUnixCreateTime != 100 { + t.Fatalf("invalid create/update unix time: %#v", createWithDefaultTimeResult) + } + + if createWithDefaultTimeResult.AutoUnixNanoCreateTime != createWithDefaultTimeResult.AutoUnixNanoUpdateTime || createWithDefaultTimeResult.AutoUnixNanoCreateTime != 100 { + t.Fatalf("invalid create/update unix nano time: %#v", createWithDefaultTimeResult) + } +} diff --git a/tests/go.mod b/tests/go.mod index 3401b9b2..de58a0de 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -6,7 +6,7 @@ require ( github.com/jinzhu/now v1.1.1 gorm.io/driver/mysql v0.0.0-20200602015408-0407d0c21cf0 gorm.io/driver/postgres v0.0.0-20200602015520-15fcc29eb286 - gorm.io/driver/sqlite v0.0.0-20200602015323-284b563f81c8 + gorm.io/driver/sqlite v1.0.0 gorm.io/driver/sqlserver v0.0.0-20200602144728-79c224f6c1a2 gorm.io/gorm v0.0.0-00010101000000-000000000000 ) diff --git a/tests/query_test.go b/tests/query_test.go index 18ffb3fb..66413b3b 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -67,22 +67,39 @@ func TestFind(t *testing.T) { } }) - var allMap = []map[string]interface{}{} - if err := DB.Model(&User{}).Where("name = ?", "find").Find(&allMap).Error; err != nil { - t.Errorf("errors happened when query first: %v", err) - } else { - for idx, user := range users { - t.Run("FindAllMap#"+strconv.Itoa(idx+1), func(t *testing.T) { - for _, name := range []string{"Name", "Age", "Birthday"} { - t.Run(name, func(t *testing.T) { - dbName := DB.NamingStrategy.ColumnName("", name) - reflectValue := reflect.Indirect(reflect.ValueOf(user)) - AssertEqual(t, allMap[idx][dbName], reflectValue.FieldByName(name).Interface()) - }) - } - }) + t.Run("FirstPtrMap", func(t *testing.T) { + var first = map[string]interface{}{} + if err := DB.Model(&User{}).Where("name = ?", "find").First(&first).Error; err != nil { + t.Errorf("errors happened when query first: %v", err) + } else { + for _, name := range []string{"Name", "Age", "Birthday"} { + t.Run(name, func(t *testing.T) { + dbName := DB.NamingStrategy.ColumnName("", name) + reflectValue := reflect.Indirect(reflect.ValueOf(users[0])) + AssertEqual(t, first[dbName], reflectValue.FieldByName(name).Interface()) + }) + } } - } + }) + + t.Run("FirstSliceOfMap", func(t *testing.T) { + var allMap = []map[string]interface{}{} + if err := DB.Model(&User{}).Where("name = ?", "find").Find(&allMap).Error; err != nil { + t.Errorf("errors happened when query first: %v", err) + } else { + for idx, user := range users { + t.Run("FindAllMap#"+strconv.Itoa(idx+1), func(t *testing.T) { + for _, name := range []string{"Name", "Age", "Birthday"} { + t.Run(name, func(t *testing.T) { + dbName := DB.NamingStrategy.ColumnName("", name) + reflectValue := reflect.Indirect(reflect.ValueOf(user)) + AssertEqual(t, allMap[idx][dbName], reflectValue.FieldByName(name).Interface()) + }) + } + }) + } + } + }) } func TestFillSmallerStruct(t *testing.T) { diff --git a/tests/sql_builder_test.go b/tests/sql_builder_test.go index 278a5b96..a60514c9 100644 --- a/tests/sql_builder_test.go +++ b/tests/sql_builder_test.go @@ -122,3 +122,19 @@ func TestQueryRaw(t *testing.T) { DB.Raw("select * from users WHERE id = ?", users[1].ID).First(&user) CheckUser(t, user, *users[1]) } + +func TestDryRun(t *testing.T) { + user := *GetUser("dry-run", Config{}) + + dryRunDB := DB.Session(&gorm.Session{DryRun: true}) + + stmt := dryRunDB.Create(&user).Statement + if stmt.SQL.String() == "" || len(stmt.Vars) != 9 { + t.Errorf("Failed to generate sql, got %v", stmt.SQL.String()) + } + + stmt2 := dryRunDB.Find(&user, "id = ?", user.ID).Statement + if stmt2.SQL.String() == "" || len(stmt2.Vars) != 1 { + t.Errorf("Failed to generate sql, got %v", stmt2.SQL.String()) + } +} From eda2f023b0d0ed31666185645cdbb82c714b8548 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 5 Jun 2020 19:19:08 +0800 Subject: [PATCH 396/881] Add Distinct support --- callbacks/query.go | 2 +- chainable_api.go | 10 +++++++ clause/select.go | 5 ++++ errors.go | 2 ++ finisher_api.go | 38 +++++++++++++++++++++----- statement.go | 2 ++ tests/distinct_test.go | 60 ++++++++++++++++++++++++++++++++++++++++++ 7 files changed, 111 insertions(+), 8 deletions(-) create mode 100644 tests/distinct_test.go diff --git a/callbacks/query.go b/callbacks/query.go index b3293576..16202187 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -37,7 +37,7 @@ func Query(db *gorm.DB) { } func BuildQuerySQL(db *gorm.DB) { - clauseSelect := clause.Select{} + clauseSelect := clause.Select{Distinct: db.Statement.Distinct} if db.Statement.ReflectValue.Kind() == reflect.Struct { var conds []clause.Expression diff --git a/chainable_api.go b/chainable_api.go index b1ae3132..6c5a6f77 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -45,6 +45,16 @@ func (db *DB) Table(name string) (tx *DB) { return } +// Distinct specify distinct fields that you want querying +func (db *DB) Distinct(args ...interface{}) (tx *DB) { + tx = db + if len(args) > 0 { + tx = tx.Select(args[0], args[1:]...) + } + tx.Statement.Distinct = true + return tx +} + // Select specify fields that you want when querying, creating, updating func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) { tx = db.getInstance() diff --git a/clause/select.go b/clause/select.go index 20b17e07..a1b77de8 100644 --- a/clause/select.go +++ b/clause/select.go @@ -2,6 +2,7 @@ package clause // Select select attrs when querying, updating, creating type Select struct { + Distinct bool Columns []Column Expression Expression } @@ -12,6 +13,10 @@ func (s Select) Name() string { func (s Select) Build(builder Builder) { if len(s.Columns) > 0 { + if s.Distinct { + builder.WriteString(" DISTINCT ") + } + for idx, column := range s.Columns { if idx > 0 { builder.WriteByte(',') diff --git a/errors.go b/errors.go index 82f24df2..ff06f24e 100644 --- a/errors.go +++ b/errors.go @@ -23,4 +23,6 @@ var ( ErrPtrStructSupported = errors.New("only ptr of struct supported") // ErrorPrimaryKeyRequired primary keys required ErrorPrimaryKeyRequired = errors.New("primary key required") + // ErrorModelValueRequired model value required + ErrorModelValueRequired = errors.New("model value required") ) diff --git a/finisher_api.go b/finisher_api.go index e493b406..d6de7aa3 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -233,13 +233,24 @@ func (db *DB) Delete(value interface{}, conds ...interface{}) (tx *DB) { func (db *DB) Count(count *int64) (tx *DB) { tx = db.getInstance() - if s, ok := tx.Statement.Clauses["SELECT"].Expression.(clause.Select); !ok || len(s.Columns) == 0 { - tx.Statement.AddClause(clause.Select{Expression: clause.Expr{SQL: "count(1)"}}) - } - if tx.Statement.Model == nil { tx.Statement.Model = tx.Statement.Dest } + + if len(tx.Statement.Selects) == 0 { + tx.Statement.AddClause(clause.Select{Expression: clause.Expr{SQL: "count(1)"}}) + } else if len(tx.Statement.Selects) == 1 && !strings.Contains(strings.ToLower(tx.Statement.Selects[0]), "count(") { + column := tx.Statement.Selects[0] + if tx.Statement.Parse(tx.Statement.Model) == nil { + if f := tx.Statement.Schema.LookUpField(column); f != nil { + column = f.DBName + } + } + tx.Statement.AddClause(clause.Select{ + Expression: clause.Expr{SQL: "COUNT(DISTINCT(?))", Vars: []interface{}{clause.Column{Name: column}}}, + }) + } + tx.Statement.Dest = count tx.callbacks.Query().Execute(tx) if db.RowsAffected != 1 { @@ -273,9 +284,22 @@ func (db *DB) Scan(dest interface{}) (tx *DB) { // db.Find(&users).Pluck("age", &ages) func (db *DB) Pluck(column string, dest interface{}) (tx *DB) { tx = db.getInstance() - tx.Statement.AddClauseIfNotExists(clause.Select{Columns: []clause.Column{{Name: column}}}) - tx.Statement.Dest = dest - tx.callbacks.Query().Execute(tx) + if tx.Statement.Model != nil { + if tx.Statement.Parse(tx.Statement.Model) == nil { + if f := tx.Statement.Schema.LookUpField(column); f != nil { + column = f.DBName + } + } + + tx.Statement.AddClauseIfNotExists(clause.Select{ + Distinct: tx.Statement.Distinct, + Columns: []clause.Column{{Name: column}}, + }) + tx.Statement.Dest = dest + tx.callbacks.Query().Execute(tx) + } else { + tx.AddError(ErrorModelValueRequired) + } return } diff --git a/statement.go b/statement.go index ffe3c75b..755d93ac 100644 --- a/statement.go +++ b/statement.go @@ -23,6 +23,7 @@ type Statement struct { Dest interface{} ReflectValue reflect.Value Clauses map[string]clause.Clause + Distinct bool Selects []string // selected columns Omits []string // omit columns Joins map[string][]interface{} @@ -331,6 +332,7 @@ func (stmt *Statement) clone() *Statement { Dest: stmt.Dest, ReflectValue: stmt.ReflectValue, Clauses: map[string]clause.Clause{}, + Distinct: stmt.Distinct, Selects: stmt.Selects, Omits: stmt.Omits, Joins: map[string][]interface{}{}, diff --git a/tests/distinct_test.go b/tests/distinct_test.go new file mode 100644 index 00000000..f5a969a8 --- /dev/null +++ b/tests/distinct_test.go @@ -0,0 +1,60 @@ +package tests_test + +import ( + "testing" + + . "gorm.io/gorm/utils/tests" +) + +func TestDistinct(t *testing.T) { + var users = []User{ + *GetUser("distinct", Config{}), + *GetUser("distinct", Config{}), + *GetUser("distinct", Config{}), + *GetUser("distinct-2", Config{}), + *GetUser("distinct-3", Config{}), + } + users[0].Age = 20 + + if err := DB.Create(&users).Error; err != nil { + t.Fatalf("errors happened when create users: %v", err) + } + + var names []string + DB.Model(&User{}).Where("name like ?", "distinct%").Order("name").Pluck("Name", &names) + AssertEqual(t, names, []string{"distinct", "distinct", "distinct", "distinct-2", "distinct-3"}) + + var names1 []string + DB.Model(&User{}).Where("name like ?", "distinct%").Distinct().Order("name").Pluck("Name", &names1) + + AssertEqual(t, names1, []string{"distinct", "distinct-2", "distinct-3"}) + + var results []User + if err := DB.Distinct("name", "age").Where("name like ?", "distinct%").Order("name, age desc").Find(&results).Error; err != nil { + t.Errorf("failed to query users, got error: %v", err) + } + + expects := []User{ + {Name: "distinct", Age: 20}, + {Name: "distinct", Age: 18}, + {Name: "distinct-2", Age: 18}, + {Name: "distinct-3", Age: 18}, + } + + if len(results) != 4 { + t.Fatalf("invalid results length found, expects: %v, got %v", len(expects), len(results)) + } + + for idx, expect := range expects { + AssertObjEqual(t, results[idx], expect, "Name", "Age") + } + + var count int64 + if err := DB.Model(&User{}).Where("name like ?", "distinct%").Count(&count).Error; err != nil || count != 5 { + t.Errorf("failed to query users count, got error: %v, count: %v", err, count) + } + + if err := DB.Model(&User{}).Distinct("name").Where("name like ?", "distinct%").Count(&count).Error; err != nil || count != 3 { + t.Errorf("failed to query users count, got error: %v, count %v", err, count) + } +} From 163200d05fb18f6c5ea8ea66ad61e76d5d26dfe3 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 5 Jun 2020 20:24:15 +0800 Subject: [PATCH 397/881] Test Hooks --- callbacks/create.go | 24 ++++++------ callbacks/delete.go | 8 ++-- callbacks/query.go | 4 +- callbacks/update.go | 32 ++++++++-------- finisher_api.go | 4 +- statement.go | 2 +- tests/hooks_test.go | 91 ++++++++++++++++++++++++++++++++++++++++++++- 7 files changed, 126 insertions(+), 39 deletions(-) diff --git a/callbacks/create.go b/callbacks/create.go index 0b88e263..99140612 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -12,31 +12,31 @@ func BeforeCreate(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeCreate) { tx := db.Session(&gorm.Session{}) callMethod := func(value interface{}) bool { - var ok bool + var called bool if db.Statement.Schema.BeforeSave { if i, ok := value.(gorm.BeforeSaveInterface); ok { - ok = true + called = true db.AddError(i.BeforeSave(tx)) } } if db.Statement.Schema.BeforeCreate { if i, ok := value.(gorm.BeforeCreateInterface); ok { - ok = true + called = true db.AddError(i.BeforeCreate(tx)) } } - return ok + return called } if ok := callMethod(db.Statement.Dest); !ok { switch db.Statement.ReflectValue.Kind() { case reflect.Slice, reflect.Array: for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - callMethod(db.Statement.ReflectValue.Index(i).Interface()) + callMethod(db.Statement.ReflectValue.Index(i).Addr().Interface()) } case reflect.Struct: - callMethod(db.Statement.ReflectValue.Interface()) + callMethod(db.Statement.ReflectValue.Addr().Interface()) } } } @@ -184,31 +184,31 @@ func AfterCreate(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) { tx := db.Session(&gorm.Session{}) callMethod := func(value interface{}) bool { - var ok bool + var called bool if db.Statement.Schema.AfterSave { if i, ok := value.(gorm.AfterSaveInterface); ok { - ok = true + called = true db.AddError(i.AfterSave(tx)) } } if db.Statement.Schema.AfterCreate { if i, ok := value.(gorm.AfterCreateInterface); ok { - ok = true + called = true db.AddError(i.AfterCreate(tx)) } } - return ok + return called } if ok := callMethod(db.Statement.Dest); !ok { switch db.Statement.ReflectValue.Kind() { case reflect.Slice, reflect.Array: for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - callMethod(db.Statement.ReflectValue.Index(i).Interface()) + callMethod(db.Statement.ReflectValue.Index(i).Addr().Interface()) } case reflect.Struct: - callMethod(db.Statement.ReflectValue.Interface()) + callMethod(db.Statement.ReflectValue.Addr().Interface()) } } } diff --git a/callbacks/delete.go b/callbacks/delete.go index b8691ff9..f1a49c11 100644 --- a/callbacks/delete.go +++ b/callbacks/delete.go @@ -25,10 +25,10 @@ func BeforeDelete(db *gorm.DB) { switch db.Statement.ReflectValue.Kind() { case reflect.Slice, reflect.Array: for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - callMethod(db.Statement.ReflectValue.Index(i).Interface()) + callMethod(db.Statement.ReflectValue.Index(i).Addr().Interface()) } case reflect.Struct: - callMethod(db.Statement.ReflectValue.Interface()) + callMethod(db.Statement.ReflectValue.Addr().Interface()) } } } @@ -101,10 +101,10 @@ func AfterDelete(db *gorm.DB) { switch db.Statement.ReflectValue.Kind() { case reflect.Slice, reflect.Array: for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - callMethod(db.Statement.ReflectValue.Index(i).Interface()) + callMethod(db.Statement.ReflectValue.Index(i).Addr().Interface()) } case reflect.Struct: - callMethod(db.Statement.ReflectValue.Interface()) + callMethod(db.Statement.ReflectValue.Addr().Interface()) } } } diff --git a/callbacks/query.go b/callbacks/query.go index 16202187..b6667414 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -203,10 +203,10 @@ func AfterQuery(db *gorm.DB) { switch db.Statement.ReflectValue.Kind() { case reflect.Slice, reflect.Array: for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - callMethod(db.Statement.ReflectValue.Index(i).Interface()) + callMethod(db.Statement.ReflectValue.Index(i).Addr().Interface()) } case reflect.Struct: - callMethod(db.Statement.ReflectValue.Interface()) + callMethod(db.Statement.ReflectValue.Addr().Interface()) } } } diff --git a/callbacks/update.go b/callbacks/update.go index 2589370f..9c922956 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -29,34 +29,34 @@ func SetupUpdateReflectValue(db *gorm.DB) { } func BeforeUpdate(db *gorm.DB) { - if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) { + if db.Error == nil && db.Statement.Schema != nil && !db.Statement.UpdatingColumn && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) { tx := db.Session(&gorm.Session{}) callMethod := func(value interface{}) bool { - var ok bool + var called bool if db.Statement.Schema.BeforeSave { if i, ok := value.(gorm.BeforeSaveInterface); ok { - ok = true + called = true db.AddError(i.BeforeSave(tx)) } } if db.Statement.Schema.BeforeUpdate { if i, ok := value.(gorm.BeforeUpdateInterface); ok { - ok = true + called = true db.AddError(i.BeforeUpdate(tx)) } } - return ok + return called } if ok := callMethod(db.Statement.Dest); !ok { switch db.Statement.ReflectValue.Kind() { case reflect.Slice, reflect.Array: for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - callMethod(db.Statement.ReflectValue.Index(i).Interface()) + callMethod(db.Statement.ReflectValue.Index(i).Addr().Interface()) } case reflect.Struct: - callMethod(db.Statement.ReflectValue.Interface()) + callMethod(db.Statement.ReflectValue.Addr().Interface()) } } } @@ -98,34 +98,34 @@ func Update(db *gorm.DB) { } func AfterUpdate(db *gorm.DB) { - if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) { + if db.Error == nil && db.Statement.Schema != nil && !db.Statement.UpdatingColumn && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) { tx := db.Session(&gorm.Session{}) callMethod := func(value interface{}) bool { - var ok bool + var called bool if db.Statement.Schema.AfterSave { if i, ok := value.(gorm.AfterSaveInterface); ok { - ok = true + called = true db.AddError(i.AfterSave(tx)) } } if db.Statement.Schema.AfterUpdate { if i, ok := value.(gorm.AfterUpdateInterface); ok { - ok = true + called = true db.AddError(i.AfterUpdate(tx)) } } - return ok + return called } if ok := callMethod(db.Statement.Dest); !ok { switch db.Statement.ReflectValue.Kind() { case reflect.Slice, reflect.Array: for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - callMethod(db.Statement.ReflectValue.Index(i).Interface()) + callMethod(db.Statement.ReflectValue.Index(i).Addr().Interface()) } case reflect.Struct: - callMethod(db.Statement.ReflectValue.Interface()) + callMethod(db.Statement.ReflectValue.Addr().Interface()) } } } @@ -191,7 +191,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { } } - if !stmt.DisableUpdateTime && stmt.Schema != nil { + if !stmt.UpdatingColumn && stmt.Schema != nil { for _, field := range stmt.Schema.FieldsByDBName { if field.AutoUpdateTime > 0 && value[field.Name] == nil && value[field.DBName] == nil { now := stmt.DB.NowFunc() @@ -215,7 +215,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { if !field.PrimaryKey || (!updatingValue.CanAddr() || stmt.Dest != stmt.Model) { if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { value, isZero := field.ValueOf(updatingValue) - if !stmt.DisableUpdateTime { + if !stmt.UpdatingColumn { if field.AutoUpdateTime > 0 { if field.AutoUpdateTime == schema.UnixNanosecond { value = stmt.DB.NowFunc().UnixNano() diff --git a/finisher_api.go b/finisher_api.go index d6de7aa3..e94fd095 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -207,7 +207,7 @@ func (db *DB) Updates(values interface{}) (tx *DB) { func (db *DB) UpdateColumn(column string, value interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.Dest = map[string]interface{}{column: value} - tx.Statement.DisableUpdateTime = true + tx.Statement.UpdatingColumn = true tx.callbacks.Update().Execute(tx) return } @@ -215,7 +215,7 @@ func (db *DB) UpdateColumn(column string, value interface{}) (tx *DB) { func (db *DB) UpdateColumns(values interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.Dest = values - tx.Statement.DisableUpdateTime = true + tx.Statement.UpdatingColumn = true tx.callbacks.Update().Execute(tx) return } diff --git a/statement.go b/statement.go index 755d93ac..e3f324b9 100644 --- a/statement.go +++ b/statement.go @@ -33,7 +33,7 @@ type Statement struct { Schema *schema.Schema Context context.Context RaiseErrorOnNotFound bool - DisableUpdateTime bool + UpdatingColumn bool SQL strings.Builder Vars []interface{} NamedVars []sql.NamedArg diff --git a/tests/hooks_test.go b/tests/hooks_test.go index e2850c27..c74e8f10 100644 --- a/tests/hooks_test.go +++ b/tests/hooks_test.go @@ -3,9 +3,11 @@ package tests_test import ( "errors" "reflect" + "strings" "testing" "gorm.io/gorm" + . "gorm.io/gorm/utils/tests" ) type Product struct { @@ -98,7 +100,7 @@ func TestRunCallbacks(t *testing.T) { DB.Save(&p) if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 1, 0, 0, 0, 0}) { - t.Errorf("Callbacks should be invoked successfully, %v", p.GetCallTimes()) + t.Fatalf("Callbacks should be invoked successfully, %v", p.GetCallTimes()) } DB.Where("Code = ?", "unique_code").First(&p) @@ -114,7 +116,7 @@ func TestRunCallbacks(t *testing.T) { var products []Product DB.Find(&products, "code = ?", "unique_code") - if products[0].AfterFindCallTimes != 1 { + if products[0].AfterFindCallTimes != 2 { t.Fatalf("AfterFind callbacks should work with slice, called %v", products[0].AfterFindCallTimes) } @@ -198,3 +200,88 @@ func TestCallbacksWithErrors(t *testing.T) { t.Fatalf("Record shouldn't be deleted because of an error happened in after delete callback") } } + +type Product2 struct { + gorm.Model + Name string + Code string + Price int64 + Owner string +} + +func (s Product2) BeforeCreate(tx *gorm.DB) (err error) { + if !strings.HasSuffix(s.Name, "_clone") { + newProduft := s + newProduft.Price *= 2 + newProduft.Name += "_clone" + err = tx.Create(&newProduft).Error + } + + if s.Name == "Invalid" { + return errors.New("invalid") + } + + return nil +} + +func (s *Product2) BeforeUpdate(tx *gorm.DB) (err error) { + tx.Statement.Where("owner != ?", "admin") + return +} + +func TestUseDBInHooks(t *testing.T) { + DB.Migrator().DropTable(&Product2{}) + DB.AutoMigrate(&Product2{}) + + product := Product2{Name: "Invalid", Price: 100} + + if err := DB.Create(&product).Error; err == nil { + t.Fatalf("should returns error %v when creating product, but got nil", err) + } + + product2 := Product2{Name: "Nice", Price: 100} + + if err := DB.Create(&product2).Error; err != nil { + t.Fatalf("Failed to create product, got error: %v", err) + } + + var result Product2 + if err := DB.First(&result, "name = ?", "Nice").Error; err != nil { + t.Fatalf("Failed to query product, got error: %v", err) + } + + var resultClone Product2 + if err := DB.First(&resultClone, "name = ?", "Nice_clone").Error; err != nil { + t.Fatalf("Failed to find cloned product, got error: %v", err) + } + + result.Price *= 2 + result.Name += "_clone" + AssertObjEqual(t, result, resultClone, "Price", "Name") + + DB.Model(&result).Update("Price", 500) + var result2 Product2 + DB.First(&result2, "name = ?", "Nice") + + if result2.Price != 500 { + t.Errorf("Failed to update product's price, expects: %v, got %v", 500, result2.Price) + } + + product3 := Product2{Name: "Nice2", Price: 600, Owner: "admin"} + if err := DB.Create(&product3).Error; err != nil { + t.Fatalf("Failed to create product, got error: %v", err) + } + + var result3 Product2 + if err := DB.First(&result3, "name = ?", "Nice2").Error; err != nil { + t.Fatalf("Failed to query product, got error: %v", err) + } + + DB.Model(&result3).Update("Price", 800) + var result4 Product2 + DB.First(&result4, "name = ?", "Nice2") + + if result4.Price != 600 { + t.Errorf("Admin product's price should not be changed, expects: %v, got %v", 600, result4.Price) + } +} From 1490a062dbd9f6e2043a70e56b41d14364bb07a6 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 5 Jun 2020 21:23:20 +0800 Subject: [PATCH 398/881] Refactor codebase and add benchmark test --- callbacks.go | 7 ++- callbacks/callmethod.go | 21 ++++++++ callbacks/create.go | 105 ++++++++++++++-------------------------- callbacks/delete.go | 49 +++++-------------- callbacks/query.go | 24 ++------- callbacks/update.go | 35 ++------------ gorm.go | 62 ++++++++---------------- migrator/migrator.go | 2 +- schema/field_test.go | 2 +- schema/schema.go | 29 +++++------ schema/schema_test.go | 4 +- statement.go | 42 +--------------- tests/benchmark_test.go | 44 +++++++++++++++++ tests/go.mod | 2 +- 14 files changed, 168 insertions(+), 260 deletions(-) create mode 100644 callbacks/callmethod.go create mode 100644 tests/benchmark_test.go diff --git a/callbacks.go b/callbacks.go index a9a6dd85..4f700081 100644 --- a/callbacks.go +++ b/callbacks.go @@ -105,8 +105,11 @@ func (p *processor) Execute(db *DB) { return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...), db.RowsAffected }, db.Error) - stmt.reinit() - // db.Config.statementPool.Put(stmt) + if !stmt.DB.DryRun { + stmt.SQL.Reset() + stmt.Vars = nil + stmt.NamedVars = nil + } } } diff --git a/callbacks/callmethod.go b/callbacks/callmethod.go new file mode 100644 index 00000000..a0e9b0e7 --- /dev/null +++ b/callbacks/callmethod.go @@ -0,0 +1,21 @@ +package callbacks + +import ( + "reflect" + + "gorm.io/gorm" +) + +func callMethod(db *gorm.DB, fc func(value interface{}, tx *gorm.DB) bool) { + tx := db.Session(&gorm.Session{}) + if called := fc(db.Statement.Dest, tx); !called { + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + for i := 0; i < db.Statement.ReflectValue.Len(); i++ { + fc(db.Statement.ReflectValue.Index(i).Addr().Interface(), tx) + } + case reflect.Struct: + fc(db.Statement.ReflectValue.Addr().Interface(), tx) + } + } +} diff --git a/callbacks/create.go b/callbacks/create.go index 99140612..ec4ee1d1 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -10,9 +10,7 @@ import ( func BeforeCreate(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeCreate) { - tx := db.Session(&gorm.Session{}) - callMethod := func(value interface{}) bool { - var called bool + callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { if db.Statement.Schema.BeforeSave { if i, ok := value.(gorm.BeforeSaveInterface); ok { called = true @@ -27,18 +25,7 @@ func BeforeCreate(db *gorm.DB) { } } return called - } - - if ok := callMethod(db.Statement.Dest); !ok { - switch db.Statement.ReflectValue.Kind() { - case reflect.Slice, reflect.Array: - for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - callMethod(db.Statement.ReflectValue.Index(i).Addr().Interface()) - } - case reflect.Struct: - callMethod(db.Statement.ReflectValue.Addr().Interface()) - } - } + }) } } @@ -67,28 +54,26 @@ func Create(config *Config) func(db *gorm.DB) { result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) if err == nil { - if db.Statement.Schema != nil && db.Statement.Schema.PrioritizedPrimaryField != nil { - if _, ok := db.Statement.Schema.FieldsWithDefaultDBValue[db.Statement.Schema.PrioritizedPrimaryField.DBName]; ok { - if insertID, err := result.LastInsertId(); err == nil { - switch db.Statement.ReflectValue.Kind() { - case reflect.Slice, reflect.Array: - if config.LastInsertIDReversed { - for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- { - db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID) - insertID-- - } - } else { - for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID) - insertID++ - } + if db.Statement.Schema != nil && db.Statement.Schema.PrioritizedPrimaryField != nil && db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue { + if insertID, err := result.LastInsertId(); err == nil { + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + if config.LastInsertIDReversed { + for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- { + db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID) + insertID-- + } + } else { + for i := 0; i < db.Statement.ReflectValue.Len(); i++ { + db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID) + insertID++ } - case reflect.Struct: - db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID) } - } else { - db.AddError(err) + case reflect.Struct: + db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID) } + } else { + db.AddError(err) } } db.RowsAffected, _ = result.RowsAffected() @@ -122,19 +107,17 @@ func CreateWithReturning(db *gorm.DB) { db.Statement.WriteString(" RETURNING ") var ( - idx int fields = make([]*schema.Field, len(sch.FieldsWithDefaultDBValue)) values = make([]interface{}, len(sch.FieldsWithDefaultDBValue)) ) - for dbName, field := range sch.FieldsWithDefaultDBValue { - if idx != 0 { + for idx, field := range sch.FieldsWithDefaultDBValue { + if idx > 0 { db.Statement.WriteByte(',') } fields[idx] = field - db.Statement.WriteQuoted(dbName) - idx++ + db.Statement.WriteQuoted(field.DBName) } if !db.DryRun { @@ -149,10 +132,11 @@ func CreateWithReturning(db *gorm.DB) { for idx, field := range fields { values[idx] = field.ReflectValueOf(db.Statement.ReflectValue.Index(int(db.RowsAffected))).Addr().Interface() } + + db.RowsAffected++ if err := rows.Scan(values...); err != nil { db.AddError(err) } - db.RowsAffected++ } case reflect.Struct: for idx, field := range fields { @@ -161,12 +145,10 @@ func CreateWithReturning(db *gorm.DB) { if rows.Next() { db.RowsAffected++ - err = rows.Scan(values...) + db.AddError(rows.Scan(values...)) } } - } - - if err != nil { + } else { db.AddError(err) } } @@ -182,9 +164,7 @@ func CreateWithReturning(db *gorm.DB) { func AfterCreate(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) { - tx := db.Session(&gorm.Session{}) - callMethod := func(value interface{}) bool { - var called bool + callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { if db.Statement.Schema.AfterSave { if i, ok := value.(gorm.AfterSaveInterface); ok { called = true @@ -199,18 +179,7 @@ func AfterCreate(db *gorm.DB) { } } return called - } - - if ok := callMethod(db.Statement.Dest); !ok { - switch db.Statement.ReflectValue.Kind() { - case reflect.Slice, reflect.Array: - for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - callMethod(db.Statement.ReflectValue.Index(i).Addr().Interface()) - } - case reflect.Struct: - callMethod(db.Statement.ReflectValue.Addr().Interface()) - } - } + }) } } @@ -230,7 +199,7 @@ func ConvertToCreateValues(stmt *gorm.Statement) clause.Values { ) for _, db := range stmt.Schema.DBNames { - if stmt.Schema.FieldsWithDefaultDBValue[db] == nil { + if field := stmt.Schema.FieldsByDBName[db]; !field.HasDefaultValue || field.DefaultValueInterface != nil { if v, ok := selectColumns[db]; (ok && v) || (!ok && !restricted) { values.Columns = append(values.Columns, clause.Column{Name: db}) } @@ -257,13 +226,13 @@ func ConvertToCreateValues(stmt *gorm.Statement) clause.Values { } } - for db, field := range stmt.Schema.FieldsWithDefaultDBValue { - if v, ok := selectColumns[db]; (ok && v) || (!ok && !restricted) { + for _, field := range stmt.Schema.FieldsWithDefaultDBValue { + if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { if v, isZero := field.ValueOf(rv); !isZero { - if len(defaultValueFieldsHavingValue[db]) == 0 { - defaultValueFieldsHavingValue[db] = make([]interface{}, stmt.ReflectValue.Len()) + if len(defaultValueFieldsHavingValue[field.DBName]) == 0 { + defaultValueFieldsHavingValue[field.DBName] = make([]interface{}, stmt.ReflectValue.Len()) } - defaultValueFieldsHavingValue[db][i] = v + defaultValueFieldsHavingValue[field.DBName][i] = v } } } @@ -294,10 +263,10 @@ func ConvertToCreateValues(stmt *gorm.Statement) clause.Values { } } - for db, field := range stmt.Schema.FieldsWithDefaultDBValue { - if v, ok := selectColumns[db]; (ok && v) || (!ok && !restricted) { + for _, field := range stmt.Schema.FieldsWithDefaultDBValue { + if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { if v, isZero := field.ValueOf(stmt.ReflectValue); !isZero { - values.Columns = append(values.Columns, clause.Column{Name: db}) + values.Columns = append(values.Columns, clause.Column{Name: field.DBName}) values.Values[0] = append(values.Values[0], v) } } diff --git a/callbacks/delete.go b/callbacks/delete.go index f1a49c11..b246e69f 100644 --- a/callbacks/delete.go +++ b/callbacks/delete.go @@ -10,27 +10,14 @@ import ( func BeforeDelete(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.BeforeDelete { - tx := db.Session(&gorm.Session{}) - callMethod := func(value interface{}) bool { - if db.Statement.Schema.BeforeDelete { - if i, ok := value.(gorm.BeforeDeleteInterface); ok { - db.AddError(i.BeforeDelete(tx)) - return true - } + callMethod(db, func(value interface{}, tx *gorm.DB) bool { + if i, ok := value.(gorm.BeforeDeleteInterface); ok { + db.AddError(i.BeforeDelete(tx)) + return true } - return false - } - if ok := callMethod(db.Statement.Dest); !ok { - switch db.Statement.ReflectValue.Kind() { - case reflect.Slice, reflect.Array: - for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - callMethod(db.Statement.ReflectValue.Index(i).Addr().Interface()) - } - case reflect.Struct: - callMethod(db.Statement.ReflectValue.Addr().Interface()) - } - } + return false + }) } } @@ -86,26 +73,12 @@ func Delete(db *gorm.DB) { func AfterDelete(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.AfterDelete { - tx := db.Session(&gorm.Session{}) - callMethod := func(value interface{}) bool { - if db.Statement.Schema.AfterDelete { - if i, ok := value.(gorm.AfterDeleteInterface); ok { - db.AddError(i.AfterDelete(tx)) - return true - } + callMethod(db, func(value interface{}, tx *gorm.DB) bool { + if i, ok := value.(gorm.AfterDeleteInterface); ok { + db.AddError(i.AfterDelete(tx)) + return true } return false - } - - if ok := callMethod(db.Statement.Dest); !ok { - switch db.Statement.ReflectValue.Kind() { - case reflect.Slice, reflect.Array: - for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - callMethod(db.Statement.ReflectValue.Index(i).Addr().Interface()) - } - case reflect.Struct: - callMethod(db.Statement.ReflectValue.Addr().Interface()) - } - } + }) } } diff --git a/callbacks/query.go b/callbacks/query.go index b6667414..41f09375 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -188,26 +188,12 @@ func Preload(db *gorm.DB) { func AfterQuery(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.AfterFind { - tx := db.Session(&gorm.Session{}) - callMethod := func(value interface{}) bool { - if db.Statement.Schema.AfterFind { - if i, ok := value.(gorm.AfterFindInterface); ok { - db.AddError(i.AfterFind(tx)) - return true - } + callMethod(db, func(value interface{}, tx *gorm.DB) bool { + if i, ok := value.(gorm.AfterFindInterface); ok { + db.AddError(i.AfterFind(tx)) + return true } return false - } - - if ok := callMethod(db.Statement.Dest); !ok { - switch db.Statement.ReflectValue.Kind() { - case reflect.Slice, reflect.Array: - for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - callMethod(db.Statement.ReflectValue.Index(i).Addr().Interface()) - } - case reflect.Struct: - callMethod(db.Statement.ReflectValue.Addr().Interface()) - } - } + }) } } diff --git a/callbacks/update.go b/callbacks/update.go index 9c922956..a41a3c59 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -30,9 +30,7 @@ func SetupUpdateReflectValue(db *gorm.DB) { func BeforeUpdate(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil && !db.Statement.UpdatingColumn && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) { - tx := db.Session(&gorm.Session{}) - callMethod := func(value interface{}) bool { - var called bool + callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { if db.Statement.Schema.BeforeSave { if i, ok := value.(gorm.BeforeSaveInterface); ok { called = true @@ -46,19 +44,9 @@ func BeforeUpdate(db *gorm.DB) { db.AddError(i.BeforeUpdate(tx)) } } - return called - } - if ok := callMethod(db.Statement.Dest); !ok { - switch db.Statement.ReflectValue.Kind() { - case reflect.Slice, reflect.Array: - for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - callMethod(db.Statement.ReflectValue.Index(i).Addr().Interface()) - } - case reflect.Struct: - callMethod(db.Statement.ReflectValue.Addr().Interface()) - } - } + return called + }) } } @@ -99,9 +87,7 @@ func Update(db *gorm.DB) { func AfterUpdate(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil && !db.Statement.UpdatingColumn && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) { - tx := db.Session(&gorm.Session{}) - callMethod := func(value interface{}) bool { - var called bool + callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { if db.Statement.Schema.AfterSave { if i, ok := value.(gorm.AfterSaveInterface); ok { called = true @@ -116,18 +102,7 @@ func AfterUpdate(db *gorm.DB) { } } return called - } - - if ok := callMethod(db.Statement.Dest); !ok { - switch db.Statement.ReflectValue.Kind() { - case reflect.Slice, reflect.Array: - for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - callMethod(db.Statement.ReflectValue.Index(i).Addr().Interface()) - } - case reflect.Struct: - callMethod(db.Statement.ReflectValue.Addr().Interface()) - } - } + }) } } diff --git a/gorm.go b/gorm.go index e6a28635..cea744f7 100644 --- a/gorm.go +++ b/gorm.go @@ -25,9 +25,10 @@ type Config struct { NowFunc func() time.Time // DryRun generate sql without execute DryRun bool - // PrepareStmt executes the given query in cached statement PrepareStmt bool + // DisableAutomaticPing + DisableAutomaticPing bool // ClauseBuilders clause builder ClauseBuilders map[string]clause.ClauseBuilder @@ -93,8 +94,8 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) { config.ClauseBuilders = map[string]clause.ClauseBuilder{} } - if dialector != nil { - err = dialector.Initialize(db) + if config.Dialector != nil { + err = config.Dialector.Initialize(db) } if config.PrepareStmt { @@ -104,16 +105,14 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) { } } - if db.Statement == nil { - db.Statement = &Statement{ - DB: db, - ConnPool: db.ConnPool, - Context: context.Background(), - Clauses: map[string]clause.Clause{}, - } + db.Statement = &Statement{ + DB: db, + ConnPool: db.ConnPool, + Context: context.Background(), + Clauses: map[string]clause.Clause{}, } - if err == nil { + if err == nil && !config.DisableAutomaticPing { if pinger, ok := db.ConnPool.(interface{ Ping() error }); ok { err = pinger.Ping() } @@ -138,17 +137,8 @@ func (db *DB) Session(config *Session) *DB { ) if config.Context != nil { - if tx.Statement != nil { - tx.Statement = tx.Statement.clone() - tx.Statement.DB = tx - } else { - tx.Statement = &Statement{ - DB: tx, - Clauses: map[string]clause.Clause{}, - ConnPool: tx.ConnPool, - } - } - + tx.Statement = tx.Statement.clone() + tx.Statement.DB = tx tx.Statement.Context = config.Context } @@ -160,7 +150,7 @@ func (db *DB) Session(config *Session) *DB { } if config.WithConditions { - tx.clone = 3 + tx.clone = 2 } if config.DryRun { @@ -200,10 +190,7 @@ func (db *DB) Set(key string, value interface{}) *DB { // Get get value with key from current db instance's context func (db *DB) Get(key string) (interface{}, bool) { - if db.Statement != nil { - return db.Statement.Settings.Load(key) - } - return nil, false + return db.Statement.Settings.Load(key) } // InstanceSet store value with key into current db instance's context @@ -215,10 +202,7 @@ func (db *DB) InstanceSet(key string, value interface{}) *DB { // InstanceGet get value with key from current db instance's context func (db *DB) InstanceGet(key string) (interface{}, bool) { - if db.Statement != nil { - return db.Statement.Settings.Load(fmt.Sprintf("%p", db.Statement) + key) - } - return nil, false + return db.Statement.Settings.Load(fmt.Sprintf("%p", db.Statement) + key) } func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interface{}) error { @@ -282,22 +266,18 @@ func (db *DB) getInstance() *DB { if db.clone > 0 { tx := &DB{Config: db.Config} - switch db.clone { - case 1: // clone with new statement + if db.clone == 1 { + // clone with new statement tx.Statement = &Statement{ DB: tx, ConnPool: db.Statement.ConnPool, Context: db.Statement.Context, Clauses: map[string]clause.Clause{}, } - case 2: // with old statement, generate new statement for future call, used to pass to callbacks - db.clone = 1 - tx.Statement = db.Statement - case 3: // with clone statement - if db.Statement != nil { - tx.Statement = db.Statement.clone() - tx.Statement.DB = tx - } + } else { + // with clone statement + tx.Statement = db.Statement.clone() + tx.Statement.DB = tx } return tx diff --git a/migrator/migrator.go b/migrator/migrator.go index afef65c3..18b2593d 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -62,7 +62,7 @@ func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) { expr.SQL += " UNIQUE" } - if field.HasDefaultValue { + if field.HasDefaultValue && field.DefaultValue != "" { if field.DataType == schema.String { defaultStmt := &gorm.Statement{Vars: []interface{}{field.DefaultValue}} m.Dialector.BindVarTo(defaultStmt, defaultStmt, field.DefaultValue) diff --git a/schema/field_test.go b/schema/field_test.go index cc4b53fc..0936c0d1 100644 --- a/schema/field_test.go +++ b/schema/field_test.go @@ -235,7 +235,7 @@ func TestParseFieldWithPermission(t *testing.T) { } fields := []schema.Field{ - {Name: "ID", DBName: "id", BindNames: []string{"ID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, Creatable: true, Updatable: true, Readable: true}, + {Name: "ID", DBName: "id", BindNames: []string{"ID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, Creatable: true, Updatable: true, Readable: true, HasDefaultValue: true}, {Name: "Name", DBName: "", BindNames: []string{"Name"}, DataType: "", Tag: `gorm:"-"`, Creatable: false, Updatable: false, Readable: false}, {Name: "Name2", DBName: "name2", BindNames: []string{"Name2"}, DataType: schema.String, Tag: `gorm:"->"`, Creatable: false, Updatable: false, Readable: true}, {Name: "Name3", DBName: "name3", BindNames: []string{"Name3"}, DataType: schema.String, Tag: `gorm:"<-"`, Creatable: true, Updatable: true, Readable: true}, diff --git a/schema/schema.go b/schema/schema.go index 9e05303a..d2c4d08b 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -26,7 +26,7 @@ type Schema struct { Fields []*Field FieldsByName map[string]*Field FieldsByDBName map[string]*Field - FieldsWithDefaultDBValue map[string]*Field // fields with default value assigned by database + FieldsWithDefaultDBValue []*Field // fields with default value assigned by database Relationships Relationships CreateClauses []clause.Interface QueryClauses []clause.Interface @@ -153,23 +153,14 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) schema.FieldsByName[field.Name] = field if v != nil && v.PrimaryKey { - if schema.PrioritizedPrimaryField == v { - schema.PrioritizedPrimaryField = nil - } - for idx, f := range schema.PrimaryFields { if f == v { schema.PrimaryFields = append(schema.PrimaryFields[0:idx], schema.PrimaryFields[idx+1:]...) - } else if schema.PrioritizedPrimaryField == nil { - schema.PrioritizedPrimaryField = f } } } if field.PrimaryKey { - if schema.PrioritizedPrimaryField == nil { - schema.PrioritizedPrimaryField = field - } schema.PrimaryFields = append(schema.PrimaryFields, field) } } @@ -192,21 +183,27 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) } } + if schema.PrioritizedPrimaryField == nil && len(schema.PrimaryFields) == 1 { + schema.PrioritizedPrimaryField = schema.PrimaryFields[0] + } + for _, field := range schema.PrimaryFields { schema.PrimaryFieldDBNames = append(schema.PrimaryFieldDBNames, field.DBName) } - schema.FieldsWithDefaultDBValue = map[string]*Field{} - for db, field := range schema.FieldsByDBName { + for _, field := range schema.FieldsByDBName { if field.HasDefaultValue && field.DefaultValueInterface == nil { - schema.FieldsWithDefaultDBValue[db] = field + schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field) } } - if schema.PrioritizedPrimaryField != nil { - switch schema.PrioritizedPrimaryField.DataType { + if field := schema.PrioritizedPrimaryField; field != nil { + switch field.DataType { case Int, Uint: - schema.FieldsWithDefaultDBValue[schema.PrioritizedPrimaryField.DBName] = schema.PrioritizedPrimaryField + if !field.HasDefaultValue || field.DefaultValueInterface != nil { + schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field) + } + field.HasDefaultValue = true } } diff --git a/schema/schema_test.go b/schema/schema_test.go index 82f07fa8..4ec7ff0c 100644 --- a/schema/schema_test.go +++ b/schema/schema_test.go @@ -32,7 +32,7 @@ func checkUserSchema(t *testing.T, user *schema.Schema) { // check fields fields := []schema.Field{ - {Name: "ID", DBName: "id", BindNames: []string{"Model", "ID"}, DataType: schema.Uint, PrimaryKey: true, Tag: `gorm:"primarykey"`, TagSettings: map[string]string{"PRIMARYKEY": "PRIMARYKEY"}, Size: 64}, + {Name: "ID", DBName: "id", BindNames: []string{"Model", "ID"}, DataType: schema.Uint, PrimaryKey: true, Tag: `gorm:"primarykey"`, TagSettings: map[string]string{"PRIMARYKEY": "PRIMARYKEY"}, Size: 64, HasDefaultValue: true}, {Name: "CreatedAt", DBName: "created_at", BindNames: []string{"Model", "CreatedAt"}, DataType: schema.Time}, {Name: "UpdatedAt", DBName: "updated_at", BindNames: []string{"Model", "UpdatedAt"}, DataType: schema.Time}, {Name: "DeletedAt", DBName: "deleted_at", BindNames: []string{"Model", "DeletedAt"}, Tag: `gorm:"index"`, DataType: schema.Time}, @@ -125,7 +125,7 @@ func TestParseSchemaWithAdvancedDataType(t *testing.T) { // check fields fields := []schema.Field{ - {Name: "ID", DBName: "id", BindNames: []string{"ID"}, DataType: schema.Int, PrimaryKey: true, Size: 64}, + {Name: "ID", DBName: "id", BindNames: []string{"ID"}, DataType: schema.Int, PrimaryKey: true, Size: 64, HasDefaultValue: true}, {Name: "Name", DBName: "name", BindNames: []string{"Name"}, DataType: schema.String}, {Name: "Birthday", DBName: "birthday", BindNames: []string{"Birthday"}, DataType: schema.Time}, {Name: "RegisteredAt", DBName: "registered_at", BindNames: []string{"RegisteredAt"}, DataType: schema.Time}, diff --git a/statement.go b/statement.go index e3f324b9..2c814547 100644 --- a/statement.go +++ b/statement.go @@ -226,6 +226,7 @@ func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (con if sql == "" && len(args) == 0 { return } else if len(args) == 0 || (len(args) > 0 && strings.Contains(sql, "?")) || strings.Contains(sql, "@") { + // looks like a where condition return []clause.Expression{clause.Expr{SQL: sql, Vars: args}} } else if len(args) == 1 { return []clause.Expression{clause.Eq{Column: sql, Value: args[0]}} @@ -242,12 +243,6 @@ func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (con switch v := arg.(type) { case clause.Expression: conds = append(conds, v) - case *DB: - if v.Statement != nil { - if cs, ok := v.Statement.Clauses["WHERE"]; ok { - conds = append(conds, cs.Expression) - } - } case map[interface{}]interface{}: for i, j := range v { conds = append(conds, clause.Eq{Column: i, Value: j}) @@ -326,7 +321,6 @@ func (stmt *Statement) Parse(value interface{}) (err error) { func (stmt *Statement) clone() *Statement { newStmt := &Statement{ - DB: stmt.DB, Table: stmt.Table, Model: stmt.Model, Dest: stmt.Dest, @@ -357,37 +351,3 @@ func (stmt *Statement) clone() *Statement { return newStmt } - -func (stmt *Statement) reinit() { - // stmt.Table = "" - // stmt.Model = nil - // stmt.Selects = nil - // stmt.Omits = nil - // stmt.ConnPool = stmt.DB.Config.ConnPool - // stmt.Context = context.Background() - // stmt.RaiseErrorOnNotFound = false - - // for k := range stmt.Clauses { - // delete(stmt.Clauses, k) - // } - - // for k := range stmt.Joins { - // delete(stmt.Joins, k) - // } - - // for k := range stmt.Preloads { - // delete(stmt.Preloads, k) - // } - - // stmt.Settings.Range(func(k, _ interface{}) bool { - // stmt.Settings.Delete(k) - // return true - // }) - - // stmt.Schema = nil - if !stmt.DB.DryRun { - stmt.SQL.Reset() - stmt.Vars = nil - stmt.NamedVars = nil - } -} diff --git a/tests/benchmark_test.go b/tests/benchmark_test.go new file mode 100644 index 00000000..c6ce93a2 --- /dev/null +++ b/tests/benchmark_test.go @@ -0,0 +1,44 @@ +package tests_test + +import ( + "testing" + + . "gorm.io/gorm/utils/tests" +) + +func BenchmarkCreate(b *testing.B) { + var user = *GetUser("bench", Config{}) + + for x := 0; x < b.N; x++ { + user.ID = 0 + DB.Create(&user) + } +} + +func BenchmarkFind(b *testing.B) { + var user = *GetUser("find", Config{}) + DB.Create(&user) + + for x := 0; x < b.N; x++ { + DB.Find(&User{}, "id = ?", user.ID) + } +} + +func BenchmarkUpdate(b *testing.B) { + var user = *GetUser("find", Config{}) + DB.Create(&user) + + for x := 0; x < b.N; x++ { + DB.Model(&user).Updates(map[string]interface{}{"Age": x}) + } +} + +func BenchmarkDelete(b *testing.B) { + var user = *GetUser("find", Config{}) + + for x := 0; x < b.N; x++ { + user.ID = 0 + DB.Create(&user) + DB.Delete(&user) + } +} diff --git a/tests/go.mod b/tests/go.mod index de58a0de..3c2dfc6c 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -7,7 +7,7 @@ require ( gorm.io/driver/mysql v0.0.0-20200602015408-0407d0c21cf0 gorm.io/driver/postgres v0.0.0-20200602015520-15fcc29eb286 gorm.io/driver/sqlite v1.0.0 - gorm.io/driver/sqlserver v0.0.0-20200602144728-79c224f6c1a2 + gorm.io/driver/sqlserver v0.0.0-20200605135528-04ae0f7a15bf gorm.io/gorm v0.0.0-00010101000000-000000000000 ) From a954d772d7b0ee0dc704573a963826b074e64fe9 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 6 Jun 2020 10:47:32 +0800 Subject: [PATCH 399/881] Support customize gorm field type --- migrator/migrator.go | 11 +++++++++++ schema/field.go | 4 ++++ schema/interfaces.go | 23 +++++++++++++++++++++++ schema/schema.go | 16 ---------------- tests/tests_all.sh | 21 ++++++++++++++------- 5 files changed, 52 insertions(+), 23 deletions(-) create mode 100644 schema/interfaces.go diff --git a/migrator/migrator.go b/migrator/migrator.go index 18b2593d..a98f7fe3 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -24,6 +24,10 @@ type Config struct { gorm.Dialector } +type GormDataTypeInterface interface { + GormDBDataType(*gorm.DB, *schema.Field) string +} + func (m Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement) error) error { stmt := &gorm.Statement{DB: m.DB} if m.DB.Statement != nil { @@ -44,6 +48,13 @@ func (m Migrator) DataTypeOf(field *schema.Field) string { return field.DBDataType } + fieldValue := reflect.New(field.IndirectFieldType) + if dataTyper, ok := fieldValue.Interface().(GormDataTypeInterface); ok { + if dataType := dataTyper.GormDBDataType(m.DB, field); dataType != "" { + return dataType + } + } + return m.Dialector.DataTypeOf(field) } diff --git a/schema/field.go b/schema/field.go index 854ec520..e0d49e2f 100644 --- a/schema/field.go +++ b/schema/field.go @@ -220,6 +220,10 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } } + if dataTyper, ok := fieldValue.Interface().(GormDataTypeInterface); ok { + field.DataType = DataType(dataTyper.GormDataType()) + } + if v, ok := field.TagSettings["AUTOCREATETIME"]; ok || (field.Name == "CreatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) { if strings.ToUpper(v) == "NANO" { field.AutoCreateTime = UnixNanosecond diff --git a/schema/interfaces.go b/schema/interfaces.go new file mode 100644 index 00000000..f5d07843 --- /dev/null +++ b/schema/interfaces.go @@ -0,0 +1,23 @@ +package schema + +import "gorm.io/gorm/clause" + +type GormDataTypeInterface interface { + GormDataType() string +} + +type CreateClausesInterface interface { + CreateClauses() []clause.Interface +} + +type QueryClausesInterface interface { + QueryClauses() []clause.Interface +} + +type UpdateClausesInterface interface { + UpdateClauses() []clause.Interface +} + +type DeleteClausesInterface interface { + DeleteClauses() []clause.Interface +} diff --git a/schema/schema.go b/schema/schema.go index d2c4d08b..5b360f5e 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -42,22 +42,6 @@ type Schema struct { cacheStore *sync.Map } -type CreateClausesInterface interface { - CreateClauses() []clause.Interface -} - -type QueryClausesInterface interface { - QueryClauses() []clause.Interface -} - -type UpdateClausesInterface interface { - UpdateClauses() []clause.Interface -} - -type DeleteClausesInterface interface { - DeleteClauses() []clause.Interface -} - func (schema Schema) String() string { if schema.ModelType.Name() == "" { return fmt.Sprintf("%v(%v)", schema.Name, schema.Table) diff --git a/tests/tests_all.sh b/tests/tests_all.sh index 92a28f3b..affb1847 100755 --- a/tests/tests_all.sh +++ b/tests/tests_all.sh @@ -17,14 +17,21 @@ for dialect in "${dialects[@]}" ; do if [ "$GORM_VERBOSE" = "" ] then - DEBUG=false GORM_DIALECT=${dialect} go test $race -count=1 ./... - cd tests - DEBUG=false GORM_DIALECT=${dialect} go test $race -count=1 ./... + GORM_DIALECT=${dialect} go test $race -count=1 ./... + if [ -d tests ] + then + cd tests + GORM_DIALECT=${dialect} go test $race -count=1 ./... + cd .. + fi else - DEBUG=false GORM_DIALECT=${dialect} go test $race -count=1 -v ./... - cd tests - DEBUG=false GORM_DIALECT=${dialect} go test $race -count=1 -v ./... + GORM_DIALECT=${dialect} go test $race -count=1 -v ./... + if [ -d tests ] + then + cd tests + GORM_DIALECT=${dialect} go test $race -count=1 -v ./... + cd .. + fi fi - cd .. fi done From edd4be3fcb2dd1c73101d8c0a0e1327874d5ab98 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 6 Jun 2020 14:23:47 +0800 Subject: [PATCH 400/881] Update README --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 84236bb9..7491748f 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ The fantastic ORM library for Golang, aims to be developer friendly. [![go report card](https://goreportcard.com/badge/gorm.io/gorm "go report card")](https://goreportcard.com/report/gorm.io/gorm) [![wercker status](https://app.wercker.com/status/8596cace912c9947dd9c8542ecc8cb8b/s/master "wercker status")](https://app.wercker.com/project/byKey/8596cace912c9947dd9c8542ecc8cb8b) -[![codecov](https://codecov.io/gh/jinzhu/gorm/branch/master/graph/badge.svg)](https://codecov.io/gh/jinzhu/gorm) +[![codecov](https://codecov.io/gh/go-gorm/gorm/branch/master/graph/badge.svg)](https://codecov.io/gh/go-gorm/gorm) [![Join the chat at https://gitter.im/jinzhu/gorm](https://img.shields.io/gitter/room/jinzhu/gorm.svg)](https://gitter.im/jinzhu/gorm?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) [![Open Collective Backer](https://opencollective.com/gorm/tiers/backer/badge.svg?label=backer&color=brightgreen "Open Collective Backer")](https://opencollective.com/gorm) [![Open Collective Sponsor](https://opencollective.com/gorm/tiers/sponsor/badge.svg?label=sponsor&color=brightgreen "Open Collective Sponsor")](https://opencollective.com/gorm) @@ -38,5 +38,5 @@ The fantastic ORM library for Golang, aims to be developer friendly. © Jinzhu, 2013~time.Now -Released under the [MIT License](https://github.com/jinzhu/gorm/blob/master/License) +Released under the [MIT License](https://github.com/go-gorm/gorm/blob/master/License) From ebb8511d59c5d95cf0d39d3ae17351bd282865fc Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 6 Jun 2020 14:28:59 +0800 Subject: [PATCH 401/881] Add go.sum --- go.sum | 4 ++++ 1 file changed, 4 insertions(+) create mode 100644 go.sum diff --git a/go.sum b/go.sum new file mode 100644 index 00000000..148bd6f5 --- /dev/null +++ b/go.sum @@ -0,0 +1,4 @@ +github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= +github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/jinzhu/now v1.1.1 h1:g39TucaRWyV3dwDO++eEc6qf8TVIQ/Da48WmqjZ3i7E= +github.com/jinzhu/now v1.1.1/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= From 1acbb34406b2f2396bd843fe57595f031494610c Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 6 Jun 2020 15:05:24 +0800 Subject: [PATCH 402/881] Update wercker.yml --- README.md | 2 +- wercker.yml | 18 +++++++++--------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index 7491748f..f5df27f5 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@ The fantastic ORM library for Golang, aims to be developer friendly. [![go report card](https://goreportcard.com/badge/gorm.io/gorm "go report card")](https://goreportcard.com/report/gorm.io/gorm) -[![wercker status](https://app.wercker.com/status/8596cace912c9947dd9c8542ecc8cb8b/s/master "wercker status")](https://app.wercker.com/project/byKey/8596cace912c9947dd9c8542ecc8cb8b) +[![wercker status](https://app.wercker.com/status/55136410c77335a6289ebd58b2f70125/s/master "wercker status")](https://app.wercker.com/project/byKey/55136410c77335a6289ebd58b2f70125) [![codecov](https://codecov.io/gh/go-gorm/gorm/branch/master/graph/badge.svg)](https://codecov.io/gh/go-gorm/gorm) [![Join the chat at https://gitter.im/jinzhu/gorm](https://img.shields.io/gitter/room/jinzhu/gorm.svg)](https://gitter.im/jinzhu/gorm?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) [![Open Collective Backer](https://opencollective.com/gorm/tiers/backer/badge.svg?label=backer&color=brightgreen "Open Collective Backer")](https://opencollective.com/gorm) diff --git a/wercker.yml b/wercker.yml index 54d80be0..baece1bc 100644 --- a/wercker.yml +++ b/wercker.yml @@ -83,47 +83,47 @@ build: - script: name: test sqlite code: | - GORM_DIALECT=sqlite $GORM_VERBOSE=true ./tests/tests_all.sh + GORM_DIALECT=sqlite GORM_VERBOSE=true ./tests/tests_all.sh - script: name: test mariadb code: | - GORM_DIALECT=mysql $GORM_VERBOSE=true GORM_DSN="gorm:gorm@tcp(mariadb:3306)/gorm?charset=utf8&parseTime=True" ./tests/tests_all.sh + GORM_DIALECT=mysql GORM_VERBOSE=true GORM_DSN="gorm:gorm@tcp(mariadb:3306)/gorm?charset=utf8&parseTime=True" ./tests/tests_all.sh - script: name: test mysql code: | - GORM_DIALECT=mysql $GORM_VERBOSE=true GORM_DSN="gorm:gorm@tcp(mysql:3306)/gorm?charset=utf8&parseTime=True" ./tests/tests_all.sh + GORM_DIALECT=mysql GORM_VERBOSE=true GORM_DSN="gorm:gorm@tcp(mysql:3306)/gorm?charset=utf8&parseTime=True" ./tests/tests_all.sh - script: name: test mysql5.7 code: | - GORM_DIALECT=mysql $GORM_VERBOSE=true GORM_DSN="gorm:gorm@tcp(mysql57:3306)/gorm?charset=utf8&parseTime=True" ./tests/tests_all.sh + GORM_DIALECT=mysql GORM_VERBOSE=true GORM_DSN="gorm:gorm@tcp(mysql57:3306)/gorm?charset=utf8&parseTime=True" ./tests/tests_all.sh - script: name: test mysql5.6 code: | - GORM_DIALECT=mysql $GORM_VERBOSE=true GORM_DSN="gorm:gorm@tcp(mysql56:3306)/gorm?charset=utf8&parseTime=True" ./tests/tests_all.sh + GORM_DIALECT=mysql GORM_VERBOSE=true GORM_DSN="gorm:gorm@tcp(mysql56:3306)/gorm?charset=utf8&parseTime=True" ./tests/tests_all.sh - script: name: test postgres code: | - GORM_DIALECT=postgres $GORM_VERBOSE=true GORM_DSN="host=postgres user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" ./tests/tests_all.sh + GORM_DIALECT=postgres GORM_VERBOSE=true GORM_DSN="host=postgres user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" ./tests/tests_all.sh - script: name: test postgres11 code: | - GORM_DIALECT=postgres $GORM_VERBOSE=true GORM_DSN="host=postgres96 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" ./tests/tests_all.sh + GORM_DIALECT=postgres GORM_VERBOSE=true GORM_DSN="host=postgres96 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" ./tests/tests_all.sh - script: name: test postgres10 code: | - GORM_DIALECT=postgres $GORM_VERBOSE=true GORM_DSN="host=postgres95 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" ./tests/tests_all.sh + GORM_DIALECT=postgres GORM_VERBOSE=true GORM_DSN="host=postgres95 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" ./tests/tests_all.sh - script: name: test mssql code: | - GORM_DIALECT=mssql $GORM_VERBOSE=true GORM_DSN="sqlserver://gorm:LoremIpsum86@mssql:1433?database=gorm" ./tests/tests_all.sh + GORM_DIALECT=mssql GORM_VERBOSE=true GORM_DSN="sqlserver://gorm:LoremIpsum86@mssql:1433?database=gorm" ./tests/tests_all.sh - script: name: codecov From 52b763aab33967ab0221d9e7bb6b45a3ac7c5ab2 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 6 Jun 2020 17:47:30 +0800 Subject: [PATCH 403/881] Add convert map Assignments helper --- clause/set.go | 21 +++++++++++++++++++++ clause/set_test.go | 19 +++++++++++++++++++ 2 files changed, 40 insertions(+) diff --git a/clause/set.go b/clause/set.go index 590e27d5..4adfe68f 100644 --- a/clause/set.go +++ b/clause/set.go @@ -1,5 +1,7 @@ package clause +import "sort" + type Set []Assignment type Assignment struct { @@ -32,3 +34,22 @@ func (set Set) Build(builder Builder) { func (set Set) MergeClause(clause *Clause) { clause.Expression = set } + +func Assignments(values map[string]interface{}) Set { + var keys []string + var assignments []Assignment + + for key := range values { + keys = append(keys, key) + } + + sort.Strings(keys) + + for _, key := range keys { + assignments = append(assignments, Assignment{ + Column: Column{Table: CurrentTable, Name: key}, + Value: values[key], + }) + } + return assignments +} diff --git a/clause/set_test.go b/clause/set_test.go index dbc1e970..56fac706 100644 --- a/clause/set_test.go +++ b/clause/set_test.go @@ -2,6 +2,8 @@ package clause_test import ( "fmt" + "sort" + "strings" "testing" "gorm.io/gorm/clause" @@ -36,3 +38,20 @@ func TestSet(t *testing.T) { }) } } + +func TestAssignments(t *testing.T) { + set := clause.Assignments(map[string]interface{}{ + "name": "jinzhu", + "age": 18, + }) + + assignments := []clause.Assignment(set) + + sort.Slice(assignments, func(i, j int) bool { + return strings.Compare(assignments[i].Column.Name, assignments[j].Column.Name) > 0 + }) + + if len(assignments) != 2 || assignments[0].Column.Name != "name" || assignments[0].Value.(string) != "jinzhu" || assignments[1].Column.Name != "age" || assignments[1].Value.(int) != 18 { + t.Errorf("invalid assignments, got %v", assignments) + } +} From 38d1cd2bf182f55f28a8c909395ceaa2019d8b99 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 6 Jun 2020 21:35:28 +0800 Subject: [PATCH 404/881] Replace For with Locking --- clause/locking.go | 41 ++++++++++++----------------------------- clause/locking_test.go | 18 +++++------------- 2 files changed, 17 insertions(+), 42 deletions(-) diff --git a/clause/locking.go b/clause/locking.go index 3be1063b..290aac92 100644 --- a/clause/locking.go +++ b/clause/locking.go @@ -1,9 +1,5 @@ package clause -type For struct { - Lockings []Locking -} - type Locking struct { Strength string Table Table @@ -11,38 +7,25 @@ type Locking struct { } // Name where clause name -func (f For) Name() string { +func (locking Locking) Name() string { return "FOR" } // Build build where clause -func (f For) Build(builder Builder) { - for idx, locking := range f.Lockings { - if idx > 0 { - builder.WriteByte(' ') - } +func (locking Locking) Build(builder Builder) { + builder.WriteString(locking.Strength) + if locking.Table.Name != "" { + builder.WriteString(" OF ") + builder.WriteQuoted(locking.Table) + } - builder.WriteString("FOR ") - builder.WriteString(locking.Strength) - if locking.Table.Name != "" { - builder.WriteString(" OF ") - builder.WriteQuoted(locking.Table) - } - - if locking.Options != "" { - builder.WriteByte(' ') - builder.WriteString(locking.Options) - } + if locking.Options != "" { + builder.WriteByte(' ') + builder.WriteString(locking.Options) } } // MergeClause merge order by clauses -func (f For) MergeClause(clause *Clause) { - clause.Name = "" - - if v, ok := clause.Expression.(For); ok { - f.Lockings = append(v.Lockings, f.Lockings...) - } - - clause.Expression = f +func (locking Locking) MergeClause(clause *Clause) { + clause.Expression = locking } diff --git a/clause/locking_test.go b/clause/locking_test.go index 6f507692..5ca30ef0 100644 --- a/clause/locking_test.go +++ b/clause/locking_test.go @@ -14,24 +14,16 @@ func TestFor(t *testing.T) { Vars []interface{} }{ { - []clause.Interface{clause.Select{}, clause.From{}, clause.For{ - Lockings: []clause.Locking{{Strength: "UPDATE"}}, - }}, + []clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: "UPDATE"}}, "SELECT * FROM `users` FOR UPDATE", nil, }, { - []clause.Interface{clause.Select{}, clause.From{}, clause.For{ - Lockings: []clause.Locking{{Strength: "UPDATE"}, {Strength: "SHARE", Table: clause.Table{Name: clause.CurrentTable}}}, - }}, - "SELECT * FROM `users` FOR UPDATE FOR SHARE OF `users`", nil, + []clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: "SHARE", Table: clause.Table{Name: clause.CurrentTable}}}, + "SELECT * FROM `users` FOR SHARE OF `users`", nil, }, { - []clause.Interface{clause.Select{}, clause.From{}, clause.For{ - Lockings: []clause.Locking{{Strength: "UPDATE"}, {Strength: "SHARE", Table: clause.Table{Name: clause.CurrentTable}}}, - }, clause.For{ - Lockings: []clause.Locking{{Strength: "UPDATE", Options: "NOWAIT"}}, - }}, - "SELECT * FROM `users` FOR UPDATE FOR SHARE OF `users` FOR UPDATE NOWAIT", nil, + []clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: "UPDATE"}, clause.Locking{Strength: "UPDATE", Options: "NOWAIT"}}, + "SELECT * FROM `users` FOR UPDATE NOWAIT", nil, }, } From 6937d713c31e23eef0c0377e73d494a631f4e9f3 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 6 Jun 2020 22:52:08 +0800 Subject: [PATCH 405/881] Refactor clauses --- clause/clause.go | 44 +++++++++++++++++++++++------------------- clause/locking_test.go | 2 +- clause/where.go | 18 ++++++++--------- clause/where_test.go | 2 +- finisher_api.go | 7 ++++--- statement.go | 16 +++++++-------- 6 files changed, 46 insertions(+), 43 deletions(-) diff --git a/clause/clause.go b/clause/clause.go index 9a5d1273..b3e96332 100644 --- a/clause/clause.go +++ b/clause/clause.go @@ -24,42 +24,46 @@ type Builder interface { // Clause type Clause struct { - Name string // WHERE - Priority float64 - BeforeExpressions []Expression - AfterNameExpressions []Expression - AfterExpressions []Expression - Expression Expression - Builder ClauseBuilder + Name string // WHERE + BeforeExpression Expression + AfterNameExpression Expression + AfterExpression Expression + Expression Expression + Builder ClauseBuilder } // Build build clause func (c Clause) Build(builder Builder) { if c.Builder != nil { c.Builder(c, builder) - } else { - builders := c.BeforeExpressions + } else if c.Expression != nil { + if c.BeforeExpression != nil { + c.BeforeExpression.Build(builder) + builder.WriteByte(' ') + } + if c.Name != "" { - builders = append(builders, Expr{SQL: c.Name}) + builder.WriteString(c.Name) + builder.WriteByte(' ') } - builders = append(builders, c.AfterNameExpressions...) - if c.Expression != nil { - builders = append(builders, c.Expression) + if c.AfterNameExpression != nil { + c.BeforeExpression.Build(builder) + builder.WriteByte(' ') } - for idx, expr := range append(builders, c.AfterExpressions...) { - if idx != 0 { - builder.WriteByte(' ') - } - expr.Build(builder) + c.Expression.Build(builder) + + if c.AfterExpression != nil { + builder.WriteByte(' ') + c.AfterExpression.Build(builder) } } } const ( - PrimaryKey string = "@@@priamry_key@@@" - CurrentTable string = "@@@table@@@" + PrimaryKey string = "@@@py@@@" // primary key + CurrentTable string = "@@@ct@@@" // current table ) var ( diff --git a/clause/locking_test.go b/clause/locking_test.go index 5ca30ef0..0e607312 100644 --- a/clause/locking_test.go +++ b/clause/locking_test.go @@ -7,7 +7,7 @@ import ( "gorm.io/gorm/clause" ) -func TestFor(t *testing.T) { +func TestLocking(t *testing.T) { results := []struct { Clauses []clause.Interface Result string diff --git a/clause/where.go b/clause/where.go index 08c78b22..015addf8 100644 --- a/clause/where.go +++ b/clause/where.go @@ -14,7 +14,7 @@ func (where Where) Name() string { func (where Where) Build(builder Builder) { // Switch position if the first query expression is a single Or condition for idx, expr := range where.Exprs { - if v, ok := expr.(OrConditions); (!ok && expr != nil) || len(v.Exprs) > 1 { + if v, ok := expr.(OrConditions); !ok || len(v.Exprs) > 1 { if idx != 0 { where.Exprs[0], where.Exprs[idx] = where.Exprs[idx], where.Exprs[0] } @@ -23,17 +23,15 @@ func (where Where) Build(builder Builder) { } for idx, expr := range where.Exprs { - if expr != nil { - if idx > 0 { - if v, ok := expr.(OrConditions); ok && len(v.Exprs) == 1 { - builder.WriteString(" OR ") - } else { - builder.WriteString(" AND ") - } + if idx > 0 { + if v, ok := expr.(OrConditions); ok && len(v.Exprs) == 1 { + builder.WriteString(" OR ") + } else { + builder.WriteString(" AND ") } - - expr.Build(builder) } + + expr.Build(builder) } return diff --git a/clause/where_test.go b/clause/where_test.go index 894e11f4..95bba820 100644 --- a/clause/where_test.go +++ b/clause/where_test.go @@ -27,7 +27,7 @@ func TestWhere(t *testing.T) { }, { []clause.Interface{clause.Select{}, clause.From{}, clause.Where{ - Exprs: []clause.Expression{clause.Or(), clause.Or(clause.Neq{Column: "name", Value: "jinzhu"}), clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}}, + Exprs: []clause.Expression{clause.Or(clause.Neq{Column: "name", Value: "jinzhu"}), clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}}, }}, "SELECT * FROM `users` WHERE `users`.`id` = ? OR `name` <> ? AND `age` > ?", []interface{}{"1", "jinzhu", 18}, }, diff --git a/finisher_api.go b/finisher_api.go index e94fd095..434f0e22 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -32,13 +32,14 @@ func (db *DB) Save(value interface{}) (tx *DB) { for idx, pf := range tx.Statement.Schema.PrimaryFields { if pv, isZero := pf.ValueOf(reflectValue); isZero { tx.callbacks.Create().Execute(tx) - where.Exprs[idx] = clause.Eq{Column: pf.DBName, Value: pv} return + } else { + where.Exprs[idx] = clause.Eq{Column: pf.DBName, Value: pv} } } - } - tx.Statement.AddClause(where) + tx.Statement.AddClause(where) + } } if len(tx.Statement.Selects) == 0 { diff --git a/statement.go b/statement.go index 2c814547..ec9e021f 100644 --- a/statement.go +++ b/statement.go @@ -201,19 +201,19 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) { func (stmt *Statement) AddClause(v clause.Interface) { if optimizer, ok := v.(StatementModifier); ok { optimizer.ModifyStatement(stmt) + } else { + c, ok := stmt.Clauses[v.Name()] + if !ok { + c.Name = v.Name() + } + v.MergeClause(&c) + stmt.Clauses[v.Name()] = c } - - c, ok := stmt.Clauses[v.Name()] - if !ok { - c.Name = v.Name() - } - v.MergeClause(&c) - stmt.Clauses[v.Name()] = c } // AddClauseIfNotExists add clause if not exists func (stmt *Statement) AddClauseIfNotExists(v clause.Interface) { - if c, ok := stmt.Clauses[v.Name()]; !ok && c.Expression == nil { + if c, ok := stmt.Clauses[v.Name()]; !ok || c.Expression == nil { stmt.AddClause(v) } } From 93043334c3a64bde3933c3dbd32bca9125b90816 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 7 Jun 2020 12:47:26 +0800 Subject: [PATCH 406/881] Create FUNDING.yml --- .github/FUNDING.yml | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 .github/FUNDING.yml diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml new file mode 100644 index 00000000..2e7a32d9 --- /dev/null +++ b/.github/FUNDING.yml @@ -0,0 +1,5 @@ +# These are supported funding model platforms + +github: [jinzhu] +patreon: jinzhu +open_collective: gorm From 82d55b105440609d52577c7414ed9e68a503687f Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 7 Jun 2020 09:39:39 +0800 Subject: [PATCH 407/881] Add OnConflict DoUpdates test --- clause/on_conflict.go | 10 ++++++++-- clause/set.go | 2 +- tests/upsert_test.go | 24 ++++++++++++++++++++++-- 3 files changed, 31 insertions(+), 5 deletions(-) diff --git a/clause/on_conflict.go b/clause/on_conflict.go index 6001399f..47f69fc9 100644 --- a/clause/on_conflict.go +++ b/clause/on_conflict.go @@ -14,8 +14,14 @@ func (OnConflict) Name() string { // Build build onConflict clause func (onConflict OnConflict) Build(builder Builder) { if len(onConflict.Columns) > 0 { - builder.WriteQuoted(onConflict.Columns) // FIXME columns - builder.WriteByte(' ') + builder.WriteByte('(') + for idx, column := range onConflict.Columns { + if idx > 0 { + builder.WriteByte(',') + } + builder.WriteQuoted(column) + } + builder.WriteString(`) `) } if len(onConflict.Where.Exprs) > 0 { diff --git a/clause/set.go b/clause/set.go index 4adfe68f..7704ca36 100644 --- a/clause/set.go +++ b/clause/set.go @@ -47,7 +47,7 @@ func Assignments(values map[string]interface{}) Set { for _, key := range keys { assignments = append(assignments, Assignment{ - Column: Column{Table: CurrentTable, Name: key}, + Column: Column{Name: key}, Value: values[key], }) } diff --git a/tests/upsert_test.go b/tests/upsert_test.go index f132a7da..311b7136 100644 --- a/tests/upsert_test.go +++ b/tests/upsert_test.go @@ -10,10 +10,14 @@ import ( func TestUpsert(t *testing.T) { lang := Language{Code: "upsert", Name: "Upsert"} - DB.Clauses(clause.OnConflict{DoNothing: true}).Create(&lang) + if err := DB.Clauses(clause.OnConflict{DoNothing: true}).Create(&lang).Error; err != nil { + t.Fatalf("failed to upsert, got %v", err) + } lang2 := Language{Code: "upsert", Name: "Upsert"} - DB.Clauses(clause.OnConflict{DoNothing: true}).Create(&lang2) + if err := DB.Clauses(clause.OnConflict{DoNothing: true}).Create(&lang2).Error; err != nil { + t.Fatalf("failed to upsert, got %v", err) + } var langs []Language if err := DB.Find(&langs, "code = ?", lang.Code).Error; err != nil { @@ -21,6 +25,22 @@ func TestUpsert(t *testing.T) { } else if len(langs) != 1 { t.Errorf("should only find only 1 languages, but got %+v", langs) } + + lang3 := Language{Code: "upsert", Name: "Upsert"} + if err := DB.Clauses(clause.OnConflict{ + Columns: []clause.Column{{Name: "code"}}, + DoUpdates: clause.Assignments(map[string]interface{}{"name": "upsert-new"}), + }).Create(&lang3).Error; err != nil { + t.Fatalf("failed to upsert, got %v", err) + } + + if err := DB.Find(&langs, "code = ?", lang.Code).Error; err != nil { + t.Errorf("no error should happen when find languages with code, but got %v", err) + } else if len(langs) != 1 { + t.Errorf("should only find only 1 languages, but got %+v", langs) + } else if langs[0].Name != "upsert-new" { + t.Errorf("should update name on conflict, but got name %+v", langs[0].Name) + } } func TestUpsertSlice(t *testing.T) { From 4a4b8234de826dc08d15bc5e8edb7cec42eff56b Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 7 Jun 2020 13:16:09 +0800 Subject: [PATCH 408/881] Update issues template --- .github/ISSUE_TEMPLATE.md | 35 ++++++++++++++--------------------- 1 file changed, 14 insertions(+), 21 deletions(-) diff --git a/.github/ISSUE_TEMPLATE.md b/.github/ISSUE_TEMPLATE.md index a0b64bfa..74824a19 100644 --- a/.github/ISSUE_TEMPLATE.md +++ b/.github/ISSUE_TEMPLATE.md @@ -1,4 +1,4 @@ -Your issue may already be reported! Please search on the [issue track](https://github.com/jinzhu/gorm/issues) before creating one. +Your issue may already be reported! Please search on the [issue track](https://github.com/go-gorm/gorm/issues) before creating one. ### What version of Go are you using (`go version`)? @@ -8,34 +8,27 @@ Your issue may already be reported! Please search on the [issue track](https://g ### Please provide a complete runnable program to reproduce your issue. **IMPORTANT** -Need to runnable with [GORM's docker compose config](https://github.com/jinzhu/gorm/blob/master/docker-compose.yml) or please provides your config. +Need to runnable with [GORM's docker compose config](https://github.com/go-gorm/gorm/blob/master/tests/docker-compose.yml) or please provides your config. ```go package main import ( - "github.com/jinzhu/gorm" - _ "github.com/jinzhu/gorm/dialects/mssql" - _ "github.com/jinzhu/gorm/dialects/mysql" - _ "github.com/jinzhu/gorm/dialects/postgres" - _ "github.com/jinzhu/gorm/dialects/sqlite" + "gorm.io/gorm" + "gorm.io/driver/sqlite" +// "gorm.io/driver/mysql" +// "gorm.io/driver/postgres" +// "gorm.io/driver/sqlserver" ) -var db *gorm.DB - -func init() { - var err error - db, err = gorm.Open("sqlite3", "test.db") - // db, err = gorm.Open("postgres", "user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable") - // db, err = gorm.Open("mysql", "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True") - // db, err = gorm.Open("mssql", "sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm") - if err != nil { - panic(err) - } - db.LogMode(true) -} - func main() { + db, err := gorm.Open(sqlite.Open(filepath.Join(os.TempDir(), "gorm.db")), &gorm.Config{}) + // db, err := gorm.Open(postgres.Open("user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable TimeZone=Asia/Shanghai"), &gorm.Config{}) + // db, err := gorm.Open(mysql.Open("gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True&loc=Local"), &gorm.Config{}) + // db, err := gorm.Open(sqlserver.Open("sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm"), &gorm.Config{}) + + /* your code */ + if /* failure condition */ { fmt.Println("failed") } else { From d11c424334b8964e48c4226f0c91ea9e4c062910 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 7 Jun 2020 15:24:34 +0800 Subject: [PATCH 409/881] Fix typo --- callbacks.go | 12 ++++++------ callbacks/helper.go | 2 +- callbacks/update.go | 2 +- tests/associations_belongs_to_test.go | 2 +- tests/associations_has_many_test.go | 12 ++++++------ tests/associations_has_one_test.go | 4 ++-- tests/associations_many2many_test.go | 8 ++++---- tests/migrate_test.go | 2 +- tests/scanner_valuer_test.go | 2 +- 9 files changed, 23 insertions(+), 23 deletions(-) diff --git a/callbacks.go b/callbacks.go index 4f700081..e6cf29af 100644 --- a/callbacks.go +++ b/callbacks.go @@ -15,12 +15,12 @@ import ( func initializeCallbacks(db *DB) *callbacks { return &callbacks{ processors: map[string]*processor{ - "create": &processor{db: db}, - "query": &processor{db: db}, - "update": &processor{db: db}, - "delete": &processor{db: db}, - "row": &processor{db: db}, - "raw": &processor{db: db}, + "create": {db: db}, + "query": {db: db}, + "update": {db: db}, + "delete": {db: db}, + "row": {db: db}, + "raw": {db: db}, }, } } diff --git a/callbacks/helper.go b/callbacks/helper.go index 828e025a..97c8ad35 100644 --- a/callbacks/helper.go +++ b/callbacks/helper.go @@ -62,7 +62,7 @@ func ConvertMapToValuesForCreate(stmt *gorm.Statement, mapValue map[string]inter selectColumns, restricted := SelectAndOmitColumns(stmt, true, false) var keys []string - for k, _ := range mapValue { + for k := range mapValue { keys = append(keys, k) } sort.Strings(keys) diff --git a/callbacks/update.go b/callbacks/update.go index a41a3c59..f5287dc6 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -141,7 +141,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { set = make([]clause.Assignment, 0, len(value)) var keys []string - for k, _ := range value { + for k := range value { keys = append(keys, k) } sort.Strings(keys) diff --git a/tests/associations_belongs_to_test.go b/tests/associations_belongs_to_test.go index 35419666..1800be91 100644 --- a/tests/associations_belongs_to_test.go +++ b/tests/associations_belongs_to_test.go @@ -180,7 +180,7 @@ func TestBelongsToAssociationForSlice(t *testing.T) { // Delete if err := DB.Model(&users).Association("Company").Delete(&users[0].Company); err != nil { - t.Errorf("no error should happend when deleting company, but got %v", err) + t.Errorf("no error should happened when deleting company, but got %v", err) } if users[0].CompanyID != nil || users[0].Company.ID != 0 { diff --git a/tests/associations_has_many_test.go b/tests/associations_has_many_test.go index 7ef0c218..d8befd8a 100644 --- a/tests/associations_has_many_test.go +++ b/tests/associations_has_many_test.go @@ -234,13 +234,13 @@ func TestHasManyAssociationForSlice(t *testing.T) { // Delete if err := DB.Model(&users).Association("Pets").Delete(&users[2].Pets); err != nil { - t.Errorf("no error should happend when deleting pet, but got %v", err) + t.Errorf("no error should happened when deleting pet, but got %v", err) } AssertAssociationCount(t, users, "Pets", 4, "after delete") if err := DB.Model(&users).Association("Pets").Delete(users[0].Pets[0], users[1].Pets[1]); err != nil { - t.Errorf("no error should happend when deleting pet, but got %v", err) + t.Errorf("no error should happened when deleting pet, but got %v", err) } AssertAssociationCount(t, users, "Pets", 2, "after delete") @@ -290,13 +290,13 @@ func TestSingleTableHasManyAssociationForSlice(t *testing.T) { // Delete if err := DB.Model(&users).Association("Team").Delete(&users[2].Team); err != nil { - t.Errorf("no error should happend when deleting pet, but got %v", err) + t.Errorf("no error should happened when deleting pet, but got %v", err) } AssertAssociationCount(t, users, "Team", 4, "after delete") if err := DB.Model(&users).Association("Team").Delete(users[0].Team[0], users[1].Team[1]); err != nil { - t.Errorf("no error should happend when deleting pet, but got %v", err) + t.Errorf("no error should happened when deleting pet, but got %v", err) } AssertAssociationCount(t, users, "Team", 2, "after delete") @@ -439,13 +439,13 @@ func TestPolymorphicHasManyAssociationForSlice(t *testing.T) { // Delete if err := DB.Model(&users).Association("Toys").Delete(&users[2].Toys); err != nil { - t.Errorf("no error should happend when deleting toy, but got %v", err) + t.Errorf("no error should happened when deleting toy, but got %v", err) } AssertAssociationCount(t, users, "Toys", 4, "after delete") if err := DB.Model(&users).Association("Toys").Delete(users[0].Toys[0], users[1].Toys[1]); err != nil { - t.Errorf("no error should happend when deleting toy, but got %v", err) + t.Errorf("no error should happened when deleting toy, but got %v", err) } AssertAssociationCount(t, users, "Toys", 2, "after delete") diff --git a/tests/associations_has_one_test.go b/tests/associations_has_one_test.go index f32a692d..a6dcc6c5 100644 --- a/tests/associations_has_one_test.go +++ b/tests/associations_has_one_test.go @@ -113,7 +113,7 @@ func TestHasOneAssociationForSlice(t *testing.T) { // Delete if err := DB.Model(&users).Association("Account").Delete(&users[0].Account); err != nil { - t.Errorf("no error should happend when deleting account, but got %v", err) + t.Errorf("no error should happened when deleting account, but got %v", err) } AssertAssociationCount(t, users, "Account", 2, "after delete") @@ -230,7 +230,7 @@ func TestPolymorphicHasOneAssociationForSlice(t *testing.T) { // Delete if err := DB.Model(&pets).Association("Toy").Delete(&pets[0].Toy); err != nil { - t.Errorf("no error should happend when deleting toy, but got %v", err) + t.Errorf("no error should happened when deleting toy, but got %v", err) } AssertAssociationCount(t, pets, "Toy", 2, "after delete") diff --git a/tests/associations_many2many_test.go b/tests/associations_many2many_test.go index ba9695b7..2ecf7b66 100644 --- a/tests/associations_many2many_test.go +++ b/tests/associations_many2many_test.go @@ -147,13 +147,13 @@ func TestMany2ManyAssociationForSlice(t *testing.T) { // Delete if err := DB.Model(&users).Association("Languages").Delete(&users[2].Languages); err != nil { - t.Errorf("no error should happend when deleting language, but got %v", err) + t.Errorf("no error should happened when deleting language, but got %v", err) } AssertAssociationCount(t, users, "Languages", 4, "after delete") if err := DB.Model(&users).Association("Languages").Delete(users[0].Languages[0], users[1].Languages[1]); err != nil { - t.Errorf("no error should happend when deleting language, but got %v", err) + t.Errorf("no error should happened when deleting language, but got %v", err) } AssertAssociationCount(t, users, "Languages", 2, "after delete") @@ -282,13 +282,13 @@ func TestSingleTableMany2ManyAssociationForSlice(t *testing.T) { // Delete if err := DB.Model(&users).Association("Team").Delete(&users[2].Team); err != nil { - t.Errorf("no error should happend when deleting team, but got %v", err) + t.Errorf("no error should happened when deleting team, but got %v", err) } AssertAssociationCount(t, users, "Team", 4, "after delete") if err := DB.Model(&users).Association("Team").Delete(users[0].Team[0], users[1].Team[1]); err != nil { - t.Errorf("no error should happend when deleting team, but got %v", err) + t.Errorf("no error should happened when deleting team, but got %v", err) } AssertAssociationCount(t, users, "Team", 2, "after delete") diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 5293898f..194b5cbf 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -131,7 +131,7 @@ func TestColumns(t *testing.T) { } if err := DB.Table("column_structs").Migrator().AlterColumn(&ColumnStruct2{}, "Name"); err != nil { - t.Fatalf("no error should happend when alter column, but got %v", err) + t.Fatalf("no error should happened when alter column, but got %v", err) } if columnTypes, err := DB.Migrator().ColumnTypes(&ColumnStruct{}); err != nil { diff --git a/tests/scanner_valuer_test.go b/tests/scanner_valuer_test.go index 7d72db15..ec228f00 100644 --- a/tests/scanner_valuer_test.go +++ b/tests/scanner_valuer_test.go @@ -38,7 +38,7 @@ func TestScannerValuer(t *testing.T) { } if err := DB.Create(&data).Error; err != nil { - t.Errorf("No error should happend when create scanner valuer struct, but got %v", err) + t.Errorf("No error should happened when create scanner valuer struct, but got %v", err) } var result ScannerValuerStruct From 31a0553b8211c3b6d36ff160ea6df08377c2058b Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 7 Jun 2020 18:27:12 +0800 Subject: [PATCH 410/881] Fix FileWithLineNum on windows --- utils/utils.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/utils/utils.go b/utils/utils.go index ce42b218..81d2dc34 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -3,8 +3,8 @@ package utils import ( "database/sql/driver" "fmt" - "path/filepath" "reflect" + "regexp" "runtime" "strconv" "strings" @@ -15,7 +15,7 @@ var gormSourceDir string func init() { _, file, _, _ := runtime.Caller(0) - gormSourceDir = filepath.Dir(filepath.Dir(file)) + gormSourceDir = regexp.MustCompile("utils.utils\\.go").ReplaceAllString(file, "") } func FileWithLineNum() string { @@ -23,7 +23,7 @@ func FileWithLineNum() string { _, file, line, ok := runtime.Caller(i) if ok && (!strings.HasPrefix(file, gormSourceDir) || strings.HasSuffix(file, "_test.go")) { - return fmt.Sprintf("%v:%v", file, line) + return file + ":" + strconv.FormatInt(int64(line), 10) } } return "" From e7b2e92ce3d3c60fc73509fd53746ec70aaae7c3 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 7 Jun 2020 22:03:45 +0800 Subject: [PATCH 411/881] Remove RecordNotFound method --- finisher_api.go | 4 ---- tests/delete_test.go | 4 ++-- tests/soft_delete_test.go | 4 +++- tests/update_test.go | 2 +- tests/upsert_test.go | 4 ++-- 5 files changed, 8 insertions(+), 10 deletions(-) diff --git a/finisher_api.go b/finisher_api.go index 434f0e22..72453b1d 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -389,7 +389,3 @@ func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) { tx.callbacks.Raw().Execute(tx) return } - -func (db *DB) RecordNotFound() bool { - return errors.Is(db.Error, ErrRecordNotFound) -} diff --git a/tests/delete_test.go b/tests/delete_test.go index 66c396d1..b853a9d3 100644 --- a/tests/delete_test.go +++ b/tests/delete_test.go @@ -52,13 +52,13 @@ func TestInlineCondDelete(t *testing.T) { if DB.Delete(&User{}, user1.ID).Error != nil { t.Errorf("No error should happen when delete a record") - } else if !DB.Where("name = ?", user1.Name).First(&User{}).RecordNotFound() { + } else if err := DB.Where("name = ?", user1.Name).First(&User{}).Error; !errors.Is(err, gorm.ErrRecordNotFound) { t.Errorf("User can't be found after delete") } if err := DB.Delete(&User{}, "name = ?", user2.Name).Error; err != nil { t.Errorf("No error should happen when delete a record, err=%s", err) - } else if !DB.Where("name = ?", user2.Name).First(&User{}).RecordNotFound() { + } else if err := DB.Where("name = ?", user2.Name).First(&User{}).Error; !errors.Is(err, gorm.ErrRecordNotFound) { t.Errorf("User can't be found after delete") } } diff --git a/tests/soft_delete_test.go b/tests/soft_delete_test.go index c632c753..b6dabe06 100644 --- a/tests/soft_delete_test.go +++ b/tests/soft_delete_test.go @@ -1,8 +1,10 @@ package tests_test import ( + "errors" "testing" + "gorm.io/gorm" . "gorm.io/gorm/utils/tests" ) @@ -22,7 +24,7 @@ func TestSoftDelete(t *testing.T) { } DB.Unscoped().Delete(&user) - if !DB.Unscoped().First(&User{}, "name = ?", user.Name).RecordNotFound() { + if err := DB.Unscoped().First(&User{}, "name = ?", user.Name).Error; !errors.Is(err, gorm.ErrRecordNotFound) { t.Errorf("Can't find permanently deleted record") } } diff --git a/tests/update_test.go b/tests/update_test.go index 220d3e76..d56e3f76 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -149,7 +149,7 @@ func TestUpdates(t *testing.T) { DB.Table("users").Where("name in ?", []string{users[1].Name}).Updates(User{Name: "updates_02_newname"}) var user3 User - if DB.First(&user3, "name = ?", "updates_02_newname").RecordNotFound() { + if err := DB.First(&user3, "name = ?", "updates_02_newname").Error; err != nil { t.Errorf("User2's name should be updated") } AssertEqual(t, user2.UpdatedAt, user3.UpdatedAt) diff --git a/tests/upsert_test.go b/tests/upsert_test.go index 311b7136..e9ba54e3 100644 --- a/tests/upsert_test.go +++ b/tests/upsert_test.go @@ -171,11 +171,11 @@ func TestFindOrCreate(t *testing.T) { } DB.Where(&User{Name: "find or create embedded struct"}).Assign(User{Age: 44, Account: Account{Number: "1231231231"}, Pets: []*Pet{{Name: "first_or_create_pet1"}, {Name: "first_or_create_pet2"}}}).FirstOrCreate(&user8) - if DB.Where("name = ?", "first_or_create_pet1").First(&Pet{}).RecordNotFound() { + if err := DB.Where("name = ?", "first_or_create_pet1").First(&Pet{}).Error; err != nil { t.Errorf("has many association should be saved") } - if DB.Where("number = ?", "1231231231").First(&Account{}).RecordNotFound() { + if err := DB.Where("number = ?", "1231231231").First(&Account{}).Error; err != nil { t.Errorf("belongs to association should be saved") } } From 72d0fa61960c5c2472b561e3945654b3f020a233 Mon Sep 17 00:00:00 2001 From: Douglas Danger Manley Date: Sun, 7 Jun 2020 16:41:54 -0400 Subject: [PATCH 412/881] Fix Statement Where clone array corruption in v2 Method-chaining in gorm is predicated on a `Clause`'s `MergeClause` method ensuring that the two clauses are disconnected in terms of pointers (at least in the Wherec case). However, the original Where implementation used `append`, which only returns a new instance if the backing array needs to be resized. In some cases, this is true. Practically, go doubles the size of the slice once it gets full, so the following slice `append` calls would result in a new slice: * 0 -> 1 * 1 -> 2 * 2 -> 4 * 4 -> 8 * and so on. So, when the number of "where" conditions was 0, 1, 2, or 4, method-chaining would work as expected. However, when it was 3, 5, 6, or 7, modifying the copy would modify the original. This also updates the "order by", "group by" and "set" clauses. --- clause/group_by.go | 9 +++++++-- clause/order_by.go | 4 +++- clause/set.go | 4 +++- clause/where.go | 4 +++- statement_test.go | 37 +++++++++++++++++++++++++++++++++++++ 5 files changed, 53 insertions(+), 5 deletions(-) create mode 100644 statement_test.go diff --git a/clause/group_by.go b/clause/group_by.go index c1383c36..88231916 100644 --- a/clause/group_by.go +++ b/clause/group_by.go @@ -30,8 +30,13 @@ func (groupBy GroupBy) Build(builder Builder) { // MergeClause merge group by clause func (groupBy GroupBy) MergeClause(clause *Clause) { if v, ok := clause.Expression.(GroupBy); ok { - groupBy.Columns = append(v.Columns, groupBy.Columns...) - groupBy.Having = append(v.Having, groupBy.Having...) + copiedColumns := make([]Column, len(v.Columns)) + copy(copiedColumns, v.Columns) + groupBy.Columns = append(copiedColumns, groupBy.Columns...) + + copiedHaving := make([]Expression, len(v.Having)) + copy(copiedHaving, v.Having) + groupBy.Having = append(copiedHaving, groupBy.Having...) } clause.Expression = groupBy } diff --git a/clause/order_by.go b/clause/order_by.go index 307bf930..a8a9539a 100644 --- a/clause/order_by.go +++ b/clause/order_by.go @@ -40,7 +40,9 @@ func (orderBy OrderBy) MergeClause(clause *Clause) { } } - orderBy.Columns = append(v.Columns, orderBy.Columns...) + copiedColumns := make([]OrderByColumn, len(v.Columns)) + copy(copiedColumns, v.Columns) + orderBy.Columns = append(copiedColumns, orderBy.Columns...) } clause.Expression = orderBy diff --git a/clause/set.go b/clause/set.go index 7704ca36..2d3965d3 100644 --- a/clause/set.go +++ b/clause/set.go @@ -32,7 +32,9 @@ func (set Set) Build(builder Builder) { // MergeClause merge assignments clauses func (set Set) MergeClause(clause *Clause) { - clause.Expression = set + copiedAssignments := make([]Assignment, len(set)) + copy(copiedAssignments, set) + clause.Expression = Set(copiedAssignments) } func Assignments(values map[string]interface{}) Set { diff --git a/clause/where.go b/clause/where.go index 015addf8..806565d1 100644 --- a/clause/where.go +++ b/clause/where.go @@ -40,7 +40,9 @@ func (where Where) Build(builder Builder) { // MergeClause merge where clauses func (where Where) MergeClause(clause *Clause) { if w, ok := clause.Expression.(Where); ok { - where.Exprs = append(w.Exprs, where.Exprs...) + copiedExpressions := make([]Expression, len(w.Exprs)) + copy(copiedExpressions, w.Exprs) + where.Exprs = append(copiedExpressions, where.Exprs...) } clause.Expression = where diff --git a/statement_test.go b/statement_test.go new file mode 100644 index 00000000..7d730875 --- /dev/null +++ b/statement_test.go @@ -0,0 +1,37 @@ +package gorm + +import ( + "fmt" + "reflect" + "testing" + + "gorm.io/gorm/clause" +) + +func TestWhereCloneCorruption(t *testing.T) { + for whereCount := 1; whereCount <= 8; whereCount++ { + t.Run(fmt.Sprintf("w=%d", whereCount), func(t *testing.T) { + s := new(Statement) + for w := 0; w < whereCount; w++ { + s = s.clone() + s.AddClause(clause.Where{ + Exprs: s.BuildCondtion(fmt.Sprintf("where%d", w)), + }) + } + + s1 := s.clone() + s1.AddClause(clause.Where{ + Exprs: s.BuildCondtion("FINAL1"), + }) + s2 := s.clone() + s2.AddClause(clause.Where{ + Exprs: s.BuildCondtion("FINAL2"), + }) + + if reflect.DeepEqual(s1.Clauses["WHERE"], s2.Clauses["WHERE"]) { + t.Errorf("Where conditions should be different") + } + }) + } +} + From 8f8d549ca36d34a1f1dbbbd422071990e9b8a78d Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 8 Jun 2020 09:10:27 +0800 Subject: [PATCH 413/881] Refactor merge where exprs --- clause/where.go | 7 ++++--- statement_test.go | 3 +-- tests/named_polymorphic_test.go | 1 + 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/clause/where.go b/clause/where.go index 806565d1..6399a2d5 100644 --- a/clause/where.go +++ b/clause/where.go @@ -40,9 +40,10 @@ func (where Where) Build(builder Builder) { // MergeClause merge where clauses func (where Where) MergeClause(clause *Clause) { if w, ok := clause.Expression.(Where); ok { - copiedExpressions := make([]Expression, len(w.Exprs)) - copy(copiedExpressions, w.Exprs) - where.Exprs = append(copiedExpressions, where.Exprs...) + exprs := make([]Expression, len(w.Exprs)+len(where.Exprs)) + copy(exprs, w.Exprs) + copy(exprs[len(w.Exprs):], where.Exprs) + where.Exprs = exprs } clause.Expression = where diff --git a/statement_test.go b/statement_test.go index 7d730875..16956e85 100644 --- a/statement_test.go +++ b/statement_test.go @@ -4,7 +4,7 @@ import ( "fmt" "reflect" "testing" - + "gorm.io/gorm/clause" ) @@ -34,4 +34,3 @@ func TestWhereCloneCorruption(t *testing.T) { }) } } - diff --git a/tests/named_polymorphic_test.go b/tests/named_polymorphic_test.go index 61655784..cbe236b5 100644 --- a/tests/named_polymorphic_test.go +++ b/tests/named_polymorphic_test.go @@ -14,6 +14,7 @@ type Hamster struct { } func TestNamedPolymorphic(t *testing.T) { + DB.Migrator().DropTable(&Hamster{}) DB.AutoMigrate(&Hamster{}) hamster := Hamster{Name: "Mr. Hammond", PreferredToy: Toy{Name: "bike"}, OtherToy: Toy{Name: "treadmill"}} From 13f96f7a158193f22d03419a5b1c0fd4c6c59f55 Mon Sep 17 00:00:00 2001 From: Douglas Danger Manley Date: Sun, 7 Jun 2020 23:38:51 -0400 Subject: [PATCH 414/881] Spelling fix for "condtion" -> "condition" (#3042) This fixes a spelling error in the word "condition"; in particular, the `BuildCondtion` function should be named `BuildCondition`. --- chainable_api.go | 10 +++++----- finisher_api.go | 20 ++++++++++---------- statement.go | 4 ++-- statement_test.go | 6 +++--- 4 files changed, 20 insertions(+), 20 deletions(-) diff --git a/chainable_api.go b/chainable_api.go index 6c5a6f77..0be86e03 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -33,7 +33,7 @@ func (db *DB) Clauses(conds ...clause.Expression) (tx *DB) { } if len(whereConds) > 0 { - tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondtion(whereConds[0], whereConds[1:]...)}) + tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondition(whereConds[0], whereConds[1:]...)}) } return } @@ -121,7 +121,7 @@ func (db *DB) Omit(columns ...string) (tx *DB) { // Where add conditions func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) { tx = db.getInstance() - if conds := tx.Statement.BuildCondtion(query, args...); len(conds) > 0 { + if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 { tx.Statement.AddClause(clause.Where{Exprs: conds}) } return @@ -130,7 +130,7 @@ func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) { // Not add NOT conditions func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) { tx = db.getInstance() - if conds := tx.Statement.BuildCondtion(query, args...); len(conds) > 0 { + if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 { tx.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.Not(conds...)}}) } return @@ -139,7 +139,7 @@ func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) { // Or add OR conditions func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) { tx = db.getInstance() - if conds := tx.Statement.BuildCondtion(query, args...); len(conds) > 0 { + if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 { tx.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(conds...)}}) } return @@ -170,7 +170,7 @@ func (db *DB) Group(name string) (tx *DB) { func (db *DB) Having(query interface{}, args ...interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.AddClause(clause.GroupBy{ - Having: tx.Statement.BuildCondtion(query, args...), + Having: tx.Statement.BuildCondition(query, args...), }) return } diff --git a/finisher_api.go b/finisher_api.go index 72453b1d..84890b51 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -55,7 +55,7 @@ func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) { Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, }) if len(conds) > 0 { - tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondtion(conds[0], conds[1:]...)}) + tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondition(conds[0], conds[1:]...)}) } tx.Statement.RaiseErrorOnNotFound = true tx.Statement.Dest = dest @@ -67,7 +67,7 @@ func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) { func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) { tx = db.getInstance().Limit(1) if len(conds) > 0 { - tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondtion(conds[0], conds[1:]...)}) + tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondition(conds[0], conds[1:]...)}) } tx.Statement.RaiseErrorOnNotFound = true tx.Statement.Dest = dest @@ -82,7 +82,7 @@ func (db *DB) Last(dest interface{}, conds ...interface{}) (tx *DB) { Desc: true, }) if len(conds) > 0 { - tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondtion(conds[0], conds[1:]...)}) + tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondition(conds[0], conds[1:]...)}) } tx.Statement.RaiseErrorOnNotFound = true tx.Statement.Dest = dest @@ -94,7 +94,7 @@ func (db *DB) Last(dest interface{}, conds ...interface{}) (tx *DB) { func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) { tx = db.getInstance() if len(conds) > 0 { - tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondtion(conds[0], conds[1:]...)}) + tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondition(conds[0], conds[1:]...)}) } tx.Statement.Dest = dest tx.callbacks.Query().Execute(tx) @@ -130,7 +130,7 @@ func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { // initialize with attrs, conds if len(tx.Statement.attrs) > 0 { - exprs := tx.Statement.BuildCondtion(tx.Statement.attrs[0], tx.Statement.attrs[1:]...) + exprs := tx.Statement.BuildCondition(tx.Statement.attrs[0], tx.Statement.attrs[1:]...) tx.assignExprsToValue(exprs) } tx.Error = nil @@ -138,7 +138,7 @@ func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { // initialize with attrs, conds if len(tx.Statement.assigns) > 0 { - exprs := tx.Statement.BuildCondtion(tx.Statement.assigns[0], tx.Statement.assigns[1:]...) + exprs := tx.Statement.BuildCondition(tx.Statement.assigns[0], tx.Statement.assigns[1:]...) tx.assignExprsToValue(exprs) } return @@ -157,19 +157,19 @@ func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) { // initialize with attrs, conds if len(tx.Statement.attrs) > 0 { - exprs := tx.Statement.BuildCondtion(tx.Statement.attrs[0], tx.Statement.attrs[1:]...) + exprs := tx.Statement.BuildCondition(tx.Statement.attrs[0], tx.Statement.attrs[1:]...) tx.assignExprsToValue(exprs) } // initialize with attrs, conds if len(tx.Statement.assigns) > 0 { - exprs := tx.Statement.BuildCondtion(tx.Statement.assigns[0], tx.Statement.assigns[1:]...) + exprs := tx.Statement.BuildCondition(tx.Statement.assigns[0], tx.Statement.assigns[1:]...) tx.assignExprsToValue(exprs) } return tx.Create(dest) } else if len(tx.Statement.assigns) > 0 { - exprs := tx.Statement.BuildCondtion(tx.Statement.assigns[0], tx.Statement.assigns[1:]...) + exprs := tx.Statement.BuildCondition(tx.Statement.assigns[0], tx.Statement.assigns[1:]...) assigns := map[string]interface{}{} for _, expr := range exprs { if eq, ok := expr.(clause.Eq); ok { @@ -225,7 +225,7 @@ func (db *DB) UpdateColumns(values interface{}) (tx *DB) { func (db *DB) Delete(value interface{}, conds ...interface{}) (tx *DB) { tx = db.getInstance() if len(conds) > 0 { - tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondtion(conds[0], conds[1:]...)}) + tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondition(conds[0], conds[1:]...)}) } tx.Statement.Dest = value tx.callbacks.Delete().Execute(tx) diff --git a/statement.go b/statement.go index ec9e021f..614a3ad3 100644 --- a/statement.go +++ b/statement.go @@ -218,8 +218,8 @@ func (stmt *Statement) AddClauseIfNotExists(v clause.Interface) { } } -// BuildCondtion build condition -func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (conds []clause.Expression) { +// BuildCondition build condition +func (stmt Statement) BuildCondition(query interface{}, args ...interface{}) (conds []clause.Expression) { if sql, ok := query.(string); ok { // if it is a number, then treats it as primary key if _, err := strconv.Atoi(sql); err != nil { diff --git a/statement_test.go b/statement_test.go index 16956e85..03ad81dc 100644 --- a/statement_test.go +++ b/statement_test.go @@ -15,17 +15,17 @@ func TestWhereCloneCorruption(t *testing.T) { for w := 0; w < whereCount; w++ { s = s.clone() s.AddClause(clause.Where{ - Exprs: s.BuildCondtion(fmt.Sprintf("where%d", w)), + Exprs: s.BuildCondition(fmt.Sprintf("where%d", w)), }) } s1 := s.clone() s1.AddClause(clause.Where{ - Exprs: s.BuildCondtion("FINAL1"), + Exprs: s.BuildCondition("FINAL1"), }) s2 := s.clone() s2.AddClause(clause.Where{ - Exprs: s.BuildCondtion("FINAL2"), + Exprs: s.BuildCondition("FINAL2"), }) if reflect.DeepEqual(s1.Clauses["WHERE"], s2.Clauses["WHERE"]) { From aaf07257719d4b7e85574ffc6fd6546f364b492e Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 8 Jun 2020 13:45:41 +0800 Subject: [PATCH 415/881] Refactor for performance --- callbacks/create.go | 7 ++- callbacks/query.go | 108 ++++++++++++++++++++------------------------ callbacks/update.go | 2 +- clause/set.go | 13 ++---- gorm.go | 79 +++++++++++++++----------------- migrator.go | 5 ++ scan.go | 31 ++++++++----- statement.go | 8 ++-- 8 files changed, 123 insertions(+), 130 deletions(-) diff --git a/callbacks/create.go b/callbacks/create.go index ec4ee1d1..6dc3f10a 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -192,19 +192,22 @@ func ConvertToCreateValues(stmt *gorm.Statement) clause.Values { return ConvertSliceOfMapToValuesForCreate(stmt, value) default: var ( - values = clause.Values{} + values = clause.Values{Columns: make([]clause.Column, len(stmt.Schema.DBNames))} selectColumns, restricted = SelectAndOmitColumns(stmt, true, false) curTime = stmt.DB.NowFunc() isZero = false ) + var columns int for _, db := range stmt.Schema.DBNames { if field := stmt.Schema.FieldsByDBName[db]; !field.HasDefaultValue || field.DefaultValueInterface != nil { if v, ok := selectColumns[db]; (ok && v) || (!ok && !restricted) { - values.Columns = append(values.Columns, clause.Column{Name: db}) + values.Columns[columns] = clause.Column{Name: db} + columns++ } } } + values.Columns = values.Columns[:columns] switch stmt.ReflectValue.Kind() { case reflect.Slice, reflect.Array: diff --git a/callbacks/query.go b/callbacks/query.go index 41f09375..571c7245 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -53,38 +53,28 @@ func BuildQuerySQL(db *gorm.DB) { } if len(db.Statement.Selects) > 0 { - for _, name := range db.Statement.Selects { + clauseSelect.Columns = make([]clause.Column, len(db.Statement.Selects)) + for idx, name := range db.Statement.Selects { if db.Statement.Schema == nil { - clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ - Name: name, - Raw: true, - }) + clauseSelect.Columns[idx] = clause.Column{Name: name, Raw: true} } else if f := db.Statement.Schema.LookUpField(name); f != nil { - clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ - Name: f.DBName, - }) + clauseSelect.Columns[idx] = clause.Column{Name: f.DBName} } else { - clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ - Name: name, - Raw: true, - }) + clauseSelect.Columns[idx] = clause.Column{Name: name, Raw: true} } } } // inline joins if len(db.Statement.Joins) != 0 { - joins := []clause.Join{} - if len(db.Statement.Selects) == 0 { - for _, dbName := range db.Statement.Schema.DBNames { - clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ - Table: db.Statement.Table, - Name: dbName, - }) + clauseSelect.Columns = make([]clause.Column, len(db.Statement.Schema.DBNames)) + for idx, dbName := range db.Statement.Schema.DBNames { + clauseSelect.Columns[idx] = clause.Column{Table: db.Statement.Table, Name: dbName} } } + joins := []clause.Join{} for name, conds := range db.Statement.Joins { if db.Statement.Schema == nil { joins = append(joins, clause.Join{ @@ -101,24 +91,24 @@ func BuildQuerySQL(db *gorm.DB) { }) } - var exprs []clause.Expression - for _, ref := range relation.References { + exprs := make([]clause.Expression, len(relation.References)) + for idx, ref := range relation.References { if ref.OwnPrimaryKey { - exprs = append(exprs, clause.Eq{ + exprs[idx] = clause.Eq{ Column: clause.Column{Table: db.Statement.Schema.Table, Name: ref.PrimaryKey.DBName}, Value: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName}, - }) + } } else { if ref.PrimaryValue == "" { - exprs = append(exprs, clause.Eq{ + exprs[idx] = clause.Eq{ Column: clause.Column{Table: db.Statement.Schema.Table, Name: ref.ForeignKey.DBName}, Value: clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName}, - }) + } } else { - exprs = append(exprs, clause.Eq{ + exprs[idx] = clause.Eq{ Column: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName}, Value: ref.PrimaryValue, - }) + } } } } @@ -146,42 +136,40 @@ func BuildQuerySQL(db *gorm.DB) { } func Preload(db *gorm.DB) { - if db.Error == nil { - if len(db.Statement.Preloads) > 0 { - preloadMap := map[string][]string{} - for name := range db.Statement.Preloads { - preloadFields := strings.Split(name, ".") - for idx := range preloadFields { - preloadMap[strings.Join(preloadFields[:idx+1], ".")] = preloadFields[:idx+1] + if db.Error == nil && len(db.Statement.Preloads) > 0 { + preloadMap := map[string][]string{} + for name := range db.Statement.Preloads { + preloadFields := strings.Split(name, ".") + for idx := range preloadFields { + preloadMap[strings.Join(preloadFields[:idx+1], ".")] = preloadFields[:idx+1] + } + } + + preloadNames := make([]string, len(preloadMap)) + idx := 0 + for key := range preloadMap { + preloadNames[idx] = key + idx++ + } + sort.Strings(preloadNames) + + for _, name := range preloadNames { + var ( + curSchema = db.Statement.Schema + preloadFields = preloadMap[name] + rels = make([]*schema.Relationship, len(preloadFields)) + ) + + for idx, preloadField := range preloadFields { + if rel := curSchema.Relationships.Relations[preloadField]; rel != nil { + rels[idx] = rel + curSchema = rel.FieldSchema + } else { + db.AddError(fmt.Errorf("%v: %w", name, gorm.ErrUnsupportedRelation)) } } - preloadNames := make([]string, len(preloadMap)) - idx := 0 - for key := range preloadMap { - preloadNames[idx] = key - idx++ - } - sort.Strings(preloadNames) - - for _, name := range preloadNames { - var ( - curSchema = db.Statement.Schema - preloadFields = preloadMap[name] - rels = make([]*schema.Relationship, len(preloadFields)) - ) - - for idx, preloadField := range preloadFields { - if rel := curSchema.Relationships.Relations[preloadField]; rel != nil { - rels[idx] = rel - curSchema = rel.FieldSchema - } else { - db.AddError(fmt.Errorf("%v: %w", name, gorm.ErrUnsupportedRelation)) - } - } - - preload(db, rels, db.Statement.Preloads[name]) - } + preload(db, rels, db.Statement.Preloads[name]) } } } diff --git a/callbacks/update.go b/callbacks/update.go index f5287dc6..4ef33598 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -140,7 +140,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { case map[string]interface{}: set = make([]clause.Assignment, 0, len(value)) - var keys []string + keys := make([]string, 0, len(value)) for k := range value { keys = append(keys, k) } diff --git a/clause/set.go b/clause/set.go index 2d3965d3..1c2a9ef2 100644 --- a/clause/set.go +++ b/clause/set.go @@ -38,20 +38,15 @@ func (set Set) MergeClause(clause *Clause) { } func Assignments(values map[string]interface{}) Set { - var keys []string - var assignments []Assignment - + keys := make([]string, 0, len(values)) for key := range values { keys = append(keys, key) } - sort.Strings(keys) - for _, key := range keys { - assignments = append(assignments, Assignment{ - Column: Column{Name: key}, - Value: values[key], - }) + assignments := make([]Assignment, len(keys)) + for idx, key := range keys { + assignments[idx] = Assignment{Column: Column{Name: key}, Value: values[key]} } return assignments } diff --git a/gorm.go b/gorm.go index cea744f7..0de6860b 100644 --- a/gorm.go +++ b/gorm.go @@ -205,53 +205,11 @@ func (db *DB) InstanceGet(key string) (interface{}, bool) { return db.Statement.Settings.Load(fmt.Sprintf("%p", db.Statement) + key) } -func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interface{}) error { - var ( - tx = db.getInstance() - stmt = tx.Statement - modelSchema, joinSchema *schema.Schema - ) - - if err := stmt.Parse(model); err == nil { - modelSchema = stmt.Schema - } else { - return err - } - - if err := stmt.Parse(joinTable); err == nil { - joinSchema = stmt.Schema - } else { - return err - } - - if relation, ok := modelSchema.Relationships.Relations[field]; ok && relation.JoinTable != nil { - for _, ref := range relation.References { - if f := joinSchema.LookUpField(ref.ForeignKey.DBName); f != nil { - f.DataType = ref.ForeignKey.DataType - ref.ForeignKey = f - } else { - return fmt.Errorf("missing field %v for join table", ref.ForeignKey.DBName) - } - } - - relation.JoinTable = joinSchema - } else { - return fmt.Errorf("failed to found relation: %v", field) - } - - return nil -} - // Callback returns callback manager func (db *DB) Callback() *callbacks { return db.callbacks } -// AutoMigrate run auto migration for given models -func (db *DB) AutoMigrate(dst ...interface{}) error { - return db.Migrator().AutoMigrate(dst...) -} - // AddError add error to db func (db *DB) AddError(err error) error { if db.Error == nil { @@ -289,3 +247,40 @@ func (db *DB) getInstance() *DB { func Expr(expr string, args ...interface{}) clause.Expr { return clause.Expr{SQL: expr, Vars: args} } + +func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interface{}) error { + var ( + tx = db.getInstance() + stmt = tx.Statement + modelSchema, joinSchema *schema.Schema + ) + + if err := stmt.Parse(model); err == nil { + modelSchema = stmt.Schema + } else { + return err + } + + if err := stmt.Parse(joinTable); err == nil { + joinSchema = stmt.Schema + } else { + return err + } + + if relation, ok := modelSchema.Relationships.Relations[field]; ok && relation.JoinTable != nil { + for _, ref := range relation.References { + if f := joinSchema.LookUpField(ref.ForeignKey.DBName); f != nil { + f.DataType = ref.ForeignKey.DataType + ref.ForeignKey = f + } else { + return fmt.Errorf("missing field %v for join table", ref.ForeignKey.DBName) + } + } + + relation.JoinTable = joinSchema + } else { + return fmt.Errorf("failed to found relation: %v", field) + } + + return nil +} diff --git a/migrator.go b/migrator.go index 865a08ef..d45e3ac2 100644 --- a/migrator.go +++ b/migrator.go @@ -9,6 +9,11 @@ func (db *DB) Migrator() Migrator { return db.Dialector.Migrator(db) } +// AutoMigrate run auto migration for given models +func (db *DB) AutoMigrate(dst ...interface{}) error { + return db.Migrator().AutoMigrate(dst...) +} + // ViewOption view option type ViewOption struct { Replace bool diff --git a/scan.go b/scan.go index acba4e9f..f1cdb2e5 100644 --- a/scan.go +++ b/scan.go @@ -71,20 +71,27 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { default: switch db.Statement.ReflectValue.Kind() { case reflect.Slice, reflect.Array: - reflectValueType := db.Statement.ReflectValue.Type().Elem() - isPtr := reflectValueType.Kind() == reflect.Ptr + var ( + reflectValueType = db.Statement.ReflectValue.Type().Elem() + isPtr = reflectValueType.Kind() == reflect.Ptr + fields = make([]*schema.Field, len(columns)) + joinFields [][2]*schema.Field + ) + if isPtr { reflectValueType = reflectValueType.Elem() } db.Statement.ReflectValue.Set(reflect.MakeSlice(db.Statement.ReflectValue.Type(), 0, 0)) - fields := make([]*schema.Field, len(columns)) - joinFields := make([][2]*schema.Field, len(columns)) for idx, column := range columns { if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable { fields[idx] = field } else if names := strings.Split(column, "__"); len(names) > 1 { + if len(joinFields) == 0 { + joinFields = make([][2]*schema.Field, len(columns)) + } + if rel, ok := db.Statement.Schema.Relationships.Relations[names[0]]; ok { if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { fields[idx] = field @@ -98,26 +105,26 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { } } + // pluck values into slice of data + isPluck := len(fields) == 1 && reflectValueType.Kind() != reflect.Struct for initialized || rows.Next() { initialized = false db.RowsAffected++ elem := reflect.New(reflectValueType).Elem() - if reflectValueType.Kind() != reflect.Struct && len(fields) == 1 { - // pluck - values[0] = elem.Addr().Interface() - db.AddError(rows.Scan(values...)) + if isPluck { + db.AddError(rows.Scan(elem.Addr().Interface())) } else { for idx, field := range fields { if field != nil { - values[idx] = reflect.New(reflect.PtrTo(field.FieldType)).Interface() + values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface() } } db.AddError(rows.Scan(values...)) for idx, field := range fields { - if joinFields[idx][0] != nil { + if len(joinFields) != 0 && joinFields[idx][0] != nil { value := reflect.ValueOf(values[idx]).Elem() relValue := joinFields[idx][0].ReflectValueOf(elem) @@ -145,11 +152,11 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { if initialized || rows.Next() { for idx, column := range columns { if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable { - values[idx] = reflect.New(reflect.PtrTo(field.FieldType)).Interface() + values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface() } else if names := strings.Split(column, "__"); len(names) > 1 { if rel, ok := db.Statement.Schema.Relationships.Relations[names[0]]; ok { if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { - values[idx] = reflect.New(reflect.PtrTo(field.FieldType)).Interface() + values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface() continue } } diff --git a/statement.go b/statement.go index 614a3ad3..e0e86019 100644 --- a/statement.go +++ b/statement.go @@ -63,7 +63,7 @@ func (stmt *Statement) WriteQuoted(value interface{}) error { } // QuoteTo write quoted value to writer -func (stmt Statement) QuoteTo(writer clause.Writer, field interface{}) { +func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) { switch v := field.(type) { case clause.Table: if v.Name == clause.CurrentTable { @@ -109,7 +109,7 @@ func (stmt Statement) QuoteTo(writer clause.Writer, field interface{}) { case []string: writer.WriteByte('(') for idx, d := range v { - if idx != 0 { + if idx > 0 { writer.WriteString(",") } stmt.DB.Dialector.QuoteTo(writer, d) @@ -121,7 +121,7 @@ func (stmt Statement) QuoteTo(writer clause.Writer, field interface{}) { } // Quote returns quoted value -func (stmt Statement) Quote(field interface{}) string { +func (stmt *Statement) Quote(field interface{}) string { var builder strings.Builder stmt.QuoteTo(&builder, field) return builder.String() @@ -219,7 +219,7 @@ func (stmt *Statement) AddClauseIfNotExists(v clause.Interface) { } // BuildCondition build condition -func (stmt Statement) BuildCondition(query interface{}, args ...interface{}) (conds []clause.Expression) { +func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) (conds []clause.Expression) { if sql, ok := query.(string); ok { // if it is a number, then treats it as primary key if _, err := strconv.Atoi(sql); err != nil { From 9f193783049d88aaa3ff9153c040dcac27fa6559 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 8 Jun 2020 20:23:47 +0800 Subject: [PATCH 416/881] Grow SQL capacity to reduce allocation --- callbacks/create.go | 2 ++ callbacks/delete.go | 1 + callbacks/query.go | 1 + callbacks/update.go | 1 + 4 files changed, 5 insertions(+) diff --git a/callbacks/create.go b/callbacks/create.go index 6dc3f10a..cb161061 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -42,6 +42,7 @@ func Create(config *Config) func(db *gorm.DB) { } if db.Statement.SQL.String() == "" { + db.Statement.SQL.Grow(180) db.Statement.AddClauseIfNotExists(clause.Insert{ Table: clause.Table{Name: db.Statement.Table}, }) @@ -211,6 +212,7 @@ func ConvertToCreateValues(stmt *gorm.Statement) clause.Values { switch stmt.ReflectValue.Kind() { case reflect.Slice, reflect.Array: + stmt.SQL.Grow(stmt.ReflectValue.Len() * 15) values.Values = make([][]interface{}, stmt.ReflectValue.Len()) defaultValueFieldsHavingValue := map[string][]interface{}{} for i := 0; i < stmt.ReflectValue.Len(); i++ { diff --git a/callbacks/delete.go b/callbacks/delete.go index b246e69f..dea8bb5e 100644 --- a/callbacks/delete.go +++ b/callbacks/delete.go @@ -30,6 +30,7 @@ func Delete(db *gorm.DB) { } if db.Statement.SQL.String() == "" { + db.Statement.SQL.Grow(100) db.Statement.AddClauseIfNotExists(clause.Delete{}) if db.Statement.Schema != nil { diff --git a/callbacks/query.go b/callbacks/query.go index 571c7245..e5557d4a 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -37,6 +37,7 @@ func Query(db *gorm.DB) { } func BuildQuerySQL(db *gorm.DB) { + db.Statement.SQL.Grow(100) clauseSelect := clause.Select{Distinct: db.Statement.Distinct} if db.Statement.ReflectValue.Kind() == reflect.Struct { diff --git a/callbacks/update.go b/callbacks/update.go index 4ef33598..03d5c1e9 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -59,6 +59,7 @@ func Update(db *gorm.DB) { } if db.Statement.SQL.String() == "" { + db.Statement.SQL.Grow(180) db.Statement.AddClauseIfNotExists(clause.Update{}) if set := ConvertToAssignments(db.Statement); len(set) != 0 { db.Statement.AddClause(set) From 4555796b62fa679f3397d5201759e387f7d88a0c Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 8 Jun 2020 22:32:35 +0800 Subject: [PATCH 417/881] Refactor Execute callbacks --- callbacks.go | 50 ++++++++++++++++++++++++------------------------- finisher_api.go | 16 +++++++--------- 2 files changed, 31 insertions(+), 35 deletions(-) diff --git a/callbacks.go b/callbacks.go index e6cf29af..5e7933af 100644 --- a/callbacks.go +++ b/callbacks.go @@ -73,26 +73,26 @@ func (cs *callbacks) Raw() *processor { func (p *processor) Execute(db *DB) { curTime := time.Now() + stmt := db.Statement db.RowsAffected = 0 - if stmt := db.Statement; stmt != nil { - if stmt.Model == nil { - stmt.Model = stmt.Dest - } - if stmt.Model != nil { - if err := stmt.Parse(stmt.Model); err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || (stmt.Table == "" && stmt.SQL.Len() == 0)) { - db.AddError(err) - } - } + if stmt.Model == nil { + stmt.Model = stmt.Dest + } - if stmt.Dest != nil { - stmt.ReflectValue = reflect.Indirect(reflect.ValueOf(stmt.Dest)) - for stmt.ReflectValue.Kind() == reflect.Ptr { - stmt.ReflectValue = stmt.ReflectValue.Elem() - } - if !stmt.ReflectValue.IsValid() { - db.AddError(fmt.Errorf("invalid value")) - } + if stmt.Model != nil { + if err := stmt.Parse(stmt.Model); err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || (stmt.Table == "" && stmt.SQL.Len() == 0)) { + db.AddError(err) + } + } + + if stmt.Dest != nil { + stmt.ReflectValue = reflect.ValueOf(stmt.Dest) + for stmt.ReflectValue.Kind() == reflect.Ptr { + stmt.ReflectValue = stmt.ReflectValue.Elem() + } + if !stmt.ReflectValue.IsValid() { + db.AddError(fmt.Errorf("invalid value")) } } @@ -100,16 +100,14 @@ func (p *processor) Execute(db *DB) { f(db) } - if stmt := db.Statement; stmt != nil { - db.Logger.Trace(stmt.Context, curTime, func() (string, int64) { - return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...), db.RowsAffected - }, db.Error) + db.Logger.Trace(stmt.Context, curTime, func() (string, int64) { + return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...), db.RowsAffected + }, db.Error) - if !stmt.DB.DryRun { - stmt.SQL.Reset() - stmt.Vars = nil - stmt.NamedVars = nil - } + if !stmt.DB.DryRun { + stmt.SQL.Reset() + stmt.Vars = nil + stmt.NamedVars = nil } } diff --git a/finisher_api.go b/finisher_api.go index 84890b51..fc21e490 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -51,7 +51,7 @@ func (db *DB) Save(value interface{}) (tx *DB) { // First find first record that match given conditions, order by primary key func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) { - tx = db.getInstance().Limit(1).Order(clause.OrderByColumn{ + tx = db.Limit(1).Order(clause.OrderByColumn{ Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, }) if len(conds) > 0 { @@ -65,7 +65,7 @@ func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) { // Take return a record that match given conditions, the order will depend on the database implementation func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) { - tx = db.getInstance().Limit(1) + tx = db.Limit(1) if len(conds) > 0 { tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondition(conds[0], conds[1:]...)}) } @@ -77,7 +77,7 @@ func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) { // Last find last record that match given conditions, order by primary key func (db *DB) Last(dest interface{}, conds ...interface{}) (tx *DB) { - tx = db.getInstance().Limit(1).Order(clause.OrderByColumn{ + tx = db.Limit(1).Order(clause.OrderByColumn{ Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Desc: true, }) @@ -120,8 +120,7 @@ func (tx *DB) assignExprsToValue(exprs []clause.Expression) { } func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { - tx = db.getInstance() - if tx = tx.First(dest, conds...); errors.Is(tx.Error, ErrRecordNotFound) { + if tx = db.First(dest, conds...); errors.Is(tx.Error, ErrRecordNotFound) { if c, ok := tx.Statement.Clauses["WHERE"]; ok { if where, ok := c.Expression.(clause.Where); ok { tx.assignExprsToValue(where.Exprs) @@ -145,8 +144,7 @@ func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { } func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) { - tx = db.getInstance() - if err := tx.First(dest, conds...).Error; errors.Is(err, ErrRecordNotFound) { + if tx = db.First(dest, conds...); errors.Is(tx.Error, ErrRecordNotFound) { tx.Error = nil if c, ok := tx.Statement.Clauses["WHERE"]; ok { @@ -168,7 +166,7 @@ func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) { } return tx.Create(dest) - } else if len(tx.Statement.assigns) > 0 { + } else if len(db.Statement.assigns) > 0 { exprs := tx.Statement.BuildCondition(tx.Statement.assigns[0], tx.Statement.assigns[1:]...) assigns := map[string]interface{}{} for _, expr := range exprs { @@ -186,7 +184,7 @@ func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) { return tx.Model(dest).Updates(assigns) } - return + return db } // Update update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update From f0b6bd9ee04691c7f6285c8d597dd630289a33b8 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 8 Jun 2020 23:25:16 +0800 Subject: [PATCH 418/881] Fix typo --- tests/transaction_test.go | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/tests/transaction_test.go b/tests/transaction_test.go index 0c04e2ed..592f1321 100644 --- a/tests/transaction_test.go +++ b/tests/transaction_test.go @@ -10,13 +10,13 @@ import ( func TestTransaction(t *testing.T) { tx := DB.Begin() - user := *GetUser("transcation", Config{}) + user := *GetUser("transaction", Config{}) if err := tx.Save(&user).Error; err != nil { t.Fatalf("No error should raise, but got %v", err) } - if err := tx.First(&User{}, "name = ?", "transcation").Error; err != nil { + if err := tx.First(&User{}, "name = ?", "transaction").Error; err != nil { t.Fatalf("Should find saved record, but got %v", err) } @@ -26,23 +26,23 @@ func TestTransaction(t *testing.T) { tx.Rollback() - if err := DB.First(&User{}, "name = ?", "transcation").Error; err == nil { + if err := DB.First(&User{}, "name = ?", "transaction").Error; err == nil { t.Fatalf("Should not find record after rollback, but got %v", err) } tx2 := DB.Begin() - user2 := *GetUser("transcation-2", Config{}) + user2 := *GetUser("transaction-2", Config{}) if err := tx2.Save(&user2).Error; err != nil { t.Fatalf("No error should raise, but got %v", err) } - if err := tx2.First(&User{}, "name = ?", "transcation-2").Error; err != nil { + if err := tx2.First(&User{}, "name = ?", "transaction-2").Error; err != nil { t.Fatalf("Should find saved record, but got %v", err) } tx2.Commit() - if err := DB.First(&User{}, "name = ?", "transcation-2").Error; err != nil { + if err := DB.First(&User{}, "name = ?", "transaction-2").Error; err != nil { t.Fatalf("Should be able to find committed record, but got %v", err) } } @@ -59,7 +59,7 @@ func TestTransactionWithBlock(t *testing.T) { // rollback err := DB.Transaction(func(tx *gorm.DB) error { - user := *GetUser("transcation-block", Config{}) + user := *GetUser("transaction-block", Config{}) if err := tx.Save(&user).Error; err != nil { t.Fatalf("No error should raise") } @@ -75,13 +75,13 @@ func TestTransactionWithBlock(t *testing.T) { t.Fatalf("Transaction return error will equal the block returns error") } - if err := DB.First(&User{}, "name = ?", "transcation-block").Error; err == nil { + if err := DB.First(&User{}, "name = ?", "transaction-block").Error; err == nil { t.Fatalf("Should not find record after rollback") } // commit DB.Transaction(func(tx *gorm.DB) error { - user := *GetUser("transcation-block-2", Config{}) + user := *GetUser("transaction-block-2", Config{}) if err := tx.Save(&user).Error; err != nil { t.Fatalf("No error should raise") } @@ -92,14 +92,14 @@ func TestTransactionWithBlock(t *testing.T) { return nil }) - if err := DB.First(&User{}, "name = ?", "transcation-block-2").Error; err != nil { + if err := DB.First(&User{}, "name = ?", "transaction-block-2").Error; err != nil { t.Fatalf("Should be able to find committed record") } // panic will rollback assertPanic(func() { DB.Transaction(func(tx *gorm.DB) error { - user := *GetUser("transcation-block-3", Config{}) + user := *GetUser("transaction-block-3", Config{}) if err := tx.Save(&user).Error; err != nil { t.Fatalf("No error should raise") } @@ -112,14 +112,14 @@ func TestTransactionWithBlock(t *testing.T) { }) }) - if err := DB.First(&User{}, "name = ?", "transcation-block-3").Error; err == nil { + if err := DB.First(&User{}, "name = ?", "transaction-block-3").Error; err == nil { t.Fatalf("Should not find record after panic rollback") } } func TestTransactionRaiseErrorOnRollbackAfterCommit(t *testing.T) { tx := DB.Begin() - user := User{Name: "transcation"} + user := User{Name: "transaction"} if err := tx.Save(&user).Error; err != nil { t.Fatalf("No error should raise") } From 649d02fddd31fe82cd8ecbe6ab63e4ab61a5be4d Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 9 Jun 2020 09:04:25 +0800 Subject: [PATCH 419/881] Add batch upsert tests --- clause/set.go | 8 ++++++++ tests/go.mod | 4 ++-- tests/upsert_test.go | 23 +++++++++++++++++++++++ 3 files changed, 33 insertions(+), 2 deletions(-) diff --git a/clause/set.go b/clause/set.go index 1c2a9ef2..6a885711 100644 --- a/clause/set.go +++ b/clause/set.go @@ -50,3 +50,11 @@ func Assignments(values map[string]interface{}) Set { } return assignments } + +func AssignmentColumns(values []string) Set { + assignments := make([]Assignment, len(values)) + for idx, value := range values { + assignments[idx] = Assignment{Column: Column{Name: value}, Value: Column{Table: "excluded", Name: value}} + } + return assignments +} diff --git a/tests/go.mod b/tests/go.mod index 3c2dfc6c..c184732c 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -4,10 +4,10 @@ go 1.14 require ( github.com/jinzhu/now v1.1.1 - gorm.io/driver/mysql v0.0.0-20200602015408-0407d0c21cf0 + gorm.io/driver/mysql v0.0.0-20200609004954-b8310c61c3f2 gorm.io/driver/postgres v0.0.0-20200602015520-15fcc29eb286 gorm.io/driver/sqlite v1.0.0 - gorm.io/driver/sqlserver v0.0.0-20200605135528-04ae0f7a15bf + gorm.io/driver/sqlserver v0.0.0-20200609005334-d550a0be1cfb gorm.io/gorm v0.0.0-00010101000000-000000000000 ) diff --git a/tests/upsert_test.go b/tests/upsert_test.go index e9ba54e3..a1307e32 100644 --- a/tests/upsert_test.go +++ b/tests/upsert_test.go @@ -65,6 +65,29 @@ func TestUpsertSlice(t *testing.T) { } else if len(langs3) != 3 { t.Errorf("should only find only 3 languages, but got %+v", langs3) } + + for idx, lang := range langs { + lang.Name = lang.Name + "_new" + langs[idx] = lang + } + + if err := DB.Clauses(clause.OnConflict{ + Columns: []clause.Column{{Name: "code"}}, + DoUpdates: clause.AssignmentColumns([]string{"name"}), + }).Create(&langs).Error; err != nil { + t.Fatalf("failed to upsert, got %v", err) + } + + for _, lang := range langs { + var results []Language + if err := DB.Find(&results, "code = ?", lang.Code).Error; err != nil { + t.Errorf("no error should happen when find languages with code, but got %v", err) + } else if len(results) != 1 { + t.Errorf("should only find only 1 languages, but got %+v", langs) + } else if results[0].Name != lang.Name { + t.Errorf("should update name on conflict, but got name %+v", results[0].Name) + } + } } func TestFindOrInitialize(t *testing.T) { From c4872cddfda178ba51c64191f8981e5f9c5a564c Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 9 Jun 2020 10:17:24 +0800 Subject: [PATCH 420/881] Refactor callbacks --- callbacks/create.go | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/callbacks/create.go b/callbacks/create.go index cb161061..fca9d374 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -193,22 +193,19 @@ func ConvertToCreateValues(stmt *gorm.Statement) clause.Values { return ConvertSliceOfMapToValuesForCreate(stmt, value) default: var ( - values = clause.Values{Columns: make([]clause.Column, len(stmt.Schema.DBNames))} + values = clause.Values{Columns: make([]clause.Column, 0, len(stmt.Schema.DBNames))} selectColumns, restricted = SelectAndOmitColumns(stmt, true, false) curTime = stmt.DB.NowFunc() isZero = false ) - var columns int for _, db := range stmt.Schema.DBNames { if field := stmt.Schema.FieldsByDBName[db]; !field.HasDefaultValue || field.DefaultValueInterface != nil { if v, ok := selectColumns[db]; (ok && v) || (!ok && !restricted) { - values.Columns[columns] = clause.Column{Name: db} - columns++ + values.Columns = append(values.Columns, clause.Column{Name: db}) } } } - values.Columns = values.Columns[:columns] switch stmt.ReflectValue.Kind() { case reflect.Slice, reflect.Array: From a42f9bf4391030acae05c5ce3286f4b237483161 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 9 Jun 2020 11:00:50 +0800 Subject: [PATCH 421/881] Remove codecov as doesn't support detect code-coverage of separated folders --- README.md | 1 - 1 file changed, 1 deletion(-) diff --git a/README.md b/README.md index f5df27f5..1260618a 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,6 @@ The fantastic ORM library for Golang, aims to be developer friendly. [![go report card](https://goreportcard.com/badge/gorm.io/gorm "go report card")](https://goreportcard.com/report/gorm.io/gorm) [![wercker status](https://app.wercker.com/status/55136410c77335a6289ebd58b2f70125/s/master "wercker status")](https://app.wercker.com/project/byKey/55136410c77335a6289ebd58b2f70125) -[![codecov](https://codecov.io/gh/go-gorm/gorm/branch/master/graph/badge.svg)](https://codecov.io/gh/go-gorm/gorm) [![Join the chat at https://gitter.im/jinzhu/gorm](https://img.shields.io/gitter/room/jinzhu/gorm.svg)](https://gitter.im/jinzhu/gorm?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) [![Open Collective Backer](https://opencollective.com/gorm/tiers/backer/badge.svg?label=backer&color=brightgreen "Open Collective Backer")](https://opencollective.com/gorm) [![Open Collective Sponsor](https://opencollective.com/gorm/tiers/sponsor/badge.svg?label=sponsor&color=brightgreen "Open Collective Sponsor")](https://opencollective.com/gorm) From 05e6a65ee13795e1ebe0a02e699ee75b41e5673c Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 9 Jun 2020 12:00:43 +0800 Subject: [PATCH 422/881] Fix typo --- README.md | 2 +- callbacks/create.go | 2 +- model.go | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 1260618a..349bb860 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ The fantastic ORM library for Golang, aims to be developer friendly. -[![go report card](https://goreportcard.com/badge/gorm.io/gorm "go report card")](https://goreportcard.com/report/gorm.io/gorm) +[![go report card](https://goreportcard.com/badge/github.com/go-gorm/gorm "go report card")](https://goreportcard.com/report/github.com/go-gorm/gorm) [![wercker status](https://app.wercker.com/status/55136410c77335a6289ebd58b2f70125/s/master "wercker status")](https://app.wercker.com/project/byKey/55136410c77335a6289ebd58b2f70125) [![Join the chat at https://gitter.im/jinzhu/gorm](https://img.shields.io/gitter/room/jinzhu/gorm.svg)](https://gitter.im/jinzhu/gorm?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) [![Open Collective Backer](https://opencollective.com/gorm/tiers/backer/badge.svg?label=backer&color=brightgreen "Open Collective Backer")](https://opencollective.com/gorm) diff --git a/callbacks/create.go b/callbacks/create.go index fca9d374..091f1774 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -196,7 +196,7 @@ func ConvertToCreateValues(stmt *gorm.Statement) clause.Values { values = clause.Values{Columns: make([]clause.Column, 0, len(stmt.Schema.DBNames))} selectColumns, restricted = SelectAndOmitColumns(stmt, true, false) curTime = stmt.DB.NowFunc() - isZero = false + isZero bool ) for _, db := range stmt.Schema.DBNames { diff --git a/model.go b/model.go index dcc3cdc2..3334d17c 100644 --- a/model.go +++ b/model.go @@ -3,7 +3,7 @@ package gorm import "time" // Model a basic GoLang struct which includes the following fields: ID, CreatedAt, UpdatedAt, DeletedAt -// It may be embeded into your model or you may build your own model without it +// It may be embedded into your model or you may build your own model without it // type User struct { // gorm.Model // } From 22ff8377dfaf208c1db8cb4923a481990f7e76a5 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 9 Jun 2020 15:34:55 +0800 Subject: [PATCH 423/881] Fix Pluck with Table only --- finisher_api.go | 16 ++++++++-------- scan.go | 34 ++++++++++++++++++---------------- tests/distinct_test.go | 2 +- 3 files changed, 27 insertions(+), 25 deletions(-) diff --git a/finisher_api.go b/finisher_api.go index fc21e490..d45c6c4f 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -289,16 +289,16 @@ func (db *DB) Pluck(column string, dest interface{}) (tx *DB) { column = f.DBName } } - - tx.Statement.AddClauseIfNotExists(clause.Select{ - Distinct: tx.Statement.Distinct, - Columns: []clause.Column{{Name: column}}, - }) - tx.Statement.Dest = dest - tx.callbacks.Query().Execute(tx) - } else { + } else if tx.Statement.Table == "" { tx.AddError(ErrorModelValueRequired) } + + tx.Statement.AddClauseIfNotExists(clause.Select{ + Distinct: tx.Statement.Distinct, + Columns: []clause.Column{{Name: column}}, + }) + tx.Statement.Dest = dest + tx.callbacks.Query().Execute(tx) return } diff --git a/scan.go b/scan.go index f1cdb2e5..1f0aacd0 100644 --- a/scan.go +++ b/scan.go @@ -84,24 +84,26 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { db.Statement.ReflectValue.Set(reflect.MakeSlice(db.Statement.ReflectValue.Type(), 0, 0)) - for idx, column := range columns { - if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable { - fields[idx] = field - } else if names := strings.Split(column, "__"); len(names) > 1 { - if len(joinFields) == 0 { - joinFields = make([][2]*schema.Field, len(columns)) - } - - if rel, ok := db.Statement.Schema.Relationships.Relations[names[0]]; ok { - if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { - fields[idx] = field - joinFields[idx] = [2]*schema.Field{rel.Field, field} - continue + if db.Statement.Schema != nil { + for idx, column := range columns { + if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable { + fields[idx] = field + } else if names := strings.Split(column, "__"); len(names) > 1 { + if len(joinFields) == 0 { + joinFields = make([][2]*schema.Field, len(columns)) } + + if rel, ok := db.Statement.Schema.Relationships.Relations[names[0]]; ok { + if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { + fields[idx] = field + joinFields[idx] = [2]*schema.Field{rel.Field, field} + continue + } + } + values[idx] = &sql.RawBytes{} + } else { + values[idx] = &sql.RawBytes{} } - values[idx] = &sql.RawBytes{} - } else { - values[idx] = &sql.RawBytes{} } } diff --git a/tests/distinct_test.go b/tests/distinct_test.go index f5a969a8..248602d3 100644 --- a/tests/distinct_test.go +++ b/tests/distinct_test.go @@ -21,7 +21,7 @@ func TestDistinct(t *testing.T) { } var names []string - DB.Model(&User{}).Where("name like ?", "distinct%").Order("name").Pluck("Name", &names) + DB.Table("users").Where("name like ?", "distinct%").Order("name").Pluck("name", &names) AssertEqual(t, names, []string{"distinct", "distinct", "distinct", "distinct-2", "distinct-3"}) var names1 []string From f3424c68645e327243c15bcbc577ea78967449d8 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 10 Jun 2020 00:02:14 +0800 Subject: [PATCH 424/881] Support save slice of data --- callbacks/create.go | 35 +++++++++++++++++++++++++++++------ finisher_api.go | 27 ++++++++++++++++----------- tests/upsert_test.go | 17 +++++++++++++++++ 3 files changed, 62 insertions(+), 17 deletions(-) diff --git a/callbacks/create.go b/callbacks/create.go index 091f1774..22adca24 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -185,19 +185,19 @@ func AfterCreate(db *gorm.DB) { } // ConvertToCreateValues convert to create values -func ConvertToCreateValues(stmt *gorm.Statement) clause.Values { +func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { switch value := stmt.Dest.(type) { case map[string]interface{}: - return ConvertMapToValuesForCreate(stmt, value) + values = ConvertMapToValuesForCreate(stmt, value) case []map[string]interface{}: - return ConvertSliceOfMapToValuesForCreate(stmt, value) + values = ConvertSliceOfMapToValuesForCreate(stmt, value) default: var ( - values = clause.Values{Columns: make([]clause.Column, 0, len(stmt.Schema.DBNames))} selectColumns, restricted = SelectAndOmitColumns(stmt, true, false) curTime = stmt.DB.NowFunc() isZero bool ) + values = clause.Values{Columns: make([]clause.Column, 0, len(stmt.Schema.DBNames))} for _, db := range stmt.Schema.DBNames { if field := stmt.Schema.FieldsByDBName[db]; !field.HasDefaultValue || field.DefaultValueInterface != nil { @@ -274,7 +274,30 @@ func ConvertToCreateValues(stmt *gorm.Statement) clause.Values { } } } - - return values } + + if stmt.UpdatingColumn { + if stmt.Schema != nil { + columns := make([]string, 0, len(stmt.Schema.DBNames)-1) + for _, name := range stmt.Schema.DBNames { + if field := stmt.Schema.LookUpField(name); field != nil { + if !field.PrimaryKey && !field.HasDefaultValue && field.AutoCreateTime == 0 { + columns = append(columns, name) + } + } + } + + onConflict := clause.OnConflict{ + Columns: make([]clause.Column, len(stmt.Schema.PrimaryFieldDBNames)), + DoUpdates: clause.AssignmentColumns(columns), + } + + for idx, field := range stmt.Schema.PrimaryFields { + onConflict.Columns[idx] = clause.Column{Name: field.DBName} + } + stmt.AddClause(onConflict) + } + } + + return values } diff --git a/finisher_api.go b/finisher_api.go index d45c6c4f..afefd9fd 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -22,13 +22,14 @@ func (db *DB) Save(value interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.Dest = value - if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil { - where := clause.Where{Exprs: make([]clause.Expression, len(tx.Statement.Schema.PrimaryFields))} - reflectValue := reflect.Indirect(reflect.ValueOf(value)) - switch reflectValue.Kind() { - case reflect.Slice, reflect.Array: - tx.AddError(ErrPtrStructSupported) - case reflect.Struct: + reflectValue := reflect.Indirect(reflect.ValueOf(value)) + switch reflectValue.Kind() { + case reflect.Slice, reflect.Array: + tx.Statement.UpdatingColumn = true + tx.callbacks.Create().Execute(tx) + case reflect.Struct: + if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil { + where := clause.Where{Exprs: make([]clause.Expression, len(tx.Statement.Schema.PrimaryFields))} for idx, pf := range tx.Statement.Schema.PrimaryFields { if pv, isZero := pf.ValueOf(reflectValue); isZero { tx.callbacks.Create().Execute(tx) @@ -40,12 +41,16 @@ func (db *DB) Save(value interface{}) (tx *DB) { tx.Statement.AddClause(where) } + + fallthrough + default: + if len(tx.Statement.Selects) == 0 { + tx.Statement.Selects = append(tx.Statement.Selects, "*") + } + + tx.callbacks.Update().Execute(tx) } - if len(tx.Statement.Selects) == 0 { - tx.Statement.Selects = append(tx.Statement.Selects, "*") - } - tx.callbacks.Update().Execute(tx) return } diff --git a/tests/upsert_test.go b/tests/upsert_test.go index a1307e32..5826b4fc 100644 --- a/tests/upsert_test.go +++ b/tests/upsert_test.go @@ -90,6 +90,23 @@ func TestUpsertSlice(t *testing.T) { } } +func TestUpsertWithSave(t *testing.T) { + langs := []Language{ + {Code: "upsert-save-1", Name: "Upsert-save-1"}, + {Code: "upsert-save-2", Name: "Upsert-save-2"}, + } + if err := DB.Save(&langs).Error; err != nil { + t.Errorf("Failed to create, got error %v", err) + } + + for _, lang := range langs { + var result Language + if err := DB.First(&result, "code = ?", lang.Code).Error; err != nil { + t.Errorf("Failed to query lang, got error %v", err) + } + } +} + func TestFindOrInitialize(t *testing.T) { var user1, user2, user3, user4, user5, user6 User if err := DB.Where(&User{Name: "find or init", Age: 33}).FirstOrInit(&user1).Error; err != nil { From 0d58d5a3a7b7b73cf6b3533ef5da6b74ed602051 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 10 Jun 2020 10:48:48 +0800 Subject: [PATCH 425/881] Upsert selected columns --- callbacks/create.go | 8 ++++---- tests/upsert_test.go | 45 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 4 deletions(-) diff --git a/callbacks/create.go b/callbacks/create.go index 22adca24..684d5530 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -278,11 +278,11 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { if stmt.UpdatingColumn { if stmt.Schema != nil { - columns := make([]string, 0, len(stmt.Schema.DBNames)-1) - for _, name := range stmt.Schema.DBNames { - if field := stmt.Schema.LookUpField(name); field != nil { + columns := make([]string, 0, len(values.Columns)-1) + for _, column := range values.Columns { + if field := stmt.Schema.LookUpField(column.Name); field != nil { if !field.PrimaryKey && !field.HasDefaultValue && field.AutoCreateTime == 0 { - columns = append(columns, name) + columns = append(columns, column.Name) } } } diff --git a/tests/upsert_test.go b/tests/upsert_test.go index 5826b4fc..ba7c1a9d 100644 --- a/tests/upsert_test.go +++ b/tests/upsert_test.go @@ -95,6 +95,7 @@ func TestUpsertWithSave(t *testing.T) { {Code: "upsert-save-1", Name: "Upsert-save-1"}, {Code: "upsert-save-2", Name: "Upsert-save-2"}, } + if err := DB.Save(&langs).Error; err != nil { t.Errorf("Failed to create, got error %v", err) } @@ -103,8 +104,52 @@ func TestUpsertWithSave(t *testing.T) { var result Language if err := DB.First(&result, "code = ?", lang.Code).Error; err != nil { t.Errorf("Failed to query lang, got error %v", err) + } else { + AssertEqual(t, result, lang) } } + + for idx, lang := range langs { + lang.Name += "_new" + langs[idx] = lang + } + + if err := DB.Save(&langs).Error; err != nil { + t.Errorf("Failed to upsert, got error %v", err) + } + + for _, lang := range langs { + var result Language + if err := DB.First(&result, "code = ?", lang.Code).Error; err != nil { + t.Errorf("Failed to query lang, got error %v", err) + } else { + AssertEqual(t, result, lang) + } + } + + // lang := Language{Code: "upsert-save-3", Name: "Upsert-save-3"} + // if err := DB.Save(&lang).Error; err != nil { + // t.Errorf("Failed to create, got error %v", err) + // } + + // var result Language + // if err := DB.First(&result, "code = ?", lang.Code).Error; err != nil { + // t.Errorf("Failed to query lang, got error %v", err) + // } else { + // AssertEqual(t, result, lang) + // } + + // lang.Name += "_new" + // if err := DB.Save(&lang).Error; err != nil { + // t.Errorf("Failed to create, got error %v", err) + // } + + // var result2 Language + // if err := DB.First(&result2, "code = ?", lang.Code).Error; err != nil { + // t.Errorf("Failed to query lang, got error %v", err) + // } else { + // AssertEqual(t, result2, lang) + // } } func TestFindOrInitialize(t *testing.T) { From dbc3f8feb0f57d7a277aac51acfaf0df793df683 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 10 Jun 2020 13:42:39 +0800 Subject: [PATCH 426/881] Add count soft deleted record test --- tests/go.mod | 2 +- tests/soft_delete_test.go | 14 ++++++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/tests/go.mod b/tests/go.mod index c184732c..3401bdfe 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -7,7 +7,7 @@ require ( gorm.io/driver/mysql v0.0.0-20200609004954-b8310c61c3f2 gorm.io/driver/postgres v0.0.0-20200602015520-15fcc29eb286 gorm.io/driver/sqlite v1.0.0 - gorm.io/driver/sqlserver v0.0.0-20200609005334-d550a0be1cfb + gorm.io/driver/sqlserver v0.0.0-20200610030356-9c9aea39e1c1 gorm.io/gorm v0.0.0-00010101000000-000000000000 ) diff --git a/tests/soft_delete_test.go b/tests/soft_delete_test.go index b6dabe06..40d46fd8 100644 --- a/tests/soft_delete_test.go +++ b/tests/soft_delete_test.go @@ -11,6 +11,12 @@ import ( func TestSoftDelete(t *testing.T) { user := *GetUser("SoftDelete", Config{}) DB.Save(&user) + + var count int64 + if DB.Model(&User{}).Where("name = ?", user.Name).Count(&count).Error != nil || count != 1 { + t.Errorf("Count soft deleted record, expects: %v, got: %v", 1, count) + } + if err := DB.Delete(&user).Error; err != nil { t.Fatalf("No error should happen when soft delete user, but got %v", err) } @@ -19,10 +25,18 @@ func TestSoftDelete(t *testing.T) { t.Errorf("Can't find a soft deleted record") } + if DB.Model(&User{}).Where("name = ?", user.Name).Count(&count).Error != nil || count != 0 { + t.Errorf("Count soft deleted record, expects: %v, got: %v", 0, count) + } + if err := DB.Unscoped().First(&User{}, "name = ?", user.Name).Error; err != nil { t.Errorf("Should find soft deleted record with Unscoped, but got err %s", err) } + if DB.Unscoped().Model(&User{}).Where("name = ?", user.Name).Count(&count).Error != nil || count != 1 { + t.Errorf("Count soft deleted record, expects: %v, count: %v", 1, count) + } + DB.Unscoped().Delete(&user) if err := DB.Unscoped().First(&User{}, "name = ?", user.Name).Error; !errors.Is(err, gorm.ErrRecordNotFound) { t.Errorf("Can't find permanently deleted record") From 45cb6b49bfce8ff837f20d4fecdae882ca1bc0f1 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 10 Jun 2020 15:36:29 +0800 Subject: [PATCH 427/881] Add FindInBatches support --- finisher_api.go | 24 ++++++++++++++++++++++++ tests/query_test.go | 38 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 62 insertions(+) diff --git a/finisher_api.go b/finisher_api.go index afefd9fd..032c3059 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -106,6 +106,30 @@ func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) { return } +// FindInBatches find records in batches +func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, batch int) error) (tx *DB) { + tx = db.Session(&Session{WithConditions: true}) + rowsAffected := int64(0) + batch := 0 + + for { + result := tx.Limit(batchSize).Offset(batch * batchSize).Find(dest) + rowsAffected += result.RowsAffected + batch++ + + if result.Error == nil && result.RowsAffected != 0 { + tx.AddError(fc(result, batch)) + } + + if tx.Error != nil || int(result.RowsAffected) < batchSize { + break + } + } + + tx.RowsAffected = rowsAffected + return +} + func (tx *DB) assignExprsToValue(exprs []clause.Expression) { for _, expr := range exprs { if eq, ok := expr.(clause.Eq); ok { diff --git a/tests/query_test.go b/tests/query_test.go index 66413b3b..de65b63b 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -102,6 +102,44 @@ func TestFind(t *testing.T) { }) } +func TestFindInBatches(t *testing.T) { + var users = []User{ + *GetUser("find_in_batches", Config{}), + *GetUser("find_in_batches", Config{}), + *GetUser("find_in_batches", Config{}), + *GetUser("find_in_batches", Config{}), + *GetUser("find_in_batches", Config{}), + *GetUser("find_in_batches", Config{}), + } + + DB.Create(&users) + + var ( + results []User + totalBatch int + ) + + if result := DB.Where("name = ?", users[0].Name).FindInBatches(&results, 2, func(tx *gorm.DB, batch int) error { + totalBatch += batch + + if tx.RowsAffected != 2 { + t.Errorf("Incorrect affected rows, expects: 2, got %v", tx.RowsAffected) + } + + if len(results) != 2 { + t.Errorf("Incorrect users length, expects: 2, got %v", len(results)) + } + + return nil + }); result.Error != nil || result.RowsAffected != 6 { + t.Errorf("Failed to batch find, got error %v, rows affected: %v", result.Error, result.RowsAffected) + } + + if totalBatch != 6 { + t.Errorf("incorrect total batch, expects: %v, got %v", 6, totalBatch) + } +} + func TestFillSmallerStruct(t *testing.T) { user := User{Name: "SmallerUser", Age: 100} DB.Save(&user) From 1af325ab4fad8e490a089ee1655e45c71ac9fa94 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 10 Jun 2020 16:06:54 +0800 Subject: [PATCH 428/881] Upgrade sqlserver driver --- tests/go.mod | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/go.mod b/tests/go.mod index 3401bdfe..e5e181d4 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -7,7 +7,7 @@ require ( gorm.io/driver/mysql v0.0.0-20200609004954-b8310c61c3f2 gorm.io/driver/postgres v0.0.0-20200602015520-15fcc29eb286 gorm.io/driver/sqlite v1.0.0 - gorm.io/driver/sqlserver v0.0.0-20200610030356-9c9aea39e1c1 + gorm.io/driver/sqlserver v0.0.0-20200610080012-25da0c25e81d gorm.io/gorm v0.0.0-00010101000000-000000000000 ) From 537065fbd9076537c4f799fff178783d17c96c22 Mon Sep 17 00:00:00 2001 From: Razon Yang Date: Fri, 12 Jun 2020 20:00:55 +0800 Subject: [PATCH 429/881] Replace godoc badge with pkg.go.dev (#3051) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 349bb860..6c2c7731 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ The fantastic ORM library for Golang, aims to be developer friendly. [![Open Collective Backer](https://opencollective.com/gorm/tiers/backer/badge.svg?label=backer&color=brightgreen "Open Collective Backer")](https://opencollective.com/gorm) [![Open Collective Sponsor](https://opencollective.com/gorm/tiers/sponsor/badge.svg?label=sponsor&color=brightgreen "Open Collective Sponsor")](https://opencollective.com/gorm) [![MIT license](https://img.shields.io/badge/license-MIT-brightgreen.svg)](https://opensource.org/licenses/MIT) -[![GoDoc](https://godoc.org/gorm.io/gorm?status.svg)](https://godoc.org/gorm.io/gorm) +[![Go.Dev reference](https://img.shields.io/badge/go.dev-reference-blue?logo=go&logoColor=white)](https://pkg.go.dev/gorm.io/gorm?tab=doc) ## Overview From 1bbaa4395115dad830e1fedfd47d0d7c4ae630e8 Mon Sep 17 00:00:00 2001 From: maiyama18 Date: Sun, 14 Jun 2020 10:24:07 +0900 Subject: [PATCH 430/881] fix typos in test method names (#3052) --- tests/create_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/create_test.go b/tests/create_test.go index c497014e..351f02a3 100644 --- a/tests/create_test.go +++ b/tests/create_test.go @@ -190,7 +190,7 @@ func TestPolymorphicHasOne(t *testing.T) { }) } -func TestCreateEmptyStrut(t *testing.T) { +func TestCreateEmptyStruct(t *testing.T) { type EmptyStruct struct { ID uint } @@ -244,7 +244,7 @@ func TestCreateWithNowFuncOverride(t *testing.T) { AssertEqual(t, newUser.UpdatedAt, curTime) } -func TestCreateWithNoGORMPrimayKey(t *testing.T) { +func TestCreateWithNoGORMPrimaryKey(t *testing.T) { type JoinTable struct { UserID uint FriendID uint From 56bdded0f851ef64b2008fda0dff4ef0854d1713 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 14 Jun 2020 11:46:17 +0800 Subject: [PATCH 431/881] Fix statement modifier support --- chainable_api.go | 2 ++ clause/clause.go | 2 +- statement.go | 9 ++++----- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/chainable_api.go b/chainable_api.go index 0be86e03..dbd783fd 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -27,6 +27,8 @@ func (db *DB) Clauses(conds ...clause.Expression) (tx *DB) { for _, cond := range conds { if c, ok := cond.(clause.Interface); ok { tx.Statement.AddClause(c) + } else if optimizer, ok := cond.(StatementModifier); ok { + optimizer.ModifyStatement(tx.Statement) } else { whereConds = append(whereConds, cond) } diff --git a/clause/clause.go b/clause/clause.go index b3e96332..64f08d14 100644 --- a/clause/clause.go +++ b/clause/clause.go @@ -48,7 +48,7 @@ func (c Clause) Build(builder Builder) { } if c.AfterNameExpression != nil { - c.BeforeExpression.Build(builder) + c.AfterNameExpression.Build(builder) builder.WriteByte(' ') } diff --git a/statement.go b/statement.go index e0e86019..720ef283 100644 --- a/statement.go +++ b/statement.go @@ -202,12 +202,11 @@ func (stmt *Statement) AddClause(v clause.Interface) { if optimizer, ok := v.(StatementModifier); ok { optimizer.ModifyStatement(stmt) } else { - c, ok := stmt.Clauses[v.Name()] - if !ok { - c.Name = v.Name() - } + name := v.Name() + c, _ := stmt.Clauses[name] + c.Name = name v.MergeClause(&c) - stmt.Clauses[v.Name()] = c + stmt.Clauses[name] = c } } From 1fdc66710e71692e188d00a26f2fc84ba40c5c10 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 14 Jun 2020 19:13:16 +0800 Subject: [PATCH 432/881] Add table options --- migrator/migrator.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/migrator/migrator.go b/migrator/migrator.go index a98f7fe3..6baa9dc3 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -203,6 +203,11 @@ func (m Migrator) CreateTable(values ...interface{}) error { createTableSQL = strings.TrimSuffix(createTableSQL, ",") createTableSQL += ")" + + if tableOption, ok := m.DB.Get("gorm:table_options"); ok { + createTableSQL += fmt.Sprint(tableOption) + } + return tx.Exec(createTableSQL, values...).Error }); err != nil { return err From 9039e36cfcff3f766a77e640d287597543006405 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 14 Jun 2020 19:18:42 +0800 Subject: [PATCH 433/881] Allow scan into float close #1373 --- scan.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scan.go b/scan.go index 1f0aacd0..2d227ec2 100644 --- a/scan.go +++ b/scan.go @@ -62,7 +62,7 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { *dest = append(*dest, mapValue) } - case *int, *int64, *uint, *uint64: + case *int, *int64, *uint, *uint64, *float32, *float64: for initialized || rows.Next() { initialized = false db.RowsAffected++ From d716e456f46bad2aac142d1b4286026e0648df3d Mon Sep 17 00:00:00 2001 From: 2BFL <1@linux.com> Date: Mon, 15 Jun 2020 12:28:35 +0800 Subject: [PATCH 434/881] fix broken url (#3053) --- finisher_api.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/finisher_api.go b/finisher_api.go index 032c3059..73e42508 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -216,7 +216,7 @@ func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) { return db } -// Update update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update +// Update update attributes with callbacks, refer: https://gorm.io/docs/update.html#Update-Changed-Fields func (db *DB) Update(column string, value interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.Dest = map[string]interface{}{column: value} @@ -224,7 +224,7 @@ func (db *DB) Update(column string, value interface{}) (tx *DB) { return } -// Updates update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update +// Updates update attributes with callbacks, refer: https://gorm.io/docs/update.html#Update-Changed-Fields func (db *DB) Updates(values interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.Dest = values From e487f355a0838bbc158c5c7d848b35753d290884 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 17 Jun 2020 19:56:03 +0800 Subject: [PATCH 435/881] Add DB method --- gorm.go | 16 ++++++++++++++++ tests/tests_test.go | 9 +++++++++ 2 files changed, 25 insertions(+) diff --git a/gorm.go b/gorm.go index 0de6860b..a5f8bbfd 100644 --- a/gorm.go +++ b/gorm.go @@ -3,6 +3,7 @@ package gorm import ( "context" "database/sql" + "errors" "fmt" "sync" "time" @@ -220,6 +221,21 @@ func (db *DB) AddError(err error) error { return db.Error } +// DB returns `*sql.DB` +func (db *DB) DB() (*sql.DB, error) { + connPool := db.ConnPool + + if stmtDB, ok := connPool.(*PreparedStmtDB); ok { + connPool = stmtDB.ConnPool + } + + if sqldb, ok := connPool.(*sql.DB); ok { + return sqldb, nil + } + + return nil, errors.New("invalid db") +} + func (db *DB) getInstance() *DB { if db.clone > 0 { tx := &DB{Config: db.Config} diff --git a/tests/tests_test.go b/tests/tests_test.go index 09850003..c80fb849 100644 --- a/tests/tests_test.go +++ b/tests/tests_test.go @@ -24,6 +24,15 @@ func init() { log.Printf("failed to connect database, got error %v\n", err) os.Exit(1) } else { + sqlDB, err := DB.DB() + if err == nil { + err = sqlDB.Ping() + } + + if err != nil { + log.Printf("failed to connect database, got error %v\n", err) + } + RunMigrations() } } From ca2c80c8e385a5959f483cd73d1df58beffd806f Mon Sep 17 00:00:00 2001 From: mojotv <34467684+mojocn@users.noreply.github.com> Date: Wed, 17 Jun 2020 20:29:37 +0800 Subject: [PATCH 436/881] add githubAction CI for tests (#3057) --- .github/workflows/go.yml | 73 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 73 insertions(+) create mode 100644 .github/workflows/go.yml diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml new file mode 100644 index 00000000..a5dc41a3 --- /dev/null +++ b/.github/workflows/go.yml @@ -0,0 +1,73 @@ +name: Go + +on: + push: + branches: [ master ] + pull_request: + branches: [ master ] + +jobs: + # Label of the container job + containerTest: + # Containers must run in Linux based operating systems + runs-on: ubuntu-latest + # Docker Hub image that `container-job` executes in + #container: node:10.18-jessie + + # Service containers to run with `container-job` + services: + # start postgres + postgres: + image: postgres:latest + env: + POSTGRES_PASSWORD: gorm + POSTGRES_USER: gorm + POSTGRES_DB: gorm + TZ: Asia/Shanghai + + ports: + - 9920:5432 + # Set health checks to wait until postgres has started + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + --health-retries 5 + + # start mysql + mysql: + image: mysql:latest + env: + MYSQL_DATABASE: gorm + MYSQL_USER: gorm + MYSQL_PASSWORD: gorm + MYSQL_RANDOM_ROOT_PASSWORD: "yes" + + ports: + - 9910:3306 + # start mssql + mssql: + image: mcmoe/mssqldocker:latest + env: + ACCEPT_EULA: Y + SA_PASSWORD: LoremIpsum86 + MSSQL_DB: gorm + MSSQL_USER: gorm + MSSQL_PASSWORD: LoremIpsum86 + ports: + - 9930:1433 + steps: + - name: Set up Go 1.x + uses: actions/setup-go@v2 + with: + go-version: ^1.13 + id: go + + - name: Check out code into the Go module directory + uses: actions/checkout@v2 + + - name: show ports + run: netstat -lntp + + - name: run tests + run: cd tests && ./tests_all.sh From 6b2f37189ee1cc1e46cdad9ef6b7f98c69748f0b Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 18 Jun 2020 08:20:41 +0800 Subject: [PATCH 437/881] Fix few cases with postgres --- migrator/migrator.go | 2 +- schema/field.go | 9 ++++++++- tests/go.mod | 2 ++ tests/postgres_test.go | 39 +++++++++++++++++++++++++++++++++++++++ 4 files changed, 50 insertions(+), 2 deletions(-) create mode 100644 tests/postgres_test.go diff --git a/migrator/migrator.go b/migrator/migrator.go index 6baa9dc3..955cc6bb 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -74,7 +74,7 @@ func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) { } if field.HasDefaultValue && field.DefaultValue != "" { - if field.DataType == schema.String { + if field.DataType == schema.String && field.DefaultValueInterface != nil { defaultStmt := &gorm.Statement{Vars: []interface{}{field.DefaultValue}} m.Dialector.BindVarTo(defaultStmt, defaultStmt, field.DefaultValue) expr.SQL += " DEFAULT " + m.Dialector.Explain(defaultStmt.SQL.String(), field.DefaultValue) diff --git a/schema/field.go b/schema/field.go index e0d49e2f..ea6dcd25 100644 --- a/schema/field.go +++ b/schema/field.go @@ -203,7 +203,10 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } case reflect.String: field.DataType = String - if field.HasDefaultValue { + isFunc := strings.Contains(field.DefaultValue, "(") && + strings.Contains(field.DefaultValue, ")") + + if field.HasDefaultValue && !isFunc { field.DefaultValue = strings.Trim(field.DefaultValue, "'") field.DefaultValue = strings.Trim(field.DefaultValue, "\"") field.DefaultValueInterface = field.DefaultValue @@ -253,6 +256,10 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } } + if field.DataType == "" && field.DBDataType != "" { + field.DataType = String + } + // setup permission if _, ok := field.TagSettings["-"]; ok { field.Creatable = false diff --git a/tests/go.mod b/tests/go.mod index e5e181d4..e500edd7 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -3,7 +3,9 @@ module gorm.io/gorm/tests go 1.14 require ( + github.com/google/uuid v1.1.1 github.com/jinzhu/now v1.1.1 + github.com/lib/pq v1.6.0 gorm.io/driver/mysql v0.0.0-20200609004954-b8310c61c3f2 gorm.io/driver/postgres v0.0.0-20200602015520-15fcc29eb286 gorm.io/driver/sqlite v1.0.0 diff --git a/tests/postgres_test.go b/tests/postgres_test.go new file mode 100644 index 00000000..98302d87 --- /dev/null +++ b/tests/postgres_test.go @@ -0,0 +1,39 @@ +package tests_test + +import ( + "testing" + + "github.com/google/uuid" + "github.com/lib/pq" + "gorm.io/gorm" +) + +func TestPostgres(t *testing.T) { + if DB.Dialector.Name() != "postgres" { + t.Skip() + } + + type Harumph struct { + gorm.Model + Test uuid.UUID `gorm:"type:uuid;not null;default:gen_random_uuid()"` + Things pq.StringArray `gorm:"type:text[]"` + } + + if err := DB.Exec("CREATE EXTENSION IF NOT EXISTS pgcrypto;").Error; err != nil { + t.Errorf("Failed to create extension pgcrypto, got error %v", err) + } + + DB.Migrator().DropTable(&Harumph{}) + + if err := DB.AutoMigrate(&Harumph{}); err != nil { + t.Fatalf("Failed to migrate for uuid default value, got error: %v", err) + } + + harumph := Harumph{} + DB.Create(&harumph) + + var result Harumph + if err := DB.First(&result, "id = ?", harumph.ID).Error; err != nil { + t.Errorf("No error should happen, but got %v", err) + } +} From 96368eb967bbfbab8ef0bdef2e9ff1fcbdee6710 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 18 Jun 2020 09:15:23 +0800 Subject: [PATCH 438/881] Test embedded struct implements Scan & Value interface --- migrator/migrator.go | 6 +---- schema/field.go | 18 ++++++-------- schema/schema_helper_test.go | 2 +- tests/embedded_struct_test.go | 45 +++++++++++++++++++++++++++++++++++ tests/go.mod | 8 +++---- 5 files changed, 58 insertions(+), 21 deletions(-) diff --git a/migrator/migrator.go b/migrator/migrator.go index 955cc6bb..8f872ee4 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -44,10 +44,6 @@ func (m Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement) error } func (m Migrator) DataTypeOf(field *schema.Field) string { - if field.DBDataType != "" { - return field.DBDataType - } - fieldValue := reflect.New(field.IndirectFieldType) if dataTyper, ok := fieldValue.Interface().(GormDataTypeInterface); ok { if dataType := dataTyper.GormDBDataType(m.DB, field); dataType != "" { @@ -155,7 +151,7 @@ func (m Migrator) CreateTable(values ...interface{}) error { for _, dbName := range stmt.Schema.DBNames { field := stmt.Schema.FieldsByDBName[dbName] createTableSQL += fmt.Sprintf("? ?") - hasPrimaryKeyInDataType = hasPrimaryKeyInDataType || strings.Contains(strings.ToUpper(field.DBDataType), "PRIMARY KEY") + hasPrimaryKeyInDataType = hasPrimaryKeyInDataType || strings.Contains(strings.ToUpper(string(field.DataType)), "PRIMARY KEY") values = append(values, clause.Column{Name: dbName}, m.FullDataTypeOf(field)) createTableSQL += "," } diff --git a/schema/field.go b/schema/field.go index ea6dcd25..8bfa3b22 100644 --- a/schema/field.go +++ b/schema/field.go @@ -38,7 +38,6 @@ type Field struct { DBName string BindNames []string DataType DataType - DBDataType string PrimaryKey bool AutoIncrement bool Creatable bool @@ -104,7 +103,8 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } // if field is valuer, used its value or first fields as data type - if valuer, isValueOf := fieldValue.Interface().(driver.Valuer); isValueOf { + valuer, isValuer := fieldValue.Interface().(driver.Valuer) + if isValuer { var overrideFieldValue bool if v, err := valuer.Value(); v != nil && err == nil { overrideFieldValue = true @@ -176,10 +176,6 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { field.Comment = val } - if val, ok := field.TagSettings["TYPE"]; ok { - field.DBDataType = val - } - switch reflect.Indirect(fieldValue).Kind() { case reflect.Bool: field.DataType = Bool @@ -227,6 +223,10 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { field.DataType = DataType(dataTyper.GormDataType()) } + if val, ok := field.TagSettings["TYPE"]; ok { + field.DataType = DataType(val) + } + if v, ok := field.TagSettings["AUTOCREATETIME"]; ok || (field.Name == "CreatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) { if strings.ToUpper(v) == "NANO" { field.AutoCreateTime = UnixNanosecond @@ -256,10 +256,6 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } } - if field.DataType == "" && field.DBDataType != "" { - field.DataType = String - } - // setup permission if _, ok := field.TagSettings["-"]; ok { field.Creatable = false @@ -293,7 +289,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } } - if _, ok := field.TagSettings["EMBEDDED"]; ok || fieldStruct.Anonymous { + if _, ok := field.TagSettings["EMBEDDED"]; ok || (fieldStruct.Anonymous && !isValuer) { var err error field.Creatable = false field.Updatable = false diff --git a/schema/schema_helper_test.go b/schema/schema_helper_test.go index d2e68536..f202b487 100644 --- a/schema/schema_helper_test.go +++ b/schema/schema_helper_test.go @@ -52,7 +52,7 @@ func checkSchemaField(t *testing.T, s *schema.Schema, f *schema.Field, fc func(* if parsedField, ok := s.FieldsByName[f.Name]; !ok { t.Errorf("schema %v failed to look up field with name %v", s, f.Name) } else { - tests.AssertObjEqual(t, parsedField, f, "Name", "DBName", "BindNames", "DataType", "DBDataType", "PrimaryKey", "AutoIncrement", "Creatable", "Updatable", "Readable", "HasDefaultValue", "DefaultValue", "NotNull", "Unique", "Comment", "Size", "Precision", "Tag", "TagSettings") + tests.AssertObjEqual(t, parsedField, f, "Name", "DBName", "BindNames", "DataType", "PrimaryKey", "AutoIncrement", "Creatable", "Updatable", "Readable", "HasDefaultValue", "DefaultValue", "NotNull", "Unique", "Comment", "Size", "Precision", "Tag", "TagSettings") if f.DBName != "" { if field, ok := s.FieldsByDBName[f.DBName]; !ok || parsedField != field { diff --git a/tests/embedded_struct_test.go b/tests/embedded_struct_test.go index 9a1436fe..5f06f63c 100644 --- a/tests/embedded_struct_test.go +++ b/tests/embedded_struct_test.go @@ -1,6 +1,9 @@ package tests_test import ( + "database/sql/driver" + "encoding/json" + "errors" "testing" "gorm.io/gorm" @@ -102,3 +105,45 @@ func TestEmbeddedPointerTypeStruct(t *testing.T) { t.Errorf("Should find correct value for embedded pointer type") } } + +type Content struct { + Content interface{} `gorm:"type:string"` +} + +func (c Content) Value() (driver.Value, error) { + return json.Marshal(c) +} + +func (c *Content) Scan(src interface{}) error { + b, ok := src.([]byte) + if !ok { + return errors.New("Embedded.Scan byte assertion failed") + } + + var value Content + if err := json.Unmarshal(b, &value); err != nil { + return err + } + + *c = value + + return nil +} + +func TestEmbeddedScanValuer(t *testing.T) { + type HNPost struct { + gorm.Model + Content + } + + DB.Migrator().DropTable(&HNPost{}) + if err := DB.Migrator().AutoMigrate(&HNPost{}); err != nil { + t.Fatalf("failed to auto migrate, got error: %v", err) + } + + hnPost := HNPost{Content: Content{Content: "hello world"}} + + if err := DB.Create(&hnPost).Error; err != nil { + t.Errorf("Failed to create got error %v", err) + } +} diff --git a/tests/go.mod b/tests/go.mod index e500edd7..07ec6be2 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -6,10 +6,10 @@ require ( github.com/google/uuid v1.1.1 github.com/jinzhu/now v1.1.1 github.com/lib/pq v1.6.0 - gorm.io/driver/mysql v0.0.0-20200609004954-b8310c61c3f2 - gorm.io/driver/postgres v0.0.0-20200602015520-15fcc29eb286 - gorm.io/driver/sqlite v1.0.0 - gorm.io/driver/sqlserver v0.0.0-20200610080012-25da0c25e81d + gorm.io/driver/mysql v0.2.0 + gorm.io/driver/postgres v0.2.0 + gorm.io/driver/sqlite v1.0.2 + gorm.io/driver/sqlserver v0.2.0 gorm.io/gorm v0.0.0-00010101000000-000000000000 ) From 07960fe661b5ced50c9ca30e010aa26513eaf851 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 18 Jun 2020 09:32:31 +0800 Subject: [PATCH 439/881] Fix []byte support --- schema/field.go | 2 +- statement.go | 3 +++ tests/scanner_valuer_test.go | 10 ++++++---- 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/schema/field.go b/schema/field.go index 8bfa3b22..f8ecef60 100644 --- a/schema/field.go +++ b/schema/field.go @@ -214,7 +214,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { field.DataType = Time } case reflect.Array, reflect.Slice: - if fieldValue.Type().Elem() == reflect.TypeOf(uint8(0)) { + if reflect.Indirect(fieldValue).Type().Elem() == reflect.TypeOf(uint8(0)) { field.DataType = Bytes } } diff --git a/statement.go b/statement.go index 720ef283..2a092966 100644 --- a/statement.go +++ b/statement.go @@ -160,6 +160,9 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) { case driver.Valuer: stmt.Vars = append(stmt.Vars, v) stmt.DB.Dialector.BindVarTo(writer, stmt, v) + case []byte: + stmt.Vars = append(stmt.Vars, v) + stmt.DB.Dialector.BindVarTo(writer, stmt, v) case []interface{}: if len(v) > 0 { writer.WriteByte('(') diff --git a/tests/scanner_valuer_test.go b/tests/scanner_valuer_test.go index ec228f00..632bd74a 100644 --- a/tests/scanner_valuer_test.go +++ b/tests/scanner_valuer_test.go @@ -17,7 +17,7 @@ import ( func TestScannerValuer(t *testing.T) { DB.Migrator().DropTable(&ScannerValuerStruct{}) if err := DB.Migrator().AutoMigrate(&ScannerValuerStruct{}); err != nil { - t.Errorf("no error should happen when migrate scanner, valuer struct") + t.Fatalf("no error should happen when migrate scanner, valuer struct, got error %v", err) } data := ScannerValuerStruct{ @@ -28,6 +28,7 @@ func TestScannerValuer(t *testing.T) { Height: sql.NullFloat64{Float64: 1.8888, Valid: true}, Birthday: sql.NullTime{Time: time.Now(), Valid: true}, Password: EncryptedData("pass1"), + Bytes: []byte("byte"), Num: 18, Strings: StringsSlice{"a", "b", "c"}, Structs: StructsSlice{ @@ -38,16 +39,16 @@ func TestScannerValuer(t *testing.T) { } if err := DB.Create(&data).Error; err != nil { - t.Errorf("No error should happened when create scanner valuer struct, but got %v", err) + t.Fatalf("No error should happened when create scanner valuer struct, but got %v", err) } var result ScannerValuerStruct if err := DB.Find(&result).Error; err != nil { - t.Errorf("no error should happen when query scanner, valuer struct, but got %v", err) + t.Fatalf("no error should happen when query scanner, valuer struct, but got %v", err) } - AssertObjEqual(t, data, result, "Name", "Gender", "Age", "Male", "Height", "Birthday", "Password", "Num", "Strings", "Structs") + AssertObjEqual(t, data, result, "Name", "Gender", "Age", "Male", "Height", "Birthday", "Password", "Bytes", "Num", "Strings", "Structs") } func TestScannerValuerWithFirstOrCreate(t *testing.T) { @@ -130,6 +131,7 @@ type ScannerValuerStruct struct { Height sql.NullFloat64 Birthday sql.NullTime Password EncryptedData + Bytes []byte Num Num Strings StringsSlice Structs StructsSlice From 2c1b04a2cf0b9740a90d70b31c9cfdb5a1058183 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 19 Jun 2020 12:38:03 +0800 Subject: [PATCH 440/881] Fix failed to create second record in same transaction, close #3060 --- callbacks/transaction.go | 2 +- finisher_api.go | 5 +++-- statement.go | 5 +++++ tests/transaction_test.go | 10 ++++++++++ 4 files changed, 19 insertions(+), 3 deletions(-) diff --git a/callbacks/transaction.go b/callbacks/transaction.go index 430a341d..14d31a62 100644 --- a/callbacks/transaction.go +++ b/callbacks/transaction.go @@ -7,7 +7,7 @@ import ( func BeginTransaction(db *gorm.DB) { if tx := db.Begin(); tx.Error == nil { db.Statement.ConnPool = tx.Statement.ConnPool - tx.InstanceSet("gorm:started_transaction", true) + db.InstanceSet("gorm:started_transaction", true) } else { tx.Error = nil } diff --git a/finisher_api.go b/finisher_api.go index 73e42508..43aff843 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -351,7 +351,7 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er } }() - err = fc(tx.Session(&Session{})) + err = fc(tx) if err == nil { err = tx.Commit().Error @@ -364,7 +364,8 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er // Begin begins a transaction func (db *DB) Begin(opts ...*sql.TxOptions) *DB { var ( - tx = db.getInstance() + // clone statement + tx = db.Session(&Session{WithConditions: true, Context: db.Statement.Context}) opt *sql.TxOptions err error ) diff --git a/statement.go b/statement.go index 2a092966..e3c882ee 100644 --- a/statement.go +++ b/statement.go @@ -351,5 +351,10 @@ func (stmt *Statement) clone() *Statement { newStmt.Joins[k] = j } + stmt.Settings.Range(func(k, v interface{}) bool { + newStmt.Settings.Store(k, v) + return true + }) + return newStmt } diff --git a/tests/transaction_test.go b/tests/transaction_test.go index 592f1321..d1bf8645 100644 --- a/tests/transaction_test.go +++ b/tests/transaction_test.go @@ -20,6 +20,16 @@ func TestTransaction(t *testing.T) { t.Fatalf("Should find saved record, but got %v", err) } + user1 := *GetUser("transaction1-1", Config{}) + + if err := tx.Save(&user1).Error; err != nil { + t.Fatalf("No error should raise, but got %v", err) + } + + if err := tx.First(&User{}, "name = ?", user1.Name).Error; err != nil { + t.Fatalf("Should find saved record, but got %v", err) + } + if sqlTx, ok := tx.Statement.ConnPool.(gorm.TxCommitter); !ok || sqlTx == nil { t.Fatalf("Should return the underlying sql.Tx") } From 7dc255acfe2e20c033e082b532c6b1c85c7751a9 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 19 Jun 2020 18:30:04 +0800 Subject: [PATCH 441/881] Add SavePoint/RollbackTo/NestedTransaction --- errors.go | 2 + finisher_api.go | 56 ++++++++++++++---- interfaces.go | 5 ++ tests/go.mod | 10 ++-- tests/transaction_test.go | 120 ++++++++++++++++++++++++++++++++++++++ wercker.yml | 6 -- 6 files changed, 177 insertions(+), 22 deletions(-) diff --git a/errors.go b/errors.go index ff06f24e..2506ecc5 100644 --- a/errors.go +++ b/errors.go @@ -25,4 +25,6 @@ var ( ErrorPrimaryKeyRequired = errors.New("primary key required") // ErrorModelValueRequired model value required ErrorModelValueRequired = errors.New("model value required") + // ErrUnsupportedDriver unsupported driver + ErrUnsupportedDriver = errors.New("unsupported driver") ) diff --git a/finisher_api.go b/finisher_api.go index 43aff843..92d4fe72 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -3,6 +3,7 @@ package gorm import ( "database/sql" "errors" + "fmt" "reflect" "strings" @@ -343,18 +344,33 @@ func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error { // Transaction start a transaction as a block, return error will rollback, otherwise to commit. func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err error) { panicked := true - tx := db.Begin(opts...) - defer func() { - // Make sure to rollback when panic, Block error or Commit error - if panicked || err != nil { - tx.Rollback() + + if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil { + // nested transaction + db.SavePoint(fmt.Sprintf("sp%p", fc)) + defer func() { + // Make sure to rollback when panic, Block error or Commit error + if panicked || err != nil { + db.RollbackTo(fmt.Sprintf("sp%p", fc)) + } + }() + + err = fc(db.Session(&Session{WithConditions: true})) + } else { + tx := db.Begin(opts...) + + defer func() { + // Make sure to rollback when panic, Block error or Commit error + if panicked || err != nil { + tx.Rollback() + } + }() + + err = fc(tx) + + if err == nil { + err = tx.Commit().Error } - }() - - err = fc(tx) - - if err == nil { - err = tx.Commit().Error } panicked = false @@ -409,6 +425,24 @@ func (db *DB) Rollback() *DB { return db } +func (db *DB) SavePoint(name string) *DB { + if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok { + savePointer.SavePoint(db, name) + } else { + db.AddError(ErrUnsupportedDriver) + } + return db +} + +func (db *DB) RollbackTo(name string) *DB { + if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok { + savePointer.RollbackTo(db, name) + } else { + db.AddError(ErrUnsupportedDriver) + } + return db +} + // Exec execute raw sql func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) { tx = db.getInstance() diff --git a/interfaces.go b/interfaces.go index 4be54565..f3e5c028 100644 --- a/interfaces.go +++ b/interfaces.go @@ -27,6 +27,11 @@ type ConnPool interface { QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row } +type SavePointerDialectorInterface interface { + SavePoint(tx *DB, name string) error + RollbackTo(tx *DB, name string) error +} + type TxBeginner interface { BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) } diff --git a/tests/go.mod b/tests/go.mod index 07ec6be2..a2121b7a 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -6,11 +6,11 @@ require ( github.com/google/uuid v1.1.1 github.com/jinzhu/now v1.1.1 github.com/lib/pq v1.6.0 - gorm.io/driver/mysql v0.2.0 - gorm.io/driver/postgres v0.2.0 - gorm.io/driver/sqlite v1.0.2 - gorm.io/driver/sqlserver v0.2.0 - gorm.io/gorm v0.0.0-00010101000000-000000000000 + gorm.io/driver/mysql v0.2.1 + gorm.io/driver/postgres v0.2.1 + gorm.io/driver/sqlite v1.0.4 + gorm.io/driver/sqlserver v0.2.1 + gorm.io/gorm v0.2.7 ) replace gorm.io/gorm => ../ diff --git a/tests/transaction_test.go b/tests/transaction_test.go index d1bf8645..c101388a 100644 --- a/tests/transaction_test.go +++ b/tests/transaction_test.go @@ -142,3 +142,123 @@ func TestTransactionRaiseErrorOnRollbackAfterCommit(t *testing.T) { t.Fatalf("Rollback after commit should raise error") } } + +func TestTransactionWithSavePoint(t *testing.T) { + tx := DB.Begin() + + user := *GetUser("transaction-save-point", Config{}) + tx.Create(&user) + + if err := tx.First(&User{}, "name = ?", user.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } + + if err := tx.SavePoint("save_point1").Error; err != nil { + t.Fatalf("Failed to save point, got error %v", err) + } + + user1 := *GetUser("transaction-save-point-1", Config{}) + tx.Create(&user1) + + if err := tx.First(&User{}, "name = ?", user1.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } + + if err := tx.RollbackTo("save_point1").Error; err != nil { + t.Fatalf("Failed to save point, got error %v", err) + } + + if err := tx.First(&User{}, "name = ?", user1.Name).Error; err == nil { + t.Fatalf("Should not find rollbacked record") + } + + if err := tx.SavePoint("save_point2").Error; err != nil { + t.Fatalf("Failed to save point, got error %v", err) + } + + user2 := *GetUser("transaction-save-point-2", Config{}) + tx.Create(&user2) + + if err := tx.First(&User{}, "name = ?", user2.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } + + if err := tx.Commit().Error; err != nil { + t.Fatalf("Failed to commit, got error %v", err) + } + + if err := DB.First(&User{}, "name = ?", user.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } + + if err := DB.First(&User{}, "name = ?", user1.Name).Error; err == nil { + t.Fatalf("Should not find rollbacked record") + } + + if err := DB.First(&User{}, "name = ?", user2.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } +} + +func TestNestedTransactionWithBlock(t *testing.T) { + var ( + user = *GetUser("transaction-nested", Config{}) + user1 = *GetUser("transaction-nested-1", Config{}) + user2 = *GetUser("transaction-nested-2", Config{}) + ) + + if err := DB.Transaction(func(tx *gorm.DB) error { + tx.Create(&user) + + if err := tx.First(&User{}, "name = ?", user.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } + + if err := tx.Transaction(func(tx1 *gorm.DB) error { + tx1.Create(&user1) + + if err := tx1.First(&User{}, "name = ?", user1.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } + + return errors.New("rollback") + }); err == nil { + t.Fatalf("nested transaction should returns error") + } + + if err := tx.First(&User{}, "name = ?", user1.Name).Error; err == nil { + t.Fatalf("Should not find rollbacked record") + } + + if err := tx.Transaction(func(tx2 *gorm.DB) error { + tx2.Create(&user2) + + if err := tx2.First(&User{}, "name = ?", user2.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } + + return nil + }); err != nil { + t.Fatalf("nested transaction returns error: %v", err) + } + + if err := tx.First(&User{}, "name = ?", user2.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } + return nil + }); err != nil { + t.Fatalf("no error should return, but got %v", err) + } + + if err := DB.First(&User{}, "name = ?", user.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } + + if err := DB.First(&User{}, "name = ?", user1.Name).Error; err == nil { + t.Fatalf("Should not find rollbacked record") + } + + if err := DB.First(&User{}, "name = ?", user2.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } +} diff --git a/wercker.yml b/wercker.yml index baece1bc..d4fb63e3 100644 --- a/wercker.yml +++ b/wercker.yml @@ -124,9 +124,3 @@ build: name: test mssql code: | GORM_DIALECT=mssql GORM_VERBOSE=true GORM_DSN="sqlserver://gorm:LoremIpsum86@mssql:1433?database=gorm" ./tests/tests_all.sh - - - script: - name: codecov - code: | - go test -race -coverprofile=coverage.txt -covermode=atomic ./... - bash <(curl -s https://codecov.io/bash) From e3292b3b4171cefe59694391729aa997640cc92e Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 19 Jun 2020 18:44:19 +0800 Subject: [PATCH 442/881] Test with latest driver vesion --- tests/tests_all.sh | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tests/tests_all.sh b/tests/tests_all.sh index affb1847..fd696e38 100755 --- a/tests/tests_all.sh +++ b/tests/tests_all.sh @@ -4,6 +4,14 @@ if [[ $(pwd) == *"gorm/tests"* ]]; then cd .. fi +if [ -d tests ] +then + cd tests + cp go.mod go.mod.bak + sed '/gorm.io\/driver/d' go.mod.bak > go.mod + cd .. +fi + for dialect in "${dialects[@]}" ; do if [ "$GORM_DIALECT" = "" ] || [ "$GORM_DIALECT" = "${dialect}" ] then @@ -35,3 +43,9 @@ for dialect in "${dialects[@]}" ; do fi fi done + +if [ -d tests ] +then + cd tests + mv go.mod.bak go.mod +fi From d4d339f3b5e9dc9d3da10d6bd34aed7ac6818d76 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 19 Jun 2020 22:51:46 +0800 Subject: [PATCH 443/881] Handle data type cases --- schema/field.go | 7 ++++++- tests/embedded_struct_test.go | 2 +- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/schema/field.go b/schema/field.go index f8ecef60..737f56c4 100644 --- a/schema/field.go +++ b/schema/field.go @@ -224,7 +224,12 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } if val, ok := field.TagSettings["TYPE"]; ok { - field.DataType = DataType(val) + switch DataType(strings.ToLower(val)) { + case Bool, Int, Uint, Float, String, Time, Bytes: + field.DataType = DataType(strings.ToLower(val)) + default: + field.DataType = DataType(val) + } } if v, ok := field.TagSettings["AUTOCREATETIME"]; ok || (field.Name == "CreatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) { diff --git a/tests/embedded_struct_test.go b/tests/embedded_struct_test.go index 5f06f63c..8536b605 100644 --- a/tests/embedded_struct_test.go +++ b/tests/embedded_struct_test.go @@ -107,7 +107,7 @@ func TestEmbeddedPointerTypeStruct(t *testing.T) { } type Content struct { - Content interface{} `gorm:"type:string"` + Content interface{} `gorm:"type:String"` } func (c Content) Value() (driver.Value, error) { From 4f19e2a7b3f56b545f61aa2e5496da3e52bbf367 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 20 Jun 2020 00:48:15 +0800 Subject: [PATCH 444/881] Test ForeignKeyConstraints --- callbacks/update.go | 52 ++++++++--------- migrator/migrator.go | 20 ++++--- schema/relationship.go | 18 ++++-- tests/associations_test.go | 109 ++++++++++++++++++++++++++++++++++++ tests/preload_suits_test.go | 4 +- tests/tests_test.go | 1 + utils/tests/models.go | 2 +- 7 files changed, 165 insertions(+), 41 deletions(-) diff --git a/callbacks/update.go b/callbacks/update.go index 03d5c1e9..1ea77552 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -137,6 +137,32 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { updatingValue = updatingValue.Elem() } + if !updatingValue.CanAddr() || stmt.Dest != stmt.Model { + switch stmt.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + var priamryKeyExprs []clause.Expression + for i := 0; i < stmt.ReflectValue.Len(); i++ { + var exprs = make([]clause.Expression, len(stmt.Schema.PrimaryFields)) + var notZero bool + for idx, field := range stmt.Schema.PrimaryFields { + value, isZero := field.ValueOf(stmt.ReflectValue.Index(i)) + exprs[idx] = clause.Eq{Column: field.DBName, Value: value} + notZero = notZero || !isZero + } + if notZero { + priamryKeyExprs = append(priamryKeyExprs, clause.And(exprs...)) + } + } + stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(priamryKeyExprs...)}}) + case reflect.Struct: + for _, field := range stmt.Schema.PrimaryFields { + if value, isZero := field.ValueOf(stmt.ReflectValue); !isZero { + stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}}) + } + } + } + } + switch value := updatingValue.Interface().(type) { case map[string]interface{}: set = make([]clause.Assignment, 0, len(value)) @@ -218,31 +244,5 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { } } - if !updatingValue.CanAddr() || stmt.Dest != stmt.Model { - switch stmt.ReflectValue.Kind() { - case reflect.Slice, reflect.Array: - var priamryKeyExprs []clause.Expression - for i := 0; i < stmt.ReflectValue.Len(); i++ { - var exprs = make([]clause.Expression, len(stmt.Schema.PrimaryFields)) - var notZero bool - for idx, field := range stmt.Schema.PrimaryFields { - value, isZero := field.ValueOf(stmt.ReflectValue.Index(i)) - exprs[idx] = clause.Eq{Column: field.DBName, Value: value} - notZero = notZero || !isZero - } - if notZero { - priamryKeyExprs = append(priamryKeyExprs, clause.And(exprs...)) - } - } - stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(priamryKeyExprs...)}}) - case reflect.Struct: - for _, field := range stmt.Schema.PrimaryFields { - if value, isZero := field.ValueOf(stmt.ReflectValue); !isZero { - stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}}) - } - } - } - } - return } diff --git a/migrator/migrator.go b/migrator/migrator.go index 8f872ee4..a4cc99a6 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -103,9 +103,11 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { for _, rel := range stmt.Schema.Relationships.Relations { if constraint := rel.ParseConstraint(); constraint != nil { - if !tx.Migrator().HasConstraint(value, constraint.Name) { - if err := tx.Migrator().CreateConstraint(value, constraint.Name); err != nil { - return err + if constraint.Schema == stmt.Schema { + if !tx.Migrator().HasConstraint(value, constraint.Name) { + if err := tx.Migrator().CreateConstraint(value, constraint.Name); err != nil { + return err + } } } } @@ -177,9 +179,11 @@ func (m Migrator) CreateTable(values ...interface{}) error { for _, rel := range stmt.Schema.Relationships.Relations { if constraint := rel.ParseConstraint(); constraint != nil { - sql, vars := buildConstraint(constraint) - createTableSQL += sql + "," - values = append(values, vars...) + if constraint.Schema == stmt.Schema { + sql, vars := buildConstraint(constraint) + createTableSQL += sql + "," + values = append(values, vars...) + } } // create join table @@ -360,7 +364,7 @@ func buildConstraint(constraint *schema.Constraint) (sql string, results []inter } if constraint.OnUpdate != "" { - sql += " ON UPDATE " + constraint.OnUpdate + sql += " ON UPDATE " + constraint.OnUpdate } var foreignKeys, references []interface{} @@ -550,7 +554,7 @@ func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []i dep.Parse(value) for _, rel := range dep.Schema.Relationships.Relations { - if c := rel.ParseConstraint(); c != nil && c.Schema != c.ReferenceSchema { + if c := rel.ParseConstraint(); c != nil && c.Schema == dep.Statement.Schema && c.Schema != c.ReferenceSchema { dep.Depends = append(dep.Depends, c.ReferenceSchema) } } diff --git a/schema/relationship.go b/schema/relationship.go index efa44554..afa083ed 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -85,6 +85,10 @@ func (schema *Schema) parseRelation(field *Field) { } if relation.Type == "has" { + if relation.FieldSchema != relation.Schema && relation.Polymorphic == nil { + relation.FieldSchema.Relationships.Relations["_"+relation.Schema.Name+"_"+relation.Name] = relation + } + switch field.IndirectFieldType.Kind() { case reflect.Struct: relation.Type = HasOne @@ -384,18 +388,24 @@ func (rel *Relationship) ParseConstraint() *Constraint { Field: rel.Field, OnUpdate: settings["ONUPDATE"], OnDelete: settings["ONDELETE"], - Schema: rel.Schema, } for _, ref := range rel.References { - if ref.PrimaryKey != nil && !ref.OwnPrimaryKey { + if ref.PrimaryKey != nil { constraint.ForeignKeys = append(constraint.ForeignKeys, ref.ForeignKey) constraint.References = append(constraint.References, ref.PrimaryKey) - constraint.ReferenceSchema = ref.PrimaryKey.Schema + + if ref.OwnPrimaryKey { + constraint.Schema = ref.ForeignKey.Schema + constraint.ReferenceSchema = rel.Schema + } else { + constraint.Schema = rel.Schema + constraint.ReferenceSchema = ref.PrimaryKey.Schema + } } } - if rel.JoinTable != nil || constraint.ReferenceSchema == nil { + if rel.JoinTable != nil { return nil } diff --git a/tests/associations_test.go b/tests/associations_test.go index 44262109..9b4dd105 100644 --- a/tests/associations_test.go +++ b/tests/associations_test.go @@ -31,3 +31,112 @@ func TestInvalidAssociation(t *testing.T) { t.Fatalf("should return errors for invalid association, but got nil") } } + +func TestForeignKeyConstraints(t *testing.T) { + type Profile struct { + ID uint + Name string + MemberID uint + } + + type Member struct { + ID uint + Refer uint `gorm:"unique_index"` + Name string + Profile Profile `gorm:"Constraint:OnUpdate:CASCADE,OnDelete:CASCADE;FOREIGNKEY:MemberID;References:Refer"` + } + + DB.Migrator().DropTable(&Profile{}, &Member{}) + + if err := DB.AutoMigrate(&Profile{}, &Member{}); err != nil { + t.Fatalf("Failed to migrate, got error: %v", err) + } + + member := Member{Refer: 1, Name: "foreign_key_constraints", Profile: Profile{Name: "my_profile"}} + + DB.Create(&member) + + var profile Profile + if err := DB.First(&profile, "id = ?", member.Profile.ID).Error; err != nil { + t.Fatalf("failed to find profile, got error: %v", err) + } else if profile.MemberID != member.ID { + t.Fatalf("member id is not equal: expects: %v, got: %v", member.ID, profile.MemberID) + } + + member.Profile = Profile{} + DB.Model(&member).Update("Refer", 100) + + var profile2 Profile + if err := DB.First(&profile2, "id = ?", profile.ID).Error; err != nil { + t.Fatalf("failed to find profile, got error: %v", err) + } else if profile2.MemberID != 100 { + t.Fatalf("member id is not equal: expects: %v, got: %v", 100, profile2.MemberID) + } + + if r := DB.Delete(&member); r.Error != nil || r.RowsAffected != 1 { + t.Fatalf("Should delete member, got error: %v, affected: %v", r.Error, r.RowsAffected) + } + + var result Member + if err := DB.First(&result, member.ID).Error; err == nil { + t.Fatalf("Should not find deleted member") + } + + if err := DB.First(&profile2, profile.ID).Error; err == nil { + t.Fatalf("Should not find deleted profile") + } +} + +func TestForeignKeyConstraintsBelongsTo(t *testing.T) { + type Profile struct { + ID uint + Name string + Refer uint `gorm:"unique_index"` + } + + type Member struct { + ID uint + Name string + ProfileID uint + Profile Profile `gorm:"Constraint:OnUpdate:CASCADE,OnDelete:CASCADE;FOREIGNKEY:ProfileID;References:Refer"` + } + + DB.Migrator().DropTable(&Profile{}, &Member{}) + + if err := DB.AutoMigrate(&Profile{}, &Member{}); err != nil { + t.Fatalf("Failed to migrate, got error: %v", err) + } + + member := Member{Name: "foreign_key_constraints_belongs_to", Profile: Profile{Name: "my_profile_belongs_to", Refer: 1}} + + DB.Create(&member) + + var profile Profile + if err := DB.First(&profile, "id = ?", member.Profile.ID).Error; err != nil { + t.Fatalf("failed to find profile, got error: %v", err) + } else if profile.Refer != member.ProfileID { + t.Fatalf("member id is not equal: expects: %v, got: %v", profile.Refer, member.ProfileID) + } + + DB.Model(&profile).Update("Refer", 100) + + var member2 Member + if err := DB.First(&member2, "id = ?", member.ID).Error; err != nil { + t.Fatalf("failed to find member, got error: %v", err) + } else if member2.ProfileID != 100 { + t.Fatalf("member id is not equal: expects: %v, got: %v", 100, member2.ProfileID) + } + + if r := DB.Delete(&profile); r.Error != nil || r.RowsAffected != 1 { + t.Fatalf("Should delete member, got error: %v, affected: %v", r.Error, r.RowsAffected) + } + + var result Member + if err := DB.First(&result, member.ID).Error; err == nil { + t.Fatalf("Should not find deleted member") + } + + if err := DB.First(&profile, profile.ID).Error; err == nil { + t.Fatalf("Should not find deleted profile") + } +} diff --git a/tests/preload_suits_test.go b/tests/preload_suits_test.go index 8f678b21..4a25a69b 100644 --- a/tests/preload_suits_test.go +++ b/tests/preload_suits_test.go @@ -433,8 +433,8 @@ func TestNestedPreload9(t *testing.T) { Level1 struct { ID uint Value string - Level2ID uint - Level2_1ID uint + Level2ID *uint + Level2_1ID *uint Level0s []Level0 `json:",omitempty"` } Level2 struct { diff --git a/tests/tests_test.go b/tests/tests_test.go index c80fb849..9e135b4e 100644 --- a/tests/tests_test.go +++ b/tests/tests_test.go @@ -66,6 +66,7 @@ func OpenTestConnection() (db *gorm.DB, err error) { default: log.Println("testing sqlite3...") db, err = gorm.Open(sqlite.Open(filepath.Join(os.TempDir(), "gorm.db")), &gorm.Config{}) + db.Exec("PRAGMA foreign_keys = ON") } if debug := os.Getenv("DEBUG"); debug == "true" { diff --git a/utils/tests/models.go b/utils/tests/models.go index 878129e8..021b0229 100644 --- a/utils/tests/models.go +++ b/utils/tests/models.go @@ -37,7 +37,7 @@ type Account struct { type Pet struct { gorm.Model - UserID uint + UserID *uint Name string Toy Toy `gorm:"polymorphic:Owner;"` } From 3d8f6f9cf9e225c964c66634b6b34df8e139f792 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 20 Jun 2020 01:55:30 +0800 Subject: [PATCH 445/881] Test GroupConditions --- clause/where.go | 6 +++++- clause/where_test.go | 6 ++++++ statement.go | 8 ++++++++ tests/sql_builder_test.go | 25 +++++++++++++++++++++++++ 4 files changed, 44 insertions(+), 1 deletion(-) diff --git a/clause/where.go b/clause/where.go index 6399a2d5..f7cd3318 100644 --- a/clause/where.go +++ b/clause/where.go @@ -66,7 +66,11 @@ func (and AndConditions) Build(builder Builder) { } for idx, c := range and.Exprs { if idx > 0 { - builder.WriteString(" AND ") + if orConditions, ok := c.(OrConditions); ok && len(orConditions.Exprs) == 1 { + builder.WriteString(" OR ") + } else { + builder.WriteString(" AND ") + } } c.Build(builder) } diff --git a/clause/where_test.go b/clause/where_test.go index 95bba820..2fa11d76 100644 --- a/clause/where_test.go +++ b/clause/where_test.go @@ -53,6 +53,12 @@ func TestWhere(t *testing.T) { }}, "SELECT * FROM `users` WHERE (`users`.`id` <> ? AND `age` <= ?) OR `name` <> ? AND (`score` <= ? OR `name` LIKE ?)", []interface{}{"1", 18, "jinzhu", 100, "%linus%"}, }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Where{ + Exprs: []clause.Expression{clause.And(clause.Eq{Column: "age", Value: 18}, clause.Or(clause.Neq{Column: "name", Value: "jinzhu"}))}, + }}, + "SELECT * FROM `users` WHERE (`age` = ? OR `name` <> ?)", []interface{}{18, "jinzhu"}, + }, } for idx, result := range results { diff --git a/statement.go b/statement.go index e3c882ee..7cc01bb8 100644 --- a/statement.go +++ b/statement.go @@ -245,6 +245,14 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) (c switch v := arg.(type) { case clause.Expression: conds = append(conds, v) + case *DB: + if cs, ok := v.Statement.Clauses["WHERE"]; ok { + if where, ok := cs.Expression.(clause.Where); ok { + conds = append(conds, clause.And(where.Exprs...)) + } else if cs.Expression != nil { + conds = append(conds, cs.Expression) + } + } case map[interface{}]interface{}: for i, j := range v { conds = append(conds, clause.Eq{Column: i, Value: j}) diff --git a/tests/sql_builder_test.go b/tests/sql_builder_test.go index a60514c9..b78c2484 100644 --- a/tests/sql_builder_test.go +++ b/tests/sql_builder_test.go @@ -1,6 +1,7 @@ package tests_test import ( + "strings" "testing" "gorm.io/gorm" @@ -138,3 +139,27 @@ func TestDryRun(t *testing.T) { t.Errorf("Failed to generate sql, got %v", stmt2.SQL.String()) } } + +func TestGroupConditions(t *testing.T) { + type Pizza struct { + ID uint + Name string + Size string + } + dryRunDB := DB.Session(&gorm.Session{DryRun: true}) + + stmt := dryRunDB.Where( + DB.Where("pizza = ?", "pepperoni").Where(DB.Where("size = ?", "small").Or("size = ?", "medium")), + ).Or( + DB.Where("pizza = ?", "hawaiian").Where("size = ?", "xlarge"), + ).Find(&Pizza{}).Statement + + execStmt := dryRunDB.Exec("WHERE (pizza = ? AND (size = ? OR size = ?)) OR (pizza = ? AND size = ?)", "pepperoni", "small", "medium", "hawaiian", "xlarge").Statement + + result := DB.Dialector.Explain(stmt.SQL.String(), stmt.Vars...) + expects := DB.Dialector.Explain(execStmt.SQL.String(), execStmt.Vars...) + + if !strings.HasSuffix(result, expects) { + t.Errorf("expects: %v, got %v", expects, result) + } +} From a1e35bdc94760e520ed40cfdeaefd6b8c67e779e Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 20 Jun 2020 10:51:36 +0800 Subject: [PATCH 446/881] Support merge batch data some having primary values --- callbacks/associations.go | 84 ++++++++++++++++-------------- callbacks/create.go | 77 ++++++++++++++++++--------- clause/clause.go | 1 + interfaces.go | 1 + migrator/migrator.go | 4 -- schema/field_test.go | 2 +- schema/relationship.go | 4 +- schema/schema.go | 1 + schema/schema_test.go | 4 +- tests/associations_has_one_test.go | 2 + tests/go.mod | 10 ++-- tests/helper_test.go | 1 - tests/preload_suits_test.go | 8 ++- tests/tests_all.sh | 2 +- tests/tests_test.go | 4 +- utils/tests/dummy_dialecter.go | 4 ++ 16 files changed, 126 insertions(+), 83 deletions(-) diff --git a/callbacks/associations.go b/callbacks/associations.go index 5ff63cc4..3ff0f4b0 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -52,21 +52,19 @@ func SaveBeforeAssociations(db *gorm.DB) { obj := db.Statement.ReflectValue.Index(i) if _, zero := rel.Field.ValueOf(obj); !zero { // check belongs to relation value rv := rel.Field.ReflectValueOf(obj) // relation reflect value - if _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(rv); isZero { - objs = append(objs, obj) - if isPtr { - elems = reflect.Append(elems, rv) - } else { - elems = reflect.Append(elems, rv.Addr()) - } + objs = append(objs, obj) + if isPtr { + elems = reflect.Append(elems, rv) } else { - setupReferences(obj, rv) + elems = reflect.Append(elems, rv.Addr()) } } } if elems.Len() > 0 { - if db.AddError(db.Session(&gorm.Session{}).Create(elems.Interface()).Error) == nil { + if db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{ + DoNothing: true, + }).Create(elems.Interface()).Error) == nil { for i := 0; i < elems.Len(); i++ { setupReferences(objs[i], elems.Index(i)) } @@ -79,10 +77,11 @@ func SaveBeforeAssociations(db *gorm.DB) { rv = rv.Addr() } - if _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(rv); isZero { - db.Session(&gorm.Session{}).Create(rv.Interface()) + if db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{ + DoNothing: true, + }).Create(rv.Interface()).Error) == nil { + setupReferences(db.Statement.ReflectValue, rv) } - setupReferences(db.Statement.ReflectValue, rv) } } } @@ -130,16 +129,20 @@ func SaveAfterAssociations(db *gorm.DB) { } } - if _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(rv); isZero { - elems = reflect.Append(elems, rv) - } else { - db.Session(&gorm.Session{}).Save(rv.Addr().Interface()) - } + elems = reflect.Append(elems, rv) } } if elems.Len() > 0 { - db.Session(&gorm.Session{}).Create(elems.Interface()) + assignmentColumns := []string{} + for _, ref := range rel.References { + assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) + } + + db.Session(&gorm.Session{}).Clauses(clause.OnConflict{ + Columns: []clause.Column{{Name: rel.FieldSchema.PrioritizedPrimaryField.DBName}}, + DoUpdates: clause.AssignmentColumns(assignmentColumns), + }).Create(elems.Interface()) } case reflect.Struct: if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero { @@ -148,6 +151,7 @@ func SaveAfterAssociations(db *gorm.DB) { f = f.Addr() } + assignmentColumns := []string{} for _, ref := range rel.References { if ref.OwnPrimaryKey { fv, _ := ref.PrimaryKey.ValueOf(db.Statement.ReflectValue) @@ -155,13 +159,13 @@ func SaveAfterAssociations(db *gorm.DB) { } else if ref.PrimaryValue != "" { ref.ForeignKey.Set(f, ref.PrimaryValue) } + assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) } - if _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(f); isZero { - db.Session(&gorm.Session{}).Create(f.Interface()) - } else { - db.Session(&gorm.Session{}).Save(f.Interface()) - } + db.Session(&gorm.Session{}).Clauses(clause.OnConflict{ + Columns: []clause.Column{{Name: rel.FieldSchema.PrioritizedPrimaryField.DBName}}, + DoUpdates: clause.AssignmentColumns(assignmentColumns), + }).Create(f.Interface()) } } } @@ -193,14 +197,10 @@ func SaveAfterAssociations(db *gorm.DB) { } } - if _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(elem); isZero { - if isPtr { - elems = reflect.Append(elems, elem) - } else { - elems = reflect.Append(elems, elem.Addr()) - } + if isPtr { + elems = reflect.Append(elems, elem) } else { - db.Session(&gorm.Session{}).Save(elem.Addr().Interface()) + elems = reflect.Append(elems, elem.Addr()) } } } @@ -216,7 +216,15 @@ func SaveAfterAssociations(db *gorm.DB) { } if elems.Len() > 0 { - db.Session(&gorm.Session{}).Create(elems.Interface()) + assignmentColumns := []string{} + for _, ref := range rel.References { + assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) + } + + db.Session(&gorm.Session{}).Clauses(clause.OnConflict{ + Columns: []clause.Column{{Name: rel.FieldSchema.PrioritizedPrimaryField.DBName}}, + DoUpdates: clause.AssignmentColumns(assignmentColumns), + }).Create(elems.Interface()) } } @@ -258,15 +266,11 @@ func SaveAfterAssociations(db *gorm.DB) { for i := 0; i < f.Len(); i++ { elem := f.Index(i) - if _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(elem); isZero { - objs = append(objs, v) - if isPtr { - elems = reflect.Append(elems, elem) - } else { - elems = reflect.Append(elems, elem.Addr()) - } + objs = append(objs, v) + if isPtr { + elems = reflect.Append(elems, elem) } else { - appendToJoins(v, elem) + elems = reflect.Append(elems, elem.Addr()) } } } @@ -282,7 +286,7 @@ func SaveAfterAssociations(db *gorm.DB) { } if elems.Len() > 0 { - db.Session(&gorm.Session{}).Create(elems.Interface()) + db.Session(&gorm.Session{}).Clauses(clause.OnConflict{DoNothing: true}).Create(elems.Interface()) for i := 0; i < elems.Len(); i++ { appendToJoins(objs[i], elems.Index(i)) diff --git a/callbacks/create.go b/callbacks/create.go index 684d5530..283d3fd1 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -55,29 +55,44 @@ func Create(config *Config) func(db *gorm.DB) { result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) if err == nil { - if db.Statement.Schema != nil && db.Statement.Schema.PrioritizedPrimaryField != nil && db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue { - if insertID, err := result.LastInsertId(); err == nil { - switch db.Statement.ReflectValue.Kind() { - case reflect.Slice, reflect.Array: - if config.LastInsertIDReversed { - for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- { - db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID) - insertID-- - } - } else { - for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID) - insertID++ + db.RowsAffected, _ = result.RowsAffected() + if db.RowsAffected > 0 { + if db.Statement.Schema != nil && db.Statement.Schema.PrioritizedPrimaryField != nil && db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue { + if insertID, err := result.LastInsertId(); err == nil { + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + if config.LastInsertIDReversed { + for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- { + _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue.Index(i)) + if isZero { + db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID) + insertID-- + } + } + } else { + allUpdated := int(db.RowsAffected) == db.Statement.ReflectValue.Len() + isZero := true + + for i := 0; i < db.Statement.ReflectValue.Len(); i++ { + + if !allUpdated { + _, isZero = db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue.Index(i)) + } + + if isZero { + db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID) + insertID++ + } + } } + case reflect.Struct: + db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID) } - case reflect.Struct: - db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID) + } else { + db.AddError(err) } - } else { - db.AddError(err) } } - db.RowsAffected, _ = result.RowsAffected() } else { db.AddError(err) } @@ -129,9 +144,19 @@ func CreateWithReturning(db *gorm.DB) { switch db.Statement.ReflectValue.Kind() { case reflect.Slice, reflect.Array: + c := db.Statement.Clauses["ON CONFLICT"] + onConflict, _ := c.Expression.(clause.OnConflict) + for rows.Next() { + BEGIN: for idx, field := range fields { - values[idx] = field.ReflectValueOf(db.Statement.ReflectValue.Index(int(db.RowsAffected))).Addr().Interface() + fieldValue := field.ReflectValueOf(db.Statement.ReflectValue.Index(int(db.RowsAffected))) + if onConflict.DoNothing && !fieldValue.IsZero() { + db.RowsAffected++ + goto BEGIN + } + + values[idx] = fieldValue.Addr().Interface() } db.RowsAffected++ @@ -211,7 +236,7 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { case reflect.Slice, reflect.Array: stmt.SQL.Grow(stmt.ReflectValue.Len() * 15) values.Values = make([][]interface{}, stmt.ReflectValue.Len()) - defaultValueFieldsHavingValue := map[string][]interface{}{} + defaultValueFieldsHavingValue := map[*schema.Field][]interface{}{} for i := 0; i < stmt.ReflectValue.Len(); i++ { rv := reflect.Indirect(stmt.ReflectValue.Index(i)) values.Values[i] = make([]interface{}, len(values.Columns)) @@ -231,20 +256,20 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { for _, field := range stmt.Schema.FieldsWithDefaultDBValue { if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { if v, isZero := field.ValueOf(rv); !isZero { - if len(defaultValueFieldsHavingValue[field.DBName]) == 0 { - defaultValueFieldsHavingValue[field.DBName] = make([]interface{}, stmt.ReflectValue.Len()) + if len(defaultValueFieldsHavingValue[field]) == 0 { + defaultValueFieldsHavingValue[field] = make([]interface{}, stmt.ReflectValue.Len()) } - defaultValueFieldsHavingValue[field.DBName][i] = v + defaultValueFieldsHavingValue[field][i] = v } } } } - for db, vs := range defaultValueFieldsHavingValue { - values.Columns = append(values.Columns, clause.Column{Name: db}) + for field, vs := range defaultValueFieldsHavingValue { + values.Columns = append(values.Columns, clause.Column{Name: field.DBName}) for idx := range values.Values { if vs[idx] == nil { - values.Values[idx] = append(values.Values[idx], clause.Expr{SQL: "DEFAULT"}) + values.Values[idx] = append(values.Values[idx], stmt.Dialector.DefaultValueOf(field)) } else { values.Values[idx] = append(values.Values[idx], vs[idx]) } diff --git a/clause/clause.go b/clause/clause.go index 64f08d14..c7d1efeb 100644 --- a/clause/clause.go +++ b/clause/clause.go @@ -64,6 +64,7 @@ func (c Clause) Build(builder Builder) { const ( PrimaryKey string = "@@@py@@@" // primary key CurrentTable string = "@@@ct@@@" // current table + Associations string = "@@@as@@@" // associations ) var ( diff --git a/interfaces.go b/interfaces.go index f3e5c028..96289a90 100644 --- a/interfaces.go +++ b/interfaces.go @@ -14,6 +14,7 @@ type Dialector interface { Initialize(*DB) error Migrator(db *DB) Migrator DataTypeOf(*schema.Field) string + DefaultValueOf(*schema.Field) clause.Expression BindVarTo(writer clause.Writer, stmt *Statement, v interface{}) QuoteTo(clause.Writer, string) Explain(sql string, vars ...interface{}) string diff --git a/migrator/migrator.go b/migrator/migrator.go index a4cc99a6..b598bd93 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -57,10 +57,6 @@ func (m Migrator) DataTypeOf(field *schema.Field) string { func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) { expr.SQL = m.DataTypeOf(field) - if field.AutoIncrement { - expr.SQL += " AUTO_INCREMENT" - } - if field.NotNull { expr.SQL += " NOT NULL" } diff --git a/schema/field_test.go b/schema/field_test.go index 0936c0d1..7970b614 100644 --- a/schema/field_test.go +++ b/schema/field_test.go @@ -235,7 +235,7 @@ func TestParseFieldWithPermission(t *testing.T) { } fields := []schema.Field{ - {Name: "ID", DBName: "id", BindNames: []string{"ID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, Creatable: true, Updatable: true, Readable: true, HasDefaultValue: true}, + {Name: "ID", DBName: "id", BindNames: []string{"ID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, Creatable: true, Updatable: true, Readable: true, HasDefaultValue: true, AutoIncrement: true}, {Name: "Name", DBName: "", BindNames: []string{"Name"}, DataType: "", Tag: `gorm:"-"`, Creatable: false, Updatable: false, Readable: false}, {Name: "Name2", DBName: "name2", BindNames: []string{"Name2"}, DataType: schema.String, Tag: `gorm:"->"`, Creatable: false, Updatable: false, Readable: true}, {Name: "Name3", DBName: "name3", BindNames: []string{"Name3"}, DataType: schema.String, Tag: `gorm:"<-"`, Creatable: true, Updatable: true, Readable: true}, diff --git a/schema/relationship.go b/schema/relationship.go index afa083ed..c69a4a09 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -251,11 +251,13 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel } relation.JoinTable.Name = many2many relation.JoinTable.Table = schema.namer.JoinTableName(many2many) + relation.JoinTable.PrimaryFields = make([]*Field, len(relation.JoinTable.Fields)) // build references - for _, f := range relation.JoinTable.Fields { + for idx, f := range relation.JoinTable.Fields { // use same data type for foreign keys f.DataType = fieldsMap[f.Name].DataType + relation.JoinTable.PrimaryFields[idx] = f relation.References = append(relation.References, &Reference{ PrimaryKey: fieldsMap[f.Name], diff --git a/schema/schema.go b/schema/schema.go index 5b360f5e..e5894443 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -188,6 +188,7 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field) } field.HasDefaultValue = true + field.AutoIncrement = true } } diff --git a/schema/schema_test.go b/schema/schema_test.go index 4ec7ff0c..99781e47 100644 --- a/schema/schema_test.go +++ b/schema/schema_test.go @@ -32,7 +32,7 @@ func checkUserSchema(t *testing.T, user *schema.Schema) { // check fields fields := []schema.Field{ - {Name: "ID", DBName: "id", BindNames: []string{"Model", "ID"}, DataType: schema.Uint, PrimaryKey: true, Tag: `gorm:"primarykey"`, TagSettings: map[string]string{"PRIMARYKEY": "PRIMARYKEY"}, Size: 64, HasDefaultValue: true}, + {Name: "ID", DBName: "id", BindNames: []string{"Model", "ID"}, DataType: schema.Uint, PrimaryKey: true, Tag: `gorm:"primarykey"`, TagSettings: map[string]string{"PRIMARYKEY": "PRIMARYKEY"}, Size: 64, HasDefaultValue: true, AutoIncrement: true}, {Name: "CreatedAt", DBName: "created_at", BindNames: []string{"Model", "CreatedAt"}, DataType: schema.Time}, {Name: "UpdatedAt", DBName: "updated_at", BindNames: []string{"Model", "UpdatedAt"}, DataType: schema.Time}, {Name: "DeletedAt", DBName: "deleted_at", BindNames: []string{"Model", "DeletedAt"}, Tag: `gorm:"index"`, DataType: schema.Time}, @@ -125,7 +125,7 @@ func TestParseSchemaWithAdvancedDataType(t *testing.T) { // check fields fields := []schema.Field{ - {Name: "ID", DBName: "id", BindNames: []string{"ID"}, DataType: schema.Int, PrimaryKey: true, Size: 64, HasDefaultValue: true}, + {Name: "ID", DBName: "id", BindNames: []string{"ID"}, DataType: schema.Int, PrimaryKey: true, Size: 64, HasDefaultValue: true, AutoIncrement: true}, {Name: "Name", DBName: "name", BindNames: []string{"Name"}, DataType: schema.String}, {Name: "Birthday", DBName: "birthday", BindNames: []string{"Birthday"}, DataType: schema.Time}, {Name: "RegisteredAt", DBName: "registered_at", BindNames: []string{"RegisteredAt"}, DataType: schema.Time}, diff --git a/tests/associations_has_one_test.go b/tests/associations_has_one_test.go index a6dcc6c5..f487bd9e 100644 --- a/tests/associations_has_one_test.go +++ b/tests/associations_has_one_test.go @@ -68,6 +68,7 @@ func TestHasOneAssociation(t *testing.T) { AssertAssociationCount(t, user2, "Account", 0, "after delete") // Prepare Data for Clear + account = Account{Number: "account-has-one-append"} if err := DB.Model(&user2).Association("Account").Append(&account); err != nil { t.Fatalf("Error happened when append Account, got %v", err) } @@ -185,6 +186,7 @@ func TestPolymorphicHasOneAssociation(t *testing.T) { AssertAssociationCount(t, pet2, "Toy", 0, "after delete") // Prepare Data for Clear + toy = Toy{Name: "toy-has-one-append"} if err := DB.Model(&pet2).Association("Toy").Append(&toy); err != nil { t.Fatalf("Error happened when append Toy, got %v", err) } diff --git a/tests/go.mod b/tests/go.mod index a2121b7a..1cd56f6b 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -6,11 +6,11 @@ require ( github.com/google/uuid v1.1.1 github.com/jinzhu/now v1.1.1 github.com/lib/pq v1.6.0 - gorm.io/driver/mysql v0.2.1 - gorm.io/driver/postgres v0.2.1 - gorm.io/driver/sqlite v1.0.4 - gorm.io/driver/sqlserver v0.2.1 - gorm.io/gorm v0.2.7 + gorm.io/driver/mysql v0.2.2 + gorm.io/driver/postgres v0.2.2 + gorm.io/driver/sqlite v1.0.5 + gorm.io/driver/sqlserver v0.2.2 + gorm.io/gorm v0.2.9 ) replace gorm.io/gorm => ../ diff --git a/tests/helper_test.go b/tests/helper_test.go index b05f5297..cc0d808c 100644 --- a/tests/helper_test.go +++ b/tests/helper_test.go @@ -58,7 +58,6 @@ func GetUser(name string, config Config) *User { for i := 0; i < config.Languages; i++ { name := name + "_locale_" + strconv.Itoa(i+1) language := Language{Code: name, Name: name} - DB.Create(&language) user.Languages = append(user.Languages, language) } diff --git a/tests/preload_suits_test.go b/tests/preload_suits_test.go index 4a25a69b..d40309e7 100644 --- a/tests/preload_suits_test.go +++ b/tests/preload_suits_test.go @@ -4,6 +4,7 @@ import ( "database/sql" "encoding/json" "reflect" + "sort" "testing" "gorm.io/gorm" @@ -735,7 +736,6 @@ func TestManyToManyPreloadWithMultiPrimaryKeys(t *testing.T) { if err := DB.Preload("Level1s").Find(&got, "value = ?", "Bob").Error; err != nil { t.Error(err) } - return if !reflect.DeepEqual(got, want) { t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) @@ -1459,6 +1459,12 @@ func TestPrefixedPreloadDuplication(t *testing.T) { t.Error(err) } + for _, level1 := range append(got, want...) { + sort.Slice(level1.Level2.Level3.Level4s, func(i, j int) bool { + return level1.Level2.Level3.Level4s[i].ID > level1.Level2.Level3.Level4s[j].ID + }) + } + if !reflect.DeepEqual(got, want) { t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) } diff --git a/tests/tests_all.sh b/tests/tests_all.sh index fd696e38..a321fe31 100755 --- a/tests/tests_all.sh +++ b/tests/tests_all.sh @@ -8,7 +8,7 @@ if [ -d tests ] then cd tests cp go.mod go.mod.bak - sed '/gorm.io\/driver/d' go.mod.bak > go.mod + sed '/$[[:space:]]*gorm.io\/driver/d' go.mod.bak > go.mod cd .. fi diff --git a/tests/tests_test.go b/tests/tests_test.go index 9e135b4e..fa8bad5c 100644 --- a/tests/tests_test.go +++ b/tests/tests_test.go @@ -34,6 +34,9 @@ func init() { } RunMigrations() + if DB.Dialector.Name() == "sqlite" { + DB.Exec("PRAGMA foreign_keys = ON") + } } } @@ -66,7 +69,6 @@ func OpenTestConnection() (db *gorm.DB, err error) { default: log.Println("testing sqlite3...") db, err = gorm.Open(sqlite.Open(filepath.Join(os.TempDir(), "gorm.db")), &gorm.Config{}) - db.Exec("PRAGMA foreign_keys = ON") } if debug := os.Getenv("DEBUG"); debug == "true" { diff --git a/utils/tests/dummy_dialecter.go b/utils/tests/dummy_dialecter.go index cd4bbd45..b8452ef9 100644 --- a/utils/tests/dummy_dialecter.go +++ b/utils/tests/dummy_dialecter.go @@ -18,6 +18,10 @@ func (DummyDialector) Initialize(*gorm.DB) error { return nil } +func (DummyDialector) DefaultValueOf(field *schema.Field) clause.Expression { + return clause.Expr{SQL: "DEFAULT"} +} + func (DummyDialector) Migrator(*gorm.DB) gorm.Migrator { return nil } From 5883490aa773ad8dbc13c901bb4ffec502417477 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 20 Jun 2020 17:21:01 +0800 Subject: [PATCH 447/881] Select, Omit, Preload supports clause.Associations --- callbacks/helper.go | 15 ++++++++++----- callbacks/query.go | 14 +++++++++++--- tests/create_test.go | 24 +++++++++++++++++++++--- tests/preload_test.go | 23 +++++++++++++++++++++++ 4 files changed, 65 insertions(+), 11 deletions(-) diff --git a/callbacks/helper.go b/callbacks/helper.go index 97c8ad35..3b0cca16 100644 --- a/callbacks/helper.go +++ b/callbacks/helper.go @@ -19,10 +19,11 @@ func SelectAndOmitColumns(stmt *gorm.Statement, requireCreate, requireUpdate boo for _, dbName := range stmt.Schema.DBNames { results[dbName] = true } - break - } - - if field := stmt.Schema.LookUpField(column); field != nil && field.DBName != "" { + } else if column == clause.Associations { + for _, rel := range stmt.Schema.Relationships.Relations { + results[rel.Name] = true + } + } else if field := stmt.Schema.LookUpField(column); field != nil && field.DBName != "" { results[field.DBName] = true } else { results[column] = true @@ -31,7 +32,11 @@ func SelectAndOmitColumns(stmt *gorm.Statement, requireCreate, requireUpdate boo // omit columns for _, omit := range stmt.Omits { - if field := stmt.Schema.LookUpField(omit); field != nil && field.DBName != "" { + if omit == clause.Associations { + for _, rel := range stmt.Schema.Relationships.Relations { + results[rel.Name] = false + } + } else if field := stmt.Schema.LookUpField(omit); field != nil && field.DBName != "" { results[field.DBName] = false } else { results[omit] = false diff --git a/callbacks/query.go b/callbacks/query.go index e5557d4a..27d53a4d 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -140,9 +140,17 @@ func Preload(db *gorm.DB) { if db.Error == nil && len(db.Statement.Preloads) > 0 { preloadMap := map[string][]string{} for name := range db.Statement.Preloads { - preloadFields := strings.Split(name, ".") - for idx := range preloadFields { - preloadMap[strings.Join(preloadFields[:idx+1], ".")] = preloadFields[:idx+1] + if name == clause.Associations { + for _, rel := range db.Statement.Schema.Relationships.Relations { + if rel.Schema == db.Statement.Schema { + preloadMap[rel.Name] = []string{rel.Name} + } + } + } else { + preloadFields := strings.Split(name, ".") + for idx := range preloadFields { + preloadMap[strings.Join(preloadFields[:idx+1], ".")] = preloadFields[:idx+1] + } } } diff --git a/tests/create_test.go b/tests/create_test.go index 351f02a3..4bf623b3 100644 --- a/tests/create_test.go +++ b/tests/create_test.go @@ -6,6 +6,7 @@ import ( "github.com/jinzhu/now" "gorm.io/gorm" + "gorm.io/gorm/clause" . "gorm.io/gorm/utils/tests" ) @@ -282,13 +283,30 @@ func TestOmitWithCreate(t *testing.T) { user := *GetUser("omit_create", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4}) DB.Omit("Account", "Toys", "Manager", "Birthday").Create(&user) - var user2 User - DB.Preload("Account").Preload("Pets").Preload("Toys").Preload("Company").Preload("Manager").Preload("Team").Preload("Languages").Preload("Friends").First(&user2, user.ID) + var result User + DB.Preload("Account").Preload("Pets").Preload("Toys").Preload("Company").Preload("Manager").Preload("Team").Preload("Languages").Preload("Friends").First(&result, user.ID) user.Birthday = nil user.Account = Account{} user.Toys = nil user.Manager = nil - CheckUser(t, user2, user) + CheckUser(t, result, user) + + user2 := *GetUser("omit_create", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4}) + DB.Omit(clause.Associations).Create(&user2) + + var result2 User + DB.Preload(clause.Associations).First(&result2, user2.ID) + + user2.Account = Account{} + user2.Toys = nil + user2.Manager = nil + user2.Company = Company{} + user2.Pets = nil + user2.Team = nil + user2.Languages = nil + user2.Friends = nil + + CheckUser(t, result2, user2) } diff --git a/tests/preload_test.go b/tests/preload_test.go index 06e38f09..3caa17b4 100644 --- a/tests/preload_test.go +++ b/tests/preload_test.go @@ -9,6 +9,29 @@ import ( . "gorm.io/gorm/utils/tests" ) +func TestPreloadWithAssociations(t *testing.T) { + var user = *GetUser("preload_with_associations", Config{ + Account: true, + Pets: 2, + Toys: 3, + Company: true, + Manager: true, + Team: 4, + Languages: 3, + Friends: 1, + }) + + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + CheckUser(t, user, user) + + var user2 User + DB.Preload(clause.Associations).Find(&user2, "id = ?", user.ID) + CheckUser(t, user2, user) +} + func TestNestedPreload(t *testing.T) { var user = *GetUser("nested_preload", Config{Pets: 2}) From fee1e4aafd39800814c08c8ab4d5c2d1dc773856 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 21 Jun 2020 10:19:16 +0800 Subject: [PATCH 448/881] Fix create foreign keys for many2many relations --- gorm.go | 7 ++++++ migrator/migrator.go | 29 ++++++++++++++++++------- schema/naming.go | 2 +- schema/relationship.go | 49 +++++++++++++++++++++++++++++++++++++++++- tests/go.mod | 4 ++-- 5 files changed, 79 insertions(+), 12 deletions(-) diff --git a/gorm.go b/gorm.go index a5f8bbfd..e3193f59 100644 --- a/gorm.go +++ b/gorm.go @@ -293,6 +293,13 @@ func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interfac } } + for name, rel := range relation.JoinTable.Relationships.Relations { + if _, ok := joinSchema.Relationships.Relations[name]; !ok { + rel.Schema = joinSchema + joinSchema.Relationships.Relations[name] = rel + } + } + relation.JoinTable = joinSchema } else { return fmt.Errorf("failed to found relation: %v", field) diff --git a/migrator/migrator.go b/migrator/migrator.go index b598bd93..90ab7892 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -88,7 +88,7 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { return err } } else { - if err := m.RunWithValue(value, func(stmt *gorm.Statement) error { + if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) { for _, field := range stmt.Schema.FieldsByDBName { if !tx.Migrator().HasColumn(value, field.DBName) { if err := tx.Migrator().AddColumn(value, field.DBName); err != nil { @@ -120,9 +120,13 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { if rel.JoinTable != nil { joinValue := reflect.New(rel.JoinTable.ModelType).Interface() if !tx.Migrator().HasTable(rel.JoinTable.Table) { - defer tx.Table(rel.JoinTable.Table).Migrator().CreateTable(joinValue) + defer func() { + errr = tx.Table(rel.JoinTable.Table).Migrator().CreateTable(joinValue) + }() } else { - defer tx.Table(rel.JoinTable.Table).Migrator().AutoMigrate(joinValue) + defer func() { + errr = tx.Table(rel.JoinTable.Table).Migrator().AutoMigrate(joinValue) + }() } } } @@ -139,7 +143,7 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { func (m Migrator) CreateTable(values ...interface{}) error { for _, value := range m.ReorderModels(values, false) { tx := m.DB.Session(&gorm.Session{}) - if err := m.RunWithValue(value, func(stmt *gorm.Statement) error { + if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) { var ( createTableSQL = "CREATE TABLE ? (" values = []interface{}{clause.Table{Name: stmt.Table}} @@ -166,7 +170,9 @@ func (m Migrator) CreateTable(values ...interface{}) error { for _, idx := range stmt.Schema.ParseIndexes() { if m.CreateIndexAfterCreateTable { - defer tx.Migrator().CreateIndex(value, idx.Name) + defer func() { + errr = tx.Migrator().CreateIndex(value, idx.Name) + }() } else { createTableSQL += "INDEX ? ?," values = append(values, clause.Expr{SQL: idx.Name}, tx.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt)) @@ -186,7 +192,9 @@ func (m Migrator) CreateTable(values ...interface{}) error { if rel.JoinTable != nil { joinValue := reflect.New(rel.JoinTable.ModelType).Interface() if !tx.Migrator().HasTable(rel.JoinTable.Table) { - defer tx.Table(rel.JoinTable.Table).Migrator().CreateTable(joinValue) + defer func(table string, joinValue interface{}) { + errr = tx.Table(table).Migrator().CreateTable(joinValue) + }(rel.JoinTable.Table, joinValue) } } } @@ -204,7 +212,8 @@ func (m Migrator) CreateTable(values ...interface{}) error { createTableSQL += fmt.Sprint(tableOption) } - return tx.Exec(createTableSQL, values...).Error + errr = tx.Exec(createTableSQL, values...).Error + return errr }); err != nil { return err } @@ -553,6 +562,10 @@ func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []i if c := rel.ParseConstraint(); c != nil && c.Schema == dep.Statement.Schema && c.Schema != c.ReferenceSchema { dep.Depends = append(dep.Depends, c.ReferenceSchema) } + + if rel.JoinTable != nil && rel.Schema != rel.FieldSchema { + dep.Depends = append(dep.Depends, rel.FieldSchema) + } } valuesMap[dep.Schema.Table] = dep @@ -566,6 +579,7 @@ func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []i if _, ok := orderedModelNamesMap[name]; ok { return // avoid loop } + orderedModelNamesMap[name] = true dep := valuesMap[name] for _, d := range dep.Depends { @@ -578,7 +592,6 @@ func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []i } orderedModelNames = append(orderedModelNames, name) - orderedModelNamesMap[name] = true } for _, value := range values { diff --git a/schema/naming.go b/schema/naming.go index f7c82f32..d2a4919f 100644 --- a/schema/naming.go +++ b/schema/naming.go @@ -46,7 +46,7 @@ func (ns NamingStrategy) JoinTableName(str string) string { // RelationshipFKName generate fk name for relation func (ns NamingStrategy) RelationshipFKName(rel Relationship) string { - return fmt.Sprintf("fk_%s_%s", rel.Schema.Table, toDBName(rel.Field.Name)) + return fmt.Sprintf("fk_%s_%s", rel.Schema.Table, toDBName(rel.Name)) } // CheckerName generate checker name diff --git a/schema/relationship.go b/schema/relationship.go index c69a4a09..a13d53b9 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -253,16 +253,63 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel relation.JoinTable.Table = schema.namer.JoinTableName(many2many) relation.JoinTable.PrimaryFields = make([]*Field, len(relation.JoinTable.Fields)) + relName := relation.Schema.Name + relRefName := relation.FieldSchema.Name + if relName == relRefName { + relRefName = relation.Field.Name + } + + if _, ok := relation.JoinTable.Relationships.Relations[relName]; !ok { + relation.JoinTable.Relationships.Relations[relName] = &Relationship{ + Name: relName, + Type: BelongsTo, + Schema: relation.JoinTable, + FieldSchema: relation.Schema, + } + } else { + relation.JoinTable.Relationships.Relations[relName].References = []*Reference{} + } + + if _, ok := relation.JoinTable.Relationships.Relations[relRefName]; !ok { + relation.JoinTable.Relationships.Relations[relRefName] = &Relationship{ + Name: relRefName, + Type: BelongsTo, + Schema: relation.JoinTable, + FieldSchema: relation.FieldSchema, + } + } else { + relation.JoinTable.Relationships.Relations[relRefName].References = []*Reference{} + } + // build references for idx, f := range relation.JoinTable.Fields { // use same data type for foreign keys f.DataType = fieldsMap[f.Name].DataType relation.JoinTable.PrimaryFields[idx] = f + ownPriamryField := schema == fieldsMap[f.Name].Schema && ownFieldsMap[f.Name] + + if ownPriamryField { + joinRel := relation.JoinTable.Relationships.Relations[relName] + joinRel.Field = relation.Field + joinRel.References = append(joinRel.References, &Reference{ + PrimaryKey: fieldsMap[f.Name], + ForeignKey: f, + }) + } else { + joinRefRel := relation.JoinTable.Relationships.Relations[relRefName] + if joinRefRel.Field == nil { + joinRefRel.Field = relation.Field + } + joinRefRel.References = append(joinRefRel.References, &Reference{ + PrimaryKey: fieldsMap[f.Name], + ForeignKey: f, + }) + } relation.References = append(relation.References, &Reference{ PrimaryKey: fieldsMap[f.Name], ForeignKey: f, - OwnPrimaryKey: schema == fieldsMap[f.Name].Schema && ownFieldsMap[f.Name], + OwnPrimaryKey: ownPriamryField, }) } return diff --git a/tests/go.mod b/tests/go.mod index 1cd56f6b..85ef8dcb 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -6,9 +6,9 @@ require ( github.com/google/uuid v1.1.1 github.com/jinzhu/now v1.1.1 github.com/lib/pq v1.6.0 - gorm.io/driver/mysql v0.2.2 + gorm.io/driver/mysql v0.2.3 gorm.io/driver/postgres v0.2.2 - gorm.io/driver/sqlite v1.0.5 + gorm.io/driver/sqlite v1.0.6 gorm.io/driver/sqlserver v0.2.2 gorm.io/gorm v0.2.9 ) From d0764bead1bb0283c1f68842ce39cb4a001b8676 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 21 Jun 2020 13:53:13 +0800 Subject: [PATCH 449/881] Test migrate with comment and check created constraints --- migrator.go | 4 ++++ migrator/migrator.go | 36 +++++++++++++++--------------------- schema/index.go | 18 ++++++++++-------- tests/go.mod | 4 ++-- tests/migrate_test.go | 32 ++++++++++++++++++++++++++++++++ 5 files changed, 63 insertions(+), 31 deletions(-) diff --git a/migrator.go b/migrator.go index d45e3ac2..37051f81 100644 --- a/migrator.go +++ b/migrator.go @@ -2,6 +2,9 @@ package gorm import ( "database/sql" + + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" ) // Migrator returns migrator @@ -27,6 +30,7 @@ type Migrator interface { // Database CurrentDatabase() string + FullDataTypeOf(*schema.Field) clause.Expr // Tables CreateTable(dst ...interface{}) error diff --git a/migrator/migrator.go b/migrator/migrator.go index 90ab7892..64e02ac7 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -18,9 +18,8 @@ type Migrator struct { // Config schema config type Config struct { - CreateIndexAfterCreateTable bool - AllowDeferredConstraintsWhenAutoMigrate bool - DB *gorm.DB + CreateIndexAfterCreateTable bool + DB *gorm.DB gorm.Dialector } @@ -120,13 +119,13 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { if rel.JoinTable != nil { joinValue := reflect.New(rel.JoinTable.ModelType).Interface() if !tx.Migrator().HasTable(rel.JoinTable.Table) { - defer func() { - errr = tx.Table(rel.JoinTable.Table).Migrator().CreateTable(joinValue) - }() + defer func(table string, joinValue interface{}) { + errr = tx.Table(table).Migrator().CreateTable(joinValue) + }(rel.JoinTable.Table, joinValue) } else { - defer func() { - errr = tx.Table(rel.JoinTable.Table).Migrator().AutoMigrate(joinValue) - }() + defer func(table string, joinValue interface{}) { + errr = tx.Table(table).Migrator().AutoMigrate(joinValue) + }(rel.JoinTable.Table, joinValue) } } } @@ -154,7 +153,7 @@ func (m Migrator) CreateTable(values ...interface{}) error { field := stmt.Schema.FieldsByDBName[dbName] createTableSQL += fmt.Sprintf("? ?") hasPrimaryKeyInDataType = hasPrimaryKeyInDataType || strings.Contains(strings.ToUpper(string(field.DataType)), "PRIMARY KEY") - values = append(values, clause.Column{Name: dbName}, m.FullDataTypeOf(field)) + values = append(values, clause.Column{Name: dbName}, m.DB.Migrator().FullDataTypeOf(field)) createTableSQL += "," } @@ -170,9 +169,9 @@ func (m Migrator) CreateTable(values ...interface{}) error { for _, idx := range stmt.Schema.ParseIndexes() { if m.CreateIndexAfterCreateTable { - defer func() { - errr = tx.Migrator().CreateIndex(value, idx.Name) - }() + defer func(value interface{}, name string) { + errr = tx.Migrator().CreateIndex(value, name) + }(value, idx.Name) } else { createTableSQL += "INDEX ? ?," values = append(values, clause.Expr{SQL: idx.Name}, tx.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt)) @@ -277,7 +276,7 @@ func (m Migrator) AddColumn(value interface{}, field string) error { if field := stmt.Schema.LookUpField(field); field != nil { return m.DB.Exec( "ALTER TABLE ? ADD ? ?", - clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, m.FullDataTypeOf(field), + clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, m.DB.Migrator().FullDataTypeOf(field), ).Error } return fmt.Errorf("failed to look up field with name: %s", field) @@ -301,7 +300,7 @@ func (m Migrator) AlterColumn(value interface{}, field string) error { if field := stmt.Schema.LookUpField(field); field != nil { return m.DB.Exec( "ALTER TABLE ? ALTER COLUMN ? TYPE ?", - clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, m.FullDataTypeOf(field), + clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, m.DB.Migrator().FullDataTypeOf(field), ).Error } return fmt.Errorf("failed to look up field with name: %s", field) @@ -436,7 +435,7 @@ func (m Migrator) HasConstraint(value interface{}, name string) bool { m.RunWithValue(value, func(stmt *gorm.Statement) error { currentDatabase := m.DB.Migrator().CurrentDatabase() return m.DB.Raw( - "SELECT count(*) FROM INFORMATION_SCHEMA.referential_constraints WHERE constraint_schema = ? AND table_name = ? AND constraint_name = ?", + "SELECT count(*) FROM INFORMATION_SCHEMA.table_constraints WHERE constraint_schema = ? AND table_name = ? AND constraint_name = ?", currentDatabase, stmt.Table, name, ).Row().Scan(&count) }) @@ -481,11 +480,6 @@ func (m Migrator) CreateIndex(value interface{}, name string) error { } createIndexSQL += "INDEX ? ON ??" - if idx.Comment != "" { - values = append(values, idx.Comment) - createIndexSQL += " COMMENT ?" - } - if idx.Type != "" { createIndexSQL += " USING " + idx.Type } diff --git a/schema/index.go b/schema/index.go index 4228bba2..cf3338c3 100644 --- a/schema/index.go +++ b/schema/index.go @@ -53,16 +53,18 @@ func (schema *Schema) ParseIndexes() map[string]Index { } func (schema *Schema) LookIndex(name string) *Index { - indexes := schema.ParseIndexes() - for _, index := range indexes { - if index.Name == name { - return &index - } - - for _, field := range index.Fields { - if field.Name == name { + if schema != nil { + indexes := schema.ParseIndexes() + for _, index := range indexes { + if index.Name == name { return &index } + + for _, field := range index.Fields { + if field.Name == name { + return &index + } + } } } diff --git a/tests/go.mod b/tests/go.mod index 85ef8dcb..abe32cd6 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -7,8 +7,8 @@ require ( github.com/jinzhu/now v1.1.1 github.com/lib/pq v1.6.0 gorm.io/driver/mysql v0.2.3 - gorm.io/driver/postgres v0.2.2 - gorm.io/driver/sqlite v1.0.6 + gorm.io/driver/postgres v0.2.3 + gorm.io/driver/sqlite v1.0.7 gorm.io/driver/sqlserver v0.2.2 gorm.io/gorm v0.2.9 ) diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 194b5cbf..fce4c4aa 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -15,6 +15,8 @@ func TestMigrate(t *testing.T) { rand.Seed(time.Now().UnixNano()) rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] }) + DB.Migrator().DropTable("user_speaks", "user_friends") + if err := DB.Migrator().DropTable(allModels...); err != nil { t.Fatalf("Failed to drop table, got error %v", err) } @@ -28,6 +30,36 @@ func TestMigrate(t *testing.T) { t.Fatalf("Failed to create table for %#v---", m) } } + + for _, indexes := range [][2]string{ + {"user_speaks", "fk_user_speaks_user"}, + {"user_speaks", "fk_user_speaks_language"}, + {"user_friends", "fk_user_friends_user"}, + {"user_friends", "fk_user_friends_friends"}, + {"accounts", "fk_users_account"}, + {"users", "fk_users_team"}, + {"users", "fk_users_manager"}, + {"users", "fk_users_company"}, + } { + if !DB.Migrator().HasConstraint(indexes[0], indexes[1]) { + t.Fatalf("Failed to find index for many2many for %v %v", indexes[0], indexes[1]) + } + } +} + +func TestMigrateWithComment(t *testing.T) { + type UserWithComment struct { + gorm.Model + Name string `gorm:"size:111;index:,comment:这是一个index;comment:this is a 字段"` + } + + if err := DB.Migrator().DropTable(&UserWithComment{}); err != nil { + t.Fatalf("Failed to drop table, got error %v", err) + } + + if err := DB.AutoMigrate(&UserWithComment{}); err != nil { + t.Fatalf("Failed to auto migrate, but got error %v", err) + } } func TestTable(t *testing.T) { From 7851faa094ef6369caccd1b9ba08c344c00ca9f5 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 21 Jun 2020 18:01:50 +0800 Subject: [PATCH 450/881] Allow close prepared statements, double check before prepare --- gorm.go | 4 ++-- prepare_stmt.go | 22 +++++++++++++++++++--- 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/gorm.go b/gorm.go index e3193f59..6027b4bb 100644 --- a/gorm.go +++ b/gorm.go @@ -102,7 +102,7 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) { if config.PrepareStmt { db.ConnPool = &PreparedStmtDB{ ConnPool: db.ConnPool, - stmts: map[string]*sql.Stmt{}, + Stmts: map[string]*sql.Stmt{}, } } @@ -146,7 +146,7 @@ func (db *DB) Session(config *Session) *DB { if config.PrepareStmt { tx.Statement.ConnPool = &PreparedStmtDB{ ConnPool: db.Config.ConnPool, - stmts: map[string]*sql.Stmt{}, + Stmts: map[string]*sql.Stmt{}, } } diff --git a/prepare_stmt.go b/prepare_stmt.go index bc11abbf..ba9b04b6 100644 --- a/prepare_stmt.go +++ b/prepare_stmt.go @@ -7,23 +7,39 @@ import ( ) type PreparedStmtDB struct { - stmts map[string]*sql.Stmt + Stmts map[string]*sql.Stmt mux sync.RWMutex ConnPool } +func (db *PreparedStmtDB) Close() { + db.mux.Lock() + for k, stmt := range db.Stmts { + delete(db.Stmts, k) + stmt.Close() + } + + db.mux.Unlock() +} + func (db *PreparedStmtDB) prepare(query string) (*sql.Stmt, error) { db.mux.RLock() - if stmt, ok := db.stmts[query]; ok { + if stmt, ok := db.Stmts[query]; ok { db.mux.RUnlock() return stmt, nil } db.mux.RUnlock() db.mux.Lock() + // double check + if stmt, ok := db.Stmts[query]; ok { + db.mux.Unlock() + return stmt, nil + } + stmt, err := db.ConnPool.PrepareContext(context.Background(), query) if err == nil { - db.stmts[query] = stmt + db.Stmts[query] = stmt } db.mux.Unlock() From 5d044642d1825dd35f3e32dc3284142ba49bb55e Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 22 Jun 2020 11:04:44 +0800 Subject: [PATCH 451/881] Allow DisableForeignKeyConstraintWhenMigrating --- gorm.go | 2 ++ migrator/migrator.go | 24 ++++++++++++++---------- 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/gorm.go b/gorm.go index 6027b4bb..47a209ab 100644 --- a/gorm.go +++ b/gorm.go @@ -30,6 +30,8 @@ type Config struct { PrepareStmt bool // DisableAutomaticPing DisableAutomaticPing bool + // DisableForeignKeyConstraintWhenMigrating + DisableForeignKeyConstraintWhenMigrating bool // ClauseBuilders clause builder ClauseBuilders map[string]clause.ClauseBuilder diff --git a/migrator/migrator.go b/migrator/migrator.go index 64e02ac7..a239c926 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -97,11 +97,13 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { } for _, rel := range stmt.Schema.Relationships.Relations { - if constraint := rel.ParseConstraint(); constraint != nil { - if constraint.Schema == stmt.Schema { - if !tx.Migrator().HasConstraint(value, constraint.Name) { - if err := tx.Migrator().CreateConstraint(value, constraint.Name); err != nil { - return err + if !m.DB.Config.DisableForeignKeyConstraintWhenMigrating { + if constraint := rel.ParseConstraint(); constraint != nil { + if constraint.Schema == stmt.Schema { + if !tx.Migrator().HasConstraint(value, constraint.Name) { + if err := tx.Migrator().CreateConstraint(value, constraint.Name); err != nil { + return err + } } } } @@ -179,11 +181,13 @@ func (m Migrator) CreateTable(values ...interface{}) error { } for _, rel := range stmt.Schema.Relationships.Relations { - if constraint := rel.ParseConstraint(); constraint != nil { - if constraint.Schema == stmt.Schema { - sql, vars := buildConstraint(constraint) - createTableSQL += sql + "," - values = append(values, vars...) + if !m.DB.DisableForeignKeyConstraintWhenMigrating { + if constraint := rel.ParseConstraint(); constraint != nil { + if constraint.Schema == stmt.Schema { + sql, vars := buildConstraint(constraint) + createTableSQL += sql + "," + values = append(values, vars...) + } } } From 59d7150917183005c9658c28ad0d3e5e55780a9a Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 22 Jun 2020 20:22:15 +0800 Subject: [PATCH 452/881] Update README --- README.md | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 6c2c7731..a73be1b9 100644 --- a/README.md +++ b/README.md @@ -12,13 +12,15 @@ The fantastic ORM library for Golang, aims to be developer friendly. ## Overview -* Full-Featured ORM (almost) -* Associations (Has One, Has Many, Belongs To, Many To Many, Polymorphism) +* Full-Featured ORM +* Associations (Has One, Has Many, Belongs To, Many To Many, Polymorphism, Single-table inheritance) * Hooks (Before/After Create/Save/Update/Delete/Find) -* Preloading (eager loading) -* Transactions +* Eager loading with `Preload`, `Joins` +* Transactions, Nested Transactions, Save Point, RollbackTo to Saved Point +* Context, Prepared Statment Mode, DryRun Mode +* Batch Insert, FindInBatches, Find To Map +* SQL Builder, Upsert, Locking, Optimizer/Index/Comment Hints * Composite Primary Key -* SQL Builder * Auto Migrations * Logger * Extendable, write Plugins based on GORM callbacks From 60d1e68567b9592f9620be914dc9d826884e1756 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 22 Jun 2020 22:32:12 +0800 Subject: [PATCH 453/881] Update github action CI --- .github/workflows/ci.yml | 157 +++++++++++++++++++++++++++++++++++++++ .github/workflows/go.yml | 73 ------------------ 2 files changed, 157 insertions(+), 73 deletions(-) create mode 100644 .github/workflows/ci.yml delete mode 100644 .github/workflows/go.yml diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 00000000..b60e369a --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,157 @@ +name: ci + +on: + push: + branches: [ master ] + pull_request: + branches: [ master ] + +jobs: + # Label of the container job + tests: + runs-on: ubuntu-latest + strategy: + matrix: + go: [ '1.14', '1.13' ] + + services: + # start postgres + postgres: + image: postgres:latest + env: + POSTGRES_PASSWORD: gorm + POSTGRES_USER: gorm + POSTGRES_DB: gorm + TZ: Asia/Shanghai + ports: + - 9920:5432 + # Set health checks to wait until postgres has started + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + --health-retries 5 + + postgres11: + image: postgres:11 + env: + POSTGRES_PASSWORD: gorm + POSTGRES_USER: gorm + POSTGRES_DB: gorm + TZ: Asia/Shanghai + ports: + - 9921:5432 + # Set health checks to wait until postgres has started + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + --health-retries 5 + + postgres10: + image: postgres:10 + env: + POSTGRES_PASSWORD: gorm + POSTGRES_USER: gorm + POSTGRES_DB: gorm + TZ: Asia/Shanghai + ports: + - 9922:5432 + # Set health checks to wait until postgres has started + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + --health-retries 5 + + # start mysql + mysql: + image: mysql:latest + env: + MYSQL_DATABASE: gorm + MYSQL_USER: gorm + MYSQL_PASSWORD: gorm + MYSQL_RANDOM_ROOT_PASSWORD: "yes" + ports: + - 9910:3306 + + mysql57: + image: mysql:5.7 + env: + MYSQL_DATABASE: gorm + MYSQL_USER: gorm + MYSQL_PASSWORD: gorm + MYSQL_RANDOM_ROOT_PASSWORD: "yes" + ports: + - 9911:3306 + + mysql56: + image: mysql:5.6 + env: + MYSQL_DATABASE: gorm + MYSQL_USER: gorm + MYSQL_PASSWORD: gorm + MYSQL_RANDOM_ROOT_PASSWORD: "yes" + ports: + - 9912:3306 + + mariadb: + image: mariadb:latest + env: + MYSQL_DATABASE: gorm + MYSQL_USER: gorm + MYSQL_PASSWORD: gorm + MYSQL_RANDOM_ROOT_PASSWORD: "yes" + ports: + - 9913:3306 + + # start mssql + mssql: + image: mcmoe/mssqldocker:latest + env: + ACCEPT_EULA: Y + SA_PASSWORD: LoremIpsum86 + MSSQL_DB: gorm + MSSQL_USER: gorm + MSSQL_PASSWORD: LoremIpsum86 + ports: + - 9930:1433 + + steps: + - name: Set up Go 1.x + uses: actions/setup-go@v2 + with: + go-version: ${{ matrix.go }} + + - name: Check out code into the Go module directory + uses: actions/checkout@v2 + + - name: show ports + run: netstat -lntp + + - name: run sqlite + run: GORM_DIALECT=sqlite GORM_VERBOSE=true ./tests/tests_all.sh + + - name: run mysql + run: GORM_DIALECT=mysql GORM_VERBOSE=true GORM_DSN="gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True" ./tests/tests_all.sh + + - name: run mysql57 + run: GORM_DIALECT=mysql GORM_VERBOSE=true GORM_DSN="gorm:gorm@tcp(localhost:9911)/gorm?charset=utf8&parseTime=True" ./tests/tests_all.sh + + - name: run mysql56 + run: GORM_DIALECT=mysql GORM_VERBOSE=true GORM_DSN="gorm:gorm@tcp(localhost:9912)/gorm?charset=utf8&parseTime=True" ./tests/tests_all.sh + + - name: run mariadb + run: GORM_DIALECT=mysql GORM_VERBOSE=true GORM_DSN="gorm:gorm@tcp(localhost:9913)/gorm?charset=utf8&parseTime=True" ./tests/tests_all.sh + + - name: run postgres + run: GORM_DIALECT=postgres GORM_VERBOSE=true GORM_DSN="user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable TimeZone=Asia/Shanghai" ./tests/tests_all.sh + + - name: run postgres11 + run: GORM_DIALECT=postgres GORM_VERBOSE=true GORM_DSN="user=gorm password=gorm DB.name=gorm port=9921 sslmode=disable TimeZone=Asia/Shanghai" ./tests/tests_all.sh + + - name: run postgres10 + run: GORM_DIALECT=postgres GORM_VERBOSE=true GORM_DSN="user=gorm password=gorm DB.name=gorm port=9922 sslmode=disable TimeZone=Asia/Shanghai" ./tests/tests_all.sh + + - name: run mssql + run: GORM_DIALECT=mssql GORM_VERBOSE=true GORM_DSN="sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm" ./tests/tests_all.sh diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml deleted file mode 100644 index a5dc41a3..00000000 --- a/.github/workflows/go.yml +++ /dev/null @@ -1,73 +0,0 @@ -name: Go - -on: - push: - branches: [ master ] - pull_request: - branches: [ master ] - -jobs: - # Label of the container job - containerTest: - # Containers must run in Linux based operating systems - runs-on: ubuntu-latest - # Docker Hub image that `container-job` executes in - #container: node:10.18-jessie - - # Service containers to run with `container-job` - services: - # start postgres - postgres: - image: postgres:latest - env: - POSTGRES_PASSWORD: gorm - POSTGRES_USER: gorm - POSTGRES_DB: gorm - TZ: Asia/Shanghai - - ports: - - 9920:5432 - # Set health checks to wait until postgres has started - options: >- - --health-cmd pg_isready - --health-interval 10s - --health-timeout 5s - --health-retries 5 - - # start mysql - mysql: - image: mysql:latest - env: - MYSQL_DATABASE: gorm - MYSQL_USER: gorm - MYSQL_PASSWORD: gorm - MYSQL_RANDOM_ROOT_PASSWORD: "yes" - - ports: - - 9910:3306 - # start mssql - mssql: - image: mcmoe/mssqldocker:latest - env: - ACCEPT_EULA: Y - SA_PASSWORD: LoremIpsum86 - MSSQL_DB: gorm - MSSQL_USER: gorm - MSSQL_PASSWORD: LoremIpsum86 - ports: - - 9930:1433 - steps: - - name: Set up Go 1.x - uses: actions/setup-go@v2 - with: - go-version: ^1.13 - id: go - - - name: Check out code into the Go module directory - uses: actions/checkout@v2 - - - name: show ports - run: netstat -lntp - - - name: run tests - run: cd tests && ./tests_all.sh From 71ae2ddbeeec6217c0418df71e247ef88597f371 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 22 Jun 2020 22:51:54 +0800 Subject: [PATCH 454/881] Refactor github actions --- .github/workflows/ci.yml | 173 +++++++++++++++++---------------------- 1 file changed, 75 insertions(+), 98 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b60e369a..01d06b77 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -8,16 +8,68 @@ on: jobs: # Label of the container job - tests: + sqlite: runs-on: ubuntu-latest strategy: matrix: - go: [ '1.14', '1.13' ] + go: ['1.14', '1.13'] + + steps: + - name: Set up Go 1.x + uses: actions/setup-go@v2 + with: + go-version: ${{ matrix.go }} + + - name: Check out code into the Go module directory + uses: actions/checkout@v2 + + - name: run sqlite + run: GORM_DIALECT=sqlite GORM_VERBOSE=true ./tests/tests_all.sh + + + mysql: + runs-on: ubuntu-latest + strategy: + matrix: + dbversion: ['mysql:latest', 'mysql:5.7', 'mysql:5.6', 'mariadb:latest'] + go: ['1.14', '1.13'] + + services: + mysql: + image: ${{ matrix.dbversion }} + env: + MYSQL_DATABASE: gorm + MYSQL_USER: gorm + MYSQL_PASSWORD: gorm + MYSQL_RANDOM_ROOT_PASSWORD: "yes" + ports: + - 9910:3306 + + steps: + - name: Set up Go 1.x + uses: actions/setup-go@v2 + with: + go-version: ${{ matrix.go }} + + - name: Check out code into the Go module directory + uses: actions/checkout@v2 + + - name: show ports + run: netstat -lntp + + - name: run mysql + run: GORM_DIALECT=mysql GORM_VERBOSE=true GORM_DSN="gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True" ./tests/tests_all.sh + + postgres: + runs-on: ubuntu-latest + strategy: + matrix: + dbversion: ['postgres:latest', 'postgres:11', 'postgres:10'] + go: ['1.14', '1.13'] services: - # start postgres postgres: - image: postgres:latest + image: ${{ matrix.dbversion }} env: POSTGRES_PASSWORD: gorm POSTGRES_USER: gorm @@ -32,80 +84,29 @@ jobs: --health-timeout 5s --health-retries 5 - postgres11: - image: postgres:11 - env: - POSTGRES_PASSWORD: gorm - POSTGRES_USER: gorm - POSTGRES_DB: gorm - TZ: Asia/Shanghai - ports: - - 9921:5432 - # Set health checks to wait until postgres has started - options: >- - --health-cmd pg_isready - --health-interval 10s - --health-timeout 5s - --health-retries 5 + steps: + - name: Set up Go 1.x + uses: actions/setup-go@v2 + with: + go-version: ${{ matrix.go }} - postgres10: - image: postgres:10 - env: - POSTGRES_PASSWORD: gorm - POSTGRES_USER: gorm - POSTGRES_DB: gorm - TZ: Asia/Shanghai - ports: - - 9922:5432 - # Set health checks to wait until postgres has started - options: >- - --health-cmd pg_isready - --health-interval 10s - --health-timeout 5s - --health-retries 5 + - name: Check out code into the Go module directory + uses: actions/checkout@v2 - # start mysql - mysql: - image: mysql:latest - env: - MYSQL_DATABASE: gorm - MYSQL_USER: gorm - MYSQL_PASSWORD: gorm - MYSQL_RANDOM_ROOT_PASSWORD: "yes" - ports: - - 9910:3306 + - name: show ports + run: netstat -lntp - mysql57: - image: mysql:5.7 - env: - MYSQL_DATABASE: gorm - MYSQL_USER: gorm - MYSQL_PASSWORD: gorm - MYSQL_RANDOM_ROOT_PASSWORD: "yes" - ports: - - 9911:3306 + - name: run postgres + run: GORM_DIALECT=postgres GORM_VERBOSE=true GORM_DSN="user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable TimeZone=Asia/Shanghai" ./tests/tests_all.sh - mysql56: - image: mysql:5.6 - env: - MYSQL_DATABASE: gorm - MYSQL_USER: gorm - MYSQL_PASSWORD: gorm - MYSQL_RANDOM_ROOT_PASSWORD: "yes" - ports: - - 9912:3306 - mariadb: - image: mariadb:latest - env: - MYSQL_DATABASE: gorm - MYSQL_USER: gorm - MYSQL_PASSWORD: gorm - MYSQL_RANDOM_ROOT_PASSWORD: "yes" - ports: - - 9913:3306 + sqlserver: + runs-on: ubuntu-latest + strategy: + matrix: + go: ['1.14', '1.13'] - # start mssql + services: mssql: image: mcmoe/mssqldocker:latest env: @@ -129,29 +130,5 @@ jobs: - name: show ports run: netstat -lntp - - name: run sqlite - run: GORM_DIALECT=sqlite GORM_VERBOSE=true ./tests/tests_all.sh - - - name: run mysql - run: GORM_DIALECT=mysql GORM_VERBOSE=true GORM_DSN="gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True" ./tests/tests_all.sh - - - name: run mysql57 - run: GORM_DIALECT=mysql GORM_VERBOSE=true GORM_DSN="gorm:gorm@tcp(localhost:9911)/gorm?charset=utf8&parseTime=True" ./tests/tests_all.sh - - - name: run mysql56 - run: GORM_DIALECT=mysql GORM_VERBOSE=true GORM_DSN="gorm:gorm@tcp(localhost:9912)/gorm?charset=utf8&parseTime=True" ./tests/tests_all.sh - - - name: run mariadb - run: GORM_DIALECT=mysql GORM_VERBOSE=true GORM_DSN="gorm:gorm@tcp(localhost:9913)/gorm?charset=utf8&parseTime=True" ./tests/tests_all.sh - - - name: run postgres - run: GORM_DIALECT=postgres GORM_VERBOSE=true GORM_DSN="user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable TimeZone=Asia/Shanghai" ./tests/tests_all.sh - - - name: run postgres11 - run: GORM_DIALECT=postgres GORM_VERBOSE=true GORM_DSN="user=gorm password=gorm DB.name=gorm port=9921 sslmode=disable TimeZone=Asia/Shanghai" ./tests/tests_all.sh - - - name: run postgres10 - run: GORM_DIALECT=postgres GORM_VERBOSE=true GORM_DSN="user=gorm password=gorm DB.name=gorm port=9922 sslmode=disable TimeZone=Asia/Shanghai" ./tests/tests_all.sh - - - name: run mssql - run: GORM_DIALECT=mssql GORM_VERBOSE=true GORM_DSN="sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm" ./tests/tests_all.sh + - name: run sqlserver + run: GORM_DIALECT=sqlserver GORM_VERBOSE=true GORM_DSN="sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm" ./tests/tests_all.sh From c84a8fe5717c5a061e19b1a3022cec864cb45f7e Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 22 Jun 2020 23:14:17 +0800 Subject: [PATCH 455/881] Switch to github actions --- .github/workflows/{ci.yml => tests.yml} | 8 +- README.md | 2 +- wercker.yml | 126 ------------------------ 3 files changed, 6 insertions(+), 130 deletions(-) rename .github/workflows/{ci.yml => tests.yml} (97%) delete mode 100644 wercker.yml diff --git a/.github/workflows/ci.yml b/.github/workflows/tests.yml similarity index 97% rename from .github/workflows/ci.yml rename to .github/workflows/tests.yml index 01d06b77..a0aac7f0 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/tests.yml @@ -1,10 +1,12 @@ -name: ci +name: tests on: push: - branches: [ master ] + branches-ignore: + - 'gh-pages' pull_request: - branches: [ master ] + branches-ignore: + - 'gh-pages' jobs: # Label of the container job diff --git a/README.md b/README.md index a73be1b9..140c0d28 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@ The fantastic ORM library for Golang, aims to be developer friendly. [![go report card](https://goreportcard.com/badge/github.com/go-gorm/gorm "go report card")](https://goreportcard.com/report/github.com/go-gorm/gorm) -[![wercker status](https://app.wercker.com/status/55136410c77335a6289ebd58b2f70125/s/master "wercker status")](https://app.wercker.com/project/byKey/55136410c77335a6289ebd58b2f70125) +[![test status](https://github.com/go-gorm/gorm/workflows/tests/badge.svg?branch=master "test status")](https://github.com/go-gorm/gorm/actions) [![Join the chat at https://gitter.im/jinzhu/gorm](https://img.shields.io/gitter/room/jinzhu/gorm.svg)](https://gitter.im/jinzhu/gorm?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) [![Open Collective Backer](https://opencollective.com/gorm/tiers/backer/badge.svg?label=backer&color=brightgreen "Open Collective Backer")](https://opencollective.com/gorm) [![Open Collective Sponsor](https://opencollective.com/gorm/tiers/sponsor/badge.svg?label=sponsor&color=brightgreen "Open Collective Sponsor")](https://opencollective.com/gorm) diff --git a/wercker.yml b/wercker.yml deleted file mode 100644 index d4fb63e3..00000000 --- a/wercker.yml +++ /dev/null @@ -1,126 +0,0 @@ -# use the default golang container from Docker Hub -box: golang - -services: - - name: mariadb - id: mariadb:latest - env: - MYSQL_DATABASE: gorm - MYSQL_USER: gorm - MYSQL_PASSWORD: gorm - MYSQL_RANDOM_ROOT_PASSWORD: "yes" - - name: mysql - id: mysql:latest - env: - MYSQL_DATABASE: gorm - MYSQL_USER: gorm - MYSQL_PASSWORD: gorm - MYSQL_RANDOM_ROOT_PASSWORD: "yes" - - name: mysql57 - id: mysql:5.7 - env: - MYSQL_DATABASE: gorm - MYSQL_USER: gorm - MYSQL_PASSWORD: gorm - MYSQL_RANDOM_ROOT_PASSWORD: "yes" - - name: mysql56 - id: mysql:5.6 - env: - MYSQL_DATABASE: gorm - MYSQL_USER: gorm - MYSQL_PASSWORD: gorm - MYSQL_RANDOM_ROOT_PASSWORD: "yes" - - name: postgres - id: postgres:latest - env: - POSTGRES_USER: gorm - POSTGRES_PASSWORD: gorm - POSTGRES_DB: gorm - - name: postgres11 - id: postgres:11 - env: - POSTGRES_USER: gorm - POSTGRES_PASSWORD: gorm - POSTGRES_DB: gorm - - name: postgres10 - id: postgres:10 - env: - POSTGRES_USER: gorm - POSTGRES_PASSWORD: gorm - POSTGRES_DB: gorm - - name: mssql - id: mcmoe/mssqldocker:latest - env: - ACCEPT_EULA: Y - SA_PASSWORD: LoremIpsum86 - MSSQL_DB: gorm - MSSQL_USER: gorm - MSSQL_PASSWORD: LoremIpsum86 - -# The steps that will be executed in the build pipeline -build: - # The steps that will be executed on build - steps: - # Sets the go workspace and places you package - # at the right place in the workspace tree - - setup-go-workspace - - # Gets the dependencies - - script: - name: go get - code: | - cd $WERCKER_SOURCE_DIR - go version - go get -t -v ./... - - # Build the project - - script: - name: go build - code: | - go build ./... - - # Test the project - - script: - name: test sqlite - code: | - GORM_DIALECT=sqlite GORM_VERBOSE=true ./tests/tests_all.sh - - - script: - name: test mariadb - code: | - GORM_DIALECT=mysql GORM_VERBOSE=true GORM_DSN="gorm:gorm@tcp(mariadb:3306)/gorm?charset=utf8&parseTime=True" ./tests/tests_all.sh - - - script: - name: test mysql - code: | - GORM_DIALECT=mysql GORM_VERBOSE=true GORM_DSN="gorm:gorm@tcp(mysql:3306)/gorm?charset=utf8&parseTime=True" ./tests/tests_all.sh - - - script: - name: test mysql5.7 - code: | - GORM_DIALECT=mysql GORM_VERBOSE=true GORM_DSN="gorm:gorm@tcp(mysql57:3306)/gorm?charset=utf8&parseTime=True" ./tests/tests_all.sh - - - script: - name: test mysql5.6 - code: | - GORM_DIALECT=mysql GORM_VERBOSE=true GORM_DSN="gorm:gorm@tcp(mysql56:3306)/gorm?charset=utf8&parseTime=True" ./tests/tests_all.sh - - - script: - name: test postgres - code: | - GORM_DIALECT=postgres GORM_VERBOSE=true GORM_DSN="host=postgres user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" ./tests/tests_all.sh - - - script: - name: test postgres11 - code: | - GORM_DIALECT=postgres GORM_VERBOSE=true GORM_DSN="host=postgres96 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" ./tests/tests_all.sh - - - script: - name: test postgres10 - code: | - GORM_DIALECT=postgres GORM_VERBOSE=true GORM_DSN="host=postgres95 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" ./tests/tests_all.sh - - - script: - name: test mssql - code: | - GORM_DIALECT=mssql GORM_VERBOSE=true GORM_DSN="sqlserver://gorm:LoremIpsum86@mssql:1433?database=gorm" ./tests/tests_all.sh From 32bd6b3e8f126e1d52e8ebb31b7533389b875ae0 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 23 Jun 2020 08:51:01 +0800 Subject: [PATCH 456/881] Fix Count with Select --- finisher_api.go | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/finisher_api.go b/finisher_api.go index 92d4fe72..b443f4b5 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -268,16 +268,18 @@ func (db *DB) Count(count *int64) (tx *DB) { if len(tx.Statement.Selects) == 0 { tx.Statement.AddClause(clause.Select{Expression: clause.Expr{SQL: "count(1)"}}) - } else if len(tx.Statement.Selects) == 1 && !strings.Contains(strings.ToLower(tx.Statement.Selects[0]), "count(") { - column := tx.Statement.Selects[0] - if tx.Statement.Parse(tx.Statement.Model) == nil { - if f := tx.Statement.Schema.LookUpField(column); f != nil { - column = f.DBName + } else if !strings.Contains(strings.ToLower(tx.Statement.Selects[0]), "count(") { + expr := clause.Expr{SQL: "count(1)"} + + if len(tx.Statement.Selects) == 1 { + if tx.Statement.Parse(tx.Statement.Model) == nil { + if f := tx.Statement.Schema.LookUpField(tx.Statement.Selects[0]); f != nil { + expr = clause.Expr{SQL: "COUNT(DISTINCT(?))", Vars: []interface{}{clause.Column{Name: f.DBName}}} + } } } - tx.Statement.AddClause(clause.Select{ - Expression: clause.Expr{SQL: "COUNT(DISTINCT(?))", Vars: []interface{}{clause.Column{Name: column}}}, - }) + + tx.Statement.AddClause(clause.Select{Expression: expr}) } tx.Statement.Dest = count From e77e7bb842499e58a9f4b53631bb3ce9c72d6d5a Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 23 Jun 2020 09:09:46 +0800 Subject: [PATCH 457/881] Fix nested embedded field with pointer, close #3071 --- schema/field.go | 12 +++++++----- tests/embedded_struct_test.go | 5 +++++ 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/schema/field.go b/schema/field.go index 737f56c4..a8328367 100644 --- a/schema/field.go +++ b/schema/field.go @@ -397,11 +397,11 @@ func (field *Field) setupValuerAndSetter() { default: field.ReflectValueOf = func(value reflect.Value) reflect.Value { v := reflect.Indirect(value) - for _, idx := range field.StructField.Index { - if idx >= 0 { - v = v.Field(idx) + for idx, fieldIdx := range field.StructField.Index { + if fieldIdx >= 0 { + v = v.Field(fieldIdx) } else { - v = v.Field(-idx - 1) + v = v.Field(-fieldIdx - 1) } if v.Kind() == reflect.Ptr { @@ -436,7 +436,9 @@ func (field *Field) setupValuerAndSetter() { fieldValue := field.ReflectValueOf(value) if reflectV.Type().AssignableTo(field.FieldType.Elem()) { - if fieldValue.IsNil() { + if !fieldValue.IsValid() { + fieldValue = reflect.New(field.FieldType.Elem()) + } else if fieldValue.IsNil() { fieldValue.Set(reflect.New(field.FieldType.Elem())) } fieldValue.Elem().Set(reflectV) diff --git a/tests/embedded_struct_test.go b/tests/embedded_struct_test.go index 8536b605..7f40a0a4 100644 --- a/tests/embedded_struct_test.go +++ b/tests/embedded_struct_test.go @@ -10,10 +10,15 @@ import ( ) func TestEmbeddedStruct(t *testing.T) { + type ReadOnly struct { + ReadOnly *bool + } + type BasePost struct { Id int64 Title string URL string + ReadOnly } type Author struct { From f4bfc435cc84824e0ca3a9c4e21458996bce67d3 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 23 Jun 2020 09:38:51 +0800 Subject: [PATCH 458/881] Add register plugin API --- errors.go | 2 ++ gorm.go | 15 +++++++++++++++ interfaces.go | 7 +++++++ 3 files changed, 24 insertions(+) diff --git a/errors.go b/errors.go index 2506ecc5..b41eefae 100644 --- a/errors.go +++ b/errors.go @@ -27,4 +27,6 @@ var ( ErrorModelValueRequired = errors.New("model value required") // ErrUnsupportedDriver unsupported driver ErrUnsupportedDriver = errors.New("unsupported driver") + // ErrRegistered registered + ErrRegistered = errors.New("registered") ) diff --git a/gorm.go b/gorm.go index 47a209ab..c506c6f3 100644 --- a/gorm.go +++ b/gorm.go @@ -39,6 +39,8 @@ type Config struct { ConnPool ConnPool // Dialector database dialector Dialector + // Plugins registered plugins + Plugins map[string]Plugin callbacks *callbacks cacheStore *sync.Map @@ -309,3 +311,16 @@ func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interfac return nil } + +func (db *DB) Use(plugin Plugin) (err error) { + name := plugin.Name() + if _, ok := db.Plugins[name]; !ok { + if err = plugin.Initialize(db); err == nil { + db.Plugins[name] = plugin + } + } else { + return ErrRegistered + } + + return err +} diff --git a/interfaces.go b/interfaces.go index 96289a90..b2ce59b3 100644 --- a/interfaces.go +++ b/interfaces.go @@ -20,6 +20,12 @@ type Dialector interface { Explain(sql string, vars ...interface{}) string } +// Plugin GORM plugin interface +type Plugin interface { + Name() string + Initialize(*DB) error +} + // ConnPool db conns pool interface type ConnPool interface { PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) @@ -28,6 +34,7 @@ type ConnPool interface { QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row } +// SavePointerDialectorInterface save pointer interface type SavePointerDialectorInterface interface { SavePoint(tx *DB, name string) error RollbackTo(tx *DB, name string) error From 1df757113ad47c4347776e3abadb1e19d6b4a55d Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 23 Jun 2020 10:36:45 +0800 Subject: [PATCH 459/881] initialize plugins map --- gorm.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/gorm.go b/gorm.go index c506c6f3..1c6d3383 100644 --- a/gorm.go +++ b/gorm.go @@ -87,6 +87,10 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) { config.Dialector = dialector } + if config.Plugins == nil { + config.Plugins = map[string]Plugin{} + } + if config.cacheStore == nil { config.cacheStore = &sync.Map{} } From b733d16f56bcc79ab68903bd6f028c521da2b6e7 Mon Sep 17 00:00:00 2001 From: Hinagiku Soranoba Date: Tue, 23 Jun 2020 15:38:36 +0900 Subject: [PATCH 460/881] Create supports Array / ArrayPtr (#3076) * add Array / ArrayPtr create tests * support create using array --- schema/schema.go | 2 +- tests/create_test.go | 42 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 43 insertions(+), 1 deletion(-) diff --git a/schema/schema.go b/schema/schema.go index e5894443..72bc6544 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -73,7 +73,7 @@ type Tabler interface { // get data type from dialector func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) { modelType := reflect.ValueOf(dest).Type() - for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Ptr { + for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr { modelType = modelType.Elem() } diff --git a/tests/create_test.go b/tests/create_test.go index 4bf623b3..75059f18 100644 --- a/tests/create_test.go +++ b/tests/create_test.go @@ -189,6 +189,48 @@ func TestPolymorphicHasOne(t *testing.T) { CheckPet(t, *pet, *pet) } }) + + t.Run("Array", func(t *testing.T) { + var pets = [...]Pet{{ + Name: "PolymorphicHasOne-Array-1", + Toy: Toy{Name: "Toy-PolymorphicHasOne-Array-1"}, + }, { + Name: "PolymorphicHasOne-Array-2", + Toy: Toy{Name: "Toy-PolymorphicHasOne-Array-2"}, + }, { + Name: "PolymorphicHasOne-Array-3", + Toy: Toy{Name: "Toy-PolymorphicHasOne-Array-3"}, + }} + + if err := DB.Create(&pets).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + for _, pet := range pets { + CheckPet(t, pet, pet) + } + }) + + t.Run("ArrayPtr", func(t *testing.T) { + var pets = [...]*Pet{{ + Name: "PolymorphicHasOne-Array-1", + Toy: Toy{Name: "Toy-PolymorphicHasOne-Array-1"}, + }, { + Name: "PolymorphicHasOne-Array-2", + Toy: Toy{Name: "Toy-PolymorphicHasOne-Array-2"}, + }, { + Name: "PolymorphicHasOne-Array-3", + Toy: Toy{Name: "Toy-PolymorphicHasOne-Array-3"}, + }} + + if err := DB.Create(&pets).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + for _, pet := range pets { + CheckPet(t, *pet, *pet) + } + }) } func TestCreateEmptyStruct(t *testing.T) { From dd7caa9db0fc598cdcbcfc58b9f1da15d407278d Mon Sep 17 00:00:00 2001 From: mojotv <34467684+mojocn@users.noreply.github.com> Date: Tue, 23 Jun 2020 16:00:04 +0800 Subject: [PATCH 461/881] add macos and windows for sqlite unit test and use cache for go mod package download (#3079) Co-authored-by: EricZhou --- .github/workflows/tests.yml | 65 ++++++++++++++++++++++++++++++------- 1 file changed, 54 insertions(+), 11 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index a0aac7f0..106afdc9 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -11,10 +11,11 @@ on: jobs: # Label of the container job sqlite: - runs-on: ubuntu-latest strategy: matrix: go: ['1.14', '1.13'] + platform: [ubuntu-latest, macos-latest] # can not run in windows OS + runs-on: ${{ matrix.platform }} steps: - name: Set up Go 1.x @@ -25,16 +26,47 @@ jobs: - name: Check out code into the Go module directory uses: actions/checkout@v2 + - name: go mod pakcage cache + uses: actions/cache@v2 + with: + path: ~/go/pkg/mod + key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} + - name: run sqlite run: GORM_DIALECT=sqlite GORM_VERBOSE=true ./tests/tests_all.sh + sqlite_windows: + strategy: + matrix: + go: ['1.14', '1.13'] + platform: [windows-latest] + runs-on: ${{ matrix.platform }} + + steps: + - name: Set up Go 1.x + uses: actions/setup-go@v2 + with: + go-version: ${{ matrix.go }} + + - name: Check out code into the Go module directory + uses: actions/checkout@v2 + + - name: go mod pakcage cache + uses: actions/cache@v2 + with: + path: ~/go/pkg/mod + key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} + + - name: run sqlite + run: cd tests && set GORM_DIALECT=sqlite && go test $race -count=1 -v ./... #run the line in widnows's CMD, default GORM_DIALECT is sqlite mysql: - runs-on: ubuntu-latest strategy: matrix: dbversion: ['mysql:latest', 'mysql:5.7', 'mysql:5.6', 'mariadb:latest'] go: ['1.14', '1.13'] + platform: [ubuntu-latest] + runs-on: ${{ matrix.platform }} services: mysql: @@ -56,18 +88,23 @@ jobs: - name: Check out code into the Go module directory uses: actions/checkout@v2 - - name: show ports - run: netstat -lntp + + - name: go mod pakcage cache + uses: actions/cache@v2 + with: + path: ~/go/pkg/mod + key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} - name: run mysql run: GORM_DIALECT=mysql GORM_VERBOSE=true GORM_DSN="gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True" ./tests/tests_all.sh postgres: - runs-on: ubuntu-latest strategy: matrix: dbversion: ['postgres:latest', 'postgres:11', 'postgres:10'] go: ['1.14', '1.13'] + platform: [ubuntu-latest] # can not run in macOS and widnowsOS + runs-on: ${{ matrix.platform }} services: postgres: @@ -95,18 +132,21 @@ jobs: - name: Check out code into the Go module directory uses: actions/checkout@v2 - - name: show ports - run: netstat -lntp + - name: go mod pakcage cache + uses: actions/cache@v2 + with: + path: ~/go/pkg/mod + key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} - name: run postgres run: GORM_DIALECT=postgres GORM_VERBOSE=true GORM_DSN="user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable TimeZone=Asia/Shanghai" ./tests/tests_all.sh - sqlserver: - runs-on: ubuntu-latest strategy: matrix: go: ['1.14', '1.13'] + platform: [ubuntu-latest] # can not run test in macOS and windows + runs-on: ${{ matrix.platform }} services: mssql: @@ -129,8 +169,11 @@ jobs: - name: Check out code into the Go module directory uses: actions/checkout@v2 - - name: show ports - run: netstat -lntp + - name: go mod pakcage cache + uses: actions/cache@v2 + with: + path: ~/go/pkg/mod + key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} - name: run sqlserver run: GORM_DIALECT=sqlserver GORM_VERBOSE=true GORM_DSN="sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm" ./tests/tests_all.sh From 4201f7bdab7826cd4523550d58a969438f6bb50b Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 23 Jun 2020 22:14:17 +0800 Subject: [PATCH 462/881] Fix create unique index when creating table, close #3081 --- migrator/migrator.go | 3 +++ tests/migrate_test.go | 17 +++++++++++++++++ 2 files changed, 20 insertions(+) diff --git a/migrator/migrator.go b/migrator/migrator.go index a239c926..c8fe17ab 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -175,6 +175,9 @@ func (m Migrator) CreateTable(values ...interface{}) error { errr = tx.Migrator().CreateIndex(value, name) }(value, idx.Name) } else { + if idx.Class != "" { + createTableSQL += idx.Class + " " + } createTableSQL += "INDEX ? ?," values = append(values, clause.Expr{SQL: idx.Name}, tx.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt)) } diff --git a/tests/migrate_test.go b/tests/migrate_test.go index fce4c4aa..2c593a70 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -62,6 +62,23 @@ func TestMigrateWithComment(t *testing.T) { } } +func TestMigrateWithUniqueIndex(t *testing.T) { + type UserWithUniqueIndex struct { + ID int + Name string `gorm:"size:20;index:idx_name,unique"` + Date time.Time `gorm:"index:idx_name,unique"` + } + + DB.Migrator().DropTable(&UserWithUniqueIndex{}) + if err := DB.AutoMigrate(&UserWithUniqueIndex{}); err != nil { + t.Fatalf("failed to migrate, got %v", err) + } + + if !DB.Migrator().HasIndex(&UserWithUniqueIndex{}, "idx_name") { + t.Errorf("Failed to find created index") + } +} + func TestTable(t *testing.T) { type TableStruct struct { gorm.Model From 7e1fa4a44de7b1febfc5620cab4afe77276b4a72 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 23 Jun 2020 22:41:41 +0800 Subject: [PATCH 463/881] Fix Count after Session --- finisher_api.go | 4 ++-- tests/count_test.go | 8 ++++++++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/finisher_api.go b/finisher_api.go index b443f4b5..6d961811 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -284,8 +284,8 @@ func (db *DB) Count(count *int64) (tx *DB) { tx.Statement.Dest = count tx.callbacks.Query().Execute(tx) - if db.RowsAffected != 1 { - *count = db.RowsAffected + if tx.RowsAffected != 1 { + *count = tx.RowsAffected } return } diff --git a/tests/count_test.go b/tests/count_test.go index 63238089..0662ae5c 100644 --- a/tests/count_test.go +++ b/tests/count_test.go @@ -4,6 +4,7 @@ import ( "fmt" "testing" + "gorm.io/gorm" . "gorm.io/gorm/utils/tests" ) @@ -31,6 +32,13 @@ func TestCount(t *testing.T) { t.Errorf("multiple count in chain should works") } + tx := DB.Model(&User{}).Where("name = ?", user1.Name).Session(&gorm.Session{WithConditions: true}) + tx.Count(&count1) + tx.Or("name in ?", []string{user2.Name, user3.Name}).Count(&count2) + if count1 != 1 || count2 != 3 { + t.Errorf("count after new session should works") + } + var count3 int64 if err := DB.Model(&User{}).Where("name in ?", []string{user2.Name, user2.Name, user3.Name}).Group("id").Count(&count3).Error; err != nil { t.Errorf("Error happened when count with group, but got %v", err) From 90f817db29b87c7ee0380d1c750c48be64f30617 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 24 Jun 2020 14:48:44 +0800 Subject: [PATCH 464/881] Update issue template --- .github/ISSUE_TEMPLATE.md | 37 ++------------------------------ .github/PULL_REQUEST_TEMPLATE.md | 4 +++- .github/workflows/tests.yml | 34 ++++++++++++++--------------- 3 files changed, 22 insertions(+), 53 deletions(-) diff --git a/.github/ISSUE_TEMPLATE.md b/.github/ISSUE_TEMPLATE.md index 74824a19..ac311633 100644 --- a/.github/ISSUE_TEMPLATE.md +++ b/.github/ISSUE_TEMPLATE.md @@ -1,38 +1,5 @@ Your issue may already be reported! Please search on the [issue track](https://github.com/go-gorm/gorm/issues) before creating one. -### What version of Go are you using (`go version`)? +To report a bug, your issue *have to* include an [GORM playground pull request link](https://github.com/go-gorm/playground), for general questions, please delete below line. - -### Which database and its version are you using? - - -### Please provide a complete runnable program to reproduce your issue. **IMPORTANT** - -Need to runnable with [GORM's docker compose config](https://github.com/go-gorm/gorm/blob/master/tests/docker-compose.yml) or please provides your config. - -```go -package main - -import ( - "gorm.io/gorm" - "gorm.io/driver/sqlite" -// "gorm.io/driver/mysql" -// "gorm.io/driver/postgres" -// "gorm.io/driver/sqlserver" -) - -func main() { - db, err := gorm.Open(sqlite.Open(filepath.Join(os.TempDir(), "gorm.db")), &gorm.Config{}) - // db, err := gorm.Open(postgres.Open("user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable TimeZone=Asia/Shanghai"), &gorm.Config{}) - // db, err := gorm.Open(mysql.Open("gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True&loc=Local"), &gorm.Config{}) - // db, err := gorm.Open(sqlserver.Open("sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm"), &gorm.Config{}) - - /* your code */ - - if /* failure condition */ { - fmt.Println("failed") - } else { - fmt.Println("success") - } -} -``` +## GORM Playground Link: https://github.com/go-gorm/playground/pull/1 (change this to your link) diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index b467b6ce..930ff176 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -2,8 +2,10 @@ Make sure these boxes checked before submitting your pull request. - [] Do only one thing - [] No API-breaking changes -- [] New code/logic commented & tested +- [] New code/logic commented & tested (important) For significant changes like big bug fixes, new features, please open an issue to make an agreement on an implementation design/plan first before starting it. ### What did this pull request do? + +### Use Case diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 106afdc9..15091def 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -32,7 +32,7 @@ jobs: path: ~/go/pkg/mod key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} - - name: run sqlite + - name: Tests run: GORM_DIALECT=sqlite GORM_VERBOSE=true ./tests/tests_all.sh sqlite_windows: @@ -43,22 +43,22 @@ jobs: runs-on: ${{ matrix.platform }} steps: - - name: Set up Go 1.x - uses: actions/setup-go@v2 - with: - go-version: ${{ matrix.go }} + - name: Set up Go 1.x + uses: actions/setup-go@v2 + with: + go-version: ${{ matrix.go }} - - name: Check out code into the Go module directory - uses: actions/checkout@v2 + - name: Check out code into the Go module directory + uses: actions/checkout@v2 - - name: go mod pakcage cache - uses: actions/cache@v2 - with: - path: ~/go/pkg/mod - key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} + - name: go mod pakcage cache + uses: actions/cache@v2 + with: + path: ~/go/pkg/mod + key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} - - name: run sqlite - run: cd tests && set GORM_DIALECT=sqlite && go test $race -count=1 -v ./... #run the line in widnows's CMD, default GORM_DIALECT is sqlite + - name: Tests + run: cd tests && set GORM_DIALECT=sqlite && go test $race -count=1 -v ./... #run the line in widnows's CMD, default GORM_DIALECT is sqlite mysql: strategy: @@ -95,7 +95,7 @@ jobs: path: ~/go/pkg/mod key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} - - name: run mysql + - name: Tests run: GORM_DIALECT=mysql GORM_VERBOSE=true GORM_DSN="gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True" ./tests/tests_all.sh postgres: @@ -138,7 +138,7 @@ jobs: path: ~/go/pkg/mod key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} - - name: run postgres + - name: Tests run: GORM_DIALECT=postgres GORM_VERBOSE=true GORM_DSN="user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable TimeZone=Asia/Shanghai" ./tests/tests_all.sh sqlserver: @@ -175,5 +175,5 @@ jobs: path: ~/go/pkg/mod key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} - - name: run sqlserver + - name: Tests run: GORM_DIALECT=sqlserver GORM_VERBOSE=true GORM_DSN="sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm" ./tests/tests_all.sh From 67bd842645f98ebf3c8db9a69b454f91e0a7590f Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 24 Jun 2020 14:56:04 +0800 Subject: [PATCH 465/881] Update tests all script --- tests/tests_all.sh | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/tests_all.sh b/tests/tests_all.sh index a321fe31..47f25401 100755 --- a/tests/tests_all.sh +++ b/tests/tests_all.sh @@ -1,3 +1,5 @@ +#!/bin/bash -e + dialects=("sqlite" "mysql" "postgres" "sqlserver") if [[ $(pwd) == *"gorm/tests"* ]]; then From 834cfa2c78866e281732b9a48ea8cef9a8cb6ec8 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 24 Jun 2020 15:04:46 +0800 Subject: [PATCH 466/881] Disable GORM_VERBOSE in github action --- .github/workflows/tests.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 15091def..108db6a6 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -33,7 +33,7 @@ jobs: key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} - name: Tests - run: GORM_DIALECT=sqlite GORM_VERBOSE=true ./tests/tests_all.sh + run: GORM_DIALECT=sqlite ./tests/tests_all.sh sqlite_windows: strategy: @@ -96,7 +96,7 @@ jobs: key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} - name: Tests - run: GORM_DIALECT=mysql GORM_VERBOSE=true GORM_DSN="gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True" ./tests/tests_all.sh + run: GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True" ./tests/tests_all.sh postgres: strategy: @@ -139,7 +139,7 @@ jobs: key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} - name: Tests - run: GORM_DIALECT=postgres GORM_VERBOSE=true GORM_DSN="user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable TimeZone=Asia/Shanghai" ./tests/tests_all.sh + run: GORM_DIALECT=postgres GORM_DSN="user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable TimeZone=Asia/Shanghai" ./tests/tests_all.sh sqlserver: strategy: @@ -176,4 +176,4 @@ jobs: key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} - name: Tests - run: GORM_DIALECT=sqlserver GORM_VERBOSE=true GORM_DSN="sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm" ./tests/tests_all.sh + run: GORM_DIALECT=sqlserver GORM_DSN="sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm" ./tests/tests_all.sh From eac6d1bdb9f4b1e04b663dbc8b211f1ffd9217cf Mon Sep 17 00:00:00 2001 From: EricZhou Date: Wed, 24 Jun 2020 16:20:12 +0800 Subject: [PATCH 467/881] issue --- .github/labeler.yml | 6 ++++++ .github/workflows/issue.yml | 15 +++++++++++++++ .github/workflows/issue_stale.yml | 19 +++++++++++++++++++ 3 files changed, 40 insertions(+) create mode 100644 .github/labeler.yml create mode 100644 .github/workflows/issue.yml create mode 100644 .github/workflows/issue_stale.yml diff --git a/.github/labeler.yml b/.github/labeler.yml new file mode 100644 index 00000000..d96bafa0 --- /dev/null +++ b/.github/labeler.yml @@ -0,0 +1,6 @@ +# Add/remove 'critical' label if issue contains the words 'urgent' or 'critical' +HasGormPlaygroundTestCase: + - '(github.com/go-gorm/playground/pull/\d)' + +NoTestCase: + - '(change this to your link)' diff --git a/.github/workflows/issue.yml b/.github/workflows/issue.yml new file mode 100644 index 00000000..0759782c --- /dev/null +++ b/.github/workflows/issue.yml @@ -0,0 +1,15 @@ +name: "Issue-Labeler" +on: + issues: + types: [opened, edited] + +jobs: + triage: + runs-on: ubuntu-latest + steps: + - uses: github/issue-labeler@v2.0 + with: + repo-token: "${{ secrets.GITHUB_TOKEN }}" + configuration-path: ".github/labeler.yml" + not-before: "2020-01-15T02:54:32Z" + enable-versioned-regex: 0 \ No newline at end of file diff --git a/.github/workflows/issue_stale.yml b/.github/workflows/issue_stale.yml new file mode 100644 index 00000000..fadfb522 --- /dev/null +++ b/.github/workflows/issue_stale.yml @@ -0,0 +1,19 @@ +name: Issue cleanup +on: + schedule: + - cron: '0 1 * * *' # At 01:00, everyday +jobs: + triage_issues: + name: Issue triage + runs-on: ubuntu-latest + steps: + - name: Find old issues and mark them stale + uses: Krizzu/issue-triage-action@v1.0.0 + with: + ghToken: ${{ secrets.GITHUB_TOKEN }} + staleAfter: 7 + closeAfter: 14 + staleLabel: "STALE 📺" + staleComment: "This issue is %DAYS_OLD% days old, marking as stale! cc: @%AUTHOR%" + closeComment: "Issue last updated %DAYS_OLD% days ago! Closing down!" + showLogs: true \ No newline at end of file From 4a01d4c263249af6a7e4e1abb2d85163c6dca616 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 24 Jun 2020 16:43:53 +0800 Subject: [PATCH 468/881] Create join table with ReorderModels --- migrator/migrator.go | 37 +++++++++----------------------- tests/multi_primary_keys_test.go | 27 +++++++++++++++++++---- 2 files changed, 33 insertions(+), 31 deletions(-) diff --git a/migrator/migrator.go b/migrator/migrator.go index c8fe17ab..799bf433 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -116,20 +116,6 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { } } } - - // create join table - if rel.JoinTable != nil { - joinValue := reflect.New(rel.JoinTable.ModelType).Interface() - if !tx.Migrator().HasTable(rel.JoinTable.Table) { - defer func(table string, joinValue interface{}) { - errr = tx.Table(table).Migrator().CreateTable(joinValue) - }(rel.JoinTable.Table, joinValue) - } else { - defer func(table string, joinValue interface{}) { - errr = tx.Table(table).Migrator().AutoMigrate(joinValue) - }(rel.JoinTable.Table, joinValue) - } - } } return nil }); err != nil { @@ -193,16 +179,6 @@ func (m Migrator) CreateTable(values ...interface{}) error { } } } - - // create join table - if rel.JoinTable != nil { - joinValue := reflect.New(rel.JoinTable.ModelType).Interface() - if !tx.Migrator().HasTable(rel.JoinTable.Table) { - defer func(table string, joinValue interface{}) { - errr = tx.Table(table).Migrator().CreateTable(joinValue) - }(rel.JoinTable.Table, joinValue) - } - } } for _, chk := range stmt.Schema.ParseCheckConstraints() { @@ -551,9 +527,10 @@ func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []i orderedModelNamesMap = map[string]bool{} valuesMap = map[string]Dependency{} insertIntoOrderedList func(name string) + parseDependence func(value interface{}, addToList bool) ) - parseDependence := func(value interface{}, addToList bool) { + parseDependence = func(value interface{}, addToList bool) { dep := Dependency{ Statement: &gorm.Statement{DB: m.DB, Dest: value}, } @@ -564,8 +541,14 @@ func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []i dep.Depends = append(dep.Depends, c.ReferenceSchema) } - if rel.JoinTable != nil && rel.Schema != rel.FieldSchema { - dep.Depends = append(dep.Depends, rel.FieldSchema) + if rel.JoinTable != nil { + if rel.Schema != rel.FieldSchema { + dep.Depends = append(dep.Depends, rel.FieldSchema) + } + // append join value + defer func(joinValue interface{}) { + parseDependence(joinValue, autoAdd) + }(reflect.New(rel.JoinTable.ModelType).Interface()) } } diff --git a/tests/multi_primary_keys_test.go b/tests/multi_primary_keys_test.go index 05267bbb..617010c5 100644 --- a/tests/multi_primary_keys_test.go +++ b/tests/multi_primary_keys_test.go @@ -4,6 +4,8 @@ import ( "reflect" "sort" "testing" + + "gorm.io/gorm" ) type Blog struct { @@ -11,7 +13,7 @@ type Blog struct { Locale string `gorm:"primary_key"` Subject string Body string - Tags []Tag `gorm:"many2many:blog_tags;"` + Tags []Tag `gorm:"many2many:blogs_tags;"` SharedTags []Tag `gorm:"many2many:shared_blog_tags;ForeignKey:id;References:id"` LocaleTags []Tag `gorm:"many2many:locale_blog_tags;ForeignKey:id,locale;References:id"` } @@ -38,7 +40,16 @@ func TestManyToManyWithMultiPrimaryKeys(t *testing.T) { t.Skip("skip sqlite, sqlserver due to it doesn't support multiple primary keys with auto increment") } - DB.Migrator().DropTable(&Blog{}, &Tag{}, "blog_tags") + if name := DB.Dialector.Name(); name == "postgres" { + stmt := gorm.Statement{DB: DB} + stmt.Parse(&Blog{}) + stmt.Schema.LookUpField("ID").Unique = true + stmt.Parse(&Tag{}) + stmt.Schema.LookUpField("ID").Unique = true + // postgers only allow unique constraint matching given keys + } + + DB.Migrator().DropTable(&Blog{}, &Tag{}, "blog_tags", "locale_blog_tags", "shared_blog_tags") if err := DB.AutoMigrate(&Blog{}, &Tag{}); err != nil { t.Fatalf("Failed to auto migrate, got error: %v", err) } @@ -127,7 +138,11 @@ func TestManyToManyWithCustomizedForeignKeys(t *testing.T) { t.Skip("skip sqlite, sqlserver due to it doesn't support multiple primary keys with auto increment") } - DB.Migrator().DropTable(&Blog{}, &Tag{}, "blog_tags") + if name := DB.Dialector.Name(); name == "postgres" { + t.Skip("skip postgers due to it only allow unique constraint matching given keys") + } + + DB.Migrator().DropTable(&Blog{}, &Tag{}, "blog_tags", "locale_blog_tags", "shared_blog_tags") if err := DB.AutoMigrate(&Blog{}, &Tag{}); err != nil { t.Fatalf("Failed to auto migrate, got error: %v", err) } @@ -248,7 +263,11 @@ func TestManyToManyWithCustomizedForeignKeys2(t *testing.T) { t.Skip("skip sqlite, sqlserver due to it doesn't support multiple primary keys with auto increment") } - DB.Migrator().DropTable(&Blog{}, &Tag{}, "blog_tags") + if name := DB.Dialector.Name(); name == "postgres" { + t.Skip("skip postgers due to it only allow unique constraint matching given keys") + } + + DB.Migrator().DropTable(&Blog{}, &Tag{}, "blog_tags", "locale_blog_tags", "shared_blog_tags") if err := DB.AutoMigrate(&Blog{}, &Tag{}); err != nil { t.Fatalf("Failed to auto migrate, got error: %v", err) } From 8ce2dd5548689f2281e290b80680764e39c4778b Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 24 Jun 2020 19:09:19 +0800 Subject: [PATCH 469/881] Update test script --- tests/main_test.go | 4 ++++ tests/tests_all.sh | 14 ++++---------- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/tests/main_test.go b/tests/main_test.go index 9d933caf..5b8c7dbb 100644 --- a/tests/main_test.go +++ b/tests/main_test.go @@ -7,6 +7,10 @@ import ( ) func TestExceptionsWithInvalidSql(t *testing.T) { + if name := DB.Dialector.Name(); name == "sqlserver" { + t.Skip("skip sqlserver due to it will raise data race for invalid sql") + } + var columns []string if DB.Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns).Error == nil { t.Errorf("Should got error with invalid SQL") diff --git a/tests/tests_all.sh b/tests/tests_all.sh index 47f25401..e87ff045 100755 --- a/tests/tests_all.sh +++ b/tests/tests_all.sh @@ -19,27 +19,21 @@ for dialect in "${dialects[@]}" ; do then echo "testing ${dialect}..." - race="" - if [ "$GORM_DIALECT" = "sqlserver" ] - then - race="-race" - fi - if [ "$GORM_VERBOSE" = "" ] then - GORM_DIALECT=${dialect} go test $race -count=1 ./... + GORM_DIALECT=${dialect} go test -race -count=1 ./... if [ -d tests ] then cd tests - GORM_DIALECT=${dialect} go test $race -count=1 ./... + GORM_DIALECT=${dialect} go test -race -count=1 ./... cd .. fi else - GORM_DIALECT=${dialect} go test $race -count=1 -v ./... + GORM_DIALECT=${dialect} go test -race -count=1 -v ./... if [ -d tests ] then cd tests - GORM_DIALECT=${dialect} go test $race -count=1 -v ./... + GORM_DIALECT=${dialect} go test -race -count=1 -v ./... cd .. fi fi From 3ec7ed1d51b94490db916b08c8c974f4234f0ccf Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 24 Jun 2020 20:19:28 +0800 Subject: [PATCH 470/881] Upgrade default mysql driver --- tests/go.mod | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/go.mod b/tests/go.mod index abe32cd6..f4d93ecb 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -6,7 +6,7 @@ require ( github.com/google/uuid v1.1.1 github.com/jinzhu/now v1.1.1 github.com/lib/pq v1.6.0 - gorm.io/driver/mysql v0.2.3 + gorm.io/driver/mysql v0.2.6 gorm.io/driver/postgres v0.2.3 gorm.io/driver/sqlite v1.0.7 gorm.io/driver/sqlserver v0.2.2 From fb56fe993af7ce155662c17fd24f94722fb3a8eb Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 25 Jun 2020 06:38:07 +0800 Subject: [PATCH 471/881] Add default value test --- tests/default_value_test.go | 37 +++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) create mode 100644 tests/default_value_test.go diff --git a/tests/default_value_test.go b/tests/default_value_test.go new file mode 100644 index 00000000..52292cf7 --- /dev/null +++ b/tests/default_value_test.go @@ -0,0 +1,37 @@ +package tests_test + +import ( + "testing" + + "gorm.io/gorm" +) + +func TestDefaultValue(t *testing.T) { + type Harumph struct { + gorm.Model + Email string `gorm:"not null;"` + Name string `gorm:"not null;default:foo"` + Name2 string `gorm:"not null;default:'foo'"` + Age int `gorm:"default:18"` + } + + DB.Migrator().DropTable(&Harumph{}) + + if err := DB.AutoMigrate(&Harumph{}); err != nil { + t.Fatalf("Failed to migrate with default value, got error: %v", err) + } + + var harumph = Harumph{Email: "hello@gorm.io"} + if err := DB.Create(&harumph).Error; err != nil { + t.Fatalf("Failed to create data with default value, got error: %v", err) + } else if harumph.Name != "foo" || harumph.Name2 != "foo" || harumph.Age != 18 { + t.Fatalf("Failed to create data with default value, got: %+v", harumph) + } + + var result Harumph + if err := DB.First(&result, "email = ?", "hello@gorm.io").Error; err != nil { + t.Fatalf("Failed to find created data, got error: %v", err) + } else if result.Name != "foo" || result.Name2 != "foo" || result.Age != 18 { + t.Fatalf("Failed to find created data with default data, got %+v", result) + } +} From 1b28c187c0374e3a1347221fece12d8d8d5e40c0 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 25 Jun 2020 08:00:10 +0800 Subject: [PATCH 472/881] Fix create with default value --- migrator/migrator.go | 8 ++++---- tests/default_value_test.go | 13 +++++++------ tests/go.mod | 2 ++ 3 files changed, 13 insertions(+), 10 deletions(-) diff --git a/migrator/migrator.go b/migrator/migrator.go index 799bf433..9c4ce2d5 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -65,10 +65,10 @@ func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) { } if field.HasDefaultValue && field.DefaultValue != "" { - if field.DataType == schema.String && field.DefaultValueInterface != nil { - defaultStmt := &gorm.Statement{Vars: []interface{}{field.DefaultValue}} - m.Dialector.BindVarTo(defaultStmt, defaultStmt, field.DefaultValue) - expr.SQL += " DEFAULT " + m.Dialector.Explain(defaultStmt.SQL.String(), field.DefaultValue) + if field.DefaultValueInterface != nil { + defaultStmt := &gorm.Statement{Vars: []interface{}{field.DefaultValueInterface}} + m.Dialector.BindVarTo(defaultStmt, defaultStmt, field.DefaultValueInterface) + expr.SQL += " DEFAULT " + m.Dialector.Explain(defaultStmt.SQL.String(), field.DefaultValueInterface) } else { expr.SQL += " DEFAULT " + field.DefaultValue } diff --git a/tests/default_value_test.go b/tests/default_value_test.go index 52292cf7..28a456d3 100644 --- a/tests/default_value_test.go +++ b/tests/default_value_test.go @@ -9,10 +9,11 @@ import ( func TestDefaultValue(t *testing.T) { type Harumph struct { gorm.Model - Email string `gorm:"not null;"` - Name string `gorm:"not null;default:foo"` - Name2 string `gorm:"not null;default:'foo'"` - Age int `gorm:"default:18"` + Email string `gorm:"not null;index:,unique"` + Name string `gorm:"not null;default:'foo'"` + Name2 string `gorm:"not null;default:'foo'"` + Age int `gorm:"default:18"` + Enabled bool `gorm:"default:true"` } DB.Migrator().DropTable(&Harumph{}) @@ -24,14 +25,14 @@ func TestDefaultValue(t *testing.T) { var harumph = Harumph{Email: "hello@gorm.io"} if err := DB.Create(&harumph).Error; err != nil { t.Fatalf("Failed to create data with default value, got error: %v", err) - } else if harumph.Name != "foo" || harumph.Name2 != "foo" || harumph.Age != 18 { + } else if harumph.Name != "foo" || harumph.Name2 != "foo" || harumph.Age != 18 || !harumph.Enabled { t.Fatalf("Failed to create data with default value, got: %+v", harumph) } var result Harumph if err := DB.First(&result, "email = ?", "hello@gorm.io").Error; err != nil { t.Fatalf("Failed to find created data, got error: %v", err) - } else if result.Name != "foo" || result.Name2 != "foo" || result.Age != 18 { + } else if result.Name != "foo" || result.Name2 != "foo" || result.Age != 18 || !result.Enabled { t.Fatalf("Failed to find created data with default data, got %+v", result) } } diff --git a/tests/go.mod b/tests/go.mod index f4d93ecb..d43ee8f1 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -14,3 +14,5 @@ require ( ) replace gorm.io/gorm => ../ + +replace gorm.io/driver/sqlserver => /Users/jinzhu/Projects/sqlserver From c5feff1591518ba500898dc6d1a5b8eb7bee1092 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 25 Jun 2020 08:08:37 +0800 Subject: [PATCH 473/881] Fix go.mod --- tests/default_value_test.go | 2 +- tests/go.mod | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/default_value_test.go b/tests/default_value_test.go index 28a456d3..7a7790bc 100644 --- a/tests/default_value_test.go +++ b/tests/default_value_test.go @@ -11,7 +11,7 @@ func TestDefaultValue(t *testing.T) { gorm.Model Email string `gorm:"not null;index:,unique"` Name string `gorm:"not null;default:'foo'"` - Name2 string `gorm:"not null;default:'foo'"` + Name2 string `gorm:"size:233;not null;default:'foo'"` Age int `gorm:"default:18"` Enabled bool `gorm:"default:true"` } diff --git a/tests/go.mod b/tests/go.mod index d43ee8f1..955bafe2 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -9,10 +9,8 @@ require ( gorm.io/driver/mysql v0.2.6 gorm.io/driver/postgres v0.2.3 gorm.io/driver/sqlite v1.0.7 - gorm.io/driver/sqlserver v0.2.2 + gorm.io/driver/sqlserver v0.2.3 gorm.io/gorm v0.2.9 ) replace gorm.io/gorm => ../ - -replace gorm.io/driver/sqlserver => /Users/jinzhu/Projects/sqlserver From f2b49437fbbab9c42ec85f9a0fcf4ad10abc32ec Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 25 Jun 2020 22:48:10 +0800 Subject: [PATCH 474/881] Test set string field's default value to blank string --- migrator/migrator.go | 2 +- tests/default_value_test.go | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/migrator/migrator.go b/migrator/migrator.go index 9c4ce2d5..5edd800e 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -64,7 +64,7 @@ func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) { expr.SQL += " UNIQUE" } - if field.HasDefaultValue && field.DefaultValue != "" { + if field.HasDefaultValue && (field.DefaultValueInterface != nil || field.DefaultValue != "") { if field.DefaultValueInterface != nil { defaultStmt := &gorm.Statement{Vars: []interface{}{field.DefaultValueInterface}} m.Dialector.BindVarTo(defaultStmt, defaultStmt, field.DefaultValueInterface) diff --git a/tests/default_value_test.go b/tests/default_value_test.go index 7a7790bc..ea496d60 100644 --- a/tests/default_value_test.go +++ b/tests/default_value_test.go @@ -12,6 +12,7 @@ func TestDefaultValue(t *testing.T) { Email string `gorm:"not null;index:,unique"` Name string `gorm:"not null;default:'foo'"` Name2 string `gorm:"size:233;not null;default:'foo'"` + Name3 string `gorm:"size:233;not null;default:''"` Age int `gorm:"default:18"` Enabled bool `gorm:"default:true"` } @@ -25,14 +26,14 @@ func TestDefaultValue(t *testing.T) { var harumph = Harumph{Email: "hello@gorm.io"} if err := DB.Create(&harumph).Error; err != nil { t.Fatalf("Failed to create data with default value, got error: %v", err) - } else if harumph.Name != "foo" || harumph.Name2 != "foo" || harumph.Age != 18 || !harumph.Enabled { + } else if harumph.Name != "foo" || harumph.Name2 != "foo" || harumph.Name3 != "" || harumph.Age != 18 || !harumph.Enabled { t.Fatalf("Failed to create data with default value, got: %+v", harumph) } var result Harumph if err := DB.First(&result, "email = ?", "hello@gorm.io").Error; err != nil { t.Fatalf("Failed to find created data, got error: %v", err) - } else if result.Name != "foo" || result.Name2 != "foo" || result.Age != 18 || !result.Enabled { + } else if result.Name != "foo" || result.Name2 != "foo" || result.Name3 != "" || result.Age != 18 || !result.Enabled { t.Fatalf("Failed to find created data with default data, got %+v", result) } } From 4eae3fea41b0f0e4badc8cb96e67588acf094ec7 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 25 Jun 2020 23:37:49 +0800 Subject: [PATCH 475/881] Test group by with multiple columns --- tests/group_by_test.go | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/group_by_test.go b/tests/group_by_test.go index cb4c4f43..b08f48f1 100644 --- a/tests/group_by_test.go +++ b/tests/group_by_test.go @@ -11,6 +11,7 @@ func TestGroupBy(t *testing.T) { Name: "groupby", Age: 10, Birthday: Now(), + Active: true, }, { Name: "groupby", Age: 20, @@ -19,6 +20,7 @@ func TestGroupBy(t *testing.T) { Name: "groupby", Age: 30, Birthday: Now(), + Active: true, }, { Name: "groupby1", Age: 110, @@ -27,10 +29,12 @@ func TestGroupBy(t *testing.T) { Name: "groupby1", Age: 220, Birthday: Now(), + Active: true, }, { Name: "groupby1", Age: 330, Birthday: Now(), + Active: true, }} if err := DB.Create(&users).Error; err != nil { @@ -54,4 +58,13 @@ func TestGroupBy(t *testing.T) { if name != "groupby1" || total != 660 { t.Errorf("name should be groupby, but got %v, total should be 660, but got %v", name, total) } + + var active bool + if err := DB.Model(&User{}).Select("name, active, sum(age)").Where("name = ? and active = ?", "groupby", true).Group("name").Group("active").Row().Scan(&name, &active, &total); err != nil { + t.Errorf("no error should happen, but got %v", err) + } + + if name != "groupby" || active != true || total != 40 { + t.Errorf("group by two columns, name %v, age %v, active: %v", name, total, active) + } } From 2476c0fbb470e3ced8f61278da2bbdce1c24564c Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 26 Jun 2020 07:26:45 +0800 Subject: [PATCH 476/881] Set db type after autotime --- schema/field.go | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/schema/field.go b/schema/field.go index a8328367..f02968fa 100644 --- a/schema/field.go +++ b/schema/field.go @@ -223,15 +223,6 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { field.DataType = DataType(dataTyper.GormDataType()) } - if val, ok := field.TagSettings["TYPE"]; ok { - switch DataType(strings.ToLower(val)) { - case Bool, Int, Uint, Float, String, Time, Bytes: - field.DataType = DataType(strings.ToLower(val)) - default: - field.DataType = DataType(val) - } - } - if v, ok := field.TagSettings["AUTOCREATETIME"]; ok || (field.Name == "CreatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) { if strings.ToUpper(v) == "NANO" { field.AutoCreateTime = UnixNanosecond @@ -248,6 +239,15 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } } + if val, ok := field.TagSettings["TYPE"]; ok { + switch DataType(strings.ToLower(val)) { + case Bool, Int, Uint, Float, String, Time, Bytes: + field.DataType = DataType(strings.ToLower(val)) + default: + field.DataType = DataType(val) + } + } + if field.Size == 0 { switch reflect.Indirect(fieldValue).Kind() { case reflect.Int, reflect.Int64, reflect.Uint, reflect.Uint64, reflect.Float64: From cb5a35a80770c8e8815da93c970d4a43b7eeafae Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 26 Jun 2020 08:39:18 +0800 Subject: [PATCH 477/881] Test group with table name --- tests/go.mod | 8 ++++---- tests/group_by_test.go | 8 ++++++++ 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/tests/go.mod b/tests/go.mod index 955bafe2..c467f34b 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -6,10 +6,10 @@ require ( github.com/google/uuid v1.1.1 github.com/jinzhu/now v1.1.1 github.com/lib/pq v1.6.0 - gorm.io/driver/mysql v0.2.6 - gorm.io/driver/postgres v0.2.3 - gorm.io/driver/sqlite v1.0.7 - gorm.io/driver/sqlserver v0.2.3 + gorm.io/driver/mysql v0.2.7 + gorm.io/driver/postgres v0.2.4 + gorm.io/driver/sqlite v1.0.8 + gorm.io/driver/sqlserver v0.2.4 gorm.io/gorm v0.2.9 ) diff --git a/tests/group_by_test.go b/tests/group_by_test.go index b08f48f1..6d0ed39c 100644 --- a/tests/group_by_test.go +++ b/tests/group_by_test.go @@ -51,6 +51,14 @@ func TestGroupBy(t *testing.T) { t.Errorf("name should be groupby, but got %v, total should be 60, but got %v", name, total) } + if err := DB.Model(&User{}).Select("name, sum(age)").Where("name = ?", "groupby").Group("users.name").Row().Scan(&name, &total); err != nil { + t.Errorf("no error should happen, but got %v", err) + } + + if name != "groupby" || total != 60 { + t.Errorf("name should be groupby, but got %v, total should be 60, but got %v", name, total) + } + if err := DB.Model(&User{}).Select("name, sum(age) as total").Where("name LIKE ?", "groupby%").Group("name").Having("name = ?", "groupby1").Row().Scan(&name, &total); err != nil { t.Errorf("no error should happen, but got %v", err) } From 9bfe3069755739e23a96255805071032a7b7fd40 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 27 Jun 2020 08:04:12 +0800 Subject: [PATCH 478/881] Only query with readable fields --- statement.go | 24 ++++++++++++++---------- tests/customize_field_test.go | 8 ++++++++ 2 files changed, 22 insertions(+), 10 deletions(-) diff --git a/statement.go b/statement.go index 7cc01bb8..e902b739 100644 --- a/statement.go +++ b/statement.go @@ -271,22 +271,26 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) (c switch reflectValue.Kind() { case reflect.Struct: for _, field := range s.Fields { - if v, isZero := field.ValueOf(reflectValue); !isZero { - if field.DBName == "" { - conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.Name}, Value: v}) - } else { - conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.DBName}, Value: v}) + if field.Readable { + if v, isZero := field.ValueOf(reflectValue); !isZero { + if field.DBName == "" { + conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.Name}, Value: v}) + } else { + conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.DBName}, Value: v}) + } } } } case reflect.Slice, reflect.Array: for i := 0; i < reflectValue.Len(); i++ { for _, field := range s.Fields { - if v, isZero := field.ValueOf(reflectValue.Index(i)); !isZero { - if field.DBName == "" { - conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.Name}, Value: v}) - } else { - conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.DBName}, Value: v}) + if field.Readable { + if v, isZero := field.ValueOf(reflectValue.Index(i)); !isZero { + if field.DBName == "" { + conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.Name}, Value: v}) + } else { + conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.DBName}, Value: v}) + } } } } diff --git a/tests/customize_field_test.go b/tests/customize_field_test.go index 910fa6ae..9c6ab948 100644 --- a/tests/customize_field_test.go +++ b/tests/customize_field_test.go @@ -134,10 +134,18 @@ func TestCustomizeField(t *testing.T) { t.Fatalf("invalid updated result: %#v", result2) } + if err := DB.Where(CustomizeFieldStruct{Name: create.Name, FieldReadonly: create.FieldReadonly, FieldIgnore: create.FieldIgnore}).First(&CustomizeFieldStruct{}).Error; err == nil { + t.Fatalf("Should failed to find result") + } + if err := DB.Table("customize_field_structs").Where("1 = 1").UpdateColumn("field_readonly", "readonly").Error; err != nil { t.Fatalf("failed to update field_readonly column") } + if err := DB.Where(CustomizeFieldStruct{Name: create.Name, FieldReadonly: "readonly", FieldIgnore: create.FieldIgnore}).First(&CustomizeFieldStruct{}).Error; err != nil { + t.Fatalf("Should find result") + } + var result3 CustomizeFieldStruct DB.Find(&result3, "name = ?", "create") From 2d048d9ece097f86ecf77872ba050c0ce242bfc0 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 30 Jun 2020 07:29:15 +0800 Subject: [PATCH 479/881] SingularTable for JoinTable --- schema/naming.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/schema/naming.go b/schema/naming.go index d2a4919f..9b7c9471 100644 --- a/schema/naming.go +++ b/schema/naming.go @@ -41,6 +41,9 @@ func (ns NamingStrategy) ColumnName(table, column string) string { // JoinTableName convert string to join table name func (ns NamingStrategy) JoinTableName(str string) string { + if ns.SingularTable { + return ns.TablePrefix + toDBName(str) + } return ns.TablePrefix + inflection.Plural(toDBName(str)) } From f5566288de9b58172f4796053055abde57988b7f Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 30 Jun 2020 16:53:54 +0800 Subject: [PATCH 480/881] Add SetColumn, Changed method --- callbacks/associations.go | 4 +- callbacks/create.go | 2 +- callbacks/helper.go | 58 +------------------ callbacks/update.go | 2 +- errors.go | 2 + statement.go | 117 ++++++++++++++++++++++++++++++++++++++ tests/hooks_test.go | 81 ++++++++++++++++++++++++++ utils/utils.go | 15 +++++ 8 files changed, 221 insertions(+), 60 deletions(-) diff --git a/callbacks/associations.go b/callbacks/associations.go index 3ff0f4b0..bcb6c414 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -11,7 +11,7 @@ import ( func SaveBeforeAssociations(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil { - selectColumns, restricted := SelectAndOmitColumns(db.Statement, true, false) + selectColumns, restricted := db.Statement.SelectAndOmitColumns(true, false) // Save Belongs To associations for _, rel := range db.Statement.Schema.Relationships.BelongsTo { @@ -90,7 +90,7 @@ func SaveBeforeAssociations(db *gorm.DB) { func SaveAfterAssociations(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil { - selectColumns, restricted := SelectAndOmitColumns(db.Statement, true, false) + selectColumns, restricted := db.Statement.SelectAndOmitColumns(true, false) // Save Has One associations for _, rel := range db.Statement.Schema.Relationships.HasOne { diff --git a/callbacks/create.go b/callbacks/create.go index 283d3fd1..eecb80a1 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -218,7 +218,7 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { values = ConvertSliceOfMapToValuesForCreate(stmt, value) default: var ( - selectColumns, restricted = SelectAndOmitColumns(stmt, true, false) + selectColumns, restricted = stmt.SelectAndOmitColumns(true, false) curTime = stmt.DB.NowFunc() isZero bool ) diff --git a/callbacks/helper.go b/callbacks/helper.go index 3b0cca16..1b06e0b7 100644 --- a/callbacks/helper.go +++ b/callbacks/helper.go @@ -7,64 +7,10 @@ import ( "gorm.io/gorm/clause" ) -// SelectAndOmitColumns get select and omit columns, select -> true, omit -> false -func SelectAndOmitColumns(stmt *gorm.Statement, requireCreate, requireUpdate bool) (map[string]bool, bool) { - results := map[string]bool{} - notRestricted := false - - // select columns - for _, column := range stmt.Selects { - if column == "*" { - notRestricted = true - for _, dbName := range stmt.Schema.DBNames { - results[dbName] = true - } - } else if column == clause.Associations { - for _, rel := range stmt.Schema.Relationships.Relations { - results[rel.Name] = true - } - } else if field := stmt.Schema.LookUpField(column); field != nil && field.DBName != "" { - results[field.DBName] = true - } else { - results[column] = true - } - } - - // omit columns - for _, omit := range stmt.Omits { - if omit == clause.Associations { - for _, rel := range stmt.Schema.Relationships.Relations { - results[rel.Name] = false - } - } else if field := stmt.Schema.LookUpField(omit); field != nil && field.DBName != "" { - results[field.DBName] = false - } else { - results[omit] = false - } - } - - if stmt.Schema != nil { - for _, field := range stmt.Schema.Fields { - name := field.DBName - if name == "" { - name = field.Name - } - - if requireCreate && !field.Creatable { - results[name] = false - } else if requireUpdate && !field.Updatable { - results[name] = false - } - } - } - - return results, !notRestricted && len(stmt.Selects) > 0 -} - // ConvertMapToValuesForCreate convert map to values func ConvertMapToValuesForCreate(stmt *gorm.Statement, mapValue map[string]interface{}) (values clause.Values) { columns := make([]string, 0, len(mapValue)) - selectColumns, restricted := SelectAndOmitColumns(stmt, true, false) + selectColumns, restricted := stmt.SelectAndOmitColumns(true, false) var keys []string for k := range mapValue { @@ -91,7 +37,7 @@ func ConvertSliceOfMapToValuesForCreate(stmt *gorm.Statement, mapValues []map[st var ( columns = []string{} result = map[string][]interface{}{} - selectColumns, restricted = SelectAndOmitColumns(stmt, true, false) + selectColumns, restricted = stmt.SelectAndOmitColumns(true, false) ) for idx, mapValue := range mapValues { diff --git a/callbacks/update.go b/callbacks/update.go index 1ea77552..f84e933c 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -110,7 +110,7 @@ func AfterUpdate(db *gorm.DB) { // ConvertToAssignments convert to update assignments func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { var ( - selectColumns, restricted = SelectAndOmitColumns(stmt, false, true) + selectColumns, restricted = stmt.SelectAndOmitColumns(false, true) assignValue func(field *schema.Field, value interface{}) ) diff --git a/errors.go b/errors.go index b41eefae..e1b58835 100644 --- a/errors.go +++ b/errors.go @@ -29,4 +29,6 @@ var ( ErrUnsupportedDriver = errors.New("unsupported driver") // ErrRegistered registered ErrRegistered = errors.New("registered") + // ErrInvalidField invalid field + ErrInvalidField = errors.New("invalid field") ) diff --git a/statement.go b/statement.go index e902b739..164ddbd7 100644 --- a/statement.go +++ b/statement.go @@ -12,6 +12,7 @@ import ( "gorm.io/gorm/clause" "gorm.io/gorm/schema" + "gorm.io/gorm/utils" ) // Statement statement @@ -370,3 +371,119 @@ func (stmt *Statement) clone() *Statement { return newStmt } + +// Helpers +// SetColumn set column's value +func (stmt *Statement) SetColumn(name string, value interface{}) { + if v, ok := stmt.Dest.(map[string]interface{}); ok { + v[name] = value + } else if stmt.Schema != nil { + if field := stmt.Schema.LookUpField(name); field != nil { + field.Set(stmt.ReflectValue, value) + } else { + stmt.AddError(ErrInvalidField) + } + } else { + stmt.AddError(ErrInvalidField) + } +} + +// Changed check model changed or not when updating +func (stmt *Statement) Changed(fields ...string) bool { + modelValue := reflect.ValueOf(stmt.Model) + for modelValue.Kind() == reflect.Ptr { + modelValue = modelValue.Elem() + } + + selectColumns, restricted := stmt.SelectAndOmitColumns(false, true) + changed := func(field *schema.Field) bool { + fieldValue, isZero := field.ValueOf(modelValue) + if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { + if v, ok := stmt.Dest.(map[string]interface{}); ok { + if fv, ok := v[field.Name]; ok { + return !utils.AssertEqual(fv, fieldValue) + } else if fv, ok := v[field.DBName]; ok { + return !utils.AssertEqual(fv, fieldValue) + } else if isZero { + return true + } + } else { + changedValue, _ := field.ValueOf(stmt.ReflectValue) + return !utils.AssertEqual(changedValue, fieldValue) + } + } + return false + } + + if len(fields) == 0 { + for _, field := range stmt.Schema.FieldsByDBName { + if changed(field) { + return true + } + } + } else { + for _, name := range fields { + if field := stmt.Schema.LookUpField(name); field != nil { + if changed(field) { + return true + } + } + } + } + + return false +} + +// SelectAndOmitColumns get select and omit columns, select -> true, omit -> false +func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (map[string]bool, bool) { + results := map[string]bool{} + notRestricted := false + + // select columns + for _, column := range stmt.Selects { + if column == "*" { + notRestricted = true + for _, dbName := range stmt.Schema.DBNames { + results[dbName] = true + } + } else if column == clause.Associations { + for _, rel := range stmt.Schema.Relationships.Relations { + results[rel.Name] = true + } + } else if field := stmt.Schema.LookUpField(column); field != nil && field.DBName != "" { + results[field.DBName] = true + } else { + results[column] = true + } + } + + // omit columns + for _, omit := range stmt.Omits { + if omit == clause.Associations { + for _, rel := range stmt.Schema.Relationships.Relations { + results[rel.Name] = false + } + } else if field := stmt.Schema.LookUpField(omit); field != nil && field.DBName != "" { + results[field.DBName] = false + } else { + results[omit] = false + } + } + + if stmt.Schema != nil { + for _, field := range stmt.Schema.Fields { + name := field.DBName + if name == "" { + name = field.Name + } + + if requireCreate && !field.Creatable { + results[name] = false + } else if requireUpdate && !field.Updatable { + results[name] = false + } + } + } + + return results, !notRestricted && len(stmt.Selects) > 0 +} diff --git a/tests/hooks_test.go b/tests/hooks_test.go index c74e8f10..8f8c60f5 100644 --- a/tests/hooks_test.go +++ b/tests/hooks_test.go @@ -285,3 +285,84 @@ func TestUseDBInHooks(t *testing.T) { t.Errorf("Admin product's price should not be changed, expects: %v, got %v", 600, result4.Price) } } + +type Product3 struct { + gorm.Model + Name string + Code string + Price int64 + Owner string +} + +func (s Product3) BeforeCreate(tx *gorm.DB) (err error) { + tx.Statement.SetColumn("Price", s.Price+100) + return nil +} + +func (s Product3) BeforeUpdate(tx *gorm.DB) (err error) { + if tx.Statement.Changed() { + tx.Statement.SetColumn("Price", s.Price+10) + } + + if tx.Statement.Changed("Code") { + s.Price += 20 + tx.Statement.SetColumn("Price", s.Price+30) + } + return nil +} + +func TestSetColumn(t *testing.T) { + DB.Migrator().DropTable(&Product3{}) + DB.AutoMigrate(&Product3{}) + + product := Product3{Name: "Product", Price: 0} + DB.Create(&product) + + if product.Price != 100 { + t.Errorf("invalid price after create, got %+v", product) + } + + DB.Model(&product).Select("code", "price").Updates(map[string]interface{}{"code": "L1212"}) + + if product.Price != 150 || product.Code != "L1212" { + t.Errorf("invalid data after update, got %+v", product) + } + + // Code not changed, price should not change + DB.Model(&product).Updates(map[string]interface{}{"Name": "Product New"}) + + if product.Name != "Product New" || product.Price != 160 || product.Code != "L1212" { + t.Errorf("invalid data after update, got %+v", product) + } + + // Code changed, but not selected, price should not change + DB.Model(&product).Select("Name", "Price").Updates(map[string]interface{}{"Name": "Product New2", "code": "L1213"}) + + if product.Name != "Product New2" || product.Price != 170 || product.Code != "L1212" { + t.Errorf("invalid data after update, got %+v", product) + } + + // Code changed, price should changed + DB.Model(&product).Select("Name", "Code", "Price").Updates(map[string]interface{}{"Name": "Product New3", "code": "L1213"}) + + if product.Name != "Product New3" || product.Price != 220 || product.Code != "L1213" { + t.Errorf("invalid data after update, got %+v", product) + } + + var result Product3 + DB.First(&result, product.ID) + + AssertEqual(t, result, product) + + // Code changed, price not selected, price should not change + DB.Model(&product).Select("code").Updates(map[string]interface{}{"name": "L1214"}) + + if product.Price != 220 || product.Code != "L1213" { + t.Errorf("invalid data after update, got %+v", product) + } + + var result2 Product3 + DB.First(&result2, product.ID) + + AssertEqual(t, result2, product) +} diff --git a/utils/utils.go b/utils/utils.go index 81d2dc34..9bf00683 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -68,3 +68,18 @@ func ToStringKey(values ...interface{}) string { return strings.Join(results, "_") } + +func AssertEqual(src, dst interface{}) bool { + if !reflect.DeepEqual(src, dst) { + if valuer, ok := src.(driver.Valuer); ok { + src, _ = valuer.Value() + } + + if valuer, ok := dst.(driver.Valuer); ok { + dst, _ = valuer.Value() + } + + return reflect.DeepEqual(src, dst) + } + return true +} From 929c0c576cd55e935cf204a4ee3c492734a4293b Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 30 Jun 2020 22:47:21 +0800 Subject: [PATCH 481/881] Test Hooks For Slice --- callbacks/callmethod.go | 4 +++- statement.go | 17 +++++++++++---- tests/hooks_test.go | 48 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 64 insertions(+), 5 deletions(-) diff --git a/callbacks/callmethod.go b/callbacks/callmethod.go index a0e9b0e7..0160f354 100644 --- a/callbacks/callmethod.go +++ b/callbacks/callmethod.go @@ -11,8 +11,10 @@ func callMethod(db *gorm.DB, fc func(value interface{}, tx *gorm.DB) bool) { if called := fc(db.Statement.Dest, tx); !called { switch db.Statement.ReflectValue.Kind() { case reflect.Slice, reflect.Array: + db.Statement.CurDestIndex = 0 for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - fc(db.Statement.ReflectValue.Index(i).Addr().Interface(), tx) + fc(reflect.Indirect(db.Statement.ReflectValue.Index(i)).Addr().Interface(), tx) + db.Statement.CurDestIndex++ } case reflect.Struct: fc(db.Statement.ReflectValue.Addr().Interface(), tx) diff --git a/statement.go b/statement.go index 164ddbd7..e65a064f 100644 --- a/statement.go +++ b/statement.go @@ -38,6 +38,7 @@ type Statement struct { SQL strings.Builder Vars []interface{} NamedVars []sql.NamedArg + CurDestIndex int attrs []interface{} assigns []interface{} } @@ -379,7 +380,12 @@ func (stmt *Statement) SetColumn(name string, value interface{}) { v[name] = value } else if stmt.Schema != nil { if field := stmt.Schema.LookUpField(name); field != nil { - field.Set(stmt.ReflectValue, value) + switch stmt.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + field.Set(stmt.ReflectValue.Index(stmt.CurDestIndex), value) + case reflect.Struct: + field.Set(stmt.ReflectValue, value) + } } else { stmt.AddError(ErrInvalidField) } @@ -395,17 +401,20 @@ func (stmt *Statement) Changed(fields ...string) bool { modelValue = modelValue.Elem() } + switch modelValue.Kind() { + case reflect.Slice, reflect.Array: + modelValue = stmt.ReflectValue.Index(stmt.CurDestIndex) + } + selectColumns, restricted := stmt.SelectAndOmitColumns(false, true) changed := func(field *schema.Field) bool { - fieldValue, isZero := field.ValueOf(modelValue) + fieldValue, _ := field.ValueOf(modelValue) if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { if v, ok := stmt.Dest.(map[string]interface{}); ok { if fv, ok := v[field.Name]; ok { return !utils.AssertEqual(fv, fieldValue) } else if fv, ok := v[field.DBName]; ok { return !utils.AssertEqual(fv, fieldValue) - } else if isZero { - return true } } else { changedValue, _ := field.ValueOf(stmt.ReflectValue) diff --git a/tests/hooks_test.go b/tests/hooks_test.go index 8f8c60f5..ed5ee746 100644 --- a/tests/hooks_test.go +++ b/tests/hooks_test.go @@ -366,3 +366,51 @@ func TestSetColumn(t *testing.T) { AssertEqual(t, result2, product) } + +func TestHooksForSlice(t *testing.T) { + products := []*Product3{ + {Name: "Product-1", Price: 100}, + {Name: "Product-2", Price: 200}, + {Name: "Product-3", Price: 300}, + } + + DB.Create(&products) + + for idx, value := range []int64{200, 300, 400} { + if products[idx].Price != value { + t.Errorf("invalid price for product #%v, expects: %v, got %v", idx, value, products[idx].Price) + } + } + + DB.Model(&products).Update("Name", "product-name") + + // will set all product's price to last product's price + 10 + for idx, value := range []int64{410, 410, 410} { + if products[idx].Price != value { + t.Errorf("invalid price for product #%v, expects: %v, got %v", idx, value, products[idx].Price) + } + } + + products2 := []Product3{ + {Name: "Product-1", Price: 100}, + {Name: "Product-2", Price: 200}, + {Name: "Product-3", Price: 300}, + } + + DB.Create(&products2) + + for idx, value := range []int64{200, 300, 400} { + if products2[idx].Price != value { + t.Errorf("invalid price for product #%v, expects: %v, got %v", idx, value, products2[idx].Price) + } + } + + DB.Model(&products2).Update("Name", "product-name") + + // will set all product's price to last product's price + 10 + for idx, value := range []int64{410, 410, 410} { + if products2[idx].Price != value { + t.Errorf("invalid price for product #%v, expects: %v, got %v", idx, value, products2[idx].Price) + } + } +} From ee1f46e3a1295f2342e72d5da9dc33f8a2a2a9d5 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 30 Jun 2020 23:06:48 +0800 Subject: [PATCH 482/881] Allow to use sql function in Group, Pluck --- chainable_api.go | 4 +++- finisher_api.go | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/chainable_api.go b/chainable_api.go index dbd783fd..e2ba44cc 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -162,8 +162,10 @@ func (db *DB) Joins(query string, args ...interface{}) (tx *DB) { // Group specify the group method on the find func (db *DB) Group(name string) (tx *DB) { tx = db.getInstance() + + fields := strings.FieldsFunc(name, utils.IsChar) tx.Statement.AddClause(clause.GroupBy{ - Columns: []clause.Column{{Name: name}}, + Columns: []clause.Column{{Name: name, Raw: len(fields) != 1}}, }) return } diff --git a/finisher_api.go b/finisher_api.go index 6d961811..af040106 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -8,6 +8,7 @@ import ( "strings" "gorm.io/gorm/clause" + "gorm.io/gorm/utils" ) // Create insert the value into database @@ -325,9 +326,10 @@ func (db *DB) Pluck(column string, dest interface{}) (tx *DB) { tx.AddError(ErrorModelValueRequired) } + fields := strings.FieldsFunc(column, utils.IsChar) tx.Statement.AddClauseIfNotExists(clause.Select{ Distinct: tx.Statement.Distinct, - Columns: []clause.Column{{Name: column}}, + Columns: []clause.Column{{Name: column, Raw: len(fields) != 1}}, }) tx.Statement.Dest = dest tx.callbacks.Query().Execute(tx) From 9075b33620f14a62680d0c296522243874be2700 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 1 Jul 2020 08:56:21 +0800 Subject: [PATCH 483/881] Query with smaller struct --- callbacks/query.go | 12 +++++++++++- scan.go | 24 +++++++++++++++++------- tests/query_test.go | 23 ++++++++++++++++++++++- 3 files changed, 50 insertions(+), 9 deletions(-) diff --git a/callbacks/query.go b/callbacks/query.go index 27d53a4d..4b7f5bd5 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -40,7 +40,7 @@ func BuildQuerySQL(db *gorm.DB) { db.Statement.SQL.Grow(100) clauseSelect := clause.Select{Distinct: db.Statement.Distinct} - if db.Statement.ReflectValue.Kind() == reflect.Struct { + if db.Statement.ReflectValue.Kind() == reflect.Struct && db.Statement.ReflectValue.Type() == db.Statement.Schema.ModelType { var conds []clause.Expression for _, primaryField := range db.Statement.Schema.PrimaryFields { if v, isZero := primaryField.ValueOf(db.Statement.ReflectValue); !isZero { @@ -64,6 +64,16 @@ func BuildQuerySQL(db *gorm.DB) { clauseSelect.Columns[idx] = clause.Column{Name: name, Raw: true} } } + } else if db.Statement.Schema != nil && db.Statement.ReflectValue.IsValid() && db.Statement.ReflectValue.Type() != db.Statement.Schema.ModelType { + stmt := gorm.Statement{DB: db} + // smaller struct + if err := stmt.Parse(db.Statement.Dest); err == nil { + clauseSelect.Columns = make([]clause.Column, len(stmt.Schema.DBNames)) + + for idx, dbName := range stmt.Schema.DBNames { + clauseSelect.Columns[idx] = clause.Column{Name: dbName} + } + } } // inline joins diff --git a/scan.go b/scan.go index 2d227ec2..0b199029 100644 --- a/scan.go +++ b/scan.go @@ -69,6 +69,8 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { db.AddError(rows.Scan(dest)) } default: + Schema := db.Statement.Schema + switch db.Statement.ReflectValue.Kind() { case reflect.Slice, reflect.Array: var ( @@ -84,16 +86,20 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { db.Statement.ReflectValue.Set(reflect.MakeSlice(db.Statement.ReflectValue.Type(), 0, 0)) - if db.Statement.Schema != nil { + if Schema != nil { + if reflectValueType != Schema.ModelType && reflectValueType.Kind() == reflect.Struct { + Schema, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy) + } + for idx, column := range columns { - if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable { + if field := Schema.LookUpField(column); field != nil && field.Readable { fields[idx] = field } else if names := strings.Split(column, "__"); len(names) > 1 { if len(joinFields) == 0 { joinFields = make([][2]*schema.Field, len(columns)) } - if rel, ok := db.Statement.Schema.Relationships.Relations[names[0]]; ok { + if rel, ok := Schema.Relationships.Relations[names[0]]; ok { if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { fields[idx] = field joinFields[idx] = [2]*schema.Field{rel.Field, field} @@ -151,12 +157,16 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { } } case reflect.Struct: + if db.Statement.ReflectValue.Type() != Schema.ModelType { + Schema, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy) + } + if initialized || rows.Next() { for idx, column := range columns { - if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable { + if field := Schema.LookUpField(column); field != nil && field.Readable { values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface() } else if names := strings.Split(column, "__"); len(names) > 1 { - if rel, ok := db.Statement.Schema.Relationships.Relations[names[0]]; ok { + if rel, ok := Schema.Relationships.Relations[names[0]]; ok { if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface() continue @@ -172,10 +182,10 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { db.AddError(rows.Scan(values...)) for idx, column := range columns { - if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable { + if field := Schema.LookUpField(column); field != nil && field.Readable { field.Set(db.Statement.ReflectValue, values[idx]) } else if names := strings.Split(column, "__"); len(names) > 1 { - if rel, ok := db.Statement.Schema.Relationships.Relations[names[0]]; ok { + if rel, ok := Schema.Relationships.Relations[names[0]]; ok { relValue := rel.Field.ReflectValueOf(db.Statement.ReflectValue) if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { value := reflect.ValueOf(values[idx]).Elem() diff --git a/tests/query_test.go b/tests/query_test.go index de65b63b..7973fd51 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -3,6 +3,7 @@ package tests_test import ( "fmt" "reflect" + "regexp" "sort" "strconv" "testing" @@ -144,8 +145,8 @@ func TestFillSmallerStruct(t *testing.T) { user := User{Name: "SmallerUser", Age: 100} DB.Save(&user) type SimpleUser struct { - Name string ID int64 + Name string UpdatedAt time.Time CreatedAt time.Time } @@ -156,6 +157,26 @@ func TestFillSmallerStruct(t *testing.T) { } AssertObjEqual(t, user, simpleUser, "Name", "ID", "UpdatedAt", "CreatedAt") + + var simpleUser2 SimpleUser + if err := DB.Model(&User{}).Select("id").First(&simpleUser2, user.ID).Error; err != nil { + t.Fatalf("Failed to query smaller user, got error %v", err) + } + + AssertObjEqual(t, user, simpleUser2, "ID") + + var simpleUsers []SimpleUser + if err := DB.Model(&User{}).Select("id").Find(&simpleUsers, user.ID).Error; err != nil || len(simpleUsers) != 1 { + t.Fatalf("Failed to query smaller user, got error %v", err) + } + + AssertObjEqual(t, user, simpleUsers[0], "ID") + + result := DB.Session(&gorm.Session{DryRun: true}).Model(&User{}).Find(&simpleUsers, user.ID) + + if !regexp.MustCompile("SELECT .*id.*name.*updated_at.*created_at.* FROM .*users").MatchString(result.Statement.SQL.String()) { + t.Fatalf("SQL should include selected names, but got %v", result.Statement.SQL.String()) + } } func TestPluck(t *testing.T) { From d02b592c6cd276c169ade515b8999132def9e555 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 1 Jul 2020 10:19:52 +0800 Subject: [PATCH 484/881] Better support Count in chain --- finisher_api.go | 2 ++ tests/count_test.go | 8 ++++++++ 2 files changed, 10 insertions(+) diff --git a/finisher_api.go b/finisher_api.go index af040106..25c56e49 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -269,6 +269,7 @@ func (db *DB) Count(count *int64) (tx *DB) { if len(tx.Statement.Selects) == 0 { tx.Statement.AddClause(clause.Select{Expression: clause.Expr{SQL: "count(1)"}}) + defer tx.Statement.AddClause(clause.Select{}) } else if !strings.Contains(strings.ToLower(tx.Statement.Selects[0]), "count(") { expr := clause.Expr{SQL: "count(1)"} @@ -281,6 +282,7 @@ func (db *DB) Count(count *int64) (tx *DB) { } tx.Statement.AddClause(clause.Select{Expression: expr}) + defer tx.Statement.AddClause(clause.Select{}) } tx.Statement.Dest = count diff --git a/tests/count_test.go b/tests/count_test.go index 0662ae5c..826d6a36 100644 --- a/tests/count_test.go +++ b/tests/count_test.go @@ -27,6 +27,14 @@ func TestCount(t *testing.T) { t.Errorf("Count() method should get correct value, expect: %v, got %v", count, len(users)) } + if err := DB.Model(&User{}).Where("name = ?", user1.Name).Or("name = ?", user3.Name).Count(&count).Find(&users).Error; err != nil { + t.Errorf(fmt.Sprintf("Count should work, but got err %v", err)) + } + + if count != int64(len(users)) { + t.Errorf("Count() method should get correct value, expect: %v, got %v", count, len(users)) + } + DB.Model(&User{}).Where("name = ?", user1.Name).Count(&count1).Or("name in ?", []string{user2.Name, user3.Name}).Count(&count2) if count1 != 1 || count2 != 3 { t.Errorf("multiple count in chain should works") From fea181e87c019a20135623b0644b6b9585d6db13 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 1 Jul 2020 11:47:46 +0800 Subject: [PATCH 485/881] Test multiple index tags --- schema/index_test.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/schema/index_test.go b/schema/index_test.go index 384e902b..71a70a8c 100644 --- a/schema/index_test.go +++ b/schema/index_test.go @@ -16,7 +16,7 @@ type UserIndex struct { Name5 int64 `gorm:"index:,class:FULLTEXT,comment:hello \\, world,where:age > 10"` Name6 int64 `gorm:"index:profile,comment:hello \\, world,where:age > 10"` Age int64 `gorm:"index:profile,expression:ABS(age)"` - OID int64 `gorm:"index:idx_id"` + OID int64 `gorm:"index:idx_id;index:idx_oid,unique"` MemberNumber string `gorm:"index:idx_id"` } @@ -70,6 +70,11 @@ func TestParseIndex(t *testing.T) { Name: "idx_id", Fields: []schema.IndexOption{{}, {}}, }, + "idx_oid": { + Name: "idx_oid", + Class: "UNIQUE", + Fields: []schema.IndexOption{{}}, + }, } indices := user.ParseIndexes() From 630f4fe03f9d2fd93ed3dcc0ec248c8c76c05cd5 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 24 Jun 2020 16:43:53 +0800 Subject: [PATCH 486/881] Create join table with ReorderModels --- migrator/migrator.go | 37 +++++++++----------------------- tests/multi_primary_keys_test.go | 27 +++++++++++++++++++---- 2 files changed, 33 insertions(+), 31 deletions(-) diff --git a/migrator/migrator.go b/migrator/migrator.go index c8fe17ab..799bf433 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -116,20 +116,6 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { } } } - - // create join table - if rel.JoinTable != nil { - joinValue := reflect.New(rel.JoinTable.ModelType).Interface() - if !tx.Migrator().HasTable(rel.JoinTable.Table) { - defer func(table string, joinValue interface{}) { - errr = tx.Table(table).Migrator().CreateTable(joinValue) - }(rel.JoinTable.Table, joinValue) - } else { - defer func(table string, joinValue interface{}) { - errr = tx.Table(table).Migrator().AutoMigrate(joinValue) - }(rel.JoinTable.Table, joinValue) - } - } } return nil }); err != nil { @@ -193,16 +179,6 @@ func (m Migrator) CreateTable(values ...interface{}) error { } } } - - // create join table - if rel.JoinTable != nil { - joinValue := reflect.New(rel.JoinTable.ModelType).Interface() - if !tx.Migrator().HasTable(rel.JoinTable.Table) { - defer func(table string, joinValue interface{}) { - errr = tx.Table(table).Migrator().CreateTable(joinValue) - }(rel.JoinTable.Table, joinValue) - } - } } for _, chk := range stmt.Schema.ParseCheckConstraints() { @@ -551,9 +527,10 @@ func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []i orderedModelNamesMap = map[string]bool{} valuesMap = map[string]Dependency{} insertIntoOrderedList func(name string) + parseDependence func(value interface{}, addToList bool) ) - parseDependence := func(value interface{}, addToList bool) { + parseDependence = func(value interface{}, addToList bool) { dep := Dependency{ Statement: &gorm.Statement{DB: m.DB, Dest: value}, } @@ -564,8 +541,14 @@ func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []i dep.Depends = append(dep.Depends, c.ReferenceSchema) } - if rel.JoinTable != nil && rel.Schema != rel.FieldSchema { - dep.Depends = append(dep.Depends, rel.FieldSchema) + if rel.JoinTable != nil { + if rel.Schema != rel.FieldSchema { + dep.Depends = append(dep.Depends, rel.FieldSchema) + } + // append join value + defer func(joinValue interface{}) { + parseDependence(joinValue, autoAdd) + }(reflect.New(rel.JoinTable.ModelType).Interface()) } } diff --git a/tests/multi_primary_keys_test.go b/tests/multi_primary_keys_test.go index 05267bbb..617010c5 100644 --- a/tests/multi_primary_keys_test.go +++ b/tests/multi_primary_keys_test.go @@ -4,6 +4,8 @@ import ( "reflect" "sort" "testing" + + "gorm.io/gorm" ) type Blog struct { @@ -11,7 +13,7 @@ type Blog struct { Locale string `gorm:"primary_key"` Subject string Body string - Tags []Tag `gorm:"many2many:blog_tags;"` + Tags []Tag `gorm:"many2many:blogs_tags;"` SharedTags []Tag `gorm:"many2many:shared_blog_tags;ForeignKey:id;References:id"` LocaleTags []Tag `gorm:"many2many:locale_blog_tags;ForeignKey:id,locale;References:id"` } @@ -38,7 +40,16 @@ func TestManyToManyWithMultiPrimaryKeys(t *testing.T) { t.Skip("skip sqlite, sqlserver due to it doesn't support multiple primary keys with auto increment") } - DB.Migrator().DropTable(&Blog{}, &Tag{}, "blog_tags") + if name := DB.Dialector.Name(); name == "postgres" { + stmt := gorm.Statement{DB: DB} + stmt.Parse(&Blog{}) + stmt.Schema.LookUpField("ID").Unique = true + stmt.Parse(&Tag{}) + stmt.Schema.LookUpField("ID").Unique = true + // postgers only allow unique constraint matching given keys + } + + DB.Migrator().DropTable(&Blog{}, &Tag{}, "blog_tags", "locale_blog_tags", "shared_blog_tags") if err := DB.AutoMigrate(&Blog{}, &Tag{}); err != nil { t.Fatalf("Failed to auto migrate, got error: %v", err) } @@ -127,7 +138,11 @@ func TestManyToManyWithCustomizedForeignKeys(t *testing.T) { t.Skip("skip sqlite, sqlserver due to it doesn't support multiple primary keys with auto increment") } - DB.Migrator().DropTable(&Blog{}, &Tag{}, "blog_tags") + if name := DB.Dialector.Name(); name == "postgres" { + t.Skip("skip postgers due to it only allow unique constraint matching given keys") + } + + DB.Migrator().DropTable(&Blog{}, &Tag{}, "blog_tags", "locale_blog_tags", "shared_blog_tags") if err := DB.AutoMigrate(&Blog{}, &Tag{}); err != nil { t.Fatalf("Failed to auto migrate, got error: %v", err) } @@ -248,7 +263,11 @@ func TestManyToManyWithCustomizedForeignKeys2(t *testing.T) { t.Skip("skip sqlite, sqlserver due to it doesn't support multiple primary keys with auto increment") } - DB.Migrator().DropTable(&Blog{}, &Tag{}, "blog_tags") + if name := DB.Dialector.Name(); name == "postgres" { + t.Skip("skip postgers due to it only allow unique constraint matching given keys") + } + + DB.Migrator().DropTable(&Blog{}, &Tag{}, "blog_tags", "locale_blog_tags", "shared_blog_tags") if err := DB.AutoMigrate(&Blog{}, &Tag{}); err != nil { t.Fatalf("Failed to auto migrate, got error: %v", err) } From 6b92bca6648ebb9137339b7347ae82ac8a462754 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 24 Jun 2020 19:09:19 +0800 Subject: [PATCH 487/881] Update test script --- tests/main_test.go | 4 ++++ tests/tests_all.sh | 14 ++++---------- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/tests/main_test.go b/tests/main_test.go index 9d933caf..5b8c7dbb 100644 --- a/tests/main_test.go +++ b/tests/main_test.go @@ -7,6 +7,10 @@ import ( ) func TestExceptionsWithInvalidSql(t *testing.T) { + if name := DB.Dialector.Name(); name == "sqlserver" { + t.Skip("skip sqlserver due to it will raise data race for invalid sql") + } + var columns []string if DB.Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns).Error == nil { t.Errorf("Should got error with invalid SQL") diff --git a/tests/tests_all.sh b/tests/tests_all.sh index 47f25401..e87ff045 100755 --- a/tests/tests_all.sh +++ b/tests/tests_all.sh @@ -19,27 +19,21 @@ for dialect in "${dialects[@]}" ; do then echo "testing ${dialect}..." - race="" - if [ "$GORM_DIALECT" = "sqlserver" ] - then - race="-race" - fi - if [ "$GORM_VERBOSE" = "" ] then - GORM_DIALECT=${dialect} go test $race -count=1 ./... + GORM_DIALECT=${dialect} go test -race -count=1 ./... if [ -d tests ] then cd tests - GORM_DIALECT=${dialect} go test $race -count=1 ./... + GORM_DIALECT=${dialect} go test -race -count=1 ./... cd .. fi else - GORM_DIALECT=${dialect} go test $race -count=1 -v ./... + GORM_DIALECT=${dialect} go test -race -count=1 -v ./... if [ -d tests ] then cd tests - GORM_DIALECT=${dialect} go test $race -count=1 -v ./... + GORM_DIALECT=${dialect} go test -race -count=1 -v ./... cd .. fi fi From 19f56ddc2a212019a950c6ef81e55950342b713a Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 24 Jun 2020 20:19:28 +0800 Subject: [PATCH 488/881] Upgrade default mysql driver --- tests/go.mod | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/go.mod b/tests/go.mod index abe32cd6..f4d93ecb 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -6,7 +6,7 @@ require ( github.com/google/uuid v1.1.1 github.com/jinzhu/now v1.1.1 github.com/lib/pq v1.6.0 - gorm.io/driver/mysql v0.2.3 + gorm.io/driver/mysql v0.2.6 gorm.io/driver/postgres v0.2.3 gorm.io/driver/sqlite v1.0.7 gorm.io/driver/sqlserver v0.2.2 From 4cbd99aa94d04292ac369fd9abe3b1a78d6d7fe6 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 25 Jun 2020 06:38:07 +0800 Subject: [PATCH 489/881] Add default value test --- tests/default_value_test.go | 37 +++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) create mode 100644 tests/default_value_test.go diff --git a/tests/default_value_test.go b/tests/default_value_test.go new file mode 100644 index 00000000..52292cf7 --- /dev/null +++ b/tests/default_value_test.go @@ -0,0 +1,37 @@ +package tests_test + +import ( + "testing" + + "gorm.io/gorm" +) + +func TestDefaultValue(t *testing.T) { + type Harumph struct { + gorm.Model + Email string `gorm:"not null;"` + Name string `gorm:"not null;default:foo"` + Name2 string `gorm:"not null;default:'foo'"` + Age int `gorm:"default:18"` + } + + DB.Migrator().DropTable(&Harumph{}) + + if err := DB.AutoMigrate(&Harumph{}); err != nil { + t.Fatalf("Failed to migrate with default value, got error: %v", err) + } + + var harumph = Harumph{Email: "hello@gorm.io"} + if err := DB.Create(&harumph).Error; err != nil { + t.Fatalf("Failed to create data with default value, got error: %v", err) + } else if harumph.Name != "foo" || harumph.Name2 != "foo" || harumph.Age != 18 { + t.Fatalf("Failed to create data with default value, got: %+v", harumph) + } + + var result Harumph + if err := DB.First(&result, "email = ?", "hello@gorm.io").Error; err != nil { + t.Fatalf("Failed to find created data, got error: %v", err) + } else if result.Name != "foo" || result.Name2 != "foo" || result.Age != 18 { + t.Fatalf("Failed to find created data with default data, got %+v", result) + } +} From dcdcc6fedc9e55ca6ebec4e8676cbdb238fc955f Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 25 Jun 2020 08:00:10 +0800 Subject: [PATCH 490/881] Fix create with default value --- migrator/migrator.go | 8 ++++---- tests/default_value_test.go | 13 +++++++------ tests/go.mod | 2 ++ 3 files changed, 13 insertions(+), 10 deletions(-) diff --git a/migrator/migrator.go b/migrator/migrator.go index 799bf433..9c4ce2d5 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -65,10 +65,10 @@ func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) { } if field.HasDefaultValue && field.DefaultValue != "" { - if field.DataType == schema.String && field.DefaultValueInterface != nil { - defaultStmt := &gorm.Statement{Vars: []interface{}{field.DefaultValue}} - m.Dialector.BindVarTo(defaultStmt, defaultStmt, field.DefaultValue) - expr.SQL += " DEFAULT " + m.Dialector.Explain(defaultStmt.SQL.String(), field.DefaultValue) + if field.DefaultValueInterface != nil { + defaultStmt := &gorm.Statement{Vars: []interface{}{field.DefaultValueInterface}} + m.Dialector.BindVarTo(defaultStmt, defaultStmt, field.DefaultValueInterface) + expr.SQL += " DEFAULT " + m.Dialector.Explain(defaultStmt.SQL.String(), field.DefaultValueInterface) } else { expr.SQL += " DEFAULT " + field.DefaultValue } diff --git a/tests/default_value_test.go b/tests/default_value_test.go index 52292cf7..28a456d3 100644 --- a/tests/default_value_test.go +++ b/tests/default_value_test.go @@ -9,10 +9,11 @@ import ( func TestDefaultValue(t *testing.T) { type Harumph struct { gorm.Model - Email string `gorm:"not null;"` - Name string `gorm:"not null;default:foo"` - Name2 string `gorm:"not null;default:'foo'"` - Age int `gorm:"default:18"` + Email string `gorm:"not null;index:,unique"` + Name string `gorm:"not null;default:'foo'"` + Name2 string `gorm:"not null;default:'foo'"` + Age int `gorm:"default:18"` + Enabled bool `gorm:"default:true"` } DB.Migrator().DropTable(&Harumph{}) @@ -24,14 +25,14 @@ func TestDefaultValue(t *testing.T) { var harumph = Harumph{Email: "hello@gorm.io"} if err := DB.Create(&harumph).Error; err != nil { t.Fatalf("Failed to create data with default value, got error: %v", err) - } else if harumph.Name != "foo" || harumph.Name2 != "foo" || harumph.Age != 18 { + } else if harumph.Name != "foo" || harumph.Name2 != "foo" || harumph.Age != 18 || !harumph.Enabled { t.Fatalf("Failed to create data with default value, got: %+v", harumph) } var result Harumph if err := DB.First(&result, "email = ?", "hello@gorm.io").Error; err != nil { t.Fatalf("Failed to find created data, got error: %v", err) - } else if result.Name != "foo" || result.Name2 != "foo" || result.Age != 18 { + } else if result.Name != "foo" || result.Name2 != "foo" || result.Age != 18 || !result.Enabled { t.Fatalf("Failed to find created data with default data, got %+v", result) } } diff --git a/tests/go.mod b/tests/go.mod index f4d93ecb..d43ee8f1 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -14,3 +14,5 @@ require ( ) replace gorm.io/gorm => ../ + +replace gorm.io/driver/sqlserver => /Users/jinzhu/Projects/sqlserver From c888560a0e9971b174f7232cb847d3dc38229575 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 25 Jun 2020 08:08:37 +0800 Subject: [PATCH 491/881] Fix go.mod --- tests/default_value_test.go | 2 +- tests/go.mod | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/default_value_test.go b/tests/default_value_test.go index 28a456d3..7a7790bc 100644 --- a/tests/default_value_test.go +++ b/tests/default_value_test.go @@ -11,7 +11,7 @@ func TestDefaultValue(t *testing.T) { gorm.Model Email string `gorm:"not null;index:,unique"` Name string `gorm:"not null;default:'foo'"` - Name2 string `gorm:"not null;default:'foo'"` + Name2 string `gorm:"size:233;not null;default:'foo'"` Age int `gorm:"default:18"` Enabled bool `gorm:"default:true"` } diff --git a/tests/go.mod b/tests/go.mod index d43ee8f1..955bafe2 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -9,10 +9,8 @@ require ( gorm.io/driver/mysql v0.2.6 gorm.io/driver/postgres v0.2.3 gorm.io/driver/sqlite v1.0.7 - gorm.io/driver/sqlserver v0.2.2 + gorm.io/driver/sqlserver v0.2.3 gorm.io/gorm v0.2.9 ) replace gorm.io/gorm => ../ - -replace gorm.io/driver/sqlserver => /Users/jinzhu/Projects/sqlserver From af632199cf92c8609975a48a66a8be976a077d96 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 25 Jun 2020 22:48:10 +0800 Subject: [PATCH 492/881] Test set string field's default value to blank string --- migrator/migrator.go | 2 +- tests/default_value_test.go | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/migrator/migrator.go b/migrator/migrator.go index 9c4ce2d5..5edd800e 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -64,7 +64,7 @@ func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) { expr.SQL += " UNIQUE" } - if field.HasDefaultValue && field.DefaultValue != "" { + if field.HasDefaultValue && (field.DefaultValueInterface != nil || field.DefaultValue != "") { if field.DefaultValueInterface != nil { defaultStmt := &gorm.Statement{Vars: []interface{}{field.DefaultValueInterface}} m.Dialector.BindVarTo(defaultStmt, defaultStmt, field.DefaultValueInterface) diff --git a/tests/default_value_test.go b/tests/default_value_test.go index 7a7790bc..ea496d60 100644 --- a/tests/default_value_test.go +++ b/tests/default_value_test.go @@ -12,6 +12,7 @@ func TestDefaultValue(t *testing.T) { Email string `gorm:"not null;index:,unique"` Name string `gorm:"not null;default:'foo'"` Name2 string `gorm:"size:233;not null;default:'foo'"` + Name3 string `gorm:"size:233;not null;default:''"` Age int `gorm:"default:18"` Enabled bool `gorm:"default:true"` } @@ -25,14 +26,14 @@ func TestDefaultValue(t *testing.T) { var harumph = Harumph{Email: "hello@gorm.io"} if err := DB.Create(&harumph).Error; err != nil { t.Fatalf("Failed to create data with default value, got error: %v", err) - } else if harumph.Name != "foo" || harumph.Name2 != "foo" || harumph.Age != 18 || !harumph.Enabled { + } else if harumph.Name != "foo" || harumph.Name2 != "foo" || harumph.Name3 != "" || harumph.Age != 18 || !harumph.Enabled { t.Fatalf("Failed to create data with default value, got: %+v", harumph) } var result Harumph if err := DB.First(&result, "email = ?", "hello@gorm.io").Error; err != nil { t.Fatalf("Failed to find created data, got error: %v", err) - } else if result.Name != "foo" || result.Name2 != "foo" || result.Age != 18 || !result.Enabled { + } else if result.Name != "foo" || result.Name2 != "foo" || result.Name3 != "" || result.Age != 18 || !result.Enabled { t.Fatalf("Failed to find created data with default data, got %+v", result) } } From 81f4fafae4c6a4237d8ad25d1b55340652d0c066 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 25 Jun 2020 23:37:49 +0800 Subject: [PATCH 493/881] Test group by with multiple columns --- tests/group_by_test.go | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/group_by_test.go b/tests/group_by_test.go index cb4c4f43..b08f48f1 100644 --- a/tests/group_by_test.go +++ b/tests/group_by_test.go @@ -11,6 +11,7 @@ func TestGroupBy(t *testing.T) { Name: "groupby", Age: 10, Birthday: Now(), + Active: true, }, { Name: "groupby", Age: 20, @@ -19,6 +20,7 @@ func TestGroupBy(t *testing.T) { Name: "groupby", Age: 30, Birthday: Now(), + Active: true, }, { Name: "groupby1", Age: 110, @@ -27,10 +29,12 @@ func TestGroupBy(t *testing.T) { Name: "groupby1", Age: 220, Birthday: Now(), + Active: true, }, { Name: "groupby1", Age: 330, Birthday: Now(), + Active: true, }} if err := DB.Create(&users).Error; err != nil { @@ -54,4 +58,13 @@ func TestGroupBy(t *testing.T) { if name != "groupby1" || total != 660 { t.Errorf("name should be groupby, but got %v, total should be 660, but got %v", name, total) } + + var active bool + if err := DB.Model(&User{}).Select("name, active, sum(age)").Where("name = ? and active = ?", "groupby", true).Group("name").Group("active").Row().Scan(&name, &active, &total); err != nil { + t.Errorf("no error should happen, but got %v", err) + } + + if name != "groupby" || active != true || total != 40 { + t.Errorf("group by two columns, name %v, age %v, active: %v", name, total, active) + } } From a550a058823234587dc53a815e158be2c9355424 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 26 Jun 2020 07:26:45 +0800 Subject: [PATCH 494/881] Set db type after autotime --- schema/field.go | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/schema/field.go b/schema/field.go index a8328367..f02968fa 100644 --- a/schema/field.go +++ b/schema/field.go @@ -223,15 +223,6 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { field.DataType = DataType(dataTyper.GormDataType()) } - if val, ok := field.TagSettings["TYPE"]; ok { - switch DataType(strings.ToLower(val)) { - case Bool, Int, Uint, Float, String, Time, Bytes: - field.DataType = DataType(strings.ToLower(val)) - default: - field.DataType = DataType(val) - } - } - if v, ok := field.TagSettings["AUTOCREATETIME"]; ok || (field.Name == "CreatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) { if strings.ToUpper(v) == "NANO" { field.AutoCreateTime = UnixNanosecond @@ -248,6 +239,15 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } } + if val, ok := field.TagSettings["TYPE"]; ok { + switch DataType(strings.ToLower(val)) { + case Bool, Int, Uint, Float, String, Time, Bytes: + field.DataType = DataType(strings.ToLower(val)) + default: + field.DataType = DataType(val) + } + } + if field.Size == 0 { switch reflect.Indirect(fieldValue).Kind() { case reflect.Int, reflect.Int64, reflect.Uint, reflect.Uint64, reflect.Float64: From d5d31b38a7442f44da356cc413ad4afb30fa1abb Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 26 Jun 2020 08:39:18 +0800 Subject: [PATCH 495/881] Test group with table name --- tests/go.mod | 8 ++++---- tests/group_by_test.go | 8 ++++++++ 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/tests/go.mod b/tests/go.mod index 955bafe2..c467f34b 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -6,10 +6,10 @@ require ( github.com/google/uuid v1.1.1 github.com/jinzhu/now v1.1.1 github.com/lib/pq v1.6.0 - gorm.io/driver/mysql v0.2.6 - gorm.io/driver/postgres v0.2.3 - gorm.io/driver/sqlite v1.0.7 - gorm.io/driver/sqlserver v0.2.3 + gorm.io/driver/mysql v0.2.7 + gorm.io/driver/postgres v0.2.4 + gorm.io/driver/sqlite v1.0.8 + gorm.io/driver/sqlserver v0.2.4 gorm.io/gorm v0.2.9 ) diff --git a/tests/group_by_test.go b/tests/group_by_test.go index b08f48f1..6d0ed39c 100644 --- a/tests/group_by_test.go +++ b/tests/group_by_test.go @@ -51,6 +51,14 @@ func TestGroupBy(t *testing.T) { t.Errorf("name should be groupby, but got %v, total should be 60, but got %v", name, total) } + if err := DB.Model(&User{}).Select("name, sum(age)").Where("name = ?", "groupby").Group("users.name").Row().Scan(&name, &total); err != nil { + t.Errorf("no error should happen, but got %v", err) + } + + if name != "groupby" || total != 60 { + t.Errorf("name should be groupby, but got %v, total should be 60, but got %v", name, total) + } + if err := DB.Model(&User{}).Select("name, sum(age) as total").Where("name LIKE ?", "groupby%").Group("name").Having("name = ?", "groupby1").Row().Scan(&name, &total); err != nil { t.Errorf("no error should happen, but got %v", err) } From eeee014500669387fb0442ebbed1556a04bad8c5 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 27 Jun 2020 08:04:12 +0800 Subject: [PATCH 496/881] Only query with readable fields --- statement.go | 24 ++++++++++++++---------- tests/customize_field_test.go | 8 ++++++++ 2 files changed, 22 insertions(+), 10 deletions(-) diff --git a/statement.go b/statement.go index 7cc01bb8..e902b739 100644 --- a/statement.go +++ b/statement.go @@ -271,22 +271,26 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) (c switch reflectValue.Kind() { case reflect.Struct: for _, field := range s.Fields { - if v, isZero := field.ValueOf(reflectValue); !isZero { - if field.DBName == "" { - conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.Name}, Value: v}) - } else { - conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.DBName}, Value: v}) + if field.Readable { + if v, isZero := field.ValueOf(reflectValue); !isZero { + if field.DBName == "" { + conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.Name}, Value: v}) + } else { + conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.DBName}, Value: v}) + } } } } case reflect.Slice, reflect.Array: for i := 0; i < reflectValue.Len(); i++ { for _, field := range s.Fields { - if v, isZero := field.ValueOf(reflectValue.Index(i)); !isZero { - if field.DBName == "" { - conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.Name}, Value: v}) - } else { - conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.DBName}, Value: v}) + if field.Readable { + if v, isZero := field.ValueOf(reflectValue.Index(i)); !isZero { + if field.DBName == "" { + conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.Name}, Value: v}) + } else { + conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.DBName}, Value: v}) + } } } } diff --git a/tests/customize_field_test.go b/tests/customize_field_test.go index 910fa6ae..9c6ab948 100644 --- a/tests/customize_field_test.go +++ b/tests/customize_field_test.go @@ -134,10 +134,18 @@ func TestCustomizeField(t *testing.T) { t.Fatalf("invalid updated result: %#v", result2) } + if err := DB.Where(CustomizeFieldStruct{Name: create.Name, FieldReadonly: create.FieldReadonly, FieldIgnore: create.FieldIgnore}).First(&CustomizeFieldStruct{}).Error; err == nil { + t.Fatalf("Should failed to find result") + } + if err := DB.Table("customize_field_structs").Where("1 = 1").UpdateColumn("field_readonly", "readonly").Error; err != nil { t.Fatalf("failed to update field_readonly column") } + if err := DB.Where(CustomizeFieldStruct{Name: create.Name, FieldReadonly: "readonly", FieldIgnore: create.FieldIgnore}).First(&CustomizeFieldStruct{}).Error; err != nil { + t.Fatalf("Should find result") + } + var result3 CustomizeFieldStruct DB.Find(&result3, "name = ?", "create") From e308b103c02b05d5b0ab5b8a6f1ea70321d9f757 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 30 Jun 2020 07:29:15 +0800 Subject: [PATCH 497/881] SingularTable for JoinTable --- schema/naming.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/schema/naming.go b/schema/naming.go index d2a4919f..9b7c9471 100644 --- a/schema/naming.go +++ b/schema/naming.go @@ -41,6 +41,9 @@ func (ns NamingStrategy) ColumnName(table, column string) string { // JoinTableName convert string to join table name func (ns NamingStrategy) JoinTableName(str string) string { + if ns.SingularTable { + return ns.TablePrefix + toDBName(str) + } return ns.TablePrefix + inflection.Plural(toDBName(str)) } From 66dcd7e3cae8998f4c22a642299d1f4e7175c148 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 30 Jun 2020 16:53:54 +0800 Subject: [PATCH 498/881] Add SetColumn, Changed method --- callbacks/associations.go | 4 +- callbacks/create.go | 2 +- callbacks/helper.go | 58 +------------------ callbacks/update.go | 2 +- errors.go | 2 + statement.go | 117 ++++++++++++++++++++++++++++++++++++++ tests/hooks_test.go | 81 ++++++++++++++++++++++++++ utils/utils.go | 15 +++++ 8 files changed, 221 insertions(+), 60 deletions(-) diff --git a/callbacks/associations.go b/callbacks/associations.go index 3ff0f4b0..bcb6c414 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -11,7 +11,7 @@ import ( func SaveBeforeAssociations(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil { - selectColumns, restricted := SelectAndOmitColumns(db.Statement, true, false) + selectColumns, restricted := db.Statement.SelectAndOmitColumns(true, false) // Save Belongs To associations for _, rel := range db.Statement.Schema.Relationships.BelongsTo { @@ -90,7 +90,7 @@ func SaveBeforeAssociations(db *gorm.DB) { func SaveAfterAssociations(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil { - selectColumns, restricted := SelectAndOmitColumns(db.Statement, true, false) + selectColumns, restricted := db.Statement.SelectAndOmitColumns(true, false) // Save Has One associations for _, rel := range db.Statement.Schema.Relationships.HasOne { diff --git a/callbacks/create.go b/callbacks/create.go index 283d3fd1..eecb80a1 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -218,7 +218,7 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { values = ConvertSliceOfMapToValuesForCreate(stmt, value) default: var ( - selectColumns, restricted = SelectAndOmitColumns(stmt, true, false) + selectColumns, restricted = stmt.SelectAndOmitColumns(true, false) curTime = stmt.DB.NowFunc() isZero bool ) diff --git a/callbacks/helper.go b/callbacks/helper.go index 3b0cca16..1b06e0b7 100644 --- a/callbacks/helper.go +++ b/callbacks/helper.go @@ -7,64 +7,10 @@ import ( "gorm.io/gorm/clause" ) -// SelectAndOmitColumns get select and omit columns, select -> true, omit -> false -func SelectAndOmitColumns(stmt *gorm.Statement, requireCreate, requireUpdate bool) (map[string]bool, bool) { - results := map[string]bool{} - notRestricted := false - - // select columns - for _, column := range stmt.Selects { - if column == "*" { - notRestricted = true - for _, dbName := range stmt.Schema.DBNames { - results[dbName] = true - } - } else if column == clause.Associations { - for _, rel := range stmt.Schema.Relationships.Relations { - results[rel.Name] = true - } - } else if field := stmt.Schema.LookUpField(column); field != nil && field.DBName != "" { - results[field.DBName] = true - } else { - results[column] = true - } - } - - // omit columns - for _, omit := range stmt.Omits { - if omit == clause.Associations { - for _, rel := range stmt.Schema.Relationships.Relations { - results[rel.Name] = false - } - } else if field := stmt.Schema.LookUpField(omit); field != nil && field.DBName != "" { - results[field.DBName] = false - } else { - results[omit] = false - } - } - - if stmt.Schema != nil { - for _, field := range stmt.Schema.Fields { - name := field.DBName - if name == "" { - name = field.Name - } - - if requireCreate && !field.Creatable { - results[name] = false - } else if requireUpdate && !field.Updatable { - results[name] = false - } - } - } - - return results, !notRestricted && len(stmt.Selects) > 0 -} - // ConvertMapToValuesForCreate convert map to values func ConvertMapToValuesForCreate(stmt *gorm.Statement, mapValue map[string]interface{}) (values clause.Values) { columns := make([]string, 0, len(mapValue)) - selectColumns, restricted := SelectAndOmitColumns(stmt, true, false) + selectColumns, restricted := stmt.SelectAndOmitColumns(true, false) var keys []string for k := range mapValue { @@ -91,7 +37,7 @@ func ConvertSliceOfMapToValuesForCreate(stmt *gorm.Statement, mapValues []map[st var ( columns = []string{} result = map[string][]interface{}{} - selectColumns, restricted = SelectAndOmitColumns(stmt, true, false) + selectColumns, restricted = stmt.SelectAndOmitColumns(true, false) ) for idx, mapValue := range mapValues { diff --git a/callbacks/update.go b/callbacks/update.go index 1ea77552..f84e933c 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -110,7 +110,7 @@ func AfterUpdate(db *gorm.DB) { // ConvertToAssignments convert to update assignments func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { var ( - selectColumns, restricted = SelectAndOmitColumns(stmt, false, true) + selectColumns, restricted = stmt.SelectAndOmitColumns(false, true) assignValue func(field *schema.Field, value interface{}) ) diff --git a/errors.go b/errors.go index b41eefae..e1b58835 100644 --- a/errors.go +++ b/errors.go @@ -29,4 +29,6 @@ var ( ErrUnsupportedDriver = errors.New("unsupported driver") // ErrRegistered registered ErrRegistered = errors.New("registered") + // ErrInvalidField invalid field + ErrInvalidField = errors.New("invalid field") ) diff --git a/statement.go b/statement.go index e902b739..164ddbd7 100644 --- a/statement.go +++ b/statement.go @@ -12,6 +12,7 @@ import ( "gorm.io/gorm/clause" "gorm.io/gorm/schema" + "gorm.io/gorm/utils" ) // Statement statement @@ -370,3 +371,119 @@ func (stmt *Statement) clone() *Statement { return newStmt } + +// Helpers +// SetColumn set column's value +func (stmt *Statement) SetColumn(name string, value interface{}) { + if v, ok := stmt.Dest.(map[string]interface{}); ok { + v[name] = value + } else if stmt.Schema != nil { + if field := stmt.Schema.LookUpField(name); field != nil { + field.Set(stmt.ReflectValue, value) + } else { + stmt.AddError(ErrInvalidField) + } + } else { + stmt.AddError(ErrInvalidField) + } +} + +// Changed check model changed or not when updating +func (stmt *Statement) Changed(fields ...string) bool { + modelValue := reflect.ValueOf(stmt.Model) + for modelValue.Kind() == reflect.Ptr { + modelValue = modelValue.Elem() + } + + selectColumns, restricted := stmt.SelectAndOmitColumns(false, true) + changed := func(field *schema.Field) bool { + fieldValue, isZero := field.ValueOf(modelValue) + if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { + if v, ok := stmt.Dest.(map[string]interface{}); ok { + if fv, ok := v[field.Name]; ok { + return !utils.AssertEqual(fv, fieldValue) + } else if fv, ok := v[field.DBName]; ok { + return !utils.AssertEqual(fv, fieldValue) + } else if isZero { + return true + } + } else { + changedValue, _ := field.ValueOf(stmt.ReflectValue) + return !utils.AssertEqual(changedValue, fieldValue) + } + } + return false + } + + if len(fields) == 0 { + for _, field := range stmt.Schema.FieldsByDBName { + if changed(field) { + return true + } + } + } else { + for _, name := range fields { + if field := stmt.Schema.LookUpField(name); field != nil { + if changed(field) { + return true + } + } + } + } + + return false +} + +// SelectAndOmitColumns get select and omit columns, select -> true, omit -> false +func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (map[string]bool, bool) { + results := map[string]bool{} + notRestricted := false + + // select columns + for _, column := range stmt.Selects { + if column == "*" { + notRestricted = true + for _, dbName := range stmt.Schema.DBNames { + results[dbName] = true + } + } else if column == clause.Associations { + for _, rel := range stmt.Schema.Relationships.Relations { + results[rel.Name] = true + } + } else if field := stmt.Schema.LookUpField(column); field != nil && field.DBName != "" { + results[field.DBName] = true + } else { + results[column] = true + } + } + + // omit columns + for _, omit := range stmt.Omits { + if omit == clause.Associations { + for _, rel := range stmt.Schema.Relationships.Relations { + results[rel.Name] = false + } + } else if field := stmt.Schema.LookUpField(omit); field != nil && field.DBName != "" { + results[field.DBName] = false + } else { + results[omit] = false + } + } + + if stmt.Schema != nil { + for _, field := range stmt.Schema.Fields { + name := field.DBName + if name == "" { + name = field.Name + } + + if requireCreate && !field.Creatable { + results[name] = false + } else if requireUpdate && !field.Updatable { + results[name] = false + } + } + } + + return results, !notRestricted && len(stmt.Selects) > 0 +} diff --git a/tests/hooks_test.go b/tests/hooks_test.go index c74e8f10..8f8c60f5 100644 --- a/tests/hooks_test.go +++ b/tests/hooks_test.go @@ -285,3 +285,84 @@ func TestUseDBInHooks(t *testing.T) { t.Errorf("Admin product's price should not be changed, expects: %v, got %v", 600, result4.Price) } } + +type Product3 struct { + gorm.Model + Name string + Code string + Price int64 + Owner string +} + +func (s Product3) BeforeCreate(tx *gorm.DB) (err error) { + tx.Statement.SetColumn("Price", s.Price+100) + return nil +} + +func (s Product3) BeforeUpdate(tx *gorm.DB) (err error) { + if tx.Statement.Changed() { + tx.Statement.SetColumn("Price", s.Price+10) + } + + if tx.Statement.Changed("Code") { + s.Price += 20 + tx.Statement.SetColumn("Price", s.Price+30) + } + return nil +} + +func TestSetColumn(t *testing.T) { + DB.Migrator().DropTable(&Product3{}) + DB.AutoMigrate(&Product3{}) + + product := Product3{Name: "Product", Price: 0} + DB.Create(&product) + + if product.Price != 100 { + t.Errorf("invalid price after create, got %+v", product) + } + + DB.Model(&product).Select("code", "price").Updates(map[string]interface{}{"code": "L1212"}) + + if product.Price != 150 || product.Code != "L1212" { + t.Errorf("invalid data after update, got %+v", product) + } + + // Code not changed, price should not change + DB.Model(&product).Updates(map[string]interface{}{"Name": "Product New"}) + + if product.Name != "Product New" || product.Price != 160 || product.Code != "L1212" { + t.Errorf("invalid data after update, got %+v", product) + } + + // Code changed, but not selected, price should not change + DB.Model(&product).Select("Name", "Price").Updates(map[string]interface{}{"Name": "Product New2", "code": "L1213"}) + + if product.Name != "Product New2" || product.Price != 170 || product.Code != "L1212" { + t.Errorf("invalid data after update, got %+v", product) + } + + // Code changed, price should changed + DB.Model(&product).Select("Name", "Code", "Price").Updates(map[string]interface{}{"Name": "Product New3", "code": "L1213"}) + + if product.Name != "Product New3" || product.Price != 220 || product.Code != "L1213" { + t.Errorf("invalid data after update, got %+v", product) + } + + var result Product3 + DB.First(&result, product.ID) + + AssertEqual(t, result, product) + + // Code changed, price not selected, price should not change + DB.Model(&product).Select("code").Updates(map[string]interface{}{"name": "L1214"}) + + if product.Price != 220 || product.Code != "L1213" { + t.Errorf("invalid data after update, got %+v", product) + } + + var result2 Product3 + DB.First(&result2, product.ID) + + AssertEqual(t, result2, product) +} diff --git a/utils/utils.go b/utils/utils.go index 81d2dc34..9bf00683 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -68,3 +68,18 @@ func ToStringKey(values ...interface{}) string { return strings.Join(results, "_") } + +func AssertEqual(src, dst interface{}) bool { + if !reflect.DeepEqual(src, dst) { + if valuer, ok := src.(driver.Valuer); ok { + src, _ = valuer.Value() + } + + if valuer, ok := dst.(driver.Valuer); ok { + dst, _ = valuer.Value() + } + + return reflect.DeepEqual(src, dst) + } + return true +} From 3e4dbde920e3fe88a56a97429ab8146408d18da6 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 30 Jun 2020 22:47:21 +0800 Subject: [PATCH 499/881] Test Hooks For Slice --- callbacks/callmethod.go | 4 +++- statement.go | 17 +++++++++++---- tests/hooks_test.go | 48 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 64 insertions(+), 5 deletions(-) diff --git a/callbacks/callmethod.go b/callbacks/callmethod.go index a0e9b0e7..0160f354 100644 --- a/callbacks/callmethod.go +++ b/callbacks/callmethod.go @@ -11,8 +11,10 @@ func callMethod(db *gorm.DB, fc func(value interface{}, tx *gorm.DB) bool) { if called := fc(db.Statement.Dest, tx); !called { switch db.Statement.ReflectValue.Kind() { case reflect.Slice, reflect.Array: + db.Statement.CurDestIndex = 0 for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - fc(db.Statement.ReflectValue.Index(i).Addr().Interface(), tx) + fc(reflect.Indirect(db.Statement.ReflectValue.Index(i)).Addr().Interface(), tx) + db.Statement.CurDestIndex++ } case reflect.Struct: fc(db.Statement.ReflectValue.Addr().Interface(), tx) diff --git a/statement.go b/statement.go index 164ddbd7..e65a064f 100644 --- a/statement.go +++ b/statement.go @@ -38,6 +38,7 @@ type Statement struct { SQL strings.Builder Vars []interface{} NamedVars []sql.NamedArg + CurDestIndex int attrs []interface{} assigns []interface{} } @@ -379,7 +380,12 @@ func (stmt *Statement) SetColumn(name string, value interface{}) { v[name] = value } else if stmt.Schema != nil { if field := stmt.Schema.LookUpField(name); field != nil { - field.Set(stmt.ReflectValue, value) + switch stmt.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + field.Set(stmt.ReflectValue.Index(stmt.CurDestIndex), value) + case reflect.Struct: + field.Set(stmt.ReflectValue, value) + } } else { stmt.AddError(ErrInvalidField) } @@ -395,17 +401,20 @@ func (stmt *Statement) Changed(fields ...string) bool { modelValue = modelValue.Elem() } + switch modelValue.Kind() { + case reflect.Slice, reflect.Array: + modelValue = stmt.ReflectValue.Index(stmt.CurDestIndex) + } + selectColumns, restricted := stmt.SelectAndOmitColumns(false, true) changed := func(field *schema.Field) bool { - fieldValue, isZero := field.ValueOf(modelValue) + fieldValue, _ := field.ValueOf(modelValue) if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { if v, ok := stmt.Dest.(map[string]interface{}); ok { if fv, ok := v[field.Name]; ok { return !utils.AssertEqual(fv, fieldValue) } else if fv, ok := v[field.DBName]; ok { return !utils.AssertEqual(fv, fieldValue) - } else if isZero { - return true } } else { changedValue, _ := field.ValueOf(stmt.ReflectValue) diff --git a/tests/hooks_test.go b/tests/hooks_test.go index 8f8c60f5..ed5ee746 100644 --- a/tests/hooks_test.go +++ b/tests/hooks_test.go @@ -366,3 +366,51 @@ func TestSetColumn(t *testing.T) { AssertEqual(t, result2, product) } + +func TestHooksForSlice(t *testing.T) { + products := []*Product3{ + {Name: "Product-1", Price: 100}, + {Name: "Product-2", Price: 200}, + {Name: "Product-3", Price: 300}, + } + + DB.Create(&products) + + for idx, value := range []int64{200, 300, 400} { + if products[idx].Price != value { + t.Errorf("invalid price for product #%v, expects: %v, got %v", idx, value, products[idx].Price) + } + } + + DB.Model(&products).Update("Name", "product-name") + + // will set all product's price to last product's price + 10 + for idx, value := range []int64{410, 410, 410} { + if products[idx].Price != value { + t.Errorf("invalid price for product #%v, expects: %v, got %v", idx, value, products[idx].Price) + } + } + + products2 := []Product3{ + {Name: "Product-1", Price: 100}, + {Name: "Product-2", Price: 200}, + {Name: "Product-3", Price: 300}, + } + + DB.Create(&products2) + + for idx, value := range []int64{200, 300, 400} { + if products2[idx].Price != value { + t.Errorf("invalid price for product #%v, expects: %v, got %v", idx, value, products2[idx].Price) + } + } + + DB.Model(&products2).Update("Name", "product-name") + + // will set all product's price to last product's price + 10 + for idx, value := range []int64{410, 410, 410} { + if products2[idx].Price != value { + t.Errorf("invalid price for product #%v, expects: %v, got %v", idx, value, products2[idx].Price) + } + } +} From 7aaac3a580d5c0a4b28853c8f53d8feb0327530f Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 30 Jun 2020 23:06:48 +0800 Subject: [PATCH 500/881] Allow to use sql function in Group, Pluck --- chainable_api.go | 4 +++- finisher_api.go | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/chainable_api.go b/chainable_api.go index dbd783fd..e2ba44cc 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -162,8 +162,10 @@ func (db *DB) Joins(query string, args ...interface{}) (tx *DB) { // Group specify the group method on the find func (db *DB) Group(name string) (tx *DB) { tx = db.getInstance() + + fields := strings.FieldsFunc(name, utils.IsChar) tx.Statement.AddClause(clause.GroupBy{ - Columns: []clause.Column{{Name: name}}, + Columns: []clause.Column{{Name: name, Raw: len(fields) != 1}}, }) return } diff --git a/finisher_api.go b/finisher_api.go index 6d961811..af040106 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -8,6 +8,7 @@ import ( "strings" "gorm.io/gorm/clause" + "gorm.io/gorm/utils" ) // Create insert the value into database @@ -325,9 +326,10 @@ func (db *DB) Pluck(column string, dest interface{}) (tx *DB) { tx.AddError(ErrorModelValueRequired) } + fields := strings.FieldsFunc(column, utils.IsChar) tx.Statement.AddClauseIfNotExists(clause.Select{ Distinct: tx.Statement.Distinct, - Columns: []clause.Column{{Name: column}}, + Columns: []clause.Column{{Name: column, Raw: len(fields) != 1}}, }) tx.Statement.Dest = dest tx.callbacks.Query().Execute(tx) From 9d7df71332b26949d6d61eff94ad416c0984d7f3 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 1 Jul 2020 08:56:21 +0800 Subject: [PATCH 501/881] Query with smaller struct --- callbacks/query.go | 12 +++++++++++- scan.go | 24 +++++++++++++++++------- tests/query_test.go | 23 ++++++++++++++++++++++- 3 files changed, 50 insertions(+), 9 deletions(-) diff --git a/callbacks/query.go b/callbacks/query.go index 27d53a4d..4b7f5bd5 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -40,7 +40,7 @@ func BuildQuerySQL(db *gorm.DB) { db.Statement.SQL.Grow(100) clauseSelect := clause.Select{Distinct: db.Statement.Distinct} - if db.Statement.ReflectValue.Kind() == reflect.Struct { + if db.Statement.ReflectValue.Kind() == reflect.Struct && db.Statement.ReflectValue.Type() == db.Statement.Schema.ModelType { var conds []clause.Expression for _, primaryField := range db.Statement.Schema.PrimaryFields { if v, isZero := primaryField.ValueOf(db.Statement.ReflectValue); !isZero { @@ -64,6 +64,16 @@ func BuildQuerySQL(db *gorm.DB) { clauseSelect.Columns[idx] = clause.Column{Name: name, Raw: true} } } + } else if db.Statement.Schema != nil && db.Statement.ReflectValue.IsValid() && db.Statement.ReflectValue.Type() != db.Statement.Schema.ModelType { + stmt := gorm.Statement{DB: db} + // smaller struct + if err := stmt.Parse(db.Statement.Dest); err == nil { + clauseSelect.Columns = make([]clause.Column, len(stmt.Schema.DBNames)) + + for idx, dbName := range stmt.Schema.DBNames { + clauseSelect.Columns[idx] = clause.Column{Name: dbName} + } + } } // inline joins diff --git a/scan.go b/scan.go index 2d227ec2..0b199029 100644 --- a/scan.go +++ b/scan.go @@ -69,6 +69,8 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { db.AddError(rows.Scan(dest)) } default: + Schema := db.Statement.Schema + switch db.Statement.ReflectValue.Kind() { case reflect.Slice, reflect.Array: var ( @@ -84,16 +86,20 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { db.Statement.ReflectValue.Set(reflect.MakeSlice(db.Statement.ReflectValue.Type(), 0, 0)) - if db.Statement.Schema != nil { + if Schema != nil { + if reflectValueType != Schema.ModelType && reflectValueType.Kind() == reflect.Struct { + Schema, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy) + } + for idx, column := range columns { - if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable { + if field := Schema.LookUpField(column); field != nil && field.Readable { fields[idx] = field } else if names := strings.Split(column, "__"); len(names) > 1 { if len(joinFields) == 0 { joinFields = make([][2]*schema.Field, len(columns)) } - if rel, ok := db.Statement.Schema.Relationships.Relations[names[0]]; ok { + if rel, ok := Schema.Relationships.Relations[names[0]]; ok { if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { fields[idx] = field joinFields[idx] = [2]*schema.Field{rel.Field, field} @@ -151,12 +157,16 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { } } case reflect.Struct: + if db.Statement.ReflectValue.Type() != Schema.ModelType { + Schema, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy) + } + if initialized || rows.Next() { for idx, column := range columns { - if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable { + if field := Schema.LookUpField(column); field != nil && field.Readable { values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface() } else if names := strings.Split(column, "__"); len(names) > 1 { - if rel, ok := db.Statement.Schema.Relationships.Relations[names[0]]; ok { + if rel, ok := Schema.Relationships.Relations[names[0]]; ok { if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface() continue @@ -172,10 +182,10 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { db.AddError(rows.Scan(values...)) for idx, column := range columns { - if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable { + if field := Schema.LookUpField(column); field != nil && field.Readable { field.Set(db.Statement.ReflectValue, values[idx]) } else if names := strings.Split(column, "__"); len(names) > 1 { - if rel, ok := db.Statement.Schema.Relationships.Relations[names[0]]; ok { + if rel, ok := Schema.Relationships.Relations[names[0]]; ok { relValue := rel.Field.ReflectValueOf(db.Statement.ReflectValue) if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { value := reflect.ValueOf(values[idx]).Elem() diff --git a/tests/query_test.go b/tests/query_test.go index de65b63b..7973fd51 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -3,6 +3,7 @@ package tests_test import ( "fmt" "reflect" + "regexp" "sort" "strconv" "testing" @@ -144,8 +145,8 @@ func TestFillSmallerStruct(t *testing.T) { user := User{Name: "SmallerUser", Age: 100} DB.Save(&user) type SimpleUser struct { - Name string ID int64 + Name string UpdatedAt time.Time CreatedAt time.Time } @@ -156,6 +157,26 @@ func TestFillSmallerStruct(t *testing.T) { } AssertObjEqual(t, user, simpleUser, "Name", "ID", "UpdatedAt", "CreatedAt") + + var simpleUser2 SimpleUser + if err := DB.Model(&User{}).Select("id").First(&simpleUser2, user.ID).Error; err != nil { + t.Fatalf("Failed to query smaller user, got error %v", err) + } + + AssertObjEqual(t, user, simpleUser2, "ID") + + var simpleUsers []SimpleUser + if err := DB.Model(&User{}).Select("id").Find(&simpleUsers, user.ID).Error; err != nil || len(simpleUsers) != 1 { + t.Fatalf("Failed to query smaller user, got error %v", err) + } + + AssertObjEqual(t, user, simpleUsers[0], "ID") + + result := DB.Session(&gorm.Session{DryRun: true}).Model(&User{}).Find(&simpleUsers, user.ID) + + if !regexp.MustCompile("SELECT .*id.*name.*updated_at.*created_at.* FROM .*users").MatchString(result.Statement.SQL.String()) { + t.Fatalf("SQL should include selected names, but got %v", result.Statement.SQL.String()) + } } func TestPluck(t *testing.T) { From d342f4122af9a14b2d4aa768af759ea6a0c56d7a Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 1 Jul 2020 10:19:52 +0800 Subject: [PATCH 502/881] Better support Count in chain --- finisher_api.go | 2 ++ tests/count_test.go | 8 ++++++++ 2 files changed, 10 insertions(+) diff --git a/finisher_api.go b/finisher_api.go index af040106..25c56e49 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -269,6 +269,7 @@ func (db *DB) Count(count *int64) (tx *DB) { if len(tx.Statement.Selects) == 0 { tx.Statement.AddClause(clause.Select{Expression: clause.Expr{SQL: "count(1)"}}) + defer tx.Statement.AddClause(clause.Select{}) } else if !strings.Contains(strings.ToLower(tx.Statement.Selects[0]), "count(") { expr := clause.Expr{SQL: "count(1)"} @@ -281,6 +282,7 @@ func (db *DB) Count(count *int64) (tx *DB) { } tx.Statement.AddClause(clause.Select{Expression: expr}) + defer tx.Statement.AddClause(clause.Select{}) } tx.Statement.Dest = count diff --git a/tests/count_test.go b/tests/count_test.go index 0662ae5c..826d6a36 100644 --- a/tests/count_test.go +++ b/tests/count_test.go @@ -27,6 +27,14 @@ func TestCount(t *testing.T) { t.Errorf("Count() method should get correct value, expect: %v, got %v", count, len(users)) } + if err := DB.Model(&User{}).Where("name = ?", user1.Name).Or("name = ?", user3.Name).Count(&count).Find(&users).Error; err != nil { + t.Errorf(fmt.Sprintf("Count should work, but got err %v", err)) + } + + if count != int64(len(users)) { + t.Errorf("Count() method should get correct value, expect: %v, got %v", count, len(users)) + } + DB.Model(&User{}).Where("name = ?", user1.Name).Count(&count1).Or("name in ?", []string{user2.Name, user3.Name}).Count(&count2) if count1 != 1 || count2 != 3 { t.Errorf("multiple count in chain should works") From 65d6c19d73e5574d5d6024b2a3fe6008962c6300 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 1 Jul 2020 11:47:46 +0800 Subject: [PATCH 503/881] Test multiple index tags --- schema/index_test.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/schema/index_test.go b/schema/index_test.go index 384e902b..71a70a8c 100644 --- a/schema/index_test.go +++ b/schema/index_test.go @@ -16,7 +16,7 @@ type UserIndex struct { Name5 int64 `gorm:"index:,class:FULLTEXT,comment:hello \\, world,where:age > 10"` Name6 int64 `gorm:"index:profile,comment:hello \\, world,where:age > 10"` Age int64 `gorm:"index:profile,expression:ABS(age)"` - OID int64 `gorm:"index:idx_id"` + OID int64 `gorm:"index:idx_id;index:idx_oid,unique"` MemberNumber string `gorm:"index:idx_id"` } @@ -70,6 +70,11 @@ func TestParseIndex(t *testing.T) { Name: "idx_id", Fields: []schema.IndexOption{{}, {}}, }, + "idx_oid": { + Name: "idx_oid", + Class: "UNIQUE", + Fields: []schema.IndexOption{{}}, + }, } indices := user.ParseIndexes() From 322c6a36ee92dd8ab375cc9eda5fb267db131c5b Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 1 Jul 2020 19:50:24 +0800 Subject: [PATCH 504/881] Fix .github config --- .github/ISSUE_TEMPLATE.md | 5 -- .github/PULL_REQUEST_TEMPLATE.md | 11 --- .github/labeler.yml | 6 -- .github/labels.json | 139 ++++++++++++++++++++++++++++++ .github/workflows/issue.yml | 15 ---- .github/workflows/issue_stale.yml | 19 ---- .github/workflows/labeler.yml | 19 ++++ .github/workflows/stale.yml | 21 +++++ 8 files changed, 179 insertions(+), 56 deletions(-) delete mode 100644 .github/ISSUE_TEMPLATE.md delete mode 100644 .github/PULL_REQUEST_TEMPLATE.md delete mode 100644 .github/labeler.yml create mode 100644 .github/labels.json delete mode 100644 .github/workflows/issue.yml delete mode 100644 .github/workflows/issue_stale.yml create mode 100644 .github/workflows/labeler.yml create mode 100644 .github/workflows/stale.yml diff --git a/.github/ISSUE_TEMPLATE.md b/.github/ISSUE_TEMPLATE.md deleted file mode 100644 index ac311633..00000000 --- a/.github/ISSUE_TEMPLATE.md +++ /dev/null @@ -1,5 +0,0 @@ -Your issue may already be reported! Please search on the [issue track](https://github.com/go-gorm/gorm/issues) before creating one. - -To report a bug, your issue *have to* include an [GORM playground pull request link](https://github.com/go-gorm/playground), for general questions, please delete below line. - -## GORM Playground Link: https://github.com/go-gorm/playground/pull/1 (change this to your link) diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md deleted file mode 100644 index 930ff176..00000000 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ /dev/null @@ -1,11 +0,0 @@ -Make sure these boxes checked before submitting your pull request. - -- [] Do only one thing -- [] No API-breaking changes -- [] New code/logic commented & tested (important) - -For significant changes like big bug fixes, new features, please open an issue to make an agreement on an implementation design/plan first before starting it. - -### What did this pull request do? - -### Use Case diff --git a/.github/labeler.yml b/.github/labeler.yml deleted file mode 100644 index d96bafa0..00000000 --- a/.github/labeler.yml +++ /dev/null @@ -1,6 +0,0 @@ -# Add/remove 'critical' label if issue contains the words 'urgent' or 'critical' -HasGormPlaygroundTestCase: - - '(github.com/go-gorm/playground/pull/\d)' - -NoTestCase: - - '(change this to your link)' diff --git a/.github/labels.json b/.github/labels.json new file mode 100644 index 00000000..8b1ce849 --- /dev/null +++ b/.github/labels.json @@ -0,0 +1,139 @@ +{ + "labels": { + "critical": { + "name": "type:critical", + "colour": "#E84137", + "description": "critical questions" + }, + "question": { + "name": "type:question", + "colour": "#EDEDED", + "description": "general questions" + }, + "with_playground": { + "name": "type:with reproduction steps", + "colour": "#00ff00", + "description": "with reproduction steps" + }, + "without_playground": { + "name": "type:missing reproduction steps", + "colour": "#CF2E1F", + "description": "missing reproduction steps" + }, + "has_pr": { + "name": "type:has pull request", + "colour": "#43952A", + "description": "has pull request" + }, + "not_tested": { + "name": "type:not tested", + "colour": "#CF2E1F", + "description": "not tested" + }, + "tested": { + "name": "type:tested", + "colour": "#00ff00", + "description": "tested" + }, + "breaking_change": { + "name": "type:breaking change", + "colour": "#CF2E1F", + "description": "breaking change" + } + }, + "issue": { + "with_playground": { + "requires": 1, + "conditions": [ + { + "type": "descriptionMatches", + "pattern": "/github.com\/go-gorm\/playground\/pull\/\\d\\d+/s" + } + ] + }, + "critical": { + "requires": 1, + "conditions": [ + { + "type": "descriptionMatches", + "pattern": "/(critical|urgent)/i" + }, + { + "type": "titleMatches", + "pattern": "/(critical|urgent)/i" + } + ] + }, + "question": { + "requires": 1, + "conditions": [ + { + "type": "titleMatches", + "pattern": "/question/i" + }, + { + "type": "descriptionMatches", + "pattern": "/question/i" + } + ] + }, + "without_playground": { + "requires": 5, + "conditions": [ + { + "type": "descriptionMatches", + "pattern": "/^((?!github.com\/go-gorm\/playground\/pull\/\\d\\d+).)*$/s" + }, + { + "type": "titleMatches", + "pattern": "/^((?!question).)*$/s" + }, + { + "type": "descriptionMatches", + "pattern": "/^((?!question).)*$/is" + }, + { + "type": "titleMatches", + "pattern": "/^((?!critical|urgent).)*$/s" + }, + { + "type": "descriptionMatches", + "pattern": "/^((?!critical|urgent).)*$/s" + } + ] + } + }, + "pr": { + "critical": { + "requires": 1, + "conditions": [ + { + "type": "descriptionMatches", + "pattern": "/(critical|urgent)/i" + }, + { + "type": "titleMatches", + "pattern": "/(critical|urgent)/i" + } + ] + }, + "not_tested": { + "requires": 1, + "conditions": [ + { + "type": "descriptionMatches", + "pattern": "/\\[\\] Tested/" + } + ] + }, + "breaking_change": { + "requires": 1, + "conditions": [ + { + "type": "descriptionMatches", + "pattern": "/\\[\\] Non breaking API changes/" + } + ] + } + } +} diff --git a/.github/workflows/issue.yml b/.github/workflows/issue.yml deleted file mode 100644 index 0759782c..00000000 --- a/.github/workflows/issue.yml +++ /dev/null @@ -1,15 +0,0 @@ -name: "Issue-Labeler" -on: - issues: - types: [opened, edited] - -jobs: - triage: - runs-on: ubuntu-latest - steps: - - uses: github/issue-labeler@v2.0 - with: - repo-token: "${{ secrets.GITHUB_TOKEN }}" - configuration-path: ".github/labeler.yml" - not-before: "2020-01-15T02:54:32Z" - enable-versioned-regex: 0 \ No newline at end of file diff --git a/.github/workflows/issue_stale.yml b/.github/workflows/issue_stale.yml deleted file mode 100644 index fadfb522..00000000 --- a/.github/workflows/issue_stale.yml +++ /dev/null @@ -1,19 +0,0 @@ -name: Issue cleanup -on: - schedule: - - cron: '0 1 * * *' # At 01:00, everyday -jobs: - triage_issues: - name: Issue triage - runs-on: ubuntu-latest - steps: - - name: Find old issues and mark them stale - uses: Krizzu/issue-triage-action@v1.0.0 - with: - ghToken: ${{ secrets.GITHUB_TOKEN }} - staleAfter: 7 - closeAfter: 14 - staleLabel: "STALE 📺" - staleComment: "This issue is %DAYS_OLD% days old, marking as stale! cc: @%AUTHOR%" - closeComment: "Issue last updated %DAYS_OLD% days ago! Closing down!" - showLogs: true \ No newline at end of file diff --git a/.github/workflows/labeler.yml b/.github/workflows/labeler.yml new file mode 100644 index 00000000..1490730b --- /dev/null +++ b/.github/workflows/labeler.yml @@ -0,0 +1,19 @@ +name: "Issue Labeler" +on: + issues: + types: [opened, edited, reopened] + pull_request: + types: [opened, edited, reopened, ready_for_review, synchronize] + +jobs: + triage: + runs-on: ubuntu-latest + name: Label issues and pull requests + steps: + - name: check out + uses: actions/checkout@v2 + + - name: labeler + uses: jinzhu/super-labeler-action@develop + with: + GITHUB_TOKEN: "${{ secrets.GITHUB_TOKEN }}" diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml new file mode 100644 index 00000000..6fb714ca --- /dev/null +++ b/.github/workflows/stale.yml @@ -0,0 +1,21 @@ +name: "Close Missing Playground issues" +on: + schedule: + - cron: "*/10 * * * *" + +jobs: + stale: + runs-on: ubuntu-latest + env: + ACTIONS_STEP_DEBUG: true + steps: + - name: Close Stale Issues + uses: actions/stale@v3.0.7 + with: + repo-token: ${{ secrets.GITHUB_TOKEN }} + stale-issue-message: "This issue has been automatically marked as stale as it missing playground pull request link, checkout [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground) for details, it will be closed in 2 days if no further activity occurs." + stale-issue-label: "status:stale" + days-before-stale: 0 + days-before-close: 2 + remove-stale-when-updated: true + only-labels: "type:missing reproduction steps" From 63e48191a83f0891af4c7a19a8a0c89a521240a0 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 1 Jul 2020 21:28:19 +0800 Subject: [PATCH 505/881] Test failed to save association should rollback, close #3100 --- callbacks/associations.go | 16 ++++++------- tests/hooks_test.go | 48 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 56 insertions(+), 8 deletions(-) diff --git a/callbacks/associations.go b/callbacks/associations.go index bcb6c414..0968b460 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -139,10 +139,10 @@ func SaveAfterAssociations(db *gorm.DB) { assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) } - db.Session(&gorm.Session{}).Clauses(clause.OnConflict{ + db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{ Columns: []clause.Column{{Name: rel.FieldSchema.PrioritizedPrimaryField.DBName}}, DoUpdates: clause.AssignmentColumns(assignmentColumns), - }).Create(elems.Interface()) + }).Create(elems.Interface()).Error) } case reflect.Struct: if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero { @@ -162,10 +162,10 @@ func SaveAfterAssociations(db *gorm.DB) { assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) } - db.Session(&gorm.Session{}).Clauses(clause.OnConflict{ + db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{ Columns: []clause.Column{{Name: rel.FieldSchema.PrioritizedPrimaryField.DBName}}, DoUpdates: clause.AssignmentColumns(assignmentColumns), - }).Create(f.Interface()) + }).Create(f.Interface()).Error) } } } @@ -221,10 +221,10 @@ func SaveAfterAssociations(db *gorm.DB) { assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) } - db.Session(&gorm.Session{}).Clauses(clause.OnConflict{ + db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{ Columns: []clause.Column{{Name: rel.FieldSchema.PrioritizedPrimaryField.DBName}}, DoUpdates: clause.AssignmentColumns(assignmentColumns), - }).Create(elems.Interface()) + }).Create(elems.Interface()).Error) } } @@ -286,7 +286,7 @@ func SaveAfterAssociations(db *gorm.DB) { } if elems.Len() > 0 { - db.Session(&gorm.Session{}).Clauses(clause.OnConflict{DoNothing: true}).Create(elems.Interface()) + db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{DoNothing: true}).Create(elems.Interface()).Error) for i := 0; i < elems.Len(); i++ { appendToJoins(objs[i], elems.Index(i)) @@ -294,7 +294,7 @@ func SaveAfterAssociations(db *gorm.DB) { } if joins.Len() > 0 { - db.Session(&gorm.Session{}).Clauses(clause.OnConflict{DoNothing: true}).Create(joins.Interface()) + db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{DoNothing: true}).Create(joins.Interface()).Error) } } } diff --git a/tests/hooks_test.go b/tests/hooks_test.go index ed5ee746..3612857b 100644 --- a/tests/hooks_test.go +++ b/tests/hooks_test.go @@ -368,6 +368,9 @@ func TestSetColumn(t *testing.T) { } func TestHooksForSlice(t *testing.T) { + DB.Migrator().DropTable(&Product3{}) + DB.AutoMigrate(&Product3{}) + products := []*Product3{ {Name: "Product-1", Price: 100}, {Name: "Product-2", Price: 200}, @@ -414,3 +417,48 @@ func TestHooksForSlice(t *testing.T) { } } } + +type Product4 struct { + gorm.Model + Name string + Code string + Price int64 + Owner string + Item ProductItem +} + +type ProductItem struct { + gorm.Model + Code string + Product4ID uint +} + +func (pi ProductItem) BeforeCreate(*gorm.DB) error { + if pi.Code == "invalid" { + return errors.New("invalid item") + } + return nil +} + +func TestFailedToSaveAssociationShouldRollback(t *testing.T) { + DB.Migrator().DropTable(&Product4{}, &ProductItem{}) + DB.AutoMigrate(&Product4{}, &ProductItem{}) + + product := Product4{Name: "Product-1", Price: 100, Item: ProductItem{Code: "invalid"}} + if err := DB.Create(&product).Error; err == nil { + t.Errorf("should got failed to save, but error is nil") + } + + if DB.First(&Product4{}, "name = ?", product.Name).Error == nil { + t.Errorf("should got RecordNotFound, but got nil") + } + + product = Product4{Name: "Product-2", Price: 100, Item: ProductItem{Code: "valid"}} + if err := DB.Create(&product).Error; err != nil { + t.Errorf("should create product, but got error %v", err) + } + + if err := DB.First(&Product4{}, "name = ?", product.Name).Error; err != nil { + t.Errorf("should find product, but got error %v", err) + } +} From 3f355dc050111d506478b9ec9bcda924596b5bcf Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 2 Jul 2020 10:14:30 +0800 Subject: [PATCH 506/881] Refactor --- callbacks/associations.go | 25 ++++--------------------- prepare_stmt.go | 24 ++++++++++++++++++++++++ 2 files changed, 28 insertions(+), 21 deletions(-) diff --git a/callbacks/associations.go b/callbacks/associations.go index 0968b460..408f3fc9 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -5,8 +5,6 @@ import ( "gorm.io/gorm" "gorm.io/gorm/clause" - "gorm.io/gorm/schema" - "gorm.io/gorm/utils" ) func SaveBeforeAssociations(db *gorm.DB) { @@ -15,7 +13,7 @@ func SaveBeforeAssociations(db *gorm.DB) { // Save Belongs To associations for _, rel := range db.Statement.Schema.Relationships.BelongsTo { - if !saveAssociationCheck(db, rel, selectColumns, restricted) { + if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) { continue } @@ -94,7 +92,7 @@ func SaveAfterAssociations(db *gorm.DB) { // Save Has One associations for _, rel := range db.Statement.Schema.Relationships.HasOne { - if !saveAssociationCheck(db, rel, selectColumns, restricted) { + if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) { continue } @@ -172,7 +170,7 @@ func SaveAfterAssociations(db *gorm.DB) { // Save Has Many associations for _, rel := range db.Statement.Schema.Relationships.HasMany { - if !saveAssociationCheck(db, rel, selectColumns, restricted) { + if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) { continue } @@ -230,7 +228,7 @@ func SaveAfterAssociations(db *gorm.DB) { // Save Many2Many associations for _, rel := range db.Statement.Schema.Relationships.Many2Many { - if !saveAssociationCheck(db, rel, selectColumns, restricted) { + if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) { continue } @@ -299,18 +297,3 @@ func SaveAfterAssociations(db *gorm.DB) { } } } - -func saveAssociationCheck(db *gorm.DB, rel *schema.Relationship, selectColumns map[string]bool, restricted bool) bool { - savable := true - if value, ok := db.Get("gorm:save_association"); ok { - savable = utils.CheckTruth(value) - } - - if savable { - if v, ok := selectColumns[rel.Name]; (ok && v) || (!ok && !restricted) { - return true - } - } - - return false -} diff --git a/prepare_stmt.go b/prepare_stmt.go index ba9b04b6..0f112a7f 100644 --- a/prepare_stmt.go +++ b/prepare_stmt.go @@ -58,6 +58,10 @@ func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args .. stmt, err := db.prepare(query) if err == nil { return stmt.ExecContext(ctx, args...) + } else { + db.mux.Lock() + delete(db.Stmts, query) + db.mux.Unlock() } return nil, err } @@ -66,6 +70,10 @@ func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args . stmt, err := db.prepare(query) if err == nil { return stmt.QueryContext(ctx, args...) + } else { + db.mux.Lock() + delete(db.Stmts, query) + db.mux.Unlock() } return nil, err } @@ -74,6 +82,10 @@ func (db *PreparedStmtDB) QueryRowContext(ctx context.Context, query string, arg stmt, err := db.prepare(query) if err == nil { return stmt.QueryRowContext(ctx, args...) + } else { + db.mux.Lock() + delete(db.Stmts, query) + db.mux.Unlock() } return &sql.Row{} } @@ -87,6 +99,10 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args .. stmt, err := tx.PreparedStmtDB.prepare(query) if err == nil { return tx.Tx.Stmt(stmt).ExecContext(ctx, args...) + } else { + tx.PreparedStmtDB.mux.Lock() + delete(tx.PreparedStmtDB.Stmts, query) + tx.PreparedStmtDB.mux.Unlock() } return nil, err } @@ -95,6 +111,10 @@ func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args . stmt, err := tx.PreparedStmtDB.prepare(query) if err == nil { return tx.Tx.Stmt(stmt).QueryContext(ctx, args...) + } else { + tx.PreparedStmtDB.mux.Lock() + delete(tx.PreparedStmtDB.Stmts, query) + tx.PreparedStmtDB.mux.Unlock() } return nil, err } @@ -103,6 +123,10 @@ func (tx *PreparedStmtTX) QueryRowContext(ctx context.Context, query string, arg stmt, err := tx.PreparedStmtDB.prepare(query) if err == nil { return tx.Tx.Stmt(stmt).QueryRowContext(ctx, args...) + } else { + tx.PreparedStmtDB.mux.Lock() + delete(tx.PreparedStmtDB.Stmts, query) + tx.PreparedStmtDB.mux.Unlock() } return &sql.Row{} } From 3c03b6e5271973b6db4543926773b237f5fc4540 Mon Sep 17 00:00:00 2001 From: SmallTianTian Date: Thu, 2 Jul 2020 18:14:33 +0800 Subject: [PATCH 507/881] fix no limit no offset. (#3101) * fix no limit no offset. * add test for playground. --- clause/limit.go | 10 ++++++---- clause/limit_test.go | 14 +++++++++++++- tests/query_test.go | 6 ++++++ 3 files changed, 25 insertions(+), 5 deletions(-) diff --git a/clause/limit.go b/clause/limit.go index ba5cf6c4..1946820d 100644 --- a/clause/limit.go +++ b/clause/limit.go @@ -18,11 +18,13 @@ func (limit Limit) Build(builder Builder) { if limit.Limit > 0 { builder.WriteString("LIMIT ") builder.WriteString(strconv.Itoa(limit.Limit)) - - if limit.Offset > 0 { - builder.WriteString(" OFFSET ") - builder.WriteString(strconv.Itoa(limit.Offset)) + } + if limit.Offset > 0 { + if limit.Limit > 0 { + builder.WriteString(" ") } + builder.WriteString("OFFSET ") + builder.WriteString(strconv.Itoa(limit.Offset)) } } diff --git a/clause/limit_test.go b/clause/limit_test.go index 80317dc3..c26294aa 100644 --- a/clause/limit_test.go +++ b/clause/limit_test.go @@ -20,6 +20,18 @@ func TestLimit(t *testing.T) { }}, "SELECT * FROM `users` LIMIT 10 OFFSET 20", nil, }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}}, + "SELECT * FROM `users` OFFSET 20", nil, + }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}, clause.Limit{Offset: 30}}, + "SELECT * FROM `users` OFFSET 30", nil, + }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}, clause.Limit{Limit: 10}}, + "SELECT * FROM `users` LIMIT 10 OFFSET 20", nil, + }, { []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}}, "SELECT * FROM `users` LIMIT 10 OFFSET 30", nil, @@ -30,7 +42,7 @@ func TestLimit(t *testing.T) { }, { []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: -10}}, - "SELECT * FROM `users`", nil, + "SELECT * FROM `users` OFFSET 30", nil, }, { []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: 50}}, diff --git a/tests/query_test.go b/tests/query_test.go index 7973fd51..594fc268 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -381,6 +381,12 @@ func TestOffset(t *testing.T) { if (len(users1) != len(users4)) || (len(users1)-len(users2) != 3) || (len(users1)-len(users3) != 5) { t.Errorf("Offset should work") } + DB.Where("name like ?", "OffsetUser%").Order("age desc").Find(&users1).Offset(3).Find(&users2).Offset(5).Find(&users3).Offset(-1).Find(&users4) + + if (len(users1) != len(users4)) || (len(users1)-len(users2) != 3) || (len(users1)-len(users3) != 5) { + t.Errorf("Offset should work without limit.") + } + } func TestSearchWithMap(t *testing.T) { From 2d945a964149da5b5bc0387fe7cb811b874c6705 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 3 Jul 2020 08:53:38 +0800 Subject: [PATCH 508/881] Switch pgx as default driver --- prepare_stmt.go | 6 ++++++ tests/go.mod | 6 +++--- tests/tests_test.go | 5 ++++- 3 files changed, 13 insertions(+), 4 deletions(-) diff --git a/prepare_stmt.go b/prepare_stmt.go index 0f112a7f..e017bb23 100644 --- a/prepare_stmt.go +++ b/prepare_stmt.go @@ -60,6 +60,7 @@ func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args .. return stmt.ExecContext(ctx, args...) } else { db.mux.Lock() + stmt.Close() delete(db.Stmts, query) db.mux.Unlock() } @@ -72,6 +73,7 @@ func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args . return stmt.QueryContext(ctx, args...) } else { db.mux.Lock() + stmt.Close() delete(db.Stmts, query) db.mux.Unlock() } @@ -84,6 +86,7 @@ func (db *PreparedStmtDB) QueryRowContext(ctx context.Context, query string, arg return stmt.QueryRowContext(ctx, args...) } else { db.mux.Lock() + stmt.Close() delete(db.Stmts, query) db.mux.Unlock() } @@ -101,6 +104,7 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args .. return tx.Tx.Stmt(stmt).ExecContext(ctx, args...) } else { tx.PreparedStmtDB.mux.Lock() + stmt.Close() delete(tx.PreparedStmtDB.Stmts, query) tx.PreparedStmtDB.mux.Unlock() } @@ -113,6 +117,7 @@ func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args . return tx.Tx.Stmt(stmt).QueryContext(ctx, args...) } else { tx.PreparedStmtDB.mux.Lock() + stmt.Close() delete(tx.PreparedStmtDB.Stmts, query) tx.PreparedStmtDB.mux.Unlock() } @@ -125,6 +130,7 @@ func (tx *PreparedStmtTX) QueryRowContext(ctx context.Context, query string, arg return tx.Tx.Stmt(stmt).QueryRowContext(ctx, args...) } else { tx.PreparedStmtDB.mux.Lock() + stmt.Close() delete(tx.PreparedStmtDB.Stmts, query) tx.PreparedStmtDB.mux.Unlock() } diff --git a/tests/go.mod b/tests/go.mod index c467f34b..3b17feac 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -6,11 +6,11 @@ require ( github.com/google/uuid v1.1.1 github.com/jinzhu/now v1.1.1 github.com/lib/pq v1.6.0 - gorm.io/driver/mysql v0.2.7 - gorm.io/driver/postgres v0.2.4 + gorm.io/driver/mysql v0.2.8 + gorm.io/driver/postgres v0.2.5 gorm.io/driver/sqlite v1.0.8 gorm.io/driver/sqlserver v0.2.4 - gorm.io/gorm v0.2.9 + gorm.io/gorm v0.2.19 ) replace gorm.io/gorm => ../ diff --git a/tests/tests_test.go b/tests/tests_test.go index fa8bad5c..9484b897 100644 --- a/tests/tests_test.go +++ b/tests/tests_test.go @@ -54,7 +54,10 @@ func OpenTestConnection() (db *gorm.DB, err error) { if dbDSN == "" { dbDSN = "user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable TimeZone=Asia/Shanghai" } - db, err = gorm.Open(postgres.Open(dbDSN), &gorm.Config{}) + db, err = gorm.Open(postgres.New(postgres.Config{ + DSN: dbDSN, + PreferSimpleProtocol: true, + }), &gorm.Config{}) case "sqlserver": // CREATE LOGIN gorm WITH PASSWORD = 'LoremIpsum86'; // CREATE DATABASE gorm; From 8100ac76638d70065b5d3fc32caa5184c95167df Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 3 Jul 2020 09:26:23 +0800 Subject: [PATCH 509/881] Change default postgres DSN for github action --- .github/workflows/tests.yml | 2 +- tests/tests_test.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 108db6a6..247b1deb 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -139,7 +139,7 @@ jobs: key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} - name: Tests - run: GORM_DIALECT=postgres GORM_DSN="user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable TimeZone=Asia/Shanghai" ./tests/tests_all.sh + run: GORM_DIALECT=postgres GORM_DSN="user=gorm password=gorm DB.name=gorm host=localhost port=9920 sslmode=disable TimeZone=Asia/Shanghai" ./tests/tests_all.sh sqlserver: strategy: diff --git a/tests/tests_test.go b/tests/tests_test.go index 9484b897..afff2d0f 100644 --- a/tests/tests_test.go +++ b/tests/tests_test.go @@ -52,7 +52,7 @@ func OpenTestConnection() (db *gorm.DB, err error) { case "postgres": log.Println("testing postgres...") if dbDSN == "" { - dbDSN = "user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable TimeZone=Asia/Shanghai" + dbDSN = "user=gorm password=gorm DB.name=gorm host=localhost port=9920 sslmode=disable TimeZone=Asia/Shanghai" } db, err = gorm.Open(postgres.New(postgres.Config{ DSN: dbDSN, From f93345afa8e17725660d370f52608c3b0014bdc0 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 3 Jul 2020 10:26:18 +0800 Subject: [PATCH 510/881] Close cached prepared stmt when got error --- prepare_stmt.go | 78 +++++++++++++++++++++++-------------------------- 1 file changed, 36 insertions(+), 42 deletions(-) diff --git a/prepare_stmt.go b/prepare_stmt.go index e017bb23..197c257c 100644 --- a/prepare_stmt.go +++ b/prepare_stmt.go @@ -54,41 +54,38 @@ func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (Conn return nil, ErrInvalidTransaction } -func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { +func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args ...interface{}) (result sql.Result, err error) { stmt, err := db.prepare(query) if err == nil { - return stmt.ExecContext(ctx, args...) - } else { - db.mux.Lock() - stmt.Close() - delete(db.Stmts, query) - db.mux.Unlock() + result, err = stmt.ExecContext(ctx, args...) + if err != nil { + db.mux.Lock() + stmt.Close() + delete(db.Stmts, query) + db.mux.Unlock() + } } - return nil, err + return result, err } -func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { +func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) { stmt, err := db.prepare(query) if err == nil { - return stmt.QueryContext(ctx, args...) - } else { - db.mux.Lock() - stmt.Close() - delete(db.Stmts, query) - db.mux.Unlock() + rows, err = stmt.QueryContext(ctx, args...) + if err != nil { + db.mux.Lock() + stmt.Close() + delete(db.Stmts, query) + db.mux.Unlock() + } } - return nil, err + return rows, err } func (db *PreparedStmtDB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { stmt, err := db.prepare(query) if err == nil { return stmt.QueryRowContext(ctx, args...) - } else { - db.mux.Lock() - stmt.Close() - delete(db.Stmts, query) - db.mux.Unlock() } return &sql.Row{} } @@ -98,41 +95,38 @@ type PreparedStmtTX struct { PreparedStmtDB *PreparedStmtDB } -func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { +func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ...interface{}) (result sql.Result, err error) { stmt, err := tx.PreparedStmtDB.prepare(query) if err == nil { - return tx.Tx.Stmt(stmt).ExecContext(ctx, args...) - } else { - tx.PreparedStmtDB.mux.Lock() - stmt.Close() - delete(tx.PreparedStmtDB.Stmts, query) - tx.PreparedStmtDB.mux.Unlock() + result, err = tx.Tx.Stmt(stmt).ExecContext(ctx, args...) + if err != nil { + tx.PreparedStmtDB.mux.Lock() + stmt.Close() + delete(tx.PreparedStmtDB.Stmts, query) + tx.PreparedStmtDB.mux.Unlock() + } } - return nil, err + return result, err } -func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { +func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) { stmt, err := tx.PreparedStmtDB.prepare(query) if err == nil { - return tx.Tx.Stmt(stmt).QueryContext(ctx, args...) - } else { - tx.PreparedStmtDB.mux.Lock() - stmt.Close() - delete(tx.PreparedStmtDB.Stmts, query) - tx.PreparedStmtDB.mux.Unlock() + rows, err = tx.Tx.Stmt(stmt).QueryContext(ctx, args...) + if err != nil { + tx.PreparedStmtDB.mux.Lock() + stmt.Close() + delete(tx.PreparedStmtDB.Stmts, query) + tx.PreparedStmtDB.mux.Unlock() + } } - return nil, err + return rows, err } func (tx *PreparedStmtTX) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { stmt, err := tx.PreparedStmtDB.prepare(query) if err == nil { return tx.Tx.Stmt(stmt).QueryRowContext(ctx, args...) - } else { - tx.PreparedStmtDB.mux.Lock() - stmt.Close() - delete(tx.PreparedStmtDB.Stmts, query) - tx.PreparedStmtDB.mux.Unlock() } return &sql.Row{} } From 2416eabd3fd78eac5e5cfb549658109b4cdd356e Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 4 Jul 2020 00:36:27 +0800 Subject: [PATCH 511/881] Change unique_idnex to UniqueIndex --- schema/index.go | 6 +++--- schema/index_test.go | 2 +- tests/associations_test.go | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/schema/index.go b/schema/index.go index cf3338c3..a0a71d2c 100644 --- a/schema/index.go +++ b/schema/index.go @@ -27,7 +27,7 @@ func (schema *Schema) ParseIndexes() map[string]Index { var indexes = map[string]Index{} for _, field := range schema.Fields { - if field.TagSettings["INDEX"] != "" || field.TagSettings["UNIQUE_INDEX"] != "" { + if field.TagSettings["INDEX"] != "" || field.TagSettings["UNIQUEINDEX"] != "" { for _, index := range parseFieldIndexes(field) { idx := indexes[index.Name] idx.Name = index.Name @@ -76,7 +76,7 @@ func parseFieldIndexes(field *Field) (indexes []Index) { if value != "" { v := strings.Split(value, ":") k := strings.TrimSpace(strings.ToUpper(v[0])) - if k == "INDEX" || k == "UNIQUE_INDEX" { + if k == "INDEX" || k == "UNIQUEINDEX" { var ( name string tag = strings.Join(v[1:], ":") @@ -97,7 +97,7 @@ func parseFieldIndexes(field *Field) (indexes []Index) { name = field.Schema.namer.IndexName(field.Schema.Table, field.Name) } - if (k == "UNIQUE_INDEX") || settings["UNIQUE"] != "" { + if (k == "UNIQUEINDEX") || settings["UNIQUE"] != "" { settings["CLASS"] = "UNIQUE" } diff --git a/schema/index_test.go b/schema/index_test.go index 71a70a8c..f6c3d247 100644 --- a/schema/index_test.go +++ b/schema/index_test.go @@ -12,7 +12,7 @@ type UserIndex struct { Name string `gorm:"index"` Name2 string `gorm:"index:idx_name,unique"` Name3 string `gorm:"index:,sort:desc,collate:utf8,type:btree,length:10,where:name3 != 'jinzhu'"` - Name4 string `gorm:"unique_index"` + Name4 string `gorm:"uniqueIndex"` Name5 int64 `gorm:"index:,class:FULLTEXT,comment:hello \\, world,where:age > 10"` Name6 int64 `gorm:"index:profile,comment:hello \\, world,where:age > 10"` Age int64 `gorm:"index:profile,expression:ABS(age)"` diff --git a/tests/associations_test.go b/tests/associations_test.go index 9b4dd105..c1a4e2b2 100644 --- a/tests/associations_test.go +++ b/tests/associations_test.go @@ -41,7 +41,7 @@ func TestForeignKeyConstraints(t *testing.T) { type Member struct { ID uint - Refer uint `gorm:"unique_index"` + Refer uint `gorm:"uniqueIndex"` Name string Profile Profile `gorm:"Constraint:OnUpdate:CASCADE,OnDelete:CASCADE;FOREIGNKEY:MemberID;References:Refer"` } @@ -91,7 +91,7 @@ func TestForeignKeyConstraintsBelongsTo(t *testing.T) { type Profile struct { ID uint Name string - Refer uint `gorm:"unique_index"` + Refer uint `gorm:"uniqueIndex"` } type Member struct { From d4f8a524423baf81aecfc6caf2780eb14e2eb187 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 4 Jul 2020 07:24:30 +0800 Subject: [PATCH 512/881] Fix join table foreign key in snake_case --- schema/relationship.go | 4 ++-- schema/relationship_test.go | 12 ++++++++++-- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/schema/relationship.go b/schema/relationship.go index a13d53b9..0967f8c8 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -210,7 +210,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel for idx, ownField := range ownForeignFields { joinFieldName := schema.Name + ownField.Name if len(joinForeignKeys) > idx { - joinFieldName = joinForeignKeys[idx] + joinFieldName = strings.Title(joinForeignKeys[idx]) } ownFieldsMap[joinFieldName] = true @@ -226,7 +226,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel for idx, relField := range refForeignFields { joinFieldName := relation.FieldSchema.Name + relField.Name if len(joinReferences) > idx { - joinFieldName = joinReferences[idx] + joinFieldName = strings.Title(joinReferences[idx]) } if _, ok := ownFieldsMap[joinFieldName]; ok { diff --git a/schema/relationship_test.go b/schema/relationship_test.go index defba9ce..2c09f528 100644 --- a/schema/relationship_test.go +++ b/schema/relationship_test.go @@ -138,8 +138,9 @@ func TestMany2ManyOverrideForeignKeyAndReferences(t *testing.T) { type User struct { gorm.Model - Profiles []Profile `gorm:"many2many:user_profiles;ForeignKey:Refer;JoinForeignKey:UserReferID;References:UserRefer;JoinReferences:ProfileRefer"` - Refer uint + Profiles []Profile `gorm:"many2many:user_profiles;ForeignKey:Refer;JoinForeignKey:UserReferID;References:UserRefer;JoinReferences:ProfileRefer"` + Profiles2 []Profile `gorm:"many2many:user_profiles2;ForeignKey:refer;JoinForeignKey:user_refer_id;References:user_refer;JoinReferences:profile_refer"` + Refer uint } checkStructRelation(t, &User{}, Relation{ @@ -149,6 +150,13 @@ func TestMany2ManyOverrideForeignKeyAndReferences(t *testing.T) { {"Refer", "User", "UserReferID", "user_profiles", "", true}, {"UserRefer", "Profile", "ProfileRefer", "user_profiles", "", false}, }, + }, Relation{ + Name: "Profiles2", Type: schema.Many2Many, Schema: "User", FieldSchema: "Profile", + JoinTable: JoinTable{Name: "user_profiles2", Table: "user_profiles2"}, + References: []Reference{ + {"Refer", "User", "User_refer_id", "user_profiles2", "", true}, + {"UserRefer", "Profile", "Profile_refer", "user_profiles2", "", false}, + }, }) } From 6b98ced13dc3eb1b3bad01e7f3aac473c00b131f Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 4 Jul 2020 07:45:07 +0800 Subject: [PATCH 513/881] Fix set time field from null, close #3108 --- schema/field.go | 6 +++++- schema/field_test.go | 7 +++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/schema/field.go b/schema/field.go index f02968fa..fbcb3cef 100644 --- a/schema/field.go +++ b/schema/field.go @@ -655,7 +655,11 @@ func (field *Field) setupValuerAndSetter() { case time.Time: field.ReflectValueOf(value).Set(reflect.ValueOf(v)) case *time.Time: - field.ReflectValueOf(value).Set(reflect.ValueOf(v).Elem()) + if data != nil { + field.ReflectValueOf(value).Set(reflect.ValueOf(data).Elem()) + } else { + field.ReflectValueOf(value).Set(reflect.ValueOf(time.Time{})) + } case string: if t, err := now.Parse(data); err == nil { field.ReflectValueOf(value).Set(reflect.ValueOf(t)) diff --git a/schema/field_test.go b/schema/field_test.go index 7970b614..7027b11d 100644 --- a/schema/field_test.go +++ b/schema/field_test.go @@ -19,6 +19,7 @@ func TestFieldValuerAndSetter(t *testing.T) { Model: gorm.Model{ ID: 10, CreatedAt: time.Now(), + UpdatedAt: time.Now(), DeletedAt: gorm.DeletedAt{Time: time.Now(), Valid: true}, }, Name: "valuer_and_setter", @@ -34,6 +35,7 @@ func TestFieldValuerAndSetter(t *testing.T) { "name": user.Name, "id": user.ID, "created_at": user.CreatedAt, + "updated_at": user.UpdatedAt, "deleted_at": user.DeletedAt, "age": user.Age, "birthday": user.Birthday, @@ -46,6 +48,7 @@ func TestFieldValuerAndSetter(t *testing.T) { "name": "valuer_and_setter_2", "id": 2, "created_at": time.Now(), + "updated_at": nil, "deleted_at": time.Now(), "age": 20, "birthday": time.Now(), @@ -57,14 +60,17 @@ func TestFieldValuerAndSetter(t *testing.T) { t.Errorf("no error should happen when assign value to field %v, but got %v", k, err) } } + newValues["updated_at"] = time.Time{} checkField(t, userSchema, reflectValue, newValues) // test valuer and other type age := myint(10) + var nilTime *time.Time newValues2 := map[string]interface{}{ "name": sql.NullString{String: "valuer_and_setter_3", Valid: true}, "id": &sql.NullInt64{Int64: 3, Valid: true}, "created_at": tests.Now(), + "updated_at": nilTime, "deleted_at": time.Now(), "age": &age, "birthday": mytime(time.Now()), @@ -76,6 +82,7 @@ func TestFieldValuerAndSetter(t *testing.T) { t.Errorf("no error should happen when assign value to field %v, but got %v", k, err) } } + newValues2["updated_at"] = time.Time{} checkField(t, userSchema, reflectValue, newValues2) } From f835a4deaca48027a5f2d98e0b3df45b2366da35 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 4 Jul 2020 07:57:33 +0800 Subject: [PATCH 514/881] Add health check for github action databases --- .github/workflows/tests.yml | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 247b1deb..0e1cbac3 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -78,6 +78,12 @@ jobs: MYSQL_RANDOM_ROOT_PASSWORD: "yes" ports: - 9910:3306 + options: >- + --health-cmd "mysqladmin ping -ugorm -pgorm" + --health-interval 10s + --health-start-period 10s + --health-timeout 5s + --health-retries 10 steps: - name: Set up Go 1.x @@ -159,6 +165,12 @@ jobs: MSSQL_PASSWORD: LoremIpsum86 ports: - 9930:1433 + options: >- + --health-cmd="/opt/mssql-tools/bin/sqlcmd -S localhost -U sa -P LoremIpsum86 -l 30 -Q \"SELECT 1\" || exit 1" + --health-start-period 10s + --health-interval 10s + --health-timeout 5s + --health-retries 10 steps: - name: Set up Go 1.x From 90a40361ed38314b8ea45e703a14f0ed58925892 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 4 Jul 2020 08:21:23 +0800 Subject: [PATCH 515/881] Fix set bool field from null --- schema/field.go | 6 +++++- schema/field_test.go | 4 +++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/schema/field.go b/schema/field.go index fbcb3cef..d72a26d5 100644 --- a/schema/field.go +++ b/schema/field.go @@ -479,7 +479,11 @@ func (field *Field) setupValuerAndSetter() { case bool: field.ReflectValueOf(value).SetBool(data) case *bool: - field.ReflectValueOf(value).SetBool(*data) + if data != nil { + field.ReflectValueOf(value).SetBool(*data) + } else { + field.ReflectValueOf(value).SetBool(false) + } case int64: if data > 0 { field.ReflectValueOf(value).SetBool(true) diff --git a/schema/field_test.go b/schema/field_test.go index 7027b11d..64f4a909 100644 --- a/schema/field_test.go +++ b/schema/field_test.go @@ -43,6 +43,7 @@ func TestFieldValuerAndSetter(t *testing.T) { } checkField(t, userSchema, reflectValue, values) + var f *bool // test setter newValues := map[string]interface{}{ "name": "valuer_and_setter_2", @@ -52,7 +53,7 @@ func TestFieldValuerAndSetter(t *testing.T) { "deleted_at": time.Now(), "age": 20, "birthday": time.Now(), - "active": false, + "active": f, } for k, v := range newValues { @@ -61,6 +62,7 @@ func TestFieldValuerAndSetter(t *testing.T) { } } newValues["updated_at"] = time.Time{} + newValues["active"] = false checkField(t, userSchema, reflectValue, newValues) // test valuer and other type From 89ea62077d4f6a1b9de92fd26b7acd6e72eb1761 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 4 Jul 2020 08:33:10 +0800 Subject: [PATCH 516/881] DryRun for RowQuery, Exec, close #3106 --- callbacks/raw.go | 2 +- callbacks/row.go | 10 ++++++---- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/callbacks/raw.go b/callbacks/raw.go index 4093a5ab..d594ab39 100644 --- a/callbacks/raw.go +++ b/callbacks/raw.go @@ -5,7 +5,7 @@ import ( ) func RawExec(db *gorm.DB) { - if db.Error == nil { + if db.Error == nil && !db.DryRun { result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) if err != nil { db.AddError(err) diff --git a/callbacks/row.go b/callbacks/row.go index b25503ff..7e70382e 100644 --- a/callbacks/row.go +++ b/callbacks/row.go @@ -10,10 +10,12 @@ func RowQuery(db *gorm.DB) { BuildQuerySQL(db) } - if _, ok := db.Get("rows"); ok { - db.Statement.Dest, db.Error = db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) - } else { - db.Statement.Dest = db.Statement.ConnPool.QueryRowContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) + if !db.DryRun { + if _, ok := db.Get("rows"); ok { + db.Statement.Dest, db.Error = db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) + } else { + db.Statement.Dest = db.Statement.ConnPool.QueryRowContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) + } } } } From 1a2fabb34d66d7581b8a37034c3575650f2a9aaa Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 5 Jul 2020 11:53:10 +0800 Subject: [PATCH 517/881] Test Not --- clause/where.go | 2 +- statement.go | 28 +++++++++++++++++++++++++++- tests/create_test.go | 2 +- tests/query_test.go | 39 +++++++++++++++++++++++++++++++++++++++ 4 files changed, 68 insertions(+), 3 deletions(-) diff --git a/clause/where.go b/clause/where.go index f7cd3318..a0f4598d 100644 --- a/clause/where.go +++ b/clause/where.go @@ -128,7 +128,7 @@ func (not NotConditions) Build(builder Builder) { if negationBuilder, ok := c.(NegationExpressionBuilder); ok { negationBuilder.NegationBuild(builder) } else { - builder.WriteString(" NOT ") + builder.WriteString("NOT ") c.Build(builder) } } diff --git a/statement.go b/statement.go index e65a064f..c03f6f88 100644 --- a/statement.go +++ b/statement.go @@ -265,7 +265,18 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) (c } case map[string]interface{}: for i, j := range v { - conds = append(conds, clause.Eq{Column: i, Value: j}) + reflectValue := reflect.Indirect(reflect.ValueOf(j)) + switch reflectValue.Kind() { + case reflect.Slice, reflect.Array: + values := make([]interface{}, reflectValue.Len()) + for i := 0; i < reflectValue.Len(); i++ { + values[i] = reflectValue.Index(i).Interface() + } + + conds = append(conds, clause.IN{Column: i, Values: values}) + default: + conds = append(conds, clause.Eq{Column: i, Value: j}) + } } default: reflectValue := reflect.Indirect(reflect.ValueOf(arg)) @@ -299,6 +310,21 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) (c } } } else if len(conds) == 0 { + if len(args) == 1 { + switch reflectValue.Kind() { + case reflect.Slice, reflect.Array: + values := make([]interface{}, reflectValue.Len()) + for i := 0; i < reflectValue.Len(); i++ { + values[i] = reflectValue.Index(i).Interface() + } + + if len(values) > 0 { + conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: values}) + } + return + } + } + conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: args}) } } diff --git a/tests/create_test.go b/tests/create_test.go index 75059f18..46cc06c6 100644 --- a/tests/create_test.go +++ b/tests/create_test.go @@ -307,7 +307,7 @@ func TestCreateWithNoGORMPrimaryKey(t *testing.T) { func TestSelectWithCreate(t *testing.T) { user := *GetUser("select_create", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4}) - DB.Select("Account", "Toys", "Manager", "ManagerID", "Languages", "Name", "CreatedAt", "UpdatedAt", "Age", "Active").Create(&user) + DB.Select("Account", "Toys", "Manager", "ManagerID", "Languages", "Name", "CreatedAt", "Age", "Active").Create(&user) var user2 User DB.Preload("Account").Preload("Pets").Preload("Toys").Preload("Company").Preload("Manager").Preload("Team").Preload("Languages").Preload("Friends").First(&user2, user.ID) diff --git a/tests/query_test.go b/tests/query_test.go index 594fc268..c9eb5903 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -179,6 +179,45 @@ func TestFillSmallerStruct(t *testing.T) { } } +func TestNot(t *testing.T) { + dryDB := DB.Session(&gorm.Session{DryRun: true}) + + result := dryDB.Not(map[string]interface{}{"name": "jinzhu"}).Find(&User{}) + if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*name.* <> .+").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Where("name = ?", "jinzhu1").Not("name = ?", "jinzhu2").Find(&User{}) + if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*name.* = .+ AND NOT.*name.* = .+").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Not("name = ?", "jinzhu").Find(&User{}) + if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE NOT.*name.* = .+").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Not(map[string]interface{}{"name": []string{"jinzhu", "jinzhu 2"}}).Find(&User{}) + if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*name.* NOT IN \\(.+,.+\\)").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Not([]int64{1, 2}).First(&User{}) + if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*id.* NOT IN \\(.+,.+\\)").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Not([]int64{}).First(&User{}) + if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .users.\\..deleted_at. IS NULL ORDER BY").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Not(User{Name: "jinzhu", Age: 18}).First(&User{}) + if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*users.*..*name.* <> .+ AND .*users.*..*age.* <> .+").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) + } +} + func TestPluck(t *testing.T) { users := []*User{ GetUser("pluck-user1", Config{}), From 4e066c9590f28c71f98fb33ada0dff65b2efd7f2 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 5 Jul 2020 12:23:45 +0800 Subject: [PATCH 518/881] Test Or --- chainable_api.go | 2 +- statement.go | 25 +++++++++++++++++++------ tests/query_test.go | 36 ++++++++++++++++++++++++++++++++++++ 3 files changed, 56 insertions(+), 7 deletions(-) diff --git a/chainable_api.go b/chainable_api.go index e2ba44cc..acceb58f 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -142,7 +142,7 @@ func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) { func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) { tx = db.getInstance() if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 { - tx.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(conds...)}}) + tx.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(clause.And(conds...))}}) } return } diff --git a/statement.go b/statement.go index c03f6f88..d6444fae 100644 --- a/statement.go +++ b/statement.go @@ -6,6 +6,7 @@ import ( "database/sql/driver" "fmt" "reflect" + "sort" "strconv" "strings" "sync" @@ -260,12 +261,24 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) (c conds = append(conds, clause.Eq{Column: i, Value: j}) } case map[string]string: - for i, j := range v { - conds = append(conds, clause.Eq{Column: i, Value: j}) + var keys = make([]string, 0, len(v)) + for i := range v { + keys = append(keys, i) + } + sort.Strings(keys) + + for _, key := range keys { + conds = append(conds, clause.Eq{Column: key, Value: v[key]}) } case map[string]interface{}: - for i, j := range v { - reflectValue := reflect.Indirect(reflect.ValueOf(j)) + var keys = make([]string, 0, len(v)) + for i := range v { + keys = append(keys, i) + } + sort.Strings(keys) + + for _, key := range keys { + reflectValue := reflect.Indirect(reflect.ValueOf(v[key])) switch reflectValue.Kind() { case reflect.Slice, reflect.Array: values := make([]interface{}, reflectValue.Len()) @@ -273,9 +286,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) (c values[i] = reflectValue.Index(i).Interface() } - conds = append(conds, clause.IN{Column: i, Values: values}) + conds = append(conds, clause.IN{Column: key, Values: values}) default: - conds = append(conds, clause.Eq{Column: i, Value: j}) + conds = append(conds, clause.Eq{Column: key, Value: v[key]}) } } default: diff --git a/tests/query_test.go b/tests/query_test.go index c9eb5903..5a8bbef2 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -218,6 +218,25 @@ func TestNot(t *testing.T) { } } +func TestOr(t *testing.T) { + dryDB := DB.Session(&gorm.Session{DryRun: true}) + + result := dryDB.Where("role = ?", "admin").Or("role = ?", "super_admin").Find(&User{}) + if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*role.* = .+ OR .*role.* = .+").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Where("name = ?", "jinzhu").Or(User{Name: "jinzhu 2", Age: 18}).Find(&User{}) + if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*name.* = .+ OR \\(.*name.* AND .*age.*\\)").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Where("name = ?", "jinzhu").Or(map[string]interface{}{"name": "jinzhu 2", "age": 18}).Find(&User{}) + if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*name.* = .+ OR \\(.*age.* AND .*name.*\\)").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) + } +} + func TestPluck(t *testing.T) { users := []*User{ GetUser("pluck-user1", Config{}), @@ -269,6 +288,23 @@ func TestSelect(t *testing.T) { if user.Name != result.Name { t.Errorf("Should have user Name when selected it") } + + dryDB := DB.Session(&gorm.Session{DryRun: true}) + r := dryDB.Select("name", "age").Find(&User{}) + if !regexp.MustCompile("SELECT .*name.*,.*age.* FROM .*users.*").MatchString(r.Statement.SQL.String()) { + t.Fatalf("Build NOT condition, but got %v", r.Statement.SQL.String()) + } + + r = dryDB.Select([]string{"name", "age"}).Find(&User{}) + if !regexp.MustCompile("SELECT .*name.*,.*age.* FROM .*users.*").MatchString(r.Statement.SQL.String()) { + t.Fatalf("Build NOT condition, but got %v", r.Statement.SQL.String()) + } + + r = dryDB.Table("users").Select("COALESCE(age,?)", 42).Find(&User{}) + if !regexp.MustCompile("SELECT COALESCE\\(age,.*\\) FROM .*users.*").MatchString(r.Statement.SQL.String()) { + t.Fatalf("Build NOT condition, but got %v", r.Statement.SQL.String()) + } + // SELECT COALESCE(age,'42') FROM users; } func TestPluckWithSelect(t *testing.T) { From 9a4941ba7021bcbac0c85d0ca54c635eeeec554c Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 5 Jul 2020 22:12:52 +0800 Subject: [PATCH 519/881] Test Order/GroupBy --- clause/select.go | 2 +- tests/group_by_test.go | 21 +++++++++++++++++++++ tests/joins_test.go | 1 + tests/query_test.go | 21 +++++++++++++++++---- 4 files changed, 40 insertions(+), 5 deletions(-) diff --git a/clause/select.go b/clause/select.go index a1b77de8..9c2bc625 100644 --- a/clause/select.go +++ b/clause/select.go @@ -14,7 +14,7 @@ func (s Select) Name() string { func (s Select) Build(builder Builder) { if len(s.Columns) > 0 { if s.Distinct { - builder.WriteString(" DISTINCT ") + builder.WriteString("DISTINCT ") } for idx, column := range s.Columns { diff --git a/tests/group_by_test.go b/tests/group_by_test.go index 6d0ed39c..7e41e94a 100644 --- a/tests/group_by_test.go +++ b/tests/group_by_test.go @@ -67,6 +67,27 @@ func TestGroupBy(t *testing.T) { t.Errorf("name should be groupby, but got %v, total should be 660, but got %v", name, total) } + var result = struct { + Name string + Total int64 + }{} + + if err := DB.Model(&User{}).Select("name, sum(age) as total").Where("name LIKE ?", "groupby%").Group("name").Having("name = ?", "groupby1").Find(&result).Error; err != nil { + t.Errorf("no error should happen, but got %v", err) + } + + if result.Name != "groupby1" || result.Total != 660 { + t.Errorf("name should be groupby, total should be 660, but got %+v", result) + } + + if err := DB.Model(&User{}).Select("name, sum(age) as total").Where("name LIKE ?", "groupby%").Group("name").Having("name = ?", "groupby1").Scan(&result).Error; err != nil { + t.Errorf("no error should happen, but got %v", err) + } + + if result.Name != "groupby1" || result.Total != 660 { + t.Errorf("name should be groupby, total should be 660, but got %+v", result) + } + var active bool if err := DB.Model(&User{}).Select("name, active, sum(age)").Where("name = ? and active = ?", "groupby", true).Group("name").Group("active").Row().Scan(&name, &active, &total); err != nil { t.Errorf("no error should happen, but got %v", err) diff --git a/tests/joins_test.go b/tests/joins_test.go index f01c8211..e54d3784 100644 --- a/tests/joins_test.go +++ b/tests/joins_test.go @@ -101,6 +101,7 @@ func TestJoinsWithSelect(t *testing.T) { DB.Save(&user) var results []result + DB.Table("users").Select("users.id, pets.id as pet_id, pets.name").Joins("left join pets on pets.user_id = users.id").Where("users.name = ?", "joins_with_select").Scan(&results) sort.Slice(results, func(i, j int) bool { diff --git a/tests/query_test.go b/tests/query_test.go index 5a8bbef2..1db490b7 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -223,17 +223,17 @@ func TestOr(t *testing.T) { result := dryDB.Where("role = ?", "admin").Or("role = ?", "super_admin").Find(&User{}) if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*role.* = .+ OR .*role.* = .+").MatchString(result.Statement.SQL.String()) { - t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) + t.Fatalf("Build OR condition, but got %v", result.Statement.SQL.String()) } result = dryDB.Where("name = ?", "jinzhu").Or(User{Name: "jinzhu 2", Age: 18}).Find(&User{}) if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*name.* = .+ OR \\(.*name.* AND .*age.*\\)").MatchString(result.Statement.SQL.String()) { - t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) + t.Fatalf("Build OR condition, but got %v", result.Statement.SQL.String()) } result = dryDB.Where("name = ?", "jinzhu").Or(map[string]interface{}{"name": "jinzhu 2", "age": 18}).Find(&User{}) if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*name.* = .+ OR \\(.*age.* AND .*name.*\\)").MatchString(result.Statement.SQL.String()) { - t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) + t.Fatalf("Build OR condition, but got %v", result.Statement.SQL.String()) } } @@ -426,6 +426,20 @@ func TestSearchWithEmptyChain(t *testing.T) { } } +func TestOrder(t *testing.T) { + dryDB := DB.Session(&gorm.Session{DryRun: true}) + + result := dryDB.Order("age desc, name").Find(&User{}) + if !regexp.MustCompile("SELECT \\* FROM .*users.* ORDER BY age desc, name").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build Order condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Order("age desc").Order("name").Find(&User{}) + if !regexp.MustCompile("SELECT \\* FROM .*users.* ORDER BY age desc,name").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build Order condition, but got %v", result.Statement.SQL.String()) + } +} + func TestLimit(t *testing.T) { users := []User{ {Name: "LimitUser1", Age: 1}, @@ -461,7 +475,6 @@ func TestOffset(t *testing.T) { if (len(users1) != len(users4)) || (len(users1)-len(users2) != 3) || (len(users1)-len(users3) != 5) { t.Errorf("Offset should work without limit.") } - } func TestSearchWithMap(t *testing.T) { From b5725940e95cc886403b12e01cba4c941881a7be Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 6 Jul 2020 11:20:43 +0800 Subject: [PATCH 520/881] Test Select with Update Struct --- callbacks/update.go | 18 ++++++++++-------- tests/update_test.go | 26 ++++++++++++++++++++++++-- utils/tests/utils.go | 7 ++++++- 3 files changed, 40 insertions(+), 11 deletions(-) diff --git a/callbacks/update.go b/callbacks/update.go index f84e933c..97a0e893 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -196,15 +196,17 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { if !stmt.UpdatingColumn && stmt.Schema != nil { for _, field := range stmt.Schema.FieldsByDBName { if field.AutoUpdateTime > 0 && value[field.Name] == nil && value[field.DBName] == nil { - now := stmt.DB.NowFunc() - assignValue(field, now) + if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { + now := stmt.DB.NowFunc() + assignValue(field, now) - if field.AutoUpdateTime == schema.UnixNanosecond { - set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano()}) - } else if field.DataType == schema.Time { - set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now}) - } else { - set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.Unix()}) + if field.AutoUpdateTime == schema.UnixNanosecond { + set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano()}) + } else if field.DataType == schema.Time { + set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now}) + } else { + set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.Unix()}) + } } } } diff --git a/tests/update_test.go b/tests/update_test.go index d56e3f76..2ff150dd 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -8,6 +8,7 @@ import ( "time" "gorm.io/gorm" + "gorm.io/gorm/utils" . "gorm.io/gorm/utils/tests" ) @@ -267,6 +268,22 @@ func TestSelectWithUpdate(t *testing.T) { }) AssertObjEqual(t, result2, result, "Name", "Account", "Toys", "Manager", "ManagerID", "Languages") + + DB.Model(&result).Select("Name", "Age").Updates(User{Name: "update_with_select"}) + if result.Age != 0 || result.Name != "update_with_select" { + t.Fatalf("Failed to update struct with select, got %+v", result) + } + AssertObjEqual(t, result, user, "UpdatedAt") + + var result3 User + DB.First(&result3, result.ID) + AssertObjEqual(t, result, result3, "Name", "Age", "UpdatedAt") + + DB.Model(&result).Select("Name", "Age", "UpdatedAt").Updates(User{Name: "update_with_select"}) + + if utils.AssertEqual(result.UpdatedAt, user.UpdatedAt) { + t.Fatalf("Update struct should update UpdatedAt, was %+v, got %+v", result.UpdatedAt, user.UpdatedAt) + } } func TestSelectWithUpdateWithMap(t *testing.T) { @@ -290,7 +307,7 @@ func TestSelectWithUpdateWithMap(t *testing.T) { "Friends": user2.Friends, } - DB.Model(&result).Select("Name", "Account", "Toys", "Manager", "ManagerID", "Languages").Updates(updateValues) + DB.Model(&result).Omit("name", "updated_at").Updates(updateValues) var result2 User DB.Preload("Account").Preload("Pets").Preload("Toys").Preload("Company").Preload("Manager").Preload("Team").Preload("Languages").Preload("Friends").First(&result2, user.ID) @@ -427,11 +444,16 @@ func TestSelectWithUpdateColumn(t *testing.T) { var result User DB.First(&result, user.ID) - DB.Model(&result).Select("Name").UpdateColumns(updateValues) + + time.Sleep(time.Second) + lastUpdatedAt := result.UpdatedAt + DB.Model(&result).Select("Name").Updates(updateValues) var result2 User DB.First(&result2, user.ID) + AssertEqual(t, lastUpdatedAt, result2.UpdatedAt) + if result2.Name == user.Name || result2.Age != user.Age { t.Errorf("Should only update users with name column") } diff --git a/utils/tests/utils.go b/utils/tests/utils.go index 5248e620..a44eb548 100644 --- a/utils/tests/utils.go +++ b/utils/tests/utils.go @@ -84,15 +84,20 @@ func AssertEqual(t *testing.T, got, expect interface{}) { if reflect.ValueOf(got).Kind() == reflect.Struct { if reflect.ValueOf(got).NumField() == reflect.ValueOf(expect).NumField() { + exported := false for i := 0; i < reflect.ValueOf(got).NumField(); i++ { if fieldStruct := reflect.ValueOf(got).Type().Field(i); ast.IsExported(fieldStruct.Name) { + exported = true field := reflect.ValueOf(got).Field(i) t.Run(fieldStruct.Name, func(t *testing.T) { AssertEqual(t, field.Interface(), reflect.ValueOf(expect).Field(i).Interface()) }) } } - return + + if exported { + return + } } } From de482f57ff48f18e5ef8b98ac687c02b60db180c Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 6 Jul 2020 15:47:33 +0800 Subject: [PATCH 521/881] Test raw sql with gorm.Expr --- tests/sql_builder_test.go | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/tests/sql_builder_test.go b/tests/sql_builder_test.go index b78c2484..634ee1cb 100644 --- a/tests/sql_builder_test.go +++ b/tests/sql_builder_test.go @@ -76,10 +76,19 @@ func TestRaw(t *testing.T) { t.Errorf("Raw with Rows should find one record with name 3") } - DB.Exec("update users set name=? where name in (?)", "jinzhu", []string{user1.Name, user2.Name, user3.Name}) + DB.Exec("update users set name=? where name in (?)", "jinzhu-raw", []string{user1.Name, user2.Name, user3.Name}) if DB.Where("name in (?)", []string{user1.Name, user2.Name, user3.Name}).First(&User{}).Error != gorm.ErrRecordNotFound { t.Error("Raw sql to update records") } + + DB.Exec("update users set age=? where name = ?", gorm.Expr("age * ? + ?", 2, 10), "jinzhu-raw") + + var age int + DB.Raw("select sum(age) from users where name = ?", "jinzhu-raw").Scan(&age) + + if age != ((1+10+20)*2 + 30) { + t.Errorf("Invalid age, got %v", age) + } } func TestRowsWithGroup(t *testing.T) { From 619cd332ec3a629177fd982726da3506d725349b Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 8 Jul 2020 17:59:40 +0800 Subject: [PATCH 522/881] Add index priority supports --- schema/index.go | 13 +++++++++++++ schema/index_test.go | 22 ++++++++++++++-------- schema/relationship.go | 2 +- tests/named_polymorphic_test.go | 4 ++-- utils/tests/models.go | 4 ++-- 5 files changed, 32 insertions(+), 13 deletions(-) diff --git a/schema/index.go b/schema/index.go index a0a71d2c..fb7ea501 100644 --- a/schema/index.go +++ b/schema/index.go @@ -1,6 +1,7 @@ package schema import ( + "sort" "strconv" "strings" ) @@ -20,6 +21,7 @@ type IndexOption struct { Sort string // DESC, ASC Collate string Length int + priority int } // ParseIndexes parse schema indexes @@ -43,7 +45,12 @@ func (schema *Schema) ParseIndexes() map[string]Index { if idx.Comment == "" { idx.Comment = index.Comment } + idx.Fields = append(idx.Fields, index.Fields...) + sort.Slice(idx.Fields, func(i, j int) bool { + return idx.Fields[i].priority < idx.Fields[j].priority + }) + indexes[index.Name] = idx } } @@ -101,6 +108,11 @@ func parseFieldIndexes(field *Field) (indexes []Index) { settings["CLASS"] = "UNIQUE" } + priority, err := strconv.Atoi(settings["PRIORITY"]) + if err != nil { + priority = 10 + } + indexes = append(indexes, Index{ Name: name, Class: settings["CLASS"], @@ -113,6 +125,7 @@ func parseFieldIndexes(field *Field) (indexes []Index) { Sort: settings["SORT"], Collate: settings["COLLATE"], Length: length, + priority: priority, }}, }) } diff --git a/schema/index_test.go b/schema/index_test.go index f6c3d247..dc1fb43b 100644 --- a/schema/index_test.go +++ b/schema/index_test.go @@ -17,7 +17,7 @@ type UserIndex struct { Name6 int64 `gorm:"index:profile,comment:hello \\, world,where:age > 10"` Age int64 `gorm:"index:profile,expression:ABS(age)"` OID int64 `gorm:"index:idx_id;index:idx_oid,unique"` - MemberNumber string `gorm:"index:idx_id"` + MemberNumber string `gorm:"index:idx_id,priority:1"` } func TestParseIndex(t *testing.T) { @@ -29,18 +29,19 @@ func TestParseIndex(t *testing.T) { results := map[string]schema.Index{ "idx_user_indices_name": { Name: "idx_user_indices_name", - Fields: []schema.IndexOption{{}}, + Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name"}}}, }, "idx_name": { Name: "idx_name", Class: "UNIQUE", - Fields: []schema.IndexOption{{}}, + Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name2"}}}, }, "idx_user_indices_name3": { Name: "idx_user_indices_name3", Type: "btree", Where: "name3 != 'jinzhu'", Fields: []schema.IndexOption{{ + Field: &schema.Field{Name: "Name3"}, Sort: "desc", Collate: "utf8", Length: 10, @@ -49,31 +50,32 @@ func TestParseIndex(t *testing.T) { "idx_user_indices_name4": { Name: "idx_user_indices_name4", Class: "UNIQUE", - Fields: []schema.IndexOption{{}}, + Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name4"}}}, }, "idx_user_indices_name5": { Name: "idx_user_indices_name5", Class: "FULLTEXT", Comment: "hello , world", Where: "age > 10", - Fields: []schema.IndexOption{{}}, + Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name5"}}}, }, "profile": { Name: "profile", Comment: "hello , world", Where: "age > 10", - Fields: []schema.IndexOption{{}, { + Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name6"}}, { + Field: &schema.Field{Name: "Age"}, Expression: "ABS(age)", }}, }, "idx_id": { Name: "idx_id", - Fields: []schema.IndexOption{{}, {}}, + Fields: []schema.IndexOption{{Field: &schema.Field{Name: "MemberNumber"}}, {Field: &schema.Field{Name: "OID"}}}, }, "idx_oid": { Name: "idx_oid", Class: "UNIQUE", - Fields: []schema.IndexOption{{}}, + Fields: []schema.IndexOption{{Field: &schema.Field{Name: "OID"}}}, }, } @@ -96,6 +98,10 @@ func TestParseIndex(t *testing.T) { for idx, ef := range result.Fields { rf := v.Fields[idx] + if rf.Field.Name != ef.Field.Name { + t.Fatalf("index field should equal, expects %v, got %v", rf.Field.Name, ef.Field.Name) + } + for _, name := range []string{"Expression", "Sort", "Collate", "Length"} { if reflect.ValueOf(ef).FieldByName(name).Interface() != reflect.ValueOf(rf).FieldByName(name).Interface() { t.Errorf( diff --git a/schema/relationship.go b/schema/relationship.go index 0967f8c8..91c2ca8d 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -130,7 +130,7 @@ func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Fi PolymorphicID: relation.FieldSchema.FieldsByName[polymorphic+"ID"], } - if value, ok := field.TagSettings["POLYMORPHIC_VALUE"]; ok { + if value, ok := field.TagSettings["POLYMORPHICVALUE"]; ok { relation.Polymorphic.Value = strings.TrimSpace(value) } diff --git a/tests/named_polymorphic_test.go b/tests/named_polymorphic_test.go index cbe236b5..956f3a7e 100644 --- a/tests/named_polymorphic_test.go +++ b/tests/named_polymorphic_test.go @@ -9,8 +9,8 @@ import ( type Hamster struct { Id int Name string - PreferredToy Toy `gorm:"polymorphic:Owner;polymorphic_value:hamster_preferred"` - OtherToy Toy `gorm:"polymorphic:Owner;polymorphic_value:hamster_other"` + PreferredToy Toy `gorm:"polymorphic:Owner;polymorphicValue:hamster_preferred"` + OtherToy Toy `gorm:"polymorphic:Owner;polymorphicValue:hamster_other"` } func TestNamedPolymorphic(t *testing.T) { diff --git a/utils/tests/models.go b/utils/tests/models.go index 021b0229..2c5e71c0 100644 --- a/utils/tests/models.go +++ b/utils/tests/models.go @@ -24,8 +24,8 @@ type User struct { ManagerID *uint Manager *User Team []User `gorm:"foreignkey:ManagerID"` - Languages []Language `gorm:"many2many:UserSpeak"` - Friends []*User `gorm:"many2many:user_friends"` + Languages []Language `gorm:"many2many:UserSpeak;"` + Friends []*User `gorm:"many2many:user_friends;"` Active bool } From 30188e7aa4b59759f5048fa4438c4e79b9e7122f Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 8 Jul 2020 18:15:45 +0800 Subject: [PATCH 523/881] CHECK constraint without parentheses --- migrator/migrator.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/migrator/migrator.go b/migrator/migrator.go index 5edd800e..169701e4 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -182,7 +182,7 @@ func (m Migrator) CreateTable(values ...interface{}) error { } for _, chk := range stmt.Schema.ParseCheckConstraints() { - createTableSQL += "CONSTRAINT ? CHECK ?," + createTableSQL += "CONSTRAINT ? CHECK (?)," values = append(values, clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint}) } @@ -371,7 +371,7 @@ func (m Migrator) CreateConstraint(value interface{}, name string) error { checkConstraints := stmt.Schema.ParseCheckConstraints() if chk, ok := checkConstraints[name]; ok { return m.DB.Exec( - "ALTER TABLE ? ADD CONSTRAINT ? CHECK ?", + "ALTER TABLE ? ADD CONSTRAINT ? CHECK (?)", clause.Table{Name: stmt.Table}, clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint}, ).Error } From e1084e78d0acea979520458ce16f2bc17141ba59 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 8 Jul 2020 18:50:49 +0800 Subject: [PATCH 524/881] Allow customize AutoIncrement for primary field --- schema/schema.go | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/schema/schema.go b/schema/schema.go index 72bc6544..b85bbd7e 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -187,8 +187,11 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) if !field.HasDefaultValue || field.DefaultValueInterface != nil { schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field) } - field.HasDefaultValue = true - field.AutoIncrement = true + + if _, ok := field.TagSettings["AUTOINCREMENT"]; !ok { + field.HasDefaultValue = true + field.AutoIncrement = true + } } } From 2ae0653af2bc19cd31f687e797b189c85f0ac3f6 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 9 Jul 2020 09:03:48 +0800 Subject: [PATCH 525/881] Fix ambiguous column when using same column name in join table, close #3120 --- association.go | 20 ++++++++++---------- callbacks/delete.go | 4 ++-- callbacks/preload.go | 4 ++-- schema/relationship.go | 4 +++- schema/utils.go | 12 +++++++++--- soft_delete.go | 4 ++-- statement.go | 9 +++++++++ tests/go.mod | 2 +- tests/multi_primary_keys_test.go | 4 ++-- 9 files changed, 40 insertions(+), 23 deletions(-) diff --git a/association.go b/association.go index 928dcf3e..eeb11efe 100644 --- a/association.go +++ b/association.go @@ -122,7 +122,7 @@ func (association *Association) Replace(values ...interface{}) error { ) if _, rvs := schema.GetIdentityFieldValuesMap(relValues, rel.FieldSchema.PrimaryFields); len(rvs) > 0 { - if column, values := schema.ToQueryValues(rel.FieldSchema.PrimaryFieldDBNames, rvs); len(values) > 0 { + if column, values := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs); len(values) > 0 { tx.Not(clause.IN{Column: column, Values: values}) } } @@ -138,7 +138,7 @@ func (association *Association) Replace(values ...interface{}) error { } if _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields); len(pvs) > 0 { - column, values := schema.ToQueryValues(foreignKeys, pvs) + column, values := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs) tx.Where(clause.IN{Column: column, Values: values}).UpdateColumns(updateMap) } case schema.Many2Many: @@ -164,14 +164,14 @@ func (association *Association) Replace(values ...interface{}) error { } _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields) - if column, values := schema.ToQueryValues(joinPrimaryKeys, pvs); len(values) > 0 { + if column, values := schema.ToQueryValues(rel.JoinTable.Table, joinPrimaryKeys, pvs); len(values) > 0 { tx.Where(clause.IN{Column: column, Values: values}) } else { return ErrorPrimaryKeyRequired } _, rvs := schema.GetIdentityFieldValuesMapFromValues(values, relPrimaryFields) - if relColumn, relValues := schema.ToQueryValues(joinRelPrimaryKeys, rvs); len(relValues) > 0 { + if relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs); len(relValues) > 0 { tx.Where(clause.Not(clause.IN{Column: relColumn, Values: relValues})) } @@ -208,11 +208,11 @@ func (association *Association) Delete(values ...interface{}) error { tx := association.DB.Model(reflect.New(rel.Schema.ModelType).Interface()) _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, rel.Schema.PrimaryFields) - pcolumn, pvalues := schema.ToQueryValues(rel.Schema.PrimaryFieldDBNames, pvs) + pcolumn, pvalues := schema.ToQueryValues(rel.Schema.Table, rel.Schema.PrimaryFieldDBNames, pvs) conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues}) _, rvs := schema.GetIdentityFieldValuesMapFromValues(values, primaryFields) - relColumn, relValues := schema.ToQueryValues(foreignKeys, rvs) + relColumn, relValues := schema.ToQueryValues(rel.Schema.Table, foreignKeys, rvs) conds = append(conds, clause.IN{Column: relColumn, Values: relValues}) association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error @@ -220,11 +220,11 @@ func (association *Association) Delete(values ...interface{}) error { tx := association.DB.Model(reflect.New(rel.FieldSchema.ModelType).Interface()) _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields) - pcolumn, pvalues := schema.ToQueryValues(foreignKeys, pvs) + pcolumn, pvalues := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs) conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues}) _, rvs := schema.GetIdentityFieldValuesMapFromValues(values, rel.FieldSchema.PrimaryFields) - relColumn, relValues := schema.ToQueryValues(rel.FieldSchema.PrimaryFieldDBNames, rvs) + relColumn, relValues := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs) conds = append(conds, clause.IN{Column: relColumn, Values: relValues}) association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error @@ -250,11 +250,11 @@ func (association *Association) Delete(values ...interface{}) error { } _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields) - pcolumn, pvalues := schema.ToQueryValues(joinPrimaryKeys, pvs) + pcolumn, pvalues := schema.ToQueryValues(rel.JoinTable.Table, joinPrimaryKeys, pvs) conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues}) _, rvs := schema.GetIdentityFieldValuesMapFromValues(values, relPrimaryFields) - relColumn, relValues := schema.ToQueryValues(joinRelPrimaryKeys, rvs) + relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs) conds = append(conds, clause.IN{Column: relColumn, Values: relValues}) association.Error = association.DB.Where(clause.Where{Exprs: conds}).Model(nil).Delete(modelValue).Error diff --git a/callbacks/delete.go b/callbacks/delete.go index dea8bb5e..ff0f601a 100644 --- a/callbacks/delete.go +++ b/callbacks/delete.go @@ -35,7 +35,7 @@ func Delete(db *gorm.DB) { if db.Statement.Schema != nil { _, queryValues := schema.GetIdentityFieldValuesMap(db.Statement.ReflectValue, db.Statement.Schema.PrimaryFields) - column, values := schema.ToQueryValues(db.Statement.Schema.PrimaryFieldDBNames, queryValues) + column, values := schema.ToQueryValues(db.Statement.Schema.Table, db.Statement.Schema.PrimaryFieldDBNames, queryValues) if len(values) > 0 { db.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}}) @@ -43,7 +43,7 @@ func Delete(db *gorm.DB) { if db.Statement.Dest != db.Statement.Model && db.Statement.Model != nil { _, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(db.Statement.Model), db.Statement.Schema.PrimaryFields) - column, values = schema.ToQueryValues(db.Statement.Schema.PrimaryFieldDBNames, queryValues) + column, values = schema.ToQueryValues(db.Statement.Schema.Table, db.Statement.Schema.PrimaryFieldDBNames, queryValues) if len(values) > 0 { db.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}}) diff --git a/callbacks/preload.go b/callbacks/preload.go index a9907d68..cd09a6d6 100644 --- a/callbacks/preload.go +++ b/callbacks/preload.go @@ -49,7 +49,7 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { } joinResults := rel.JoinTable.MakeSlice().Elem() - column, values := schema.ToQueryValues(joinForeignKeys, joinForeignValues) + column, values := schema.ToQueryValues(rel.JoinTable.Table, joinForeignKeys, joinForeignValues) tx.Where(clause.IN{Column: column, Values: values}).Find(joinResults.Addr().Interface()) // convert join identity map to relation identity map @@ -93,7 +93,7 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { } reflectResults := rel.FieldSchema.MakeSlice().Elem() - column, values := schema.ToQueryValues(relForeignKeys, foreignValues) + column, values := schema.ToQueryValues(rel.FieldSchema.Table, relForeignKeys, foreignValues) for _, cond := range conds { if fc, ok := cond.(func(*gorm.DB) *gorm.DB); ok { diff --git a/schema/relationship.go b/schema/relationship.go index 91c2ca8d..e3ff0307 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -462,10 +462,12 @@ func (rel *Relationship) ParseConstraint() *Constraint { } func (rel *Relationship) ToQueryConditions(reflectValue reflect.Value) (conds []clause.Expression) { + table := rel.FieldSchema.Table foreignFields := []*Field{} relForeignKeys := []string{} if rel.JoinTable != nil { + table = rel.JoinTable.Table for _, ref := range rel.References { if ref.OwnPrimaryKey { foreignFields = append(foreignFields, ref.PrimaryKey) @@ -500,7 +502,7 @@ func (rel *Relationship) ToQueryConditions(reflectValue reflect.Value) (conds [] } _, foreignValues := GetIdentityFieldValuesMap(reflectValue, foreignFields) - column, values := ToQueryValues(relForeignKeys, foreignValues) + column, values := ToQueryValues(table, relForeignKeys, foreignValues) conds = append(conds, clause.IN{Column: column, Values: values}) return diff --git a/schema/utils.go b/schema/utils.go index da236a18..defa83af 100644 --- a/schema/utils.go +++ b/schema/utils.go @@ -5,6 +5,7 @@ import ( "regexp" "strings" + "gorm.io/gorm/clause" "gorm.io/gorm/utils" ) @@ -164,18 +165,23 @@ func GetIdentityFieldValuesMapFromValues(values []interface{}, fields []*Field) } // ToQueryValues to query values -func ToQueryValues(foreignKeys []string, foreignValues [][]interface{}) (interface{}, []interface{}) { +func ToQueryValues(table string, foreignKeys []string, foreignValues [][]interface{}) (interface{}, []interface{}) { queryValues := make([]interface{}, len(foreignValues)) if len(foreignKeys) == 1 { for idx, r := range foreignValues { queryValues[idx] = r[0] } - return foreignKeys[0], queryValues + return clause.Column{Table: table, Name: foreignKeys[0]}, queryValues } else { + columns := make([]clause.Column, len(foreignKeys)) + for idx, key := range foreignKeys { + columns[idx] = clause.Column{Table: table, Name: key} + } + for idx, r := range foreignValues { queryValues[idx] = r } + return columns, queryValues } - return foreignKeys, queryValues } diff --git a/soft_delete.go b/soft_delete.go index 4ffceba6..e3e6e960 100644 --- a/soft_delete.go +++ b/soft_delete.go @@ -58,7 +58,7 @@ func (SoftDeleteClause) ModifyStatement(stmt *Statement) { if stmt.Schema != nil { _, queryValues := schema.GetIdentityFieldValuesMap(stmt.ReflectValue, stmt.Schema.PrimaryFields) - column, values := schema.ToQueryValues(stmt.Schema.PrimaryFieldDBNames, queryValues) + column, values := schema.ToQueryValues(stmt.Schema.Table, stmt.Schema.PrimaryFieldDBNames, queryValues) if len(values) > 0 { stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}}) @@ -66,7 +66,7 @@ func (SoftDeleteClause) ModifyStatement(stmt *Statement) { if stmt.Dest != stmt.Model && stmt.Model != nil { _, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(stmt.Model), stmt.Schema.PrimaryFields) - column, values = schema.ToQueryValues(stmt.Schema.PrimaryFieldDBNames, queryValues) + column, values = schema.ToQueryValues(stmt.Schema.Table, stmt.Schema.PrimaryFieldDBNames, queryValues) if len(values) > 0 { stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}}) diff --git a/statement.go b/statement.go index d6444fae..036b8297 100644 --- a/statement.go +++ b/statement.go @@ -107,6 +107,15 @@ func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) { writer.WriteString(" AS ") stmt.DB.Dialector.QuoteTo(writer, v.Alias) } + case []clause.Column: + writer.WriteByte('(') + for idx, d := range v { + if idx > 0 { + writer.WriteString(",") + } + stmt.QuoteTo(writer, d) + } + writer.WriteByte(')') case string: stmt.DB.Dialector.QuoteTo(writer, v) case []string: diff --git a/tests/go.mod b/tests/go.mod index 3b17feac..3a5b4224 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -6,7 +6,7 @@ require ( github.com/google/uuid v1.1.1 github.com/jinzhu/now v1.1.1 github.com/lib/pq v1.6.0 - gorm.io/driver/mysql v0.2.8 + gorm.io/driver/mysql v0.2.9 gorm.io/driver/postgres v0.2.5 gorm.io/driver/sqlite v1.0.8 gorm.io/driver/sqlserver v0.2.4 diff --git a/tests/multi_primary_keys_test.go b/tests/multi_primary_keys_test.go index 617010c5..051e3ee2 100644 --- a/tests/multi_primary_keys_test.go +++ b/tests/multi_primary_keys_test.go @@ -13,7 +13,7 @@ type Blog struct { Locale string `gorm:"primary_key"` Subject string Body string - Tags []Tag `gorm:"many2many:blogs_tags;"` + Tags []Tag `gorm:"many2many:blog_tags;"` SharedTags []Tag `gorm:"many2many:shared_blog_tags;ForeignKey:id;References:id"` LocaleTags []Tag `gorm:"many2many:locale_blog_tags;ForeignKey:id,locale;References:id"` } @@ -22,7 +22,7 @@ type Tag struct { ID uint `gorm:"primary_key"` Locale string `gorm:"primary_key"` Value string - Blogs []*Blog `gorm:"many2many:blogs_tags"` + Blogs []*Blog `gorm:"many2many:blog_tags"` } func compareTags(tags []Tag, contents []string) bool { From 0790ff69373366a536bb183b2d1646d14af63594 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 9 Jul 2020 09:42:27 +0800 Subject: [PATCH 526/881] Update tests helper to check time --- utils/tests/utils.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils/tests/utils.go b/utils/tests/utils.go index a44eb548..0067d5c6 100644 --- a/utils/tests/utils.go +++ b/utils/tests/utils.go @@ -27,7 +27,7 @@ func AssertEqual(t *testing.T, got, expect interface{}) { if curTime, ok := got.(time.Time); ok { format := "2006-01-02T15:04:05Z07:00" - if curTime.Round(time.Second).Format(format) != expect.(time.Time).Round(time.Second).Format(format) && curTime.Truncate(time.Second).Format(format) != expect.(time.Time).Truncate(time.Second).Format(format) { + if curTime.Round(time.Second).UTC().Format(format) != expect.(time.Time).Round(time.Second).UTC().Format(format) && curTime.Truncate(time.Second).UTC().Format(format) != expect.(time.Time).Truncate(time.Second).UTC().Format(format) { t.Errorf("%v: expect: %v, got %v after time round", utils.FileWithLineNum(), expect.(time.Time), curTime) } } else if fmt.Sprint(got) != fmt.Sprint(expect) { From a8655f79477cc5d287a136d369141b5b9a468ba7 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 9 Jul 2020 12:15:35 +0800 Subject: [PATCH 527/881] Fix auto select with smaller struct for slices --- callbacks/query.go | 24 +++++++++++++++++------- tests/query_test.go | 18 ++++++++++++++++++ 2 files changed, 35 insertions(+), 7 deletions(-) diff --git a/callbacks/query.go b/callbacks/query.go index 4b7f5bd5..9601f9bd 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -64,14 +64,24 @@ func BuildQuerySQL(db *gorm.DB) { clauseSelect.Columns[idx] = clause.Column{Name: name, Raw: true} } } - } else if db.Statement.Schema != nil && db.Statement.ReflectValue.IsValid() && db.Statement.ReflectValue.Type() != db.Statement.Schema.ModelType { - stmt := gorm.Statement{DB: db} - // smaller struct - if err := stmt.Parse(db.Statement.Dest); err == nil { - clauseSelect.Columns = make([]clause.Column, len(stmt.Schema.DBNames)) + } else if db.Statement.Schema != nil && db.Statement.ReflectValue.IsValid() { + smallerStruct := false + switch db.Statement.ReflectValue.Kind() { + case reflect.Struct: + smallerStruct = db.Statement.ReflectValue.Type() != db.Statement.Schema.ModelType + case reflect.Slice: + smallerStruct = db.Statement.ReflectValue.Type().Elem() != db.Statement.Schema.ModelType + } - for idx, dbName := range stmt.Schema.DBNames { - clauseSelect.Columns[idx] = clause.Column{Name: dbName} + if smallerStruct { + stmt := gorm.Statement{DB: db} + // smaller struct + if err := stmt.Parse(db.Statement.Dest); err == nil && stmt.Schema.ModelType != db.Statement.Schema.ModelType { + clauseSelect.Columns = make([]clause.Column, len(stmt.Schema.DBNames)) + + for idx, dbName := range stmt.Schema.DBNames { + clauseSelect.Columns[idx] = clause.Column{Name: dbName} + } } } } diff --git a/tests/query_test.go b/tests/query_test.go index 1db490b7..62005e3a 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -177,6 +177,24 @@ func TestFillSmallerStruct(t *testing.T) { if !regexp.MustCompile("SELECT .*id.*name.*updated_at.*created_at.* FROM .*users").MatchString(result.Statement.SQL.String()) { t.Fatalf("SQL should include selected names, but got %v", result.Statement.SQL.String()) } + + result = DB.Session(&gorm.Session{DryRun: true}).Model(&User{}).Find(&User{}, user.ID) + + if regexp.MustCompile("SELECT .*name.* FROM .*users").MatchString(result.Statement.SQL.String()) { + t.Fatalf("SQL should not include selected names, but got %v", result.Statement.SQL.String()) + } + + result = DB.Session(&gorm.Session{DryRun: true}).Model(&User{}).Find(&[]User{}, user.ID) + + if regexp.MustCompile("SELECT .*name.* FROM .*users").MatchString(result.Statement.SQL.String()) { + t.Fatalf("SQL should not include selected names, but got %v", result.Statement.SQL.String()) + } + + result = DB.Session(&gorm.Session{DryRun: true}).Model(&User{}).Find(&[]*User{}, user.ID) + + if regexp.MustCompile("SELECT .*name.* FROM .*users").MatchString(result.Statement.SQL.String()) { + t.Fatalf("SQL should not include selected names, but got %v", result.Statement.SQL.String()) + } } func TestNot(t *testing.T) { From d04984323f4545b39f39767629d37d4c4492b690 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 9 Jul 2020 22:02:29 +0800 Subject: [PATCH 528/881] Add stale for v1 action --- .github/workflows/missing_playground.yml | 21 +++++++++++++++++++++ .github/workflows/stale.yml | 12 +++++------- 2 files changed, 26 insertions(+), 7 deletions(-) create mode 100644 .github/workflows/missing_playground.yml diff --git a/.github/workflows/missing_playground.yml b/.github/workflows/missing_playground.yml new file mode 100644 index 00000000..6fb714ca --- /dev/null +++ b/.github/workflows/missing_playground.yml @@ -0,0 +1,21 @@ +name: "Close Missing Playground issues" +on: + schedule: + - cron: "*/10 * * * *" + +jobs: + stale: + runs-on: ubuntu-latest + env: + ACTIONS_STEP_DEBUG: true + steps: + - name: Close Stale Issues + uses: actions/stale@v3.0.7 + with: + repo-token: ${{ secrets.GITHUB_TOKEN }} + stale-issue-message: "This issue has been automatically marked as stale as it missing playground pull request link, checkout [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground) for details, it will be closed in 2 days if no further activity occurs." + stale-issue-label: "status:stale" + days-before-stale: 0 + days-before-close: 2 + remove-stale-when-updated: true + only-labels: "type:missing reproduction steps" diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml index 6fb714ca..7a304eb7 100644 --- a/.github/workflows/stale.yml +++ b/.github/workflows/stale.yml @@ -1,4 +1,4 @@ -name: "Close Missing Playground issues" +name: "Stale" on: schedule: - cron: "*/10 * * * *" @@ -13,9 +13,7 @@ jobs: uses: actions/stale@v3.0.7 with: repo-token: ${{ secrets.GITHUB_TOKEN }} - stale-issue-message: "This issue has been automatically marked as stale as it missing playground pull request link, checkout [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground) for details, it will be closed in 2 days if no further activity occurs." - stale-issue-label: "status:stale" - days-before-stale: 0 - days-before-close: 2 - remove-stale-when-updated: true - only-labels: "type:missing reproduction steps" + stale-issue-message: "This issue will be automatically closed because it is marked as GORM V1 issue, we have released the public testing GORM V2 release and its documents https://v2.gorm.io/docs/ already, the testing release has been used in some production services for a while, and going to release the final version in following weeks, we are still actively collecting feedback before it, please open a new issue for any suggestion or problem, thank you\n\n Also check out https://github.com/go-gorm/gorm/wiki/GORM-V2-Release-Note-Draft for how to use the public testing version and its changelog" + stale-issue-label: "status:gorm_v1" + days-before-stale: 30 + days-before-close: 0 From c091cd6aa42aa8d7278f02654ac55adf2b6a3202 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 9 Jul 2020 22:14:11 +0800 Subject: [PATCH 529/881] Update stale action --- .github/workflows/stale.yml | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml index 7a304eb7..f9c1bece 100644 --- a/.github/workflows/stale.yml +++ b/.github/workflows/stale.yml @@ -1,7 +1,7 @@ name: "Stale" on: schedule: - - cron: "*/10 * * * *" + - cron: "0 2 * * *" jobs: stale: @@ -13,7 +13,10 @@ jobs: uses: actions/stale@v3.0.7 with: repo-token: ${{ secrets.GITHUB_TOKEN }} - stale-issue-message: "This issue will be automatically closed because it is marked as GORM V1 issue, we have released the public testing GORM V2 release and its documents https://v2.gorm.io/docs/ already, the testing release has been used in some production services for a while, and going to release the final version in following weeks, we are still actively collecting feedback before it, please open a new issue for any suggestion or problem, thank you\n\n Also check out https://github.com/go-gorm/gorm/wiki/GORM-V2-Release-Note-Draft for how to use the public testing version and its changelog" - stale-issue-label: "status:gorm_v1" - days-before-stale: 30 - days-before-close: 0 + stale-issue-message: "This issue has been automatically marked as stale because it has been open 60 days with no activity. Remove stale label or comment or this will be closed in 30 days" + days-before-stale: 60 + days-before-close: 30 + stale-issue-label: "status:stale" + exempt-issue-labels: 'type:feature,type:with reproduction steps,type:has pull request' + stale-pr-label: 'status:stale' + exempt-pr-labels: 'type:feature,type:with reproduction steps,type:has pull request' From bc3728a18f380f28a007ba1100993e2c9f7e0288 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 10 Jul 2020 07:14:37 +0800 Subject: [PATCH 530/881] Fix concurrent map writes, close #3126 --- schema/schema.go | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/schema/schema.go b/schema/schema.go index b85bbd7e..66e02443 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -207,13 +207,13 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) } } - cacheStore.Store(modelType, schema) - - // parse relations for unidentified fields - for _, field := range schema.Fields { - if field.DataType == "" && field.Creatable { - if schema.parseRelation(field); schema.err != nil { - return schema, schema.err + if _, loaded := cacheStore.LoadOrStore(modelType, schema); !loaded { + // parse relations for unidentified fields + for _, field := range schema.Fields { + if field.DataType == "" && field.Creatable { + if schema.parseRelation(field); schema.err != nil { + return schema, schema.err + } } } } From bba569af2b6e13484c78773f85dee0bd585c50a7 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 10 Jul 2020 12:28:24 +0800 Subject: [PATCH 531/881] Add NamedArg support --- README.md | 2 +- callbacks.go | 1 - chainable_api.go | 7 ++++- clause/expression.go | 59 ++++++++++++++++++++++++++++++++++++ clause/expression_test.go | 50 ++++++++++++++++++++++++++++++ finisher_api.go | 8 ++++- statement.go | 25 ++++++--------- tests/named_argument_test.go | 57 ++++++++++++++++++++++++++++++++++ 8 files changed, 190 insertions(+), 19 deletions(-) create mode 100644 tests/named_argument_test.go diff --git a/README.md b/README.md index 140c0d28..b51297c4 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,7 @@ The fantastic ORM library for Golang, aims to be developer friendly. * Transactions, Nested Transactions, Save Point, RollbackTo to Saved Point * Context, Prepared Statment Mode, DryRun Mode * Batch Insert, FindInBatches, Find To Map -* SQL Builder, Upsert, Locking, Optimizer/Index/Comment Hints +* SQL Builder, Upsert, Locking, Optimizer/Index/Comment Hints, NamedArg * Composite Primary Key * Auto Migrations * Logger diff --git a/callbacks.go b/callbacks.go index 5e7933af..c917a678 100644 --- a/callbacks.go +++ b/callbacks.go @@ -107,7 +107,6 @@ func (p *processor) Execute(db *DB) { if !stmt.DB.DryRun { stmt.SQL.Reset() stmt.Vars = nil - stmt.NamedVars = nil } } diff --git a/chainable_api.go b/chainable_api.go index acceb58f..3e509f12 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -265,6 +265,11 @@ func (db *DB) Unscoped() (tx *DB) { func (db *DB) Raw(sql string, values ...interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.SQL = strings.Builder{} - clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement) + + if strings.Contains(sql, "@") { + clause.NamedExpr{SQL: sql, Vars: values}.Build(tx.Statement) + } else { + clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement) + } return } diff --git a/clause/expression.go b/clause/expression.go index ecf8ba85..4d5e328b 100644 --- a/clause/expression.go +++ b/clause/expression.go @@ -1,6 +1,7 @@ package clause import ( + "database/sql" "database/sql/driver" "reflect" ) @@ -62,6 +63,64 @@ func (expr Expr) Build(builder Builder) { } } +// NamedExpr raw expression for named expr +type NamedExpr struct { + SQL string + Vars []interface{} +} + +// Build build raw expression +func (expr NamedExpr) Build(builder Builder) { + var ( + idx int + inName bool + namedMap = make(map[string]interface{}, len(expr.Vars)) + ) + + for _, v := range expr.Vars { + switch value := v.(type) { + case sql.NamedArg: + namedMap[value.Name] = value.Value + case map[string]interface{}: + for k, v := range value { + namedMap[k] = v + } + } + } + + name := make([]byte, 0, 10) + + for _, v := range []byte(expr.SQL) { + if v == '@' && !inName { + inName = true + name = []byte{} + } else if v == ' ' || v == ',' || v == ')' || v == '"' || v == '\'' || v == '`' { + if inName { + if nv, ok := namedMap[string(name)]; ok { + builder.AddVar(builder, nv) + } else { + builder.WriteByte('@') + builder.WriteString(string(name)) + } + inName = false + } + + builder.WriteByte(v) + } else if v == '?' { + builder.AddVar(builder, expr.Vars[idx]) + idx++ + } else if inName { + name = append(name, v) + } else { + builder.WriteByte(v) + } + } + + if inName { + builder.AddVar(builder, namedMap[string(name)]) + } +} + // IN Whether a value is within a set of values type IN struct { Column interface{} diff --git a/clause/expression_test.go b/clause/expression_test.go index 3059aea6..17af737d 100644 --- a/clause/expression_test.go +++ b/clause/expression_test.go @@ -1,7 +1,9 @@ package clause_test import ( + "database/sql" "fmt" + "reflect" "sync" "testing" @@ -33,3 +35,51 @@ func TestExpr(t *testing.T) { }) } } + +func TestNamedExpr(t *testing.T) { + results := []struct { + SQL string + Result string + Vars []interface{} + ExpectedVars []interface{} + }{{ + SQL: "create table ? (? ?, ? ?)", + Vars: []interface{}{clause.Table{Name: "users"}, clause.Column{Name: "id"}, clause.Expr{SQL: "int"}, clause.Column{Name: "name"}, clause.Expr{SQL: "text"}}, + Result: "create table `users` (`id` int, `name` text)", + }, { + SQL: "name1 = @name AND name2 = @name", + Vars: []interface{}{sql.Named("name", "jinzhu")}, + Result: "name1 = ? AND name2 = ?", + ExpectedVars: []interface{}{"jinzhu", "jinzhu"}, + }, { + SQL: "name1 = @name1 AND name2 = @name2 AND name3 = @name1", + Vars: []interface{}{sql.Named("name1", "jinzhu"), sql.Named("name2", "jinzhu2")}, + Result: "name1 = ? AND name2 = ? AND name3 = ?", + ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu"}, + }, { + SQL: "name1 = @name1 AND name2 = @name2 AND name3 = @name1", + Vars: []interface{}{map[string]interface{}{"name1": "jinzhu", "name2": "jinzhu2"}}, + Result: "name1 = ? AND name2 = ? AND name3 = ?", + ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu"}, + }, { + SQL: "@@test AND name1 = @name1 AND name2 = @name2 AND name3 = @name1 @notexist", + Vars: []interface{}{sql.Named("name1", "jinzhu"), sql.Named("name2", "jinzhu2")}, + Result: "@@test AND name1 = ? AND name2 = ? AND name3 = ? ?", + ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu", nil}, + }} + + for idx, result := range results { + t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { + user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) + stmt := &gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}} + clause.NamedExpr{SQL: result.SQL, Vars: result.Vars}.Build(stmt) + if stmt.SQL.String() != result.Result { + t.Errorf("generated SQL is not equal, expects %v, but got %v", result.Result, stmt.SQL.String()) + } + + if !reflect.DeepEqual(result.ExpectedVars, stmt.Vars) { + t.Errorf("generated vars is not equal, expects %v, but got %v", result.ExpectedVars, stmt.Vars) + } + }) + } +} diff --git a/finisher_api.go b/finisher_api.go index 25c56e49..d70b3cd0 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -453,7 +453,13 @@ func (db *DB) RollbackTo(name string) *DB { func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.SQL = strings.Builder{} - clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement) + + if strings.Contains(sql, "@") { + clause.NamedExpr{SQL: sql, Vars: values}.Build(tx.Statement) + } else { + clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement) + } + tx.callbacks.Raw().Execute(tx) return } diff --git a/statement.go b/statement.go index 036b8297..00feeac5 100644 --- a/statement.go +++ b/statement.go @@ -38,7 +38,6 @@ type Statement struct { UpdatingColumn bool SQL strings.Builder Vars []interface{} - NamedVars []sql.NamedArg CurDestIndex int attrs []interface{} assigns []interface{} @@ -148,14 +147,7 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) { switch v := v.(type) { case sql.NamedArg: - if len(v.Name) > 0 { - stmt.NamedVars = append(stmt.NamedVars, v) - writer.WriteByte('@') - writer.WriteString(v.Name) - } else { - stmt.Vars = append(stmt.Vars, v.Value) - stmt.DB.Dialector.BindVarTo(writer, stmt, v.Value) - } + stmt.Vars = append(stmt.Vars, v.Value) case clause.Column, clause.Table: stmt.QuoteTo(writer, v) case clause.Expr: @@ -234,16 +226,19 @@ func (stmt *Statement) AddClauseIfNotExists(v clause.Interface) { // BuildCondition build condition func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) (conds []clause.Expression) { - if sql, ok := query.(string); ok { + if s, ok := query.(string); ok { // if it is a number, then treats it as primary key - if _, err := strconv.Atoi(sql); err != nil { - if sql == "" && len(args) == 0 { + if _, err := strconv.Atoi(s); err != nil { + if s == "" && len(args) == 0 { return - } else if len(args) == 0 || (len(args) > 0 && strings.Contains(sql, "?")) || strings.Contains(sql, "@") { + } else if len(args) == 0 || (len(args) > 0 && strings.Contains(s, "?")) { // looks like a where condition - return []clause.Expression{clause.Expr{SQL: sql, Vars: args}} + return []clause.Expression{clause.Expr{SQL: s, Vars: args}} + } else if len(args) > 0 && strings.Contains(s, "@") { + // looks like a named query + return []clause.Expression{clause.NamedExpr{SQL: s, Vars: args}} } else if len(args) == 1 { - return []clause.Expression{clause.Eq{Column: sql, Value: args[0]}} + return []clause.Expression{clause.Eq{Column: s, Value: args[0]}} } } } diff --git a/tests/named_argument_test.go b/tests/named_argument_test.go new file mode 100644 index 00000000..60f5a535 --- /dev/null +++ b/tests/named_argument_test.go @@ -0,0 +1,57 @@ +package tests_test + +import ( + "database/sql" + "testing" + + "gorm.io/gorm" + . "gorm.io/gorm/utils/tests" +) + +func TestNamedArg(t *testing.T) { + type NamedUser struct { + gorm.Model + Name1 string + Name2 string + Name3 string + } + + DB.Migrator().DropTable(&NamedUser{}) + DB.AutoMigrate(&NamedUser{}) + + namedUser := NamedUser{Name1: "jinzhu1", Name2: "jinzhu2", Name3: "jinzhu3"} + DB.Create(&namedUser) + + var result NamedUser + DB.First(&result, "name1 = @name OR name2 = @name OR name3 = @name", sql.Named("name", "jinzhu2")) + + AssertEqual(t, result, namedUser) + + var result2 NamedUser + DB.Where("name1 = @name OR name2 = @name OR name3 = @name", sql.Named("name", "jinzhu2")).First(&result2) + + AssertEqual(t, result2, namedUser) + + var result3 NamedUser + DB.Where("name1 = @name OR name2 = @name OR name3 = @name", map[string]interface{}{"name": "jinzhu2"}).First(&result3) + + AssertEqual(t, result3, namedUser) + + var result4 NamedUser + if err := DB.Raw("SELECT * FROM named_users WHERE name1 = @name OR name2 = @name2 OR name3 = @name", sql.Named("name", "jinzhu-none"), sql.Named("name2", "jinzhu2")).Find(&result4).Error; err != nil { + t.Errorf("failed to update with named arg") + } + + AssertEqual(t, result4, namedUser) + + if err := DB.Exec("UPDATE named_users SET name1 = @name, name2 = @name2, name3 = @name", sql.Named("name", "jinzhu-new"), sql.Named("name2", "jinzhu-new2")).Error; err != nil { + t.Errorf("failed to update with named arg") + } + + var result5 NamedUser + if err := DB.Raw("SELECT * FROM named_users WHERE (name1 = @name AND name3 = @name) AND name2 = @name2", sql.Named("name", "jinzhu-new"), sql.Named("name2", "jinzhu-new2")).Find(&result5).Error; err != nil { + t.Errorf("failed to update with named arg") + } + + AssertEqual(t, result4, namedUser) +} From c0319f6eed8c56ba09d0b6674d5bcd5e062b9981 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 10 Jul 2020 12:52:01 +0800 Subject: [PATCH 532/881] Test map with named argument for raw sql --- tests/named_argument_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/named_argument_test.go b/tests/named_argument_test.go index 60f5a535..56fad5f4 100644 --- a/tests/named_argument_test.go +++ b/tests/named_argument_test.go @@ -49,7 +49,7 @@ func TestNamedArg(t *testing.T) { } var result5 NamedUser - if err := DB.Raw("SELECT * FROM named_users WHERE (name1 = @name AND name3 = @name) AND name2 = @name2", sql.Named("name", "jinzhu-new"), sql.Named("name2", "jinzhu-new2")).Find(&result5).Error; err != nil { + if err := DB.Raw("SELECT * FROM named_users WHERE (name1 = @name AND name3 = @name) AND name2 = @name2", map[string]interface{}{"name": "jinzhu-new", "name2": "jinzhu-new2"}).Find(&result5).Error; err != nil { t.Errorf("failed to update with named arg") } From 33c48611b6614667c231307833c84899436e076a Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 10 Jul 2020 13:08:15 +0800 Subject: [PATCH 533/881] Fix customize table with Delete, close #3129 --- callbacks/delete.go | 4 ++-- soft_delete.go | 4 ++-- tests/delete_test.go | 43 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 47 insertions(+), 4 deletions(-) diff --git a/callbacks/delete.go b/callbacks/delete.go index ff0f601a..51a33bf0 100644 --- a/callbacks/delete.go +++ b/callbacks/delete.go @@ -35,7 +35,7 @@ func Delete(db *gorm.DB) { if db.Statement.Schema != nil { _, queryValues := schema.GetIdentityFieldValuesMap(db.Statement.ReflectValue, db.Statement.Schema.PrimaryFields) - column, values := schema.ToQueryValues(db.Statement.Schema.Table, db.Statement.Schema.PrimaryFieldDBNames, queryValues) + column, values := schema.ToQueryValues(db.Statement.Table, db.Statement.Schema.PrimaryFieldDBNames, queryValues) if len(values) > 0 { db.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}}) @@ -43,7 +43,7 @@ func Delete(db *gorm.DB) { if db.Statement.Dest != db.Statement.Model && db.Statement.Model != nil { _, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(db.Statement.Model), db.Statement.Schema.PrimaryFields) - column, values = schema.ToQueryValues(db.Statement.Schema.Table, db.Statement.Schema.PrimaryFieldDBNames, queryValues) + column, values = schema.ToQueryValues(db.Statement.Table, db.Statement.Schema.PrimaryFieldDBNames, queryValues) if len(values) > 0 { db.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}}) diff --git a/soft_delete.go b/soft_delete.go index e3e6e960..6b88b1a5 100644 --- a/soft_delete.go +++ b/soft_delete.go @@ -58,7 +58,7 @@ func (SoftDeleteClause) ModifyStatement(stmt *Statement) { if stmt.Schema != nil { _, queryValues := schema.GetIdentityFieldValuesMap(stmt.ReflectValue, stmt.Schema.PrimaryFields) - column, values := schema.ToQueryValues(stmt.Schema.Table, stmt.Schema.PrimaryFieldDBNames, queryValues) + column, values := schema.ToQueryValues(stmt.Table, stmt.Schema.PrimaryFieldDBNames, queryValues) if len(values) > 0 { stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}}) @@ -66,7 +66,7 @@ func (SoftDeleteClause) ModifyStatement(stmt *Statement) { if stmt.Dest != stmt.Model && stmt.Model != nil { _, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(stmt.Model), stmt.Schema.PrimaryFields) - column, values = schema.ToQueryValues(stmt.Schema.Table, stmt.Schema.PrimaryFieldDBNames, queryValues) + column, values = schema.ToQueryValues(stmt.Table, stmt.Schema.PrimaryFieldDBNames, queryValues) if len(values) > 0 { stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}}) diff --git a/tests/delete_test.go b/tests/delete_test.go index b853a9d3..3d461f65 100644 --- a/tests/delete_test.go +++ b/tests/delete_test.go @@ -45,6 +45,49 @@ func TestDelete(t *testing.T) { } } +func TestDeleteWithTable(t *testing.T) { + type UserWithDelete struct { + gorm.Model + Name string + } + + DB.Table("deleted_users").Migrator().DropTable(UserWithDelete{}) + DB.Table("deleted_users").AutoMigrate(UserWithDelete{}) + + user := UserWithDelete{Name: "delete1"} + DB.Table("deleted_users").Create(&user) + + var result UserWithDelete + if err := DB.Table("deleted_users").First(&result).Error; err != nil { + t.Errorf("failed to find deleted user, got error %v", err) + } + + AssertEqual(t, result, user) + + if err := DB.Table("deleted_users").Delete(&result).Error; err != nil { + t.Errorf("failed to delete user, got error %v", err) + } + + var result2 UserWithDelete + if err := DB.Table("deleted_users").First(&result2, user.ID).Error; !errors.Is(err, gorm.ErrRecordNotFound) { + t.Errorf("should raise record not found error, but got error %v", err) + } + + var result3 UserWithDelete + if err := DB.Table("deleted_users").Unscoped().First(&result3, user.ID).Error; err != nil { + t.Fatalf("failed to find record, got error %v", err) + } + + if err := DB.Table("deleted_users").Unscoped().Delete(&result).Error; err != nil { + t.Errorf("failed to delete user with unscoped, got error %v", err) + } + + var result4 UserWithDelete + if err := DB.Table("deleted_users").Unscoped().First(&result4, user.ID).Error; !errors.Is(err, gorm.ErrRecordNotFound) { + t.Errorf("should raise record not found error, but got error %v", err) + } +} + func TestInlineCondDelete(t *testing.T) { user1 := *GetUser("inline_delete_1", Config{}) user2 := *GetUser("inline_delete_2", Config{}) From d4b462a351949f7a7002147c13f69bb3e5ab63e5 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 10 Jul 2020 21:11:28 +0800 Subject: [PATCH 534/881] Fix alias keyword with Table, close #3104 --- chainable_api.go | 11 +++++++++++ statement.go | 8 +++++++- tests/sql_builder_test.go | 16 ++++++++++++++++ 3 files changed, 34 insertions(+), 1 deletion(-) diff --git a/chainable_api.go b/chainable_api.go index 3e509f12..7ee20324 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -2,6 +2,7 @@ package gorm import ( "fmt" + "regexp" "strings" "gorm.io/gorm/clause" @@ -40,9 +41,19 @@ func (db *DB) Clauses(conds ...clause.Expression) (tx *DB) { return } +var tableRegexp = regexp.MustCompile("(?i).+ AS (\\w+)\\s*$") + // Table specify the table you would like to run db operations func (db *DB) Table(name string) (tx *DB) { tx = db.getInstance() + if strings.Contains(name, " ") { + tx.Statement.FullTable = name + if results := tableRegexp.FindStringSubmatch(name); len(results) == 2 { + tx.Statement.Table = results[1] + return + } + } + tx.Statement.Table = name return } diff --git a/statement.go b/statement.go index 00feeac5..142c7c31 100644 --- a/statement.go +++ b/statement.go @@ -19,6 +19,7 @@ import ( // Statement statement type Statement struct { *DB + FullTable string Table string Model interface{} Unscoped bool @@ -69,7 +70,11 @@ func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) { switch v := field.(type) { case clause.Table: if v.Name == clause.CurrentTable { - stmt.DB.Dialector.QuoteTo(writer, stmt.Table) + if stmt.FullTable != "" { + writer.WriteString(stmt.FullTable) + } else { + stmt.DB.Dialector.QuoteTo(writer, stmt.Table) + } } else if v.Raw { writer.WriteString(v.Name) } else { @@ -374,6 +379,7 @@ func (stmt *Statement) Build(clauses ...string) { func (stmt *Statement) Parse(value interface{}) (err error) { if stmt.Schema, err = schema.Parse(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil && stmt.Table == "" { stmt.Table = stmt.Schema.Table + stmt.FullTable = stmt.Schema.Table } return err } diff --git a/tests/sql_builder_test.go b/tests/sql_builder_test.go index 634ee1cb..e6038947 100644 --- a/tests/sql_builder_test.go +++ b/tests/sql_builder_test.go @@ -24,6 +24,22 @@ func TestRow(t *testing.T) { if age != 10 { t.Errorf("Scan with Row, age expects: %v, got %v", user2.Age, age) } + + table := "gorm.users" + if DB.Dialector.Name() != "mysql" { + table = "users" // other databases doesn't support select with `database.table` + } + + DB.Table(table).Where(map[string]interface{}{"name": user2.Name}).Update("age", 20) + + row = DB.Table(table+" as u").Where("u.name = ?", user2.Name).Select("age").Row() + if err := row.Scan(&age); err != nil { + t.Fatalf("Failed to scan age, got %v", err) + } + + if age != 20 { + t.Errorf("Scan with Row, age expects: %v, got %v", user2.Age, age) + } } func TestRows(t *testing.T) { From 1f05cb7e55ece75a08ae79fc1c867ae023ade8c6 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 10 Jul 2020 22:53:03 +0800 Subject: [PATCH 535/881] Handle Associations with pointer of pointer, close #3130 --- association.go | 5 ++++- tests/associations_belongs_to_test.go | 5 ++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/association.go b/association.go index eeb11efe..516a8c57 100644 --- a/association.go +++ b/association.go @@ -30,7 +30,10 @@ func (db *DB) Association(column string) *Association { association.Error = fmt.Errorf("%w: %v", ErrUnsupportedRelation, column) } - db.Statement.ReflectValue = reflect.Indirect(reflect.ValueOf(db.Statement.Model)) + db.Statement.ReflectValue = reflect.ValueOf(db.Statement.Model) + for db.Statement.ReflectValue.Kind() == reflect.Ptr { + db.Statement.ReflectValue = db.Statement.ReflectValue.Elem() + } } else { association.Error = err } diff --git a/tests/associations_belongs_to_test.go b/tests/associations_belongs_to_test.go index 1800be91..3e4de726 100644 --- a/tests/associations_belongs_to_test.go +++ b/tests/associations_belongs_to_test.go @@ -18,7 +18,10 @@ func TestBelongsToAssociation(t *testing.T) { // Find var user2 User DB.Find(&user2, "id = ?", user.ID) - DB.Model(&user2).Association("Company").Find(&user2.Company) + pointerOfUser := &user2 + if err := DB.Model(&pointerOfUser).Association("Company").Find(&user2.Company); err != nil { + t.Errorf("failed to query users, got error %#v", err) + } user2.Manager = &User{} DB.Model(&user2).Association("Manager").Find(user2.Manager) CheckUser(t, user2, user) From 72a64bef1185cfd036aa08abdb300433c28d6889 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 15 Jul 2020 10:25:10 +0800 Subject: [PATCH 536/881] Don't merge clause From --- clause/from.go | 4 ---- clause/from_test.go | 22 +++++++++++----------- 2 files changed, 11 insertions(+), 15 deletions(-) diff --git a/clause/from.go b/clause/from.go index 59b0bfaf..1ea2d595 100644 --- a/clause/from.go +++ b/clause/from.go @@ -33,9 +33,5 @@ func (from From) Build(builder Builder) { // MergeClause merge from clause func (from From) MergeClause(clause *Clause) { - if v, ok := clause.Expression.(From); ok { - from.Tables = append(v.Tables, from.Tables...) - from.Joins = append(v.Joins, from.Joins...) - } clause.Expression = from } diff --git a/clause/from_test.go b/clause/from_test.go index 3ebb754c..75422f8e 100644 --- a/clause/from_test.go +++ b/clause/from_test.go @@ -38,6 +38,16 @@ func TestFrom(t *testing.T) { []clause.Interface{ clause.Select{}, clause.From{ Tables: []clause.Table{{Name: "users"}}, + Joins: []clause.Join{ + { + Type: clause.RightJoin, + Table: clause.Table{Name: "profiles"}, + ON: clause.Where{ + []clause.Expression{clause.Eq{clause.Column{Table: "profiles", Name: "email"}, clause.Column{Table: clause.CurrentTable, Name: "email"}}}, + }, + }, + }, + }, clause.From{ Joins: []clause.Join{ { Type: clause.InnerJoin, @@ -51,19 +61,9 @@ func TestFrom(t *testing.T) { Using: []string{"company_name"}, }, }, - }, clause.From{ - Joins: []clause.Join{ - { - Type: clause.RightJoin, - Table: clause.Table{Name: "profiles"}, - ON: clause.Where{ - []clause.Expression{clause.Eq{clause.Column{Table: "profiles", Name: "email"}, clause.Column{Table: clause.CurrentTable, Name: "email"}}}, - }, - }, - }, }, }, - "SELECT * FROM `users` INNER JOIN `articles` ON `articles`.`id` = `users`.`id` LEFT JOIN `companies` USING (`company_name`) RIGHT JOIN `profiles` ON `profiles`.`email` = `users`.`email`", nil, + "SELECT * FROM `users` INNER JOIN `articles` ON `articles`.`id` = `users`.`id` LEFT JOIN `companies` USING (`company_name`)", nil, }, } From 0028246ea519b2bbb4adc6d6bbd66636a58c1c81 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 16 Jul 2020 10:18:16 +0800 Subject: [PATCH 537/881] Don't set DefaultValueInterface when DefaultValue not set, close #3152 --- schema/field.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/schema/field.go b/schema/field.go index d72a26d5..3e08802a 100644 --- a/schema/field.go +++ b/schema/field.go @@ -179,22 +179,22 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { switch reflect.Indirect(fieldValue).Kind() { case reflect.Bool: field.DataType = Bool - if field.HasDefaultValue { + if field.HasDefaultValue && field.DefaultValue != "" { field.DefaultValueInterface, _ = strconv.ParseBool(field.DefaultValue) } case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: field.DataType = Int - if field.HasDefaultValue { + if field.HasDefaultValue && field.DefaultValue != "" { field.DefaultValueInterface, _ = strconv.ParseInt(field.DefaultValue, 0, 64) } case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: field.DataType = Uint - if field.HasDefaultValue { + if field.HasDefaultValue && field.DefaultValue != "" { field.DefaultValueInterface, _ = strconv.ParseUint(field.DefaultValue, 0, 64) } case reflect.Float32, reflect.Float64: field.DataType = Float - if field.HasDefaultValue { + if field.HasDefaultValue && field.DefaultValue != "" { field.DefaultValueInterface, _ = strconv.ParseFloat(field.DefaultValue, 64) } case reflect.String: From 4456df7a5d4de3e5e2121d346b79d21c7df29b49 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 16 Jul 2020 11:27:04 +0800 Subject: [PATCH 538/881] Lint with golangci-lint --- association.go | 43 +++++++++++++++++++++++---------------- callbacks/associations.go | 6 +++--- callbacks/helper.go | 4 ++-- callbacks/interface.go | 11 ---------- chainable_api.go | 2 +- clause/clause.go | 2 +- clause/joins.go | 6 +++--- clause/where.go | 2 -- finisher_api.go | 8 ++++---- logger/logger.go | 2 +- logger/sql_test.go | 6 +++--- migrator/migrator.go | 7 +++++-- schema/field.go | 18 ++++++++-------- schema/relationship.go | 5 ++--- statement.go | 5 ++--- utils/utils.go | 2 +- 16 files changed, 62 insertions(+), 67 deletions(-) delete mode 100644 callbacks/interface.go diff --git a/association.go b/association.go index 516a8c57..aa740fc5 100644 --- a/association.go +++ b/association.go @@ -102,10 +102,10 @@ func (association *Association) Replace(values ...interface{}) error { switch reflectValue.Kind() { case reflect.Slice, reflect.Array: for i := 0; i < reflectValue.Len(); i++ { - rel.Field.Set(reflectValue.Index(i), reflect.Zero(rel.Field.FieldType).Interface()) + association.Error = rel.Field.Set(reflectValue.Index(i), reflect.Zero(rel.Field.FieldType).Interface()) } case reflect.Struct: - rel.Field.Set(reflectValue, reflect.Zero(rel.Field.FieldType).Interface()) + association.Error = rel.Field.Set(reflectValue, reflect.Zero(rel.Field.FieldType).Interface()) } for _, ref := range rel.References { @@ -187,18 +187,17 @@ func (association *Association) Replace(values ...interface{}) error { func (association *Association) Delete(values ...interface{}) error { if association.Error == nil { var ( - reflectValue = association.DB.Statement.ReflectValue - rel = association.Relationship - primaryFields, foreignFields []*schema.Field - foreignKeys []string - updateAttrs = map[string]interface{}{} - conds []clause.Expression + reflectValue = association.DB.Statement.ReflectValue + rel = association.Relationship + primaryFields []*schema.Field + foreignKeys []string + updateAttrs = map[string]interface{}{} + conds []clause.Expression ) for _, ref := range rel.References { if ref.PrimaryValue == "" { primaryFields = append(primaryFields, ref.PrimaryKey) - foreignFields = append(foreignFields, ref.ForeignKey) foreignKeys = append(foreignKeys, ref.ForeignKey.DBName) updateAttrs[ref.ForeignKey.DBName] = nil } else { @@ -284,21 +283,23 @@ func (association *Association) Delete(values ...interface{}) error { } } - rel.Field.Set(data, validFieldValues.Interface()) + association.Error = rel.Field.Set(data, validFieldValues.Interface()) case reflect.Struct: for idx, field := range rel.FieldSchema.PrimaryFields { primaryValues[idx], _ = field.ValueOf(fieldValue) } if _, ok := relValuesMap[utils.ToStringKey(primaryValues...)]; ok { - rel.Field.Set(data, reflect.Zero(rel.FieldSchema.ModelType).Interface()) + if association.Error = rel.Field.Set(data, reflect.Zero(rel.FieldSchema.ModelType).Interface()); association.Error != nil { + break + } if rel.JoinTable == nil { for _, ref := range rel.References { if ref.OwnPrimaryKey || ref.PrimaryValue != "" { - ref.ForeignKey.Set(fieldValue, reflect.Zero(ref.ForeignKey.FieldType).Interface()) + association.Error = ref.ForeignKey.Set(fieldValue, reflect.Zero(ref.ForeignKey.FieldType).Interface()) } else { - ref.ForeignKey.Set(data, reflect.Zero(ref.ForeignKey.FieldType).Interface()) + association.Error = ref.ForeignKey.Set(data, reflect.Zero(ref.ForeignKey.FieldType).Interface()) } } } @@ -436,12 +437,18 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ if len(values) != reflectValue.Len() { if clear && len(values) == 0 { for i := 0; i < reflectValue.Len(); i++ { - association.Relationship.Field.Set(reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType).Interface()) + if err := association.Relationship.Field.Set(reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType).Interface()); err != nil { + association.Error = err + break + } if association.Relationship.JoinTable == nil { for _, ref := range association.Relationship.References { if !ref.OwnPrimaryKey && ref.PrimaryValue == "" { - ref.ForeignKey.Set(reflectValue.Index(i), reflect.Zero(ref.ForeignKey.FieldType).Interface()) + if err := ref.ForeignKey.Set(reflectValue.Index(i), reflect.Zero(ref.ForeignKey.FieldType).Interface()); err != nil { + association.Error = err + break + } } } } @@ -461,12 +468,12 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ } case reflect.Struct: if clear && len(values) == 0 { - association.Relationship.Field.Set(reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType).Interface()) + association.Error = association.Relationship.Field.Set(reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType).Interface()) - if association.Relationship.JoinTable == nil { + if association.Relationship.JoinTable == nil && association.Error == nil { for _, ref := range association.Relationship.References { if !ref.OwnPrimaryKey && ref.PrimaryValue == "" { - ref.ForeignKey.Set(reflectValue, reflect.Zero(ref.ForeignKey.FieldType).Interface()) + association.Error = ref.ForeignKey.Set(reflectValue, reflect.Zero(ref.ForeignKey.FieldType).Interface()) } } } diff --git a/callbacks/associations.go b/callbacks/associations.go index 408f3fc9..3508335a 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -21,7 +21,7 @@ func SaveBeforeAssociations(db *gorm.DB) { for _, ref := range rel.References { if !ref.OwnPrimaryKey { pv, _ := ref.PrimaryKey.ValueOf(elem) - ref.ForeignKey.Set(obj, pv) + db.AddError(ref.ForeignKey.Set(obj, pv)) if dest, ok := db.Statement.Dest.(map[string]interface{}); ok { dest[ref.ForeignKey.DBName] = pv @@ -121,9 +121,9 @@ func SaveAfterAssociations(db *gorm.DB) { for _, ref := range rel.References { if ref.OwnPrimaryKey { fv, _ := ref.PrimaryKey.ValueOf(obj) - ref.ForeignKey.Set(rv, fv) + db.AddError(ref.ForeignKey.Set(rv, fv)) } else if ref.PrimaryValue != "" { - ref.ForeignKey.Set(rv, ref.PrimaryValue) + db.AddError(ref.ForeignKey.Set(rv, ref.PrimaryValue)) } } diff --git a/callbacks/helper.go b/callbacks/helper.go index 1b06e0b7..7bd910f6 100644 --- a/callbacks/helper.go +++ b/callbacks/helper.go @@ -9,7 +9,7 @@ import ( // ConvertMapToValuesForCreate convert map to values func ConvertMapToValuesForCreate(stmt *gorm.Statement, mapValue map[string]interface{}) (values clause.Values) { - columns := make([]string, 0, len(mapValue)) + values.Columns = make([]clause.Column, 0, len(mapValue)) selectColumns, restricted := stmt.SelectAndOmitColumns(true, false) var keys []string @@ -25,7 +25,7 @@ func ConvertMapToValuesForCreate(stmt *gorm.Statement, mapValue map[string]inter } if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) { - columns = append(columns, k) + values.Columns = append(values.Columns, clause.Column{Name: k}) values.Values[0] = append(values.Values[0], value) } } diff --git a/callbacks/interface.go b/callbacks/interface.go deleted file mode 100644 index ee0044e8..00000000 --- a/callbacks/interface.go +++ /dev/null @@ -1,11 +0,0 @@ -package callbacks - -import "gorm.io/gorm" - -type beforeSaveInterface interface { - BeforeSave(*gorm.DB) error -} - -type beforeCreateInterface interface { - BeforeCreate(*gorm.DB) error -} diff --git a/chainable_api.go b/chainable_api.go index 7ee20324..730f6308 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -41,7 +41,7 @@ func (db *DB) Clauses(conds ...clause.Expression) (tx *DB) { return } -var tableRegexp = regexp.MustCompile("(?i).+ AS (\\w+)\\s*$") +var tableRegexp = regexp.MustCompile(`(?i).+ AS (\w+)\s*$`) // Table specify the table you would like to run db operations func (db *DB) Table(name string) (tx *DB) { diff --git a/clause/clause.go b/clause/clause.go index c7d1efeb..d413d0ee 100644 --- a/clause/clause.go +++ b/clause/clause.go @@ -18,7 +18,7 @@ type Writer interface { // Builder builder interface type Builder interface { Writer - WriteQuoted(field interface{}) error + WriteQuoted(field interface{}) AddVar(Writer, ...interface{}) } diff --git a/clause/joins.go b/clause/joins.go index 8d9055cd..f3e373f2 100644 --- a/clause/joins.go +++ b/clause/joins.go @@ -4,9 +4,9 @@ type JoinType string const ( CrossJoin JoinType = "CROSS" - InnerJoin = "INNER" - LeftJoin = "LEFT" - RightJoin = "RIGHT" + InnerJoin JoinType = "INNER" + LeftJoin JoinType = "LEFT" + RightJoin JoinType = "RIGHT" ) // Join join clause for from diff --git a/clause/where.go b/clause/where.go index a0f4598d..9af9701c 100644 --- a/clause/where.go +++ b/clause/where.go @@ -33,8 +33,6 @@ func (where Where) Build(builder Builder) { expr.Build(builder) } - - return } // MergeClause merge where clauses diff --git a/finisher_api.go b/finisher_api.go index d70b3cd0..6bfe5d20 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -138,11 +138,11 @@ func (tx *DB) assignExprsToValue(exprs []clause.Expression) { switch column := eq.Column.(type) { case string: if field := tx.Statement.Schema.LookUpField(column); field != nil { - field.Set(tx.Statement.ReflectValue, eq.Value) + tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value)) } case clause.Column: if field := tx.Statement.Schema.LookUpField(column.Name); field != nil { - field.Set(tx.Statement.ReflectValue, eq.Value) + tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value)) } default: } @@ -433,7 +433,7 @@ func (db *DB) Rollback() *DB { func (db *DB) SavePoint(name string) *DB { if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok { - savePointer.SavePoint(db, name) + db.AddError(savePointer.SavePoint(db, name)) } else { db.AddError(ErrUnsupportedDriver) } @@ -442,7 +442,7 @@ func (db *DB) SavePoint(name string) *DB { func (db *DB) RollbackTo(name string) *DB { if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok { - savePointer.RollbackTo(db, name) + db.AddError(savePointer.RollbackTo(db, name)) } else { db.AddError(ErrUnsupportedDriver) } diff --git a/logger/logger.go b/logger/logger.go index 2a5e445c..49ae988c 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -129,7 +129,7 @@ func (l logger) Error(ctx context.Context, msg string, data ...interface{}) { // Trace print sql message func (l logger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) { if l.LogLevel > 0 { - elapsed := time.Now().Sub(begin) + elapsed := time.Since(begin) switch { case err != nil && l.LogLevel >= Error: sql, rows := fc() diff --git a/logger/sql_test.go b/logger/sql_test.go index 8bc48116..180570b8 100644 --- a/logger/sql_test.go +++ b/logger/sql_test.go @@ -31,19 +31,19 @@ func TestExplainSQL(t *testing.T) { }, { SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (@p1, @p2, @p3, @p4, @p5, @p6, @p7, @p8, @p9, @p10, @p11)", - NumericRegexp: regexp.MustCompile("@p(\\d+)"), + NumericRegexp: regexp.MustCompile(`@p(\d+)`), Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.com", myrole, pwd}, Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`, }, { SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ($3, $4, $1, $2, $7, $8, $5, $6, $9, $10, $11)", - NumericRegexp: regexp.MustCompile("\\$(\\d+)"), + NumericRegexp: regexp.MustCompile(`\$(\d+)`), Vars: []interface{}{999.99, true, "jinzhu", 1, &tt, nil, []byte("12345"), tt, "w@g.com", myrole, pwd}, Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`, }, { SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (@p1, @p11, @p2, @p3, @p4, @p5, @p6, @p7, @p8, @p9, @p10)", - NumericRegexp: regexp.MustCompile("@p(\\d+)"), + NumericRegexp: regexp.MustCompile(`@p(\d+)`), Vars: []interface{}{"jinzhu", 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.com", myrole, pwd, 1}, Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`, }, diff --git a/migrator/migrator.go b/migrator/migrator.go index 169701e4..3e5d86d3 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -1,6 +1,7 @@ package migrator import ( + "context" "database/sql" "fmt" "reflect" @@ -139,7 +140,7 @@ func (m Migrator) CreateTable(values ...interface{}) error { for _, dbName := range stmt.Schema.DBNames { field := stmt.Schema.FieldsByDBName[dbName] - createTableSQL += fmt.Sprintf("? ?") + createTableSQL += "? ?" hasPrimaryKeyInDataType = hasPrimaryKeyInDataType || strings.Contains(strings.ToUpper(string(field.DataType)), "PRIMARY KEY") values = append(values, clause.Column{Name: dbName}, m.DB.Migrator().FullDataTypeOf(field)) createTableSQL += "," @@ -534,7 +535,9 @@ func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []i dep := Dependency{ Statement: &gorm.Statement{DB: m.DB, Dest: value}, } - dep.Parse(value) + if err := dep.Parse(value); err != nil { + m.DB.Logger.Error(context.Background(), "failed to parse value %#v, got error %v", value, err) + } for _, rel := range dep.Schema.Relationships.Relations { if c := rel.ParseConstraint(); c != nil && c.Schema == dep.Statement.Schema && c.Schema != c.ReferenceSchema { diff --git a/schema/field.go b/schema/field.go index 3e08802a..2c43229b 100644 --- a/schema/field.go +++ b/schema/field.go @@ -25,12 +25,12 @@ const ( const ( Bool DataType = "bool" - Int = "int" - Uint = "uint" - Float = "float" - String = "string" - Time = "time" - Bytes = "bytes" + Int DataType = "int" + Uint DataType = "uint" + Float DataType = "float" + String DataType = "string" + Time DataType = "time" + Bytes DataType = "bytes" ) type Field struct { @@ -455,13 +455,13 @@ func (field *Field) setupValuerAndSetter() { if valuer, ok := v.(driver.Valuer); ok { if v, err = valuer.Value(); err == nil { - setter(value, v) + err = setter(value, v) } } else if reflectV.Kind() == reflect.Ptr { if reflectV.IsNil() { field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) } else { - setter(value, reflectV.Elem().Interface()) + err = setter(value, reflectV.Elem().Interface()) } } else { return fmt.Errorf("failed to set value %+v to field %v", v, field.Name) @@ -744,7 +744,7 @@ func (field *Field) setupValuerAndSetter() { if reflectV.IsNil() { field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) } else { - field.Set(value, reflectV.Elem().Interface()) + err = field.Set(value, reflectV.Elem().Interface()) } } else { fieldValue := field.ReflectValueOf(value) diff --git a/schema/relationship.go b/schema/relationship.go index e3ff0307..c290c5ba 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -71,9 +71,9 @@ func (schema *Schema) parseRelation(field *Field) { return } - if polymorphic, _ := field.TagSettings["POLYMORPHIC"]; polymorphic != "" { + if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" { schema.buildPolymorphicRelation(relation, field, polymorphic) - } else if many2many, _ := field.TagSettings["MANY2MANY"]; many2many != "" { + } else if many2many := field.TagSettings["MANY2MANY"]; many2many != "" { schema.buildMany2ManyRelation(relation, field, many2many) } else { switch field.IndirectFieldType.Kind() { @@ -312,7 +312,6 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel OwnPrimaryKey: ownPriamryField, }) } - return } func (schema *Schema) guessRelation(relation *Relationship, field *Field, guessHas bool) { diff --git a/statement.go b/statement.go index 142c7c31..38154939 100644 --- a/statement.go +++ b/statement.go @@ -60,9 +60,8 @@ func (stmt *Statement) WriteByte(c byte) error { } // WriteQuoted write quoted value -func (stmt *Statement) WriteQuoted(value interface{}) error { +func (stmt *Statement) WriteQuoted(value interface{}) { stmt.QuoteTo(&stmt.SQL, value) - return nil } // QuoteTo write quoted value to writer @@ -215,7 +214,7 @@ func (stmt *Statement) AddClause(v clause.Interface) { optimizer.ModifyStatement(stmt) } else { name := v.Name() - c, _ := stmt.Clauses[name] + c := stmt.Clauses[name] c.Name = name v.MergeClause(&c) stmt.Clauses[name] = c diff --git a/utils/utils.go b/utils/utils.go index 9bf00683..3d7e395b 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -15,7 +15,7 @@ var gormSourceDir string func init() { _, file, _, _ := runtime.Caller(0) - gormSourceDir = regexp.MustCompile("utils.utils\\.go").ReplaceAllString(file, "") + gormSourceDir = regexp.MustCompile(`utils.utils\.go`).ReplaceAllString(file, "") } func FileWithLineNum() string { From 25954025078a5ce997c55eb471783d7527138167 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 16 Jul 2020 13:37:02 +0800 Subject: [PATCH 539/881] Add reviewdog --- .github/workflows/reviewdog.yml | 11 +++++++++++ 1 file changed, 11 insertions(+) create mode 100644 .github/workflows/reviewdog.yml diff --git a/.github/workflows/reviewdog.yml b/.github/workflows/reviewdog.yml new file mode 100644 index 00000000..4511c378 --- /dev/null +++ b/.github/workflows/reviewdog.yml @@ -0,0 +1,11 @@ +name: reviewdog +on: [pull_request] +jobs: + golangci-lint: + name: runner / golangci-lint + runs-on: ubuntu-latest + steps: + - name: Check out code into the Go module directory + uses: actions/checkout@v1 + - name: golangci-lint + uses: reviewdog/action-golangci-lint@v1 From e83e21097138e4a3603c5e23e6690fb787ce54df Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 16 Jul 2020 17:15:57 +0800 Subject: [PATCH 540/881] Update postgres DSN --- .github/workflows/tests.yml | 2 +- tests/tests_test.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 0e1cbac3..b626ce94 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -145,7 +145,7 @@ jobs: key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} - name: Tests - run: GORM_DIALECT=postgres GORM_DSN="user=gorm password=gorm DB.name=gorm host=localhost port=9920 sslmode=disable TimeZone=Asia/Shanghai" ./tests/tests_all.sh + run: GORM_DIALECT=postgres GORM_DSN="user=gorm password=gorm dbname=gorm host=localhost port=9920 sslmode=disable TimeZone=Asia/Shanghai" ./tests/tests_all.sh sqlserver: strategy: diff --git a/tests/tests_test.go b/tests/tests_test.go index afff2d0f..5aedc061 100644 --- a/tests/tests_test.go +++ b/tests/tests_test.go @@ -52,7 +52,7 @@ func OpenTestConnection() (db *gorm.DB, err error) { case "postgres": log.Println("testing postgres...") if dbDSN == "" { - dbDSN = "user=gorm password=gorm DB.name=gorm host=localhost port=9920 sslmode=disable TimeZone=Asia/Shanghai" + dbDSN = "user=gorm password=gorm dbname=gorm host=localhost port=9920 sslmode=disable TimeZone=Asia/Shanghai" } db, err = gorm.Open(postgres.New(postgres.Config{ DSN: dbDSN, From b8692c76711f473bb1f5fcd54a38f0611b7410bd Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 16 Jul 2020 18:05:55 +0800 Subject: [PATCH 541/881] Allow temporarily disable default transaction --- callbacks/transaction.go | 26 +++++++++++++++----------- gorm.go | 17 +++++++++++------ 2 files changed, 26 insertions(+), 17 deletions(-) diff --git a/callbacks/transaction.go b/callbacks/transaction.go index 14d31a62..3171b5bb 100644 --- a/callbacks/transaction.go +++ b/callbacks/transaction.go @@ -5,21 +5,25 @@ import ( ) func BeginTransaction(db *gorm.DB) { - if tx := db.Begin(); tx.Error == nil { - db.Statement.ConnPool = tx.Statement.ConnPool - db.InstanceSet("gorm:started_transaction", true) - } else { - tx.Error = nil + if !db.Config.SkipDefaultTransaction { + if tx := db.Begin(); tx.Error == nil { + db.Statement.ConnPool = tx.Statement.ConnPool + db.InstanceSet("gorm:started_transaction", true) + } else { + tx.Error = nil + } } } func CommitOrRollbackTransaction(db *gorm.DB) { - if _, ok := db.InstanceGet("gorm:started_transaction"); ok { - if db.Error == nil { - db.Commit() - } else { - db.Rollback() + if !db.Config.SkipDefaultTransaction { + if _, ok := db.InstanceGet("gorm:started_transaction"); ok { + if db.Error == nil { + db.Commit() + } else { + db.Rollback() + } + db.Statement.ConnPool = db.ConnPool } - db.Statement.ConnPool = db.ConnPool } } diff --git a/gorm.go b/gorm.go index 1c6d3383..e3b1dd35 100644 --- a/gorm.go +++ b/gorm.go @@ -57,12 +57,13 @@ type DB struct { // Session session config when create session with Session() method type Session struct { - DryRun bool - PrepareStmt bool - WithConditions bool - Context context.Context - Logger logger.Interface - NowFunc func() time.Time + DryRun bool + PrepareStmt bool + WithConditions bool + SkipDefaultTransaction bool + Context context.Context + Logger logger.Interface + NowFunc func() time.Time } // Open initialize db session based on dialector @@ -145,6 +146,10 @@ func (db *DB) Session(config *Session) *DB { } ) + if config.SkipDefaultTransaction { + tx.Config.SkipDefaultTransaction = true + } + if config.Context != nil { tx.Statement = tx.Statement.clone() tx.Statement.DB = tx From 58e32415449ac9e5184de006d00c83072b500a5c Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 17 Jul 2020 11:06:20 +0800 Subject: [PATCH 542/881] Fix Select with specific symbol, close #3158 --- tests/query_test.go | 13 +++++++++---- utils/utils.go | 2 +- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/tests/query_test.go b/tests/query_test.go index 62005e3a..22807377 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -310,19 +310,24 @@ func TestSelect(t *testing.T) { dryDB := DB.Session(&gorm.Session{DryRun: true}) r := dryDB.Select("name", "age").Find(&User{}) if !regexp.MustCompile("SELECT .*name.*,.*age.* FROM .*users.*").MatchString(r.Statement.SQL.String()) { - t.Fatalf("Build NOT condition, but got %v", r.Statement.SQL.String()) + t.Fatalf("Build Select with strings, but got %v", r.Statement.SQL.String()) } r = dryDB.Select([]string{"name", "age"}).Find(&User{}) if !regexp.MustCompile("SELECT .*name.*,.*age.* FROM .*users.*").MatchString(r.Statement.SQL.String()) { - t.Fatalf("Build NOT condition, but got %v", r.Statement.SQL.String()) + t.Fatalf("Build Select with slice, but got %v", r.Statement.SQL.String()) } r = dryDB.Table("users").Select("COALESCE(age,?)", 42).Find(&User{}) - if !regexp.MustCompile("SELECT COALESCE\\(age,.*\\) FROM .*users.*").MatchString(r.Statement.SQL.String()) { - t.Fatalf("Build NOT condition, but got %v", r.Statement.SQL.String()) + if !regexp.MustCompile(`SELECT COALESCE\(age,.*\) FROM .*users.*`).MatchString(r.Statement.SQL.String()) { + t.Fatalf("Build Select with func, but got %v", r.Statement.SQL.String()) } // SELECT COALESCE(age,'42') FROM users; + + r = dryDB.Select("u.*").Table("users as u").First(&User{}, user.ID) + if !regexp.MustCompile(`SELECT u\.\* FROM .*users.*`).MatchString(r.Statement.SQL.String()) { + t.Fatalf("Build Select with u.*, but got %v", r.Statement.SQL.String()) + } } func TestPluckWithSelect(t *testing.T) { diff --git a/utils/utils.go b/utils/utils.go index 3d7e395b..e93f3055 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -30,7 +30,7 @@ func FileWithLineNum() string { } func IsChar(c rune) bool { - return !unicode.IsLetter(c) && !unicode.IsNumber(c) + return !unicode.IsLetter(c) && !unicode.IsNumber(c) && c != '.' && c != '*' } func CheckTruth(val interface{}) bool { From 362779575c2a91d29074b0a03b27187d615070ce Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 17 Jul 2020 11:24:24 +0800 Subject: [PATCH 543/881] Fix Select with specific symbol, close #3157 --- chainable_api.go | 6 ++++-- clause/select.go | 8 ++++++++ tests/distinct_test.go | 8 ++++++++ 3 files changed, 20 insertions(+), 2 deletions(-) diff --git a/chainable_api.go b/chainable_api.go index 730f6308..7c352268 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -60,11 +60,11 @@ func (db *DB) Table(name string) (tx *DB) { // Distinct specify distinct fields that you want querying func (db *DB) Distinct(args ...interface{}) (tx *DB) { - tx = db + tx = db.getInstance() + tx.Statement.Distinct = true if len(args) > 0 { tx = tx.Select(args[0], args[1:]...) } - tx.Statement.Distinct = true return tx } @@ -102,6 +102,7 @@ func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) { tx.Statement.Selects = append(tx.Statement.Selects, arg...) default: tx.Statement.AddClause(clause.Select{ + Distinct: db.Statement.Distinct, Expression: clause.Expr{SQL: v, Vars: args}, }) return @@ -109,6 +110,7 @@ func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) { } } else { tx.Statement.AddClause(clause.Select{ + Distinct: db.Statement.Distinct, Expression: clause.Expr{SQL: v, Vars: args}, }) } diff --git a/clause/select.go b/clause/select.go index 9c2bc625..b93b8769 100644 --- a/clause/select.go +++ b/clause/select.go @@ -30,6 +30,14 @@ func (s Select) Build(builder Builder) { func (s Select) MergeClause(clause *Clause) { if s.Expression != nil { + if s.Distinct { + if expr, ok := s.Expression.(Expr); ok { + expr.SQL = "DISTINCT " + expr.SQL + clause.Expression = expr + return + } + } + clause.Expression = s.Expression } else { clause.Expression = s diff --git a/tests/distinct_test.go b/tests/distinct_test.go index 248602d3..29a320ff 100644 --- a/tests/distinct_test.go +++ b/tests/distinct_test.go @@ -1,8 +1,10 @@ package tests_test import ( + "regexp" "testing" + "gorm.io/gorm" . "gorm.io/gorm/utils/tests" ) @@ -57,4 +59,10 @@ func TestDistinct(t *testing.T) { if err := DB.Model(&User{}).Distinct("name").Where("name like ?", "distinct%").Count(&count).Error; err != nil || count != 3 { t.Errorf("failed to query users count, got error: %v, count %v", err, count) } + + dryDB := DB.Session(&gorm.Session{DryRun: true}) + r := dryDB.Distinct("u.id, u.*").Table("user_speaks as s").Joins("inner join users as u on u.id = s.user_id").Where("s.language_code ='US' or s.language_code ='ES'").Find(&User{}) + if !regexp.MustCompile(`SELECT DISTINCT u\.id, u\.\* FROM user_speaks as s inner join users as u`).MatchString(r.Statement.SQL.String()) { + t.Fatalf("Build Distinct with u.*, but got %v", r.Statement.SQL.String()) + } } From 6dc583869b5aef690650f3e3e62d6a80c5de99ea Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 17 Jul 2020 12:02:00 +0800 Subject: [PATCH 544/881] Don't use value's first field to guess data type for struct implements GormDataTypeInterface --- schema/field.go | 40 +++++++++++++++++++++------------------- 1 file changed, 21 insertions(+), 19 deletions(-) diff --git a/schema/field.go b/schema/field.go index 2c43229b..bc3dbc62 100644 --- a/schema/field.go +++ b/schema/field.go @@ -105,28 +105,30 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { // if field is valuer, used its value or first fields as data type valuer, isValuer := fieldValue.Interface().(driver.Valuer) if isValuer { - var overrideFieldValue bool - if v, err := valuer.Value(); v != nil && err == nil { - overrideFieldValue = true - fieldValue = reflect.ValueOf(v) - } + if _, ok := fieldValue.Interface().(GormDataTypeInterface); !ok { + var overrideFieldValue bool + if v, err := valuer.Value(); v != nil && err == nil { + overrideFieldValue = true + fieldValue = reflect.ValueOf(v) + } - if field.IndirectFieldType.Kind() == reflect.Struct { - for i := 0; i < field.IndirectFieldType.NumField(); i++ { - if !overrideFieldValue { - newFieldType := field.IndirectFieldType.Field(i).Type - for newFieldType.Kind() == reflect.Ptr { - newFieldType = newFieldType.Elem() + if field.IndirectFieldType.Kind() == reflect.Struct { + for i := 0; i < field.IndirectFieldType.NumField(); i++ { + if !overrideFieldValue { + newFieldType := field.IndirectFieldType.Field(i).Type + for newFieldType.Kind() == reflect.Ptr { + newFieldType = newFieldType.Elem() + } + + fieldValue = reflect.New(newFieldType) + overrideFieldValue = true } - fieldValue = reflect.New(newFieldType) - overrideFieldValue = true - } - - // copy tag settings from valuer - for key, value := range ParseTagSetting(field.IndirectFieldType.Field(i).Tag.Get("gorm"), ";") { - if _, ok := field.TagSettings[key]; !ok { - field.TagSettings[key] = value + // copy tag settings from valuer + for key, value := range ParseTagSetting(field.IndirectFieldType.Field(i).Tag.Get("gorm"), ";") { + if _, ok := field.TagSettings[key]; !ok { + field.TagSettings[key] = value + } } } } From e77156980cd74639fefdaf0576785018464a3ca1 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 17 Jul 2020 15:49:41 +0800 Subject: [PATCH 545/881] Fix panic when using Select/Omit Associations with no schema, close #3160 --- statement.go | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/statement.go b/statement.go index 38154939..3a2344ae 100644 --- a/statement.go +++ b/statement.go @@ -503,7 +503,7 @@ func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) ( for _, dbName := range stmt.Schema.DBNames { results[dbName] = true } - } else if column == clause.Associations { + } else if column == clause.Associations && stmt.Schema != nil { for _, rel := range stmt.Schema.Relationships.Relations { results[rel.Name] = true } @@ -517,8 +517,10 @@ func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) ( // omit columns for _, omit := range stmt.Omits { if omit == clause.Associations { - for _, rel := range stmt.Schema.Relationships.Relations { - results[rel.Name] = false + if stmt.Schema != nil { + for _, rel := range stmt.Schema.Relationships.Relations { + results[rel.Name] = false + } } } else if field := stmt.Schema.LookUpField(omit); field != nil && field.DBName != "" { results[field.DBName] = false From de764d9e3deb99d489e6538219fe5fbb12062e72 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 17 Jul 2020 21:19:11 +0800 Subject: [PATCH 546/881] Replace FullTable with TableExpr --- chainable_api.go | 2 +- statement.go | 7 +++---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/chainable_api.go b/chainable_api.go index 7c352268..fe11e474 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -47,7 +47,7 @@ var tableRegexp = regexp.MustCompile(`(?i).+ AS (\w+)\s*$`) func (db *DB) Table(name string) (tx *DB) { tx = db.getInstance() if strings.Contains(name, " ") { - tx.Statement.FullTable = name + tx.Statement.TableExpr = &clause.Expr{SQL: name} if results := tableRegexp.FindStringSubmatch(name); len(results) == 2 { tx.Statement.Table = results[1] return diff --git a/statement.go b/statement.go index 3a2344ae..6641aed8 100644 --- a/statement.go +++ b/statement.go @@ -19,7 +19,7 @@ import ( // Statement statement type Statement struct { *DB - FullTable string + TableExpr *clause.Expr Table string Model interface{} Unscoped bool @@ -69,8 +69,8 @@ func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) { switch v := field.(type) { case clause.Table: if v.Name == clause.CurrentTable { - if stmt.FullTable != "" { - writer.WriteString(stmt.FullTable) + if stmt.TableExpr != nil { + stmt.TableExpr.Build(stmt) } else { stmt.DB.Dialector.QuoteTo(writer, stmt.Table) } @@ -378,7 +378,6 @@ func (stmt *Statement) Build(clauses ...string) { func (stmt *Statement) Parse(value interface{}) (err error) { if stmt.Schema, err = schema.Parse(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil && stmt.Table == "" { stmt.Table = stmt.Schema.Table - stmt.FullTable = stmt.Schema.Table } return err } From 90183fadde3ee228383daadff845ae3a75bc75d0 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 19 Jul 2020 21:30:24 +0800 Subject: [PATCH 547/881] Allow advanced table with args --- chainable_api.go | 12 ++++++---- statement.go | 6 +++++ tests/migrate_test.go | 6 ++--- tests/table_test.go | 52 +++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 69 insertions(+), 7 deletions(-) create mode 100644 tests/table_test.go diff --git a/chainable_api.go b/chainable_api.go index fe11e474..4df8780e 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -41,17 +41,21 @@ func (db *DB) Clauses(conds ...clause.Expression) (tx *DB) { return } -var tableRegexp = regexp.MustCompile(`(?i).+ AS (\w+)\s*$`) +var tableRegexp = regexp.MustCompile(`(?i).+? AS (\w+)\s*(?:$|,)`) // Table specify the table you would like to run db operations -func (db *DB) Table(name string) (tx *DB) { +func (db *DB) Table(name string, args ...interface{}) (tx *DB) { tx = db.getInstance() - if strings.Contains(name, " ") { - tx.Statement.TableExpr = &clause.Expr{SQL: name} + if strings.Contains(name, " ") || strings.Contains(name, "`") || len(args) > 0 { + tx.Statement.TableExpr = &clause.Expr{SQL: name, Vars: args} if results := tableRegexp.FindStringSubmatch(name); len(results) == 2 { tx.Statement.Table = results[1] return } + } else if tables := strings.Split(name, "."); len(tables) == 2 { + tx.Statement.TableExpr = &clause.Expr{SQL: tx.Statement.Quote(name)} + tx.Statement.Table = tables[1] + return } tx.Statement.Table = name diff --git a/statement.go b/statement.go index 6641aed8..5f4238ef 100644 --- a/statement.go +++ b/statement.go @@ -377,6 +377,12 @@ func (stmt *Statement) Build(clauses ...string) { func (stmt *Statement) Parse(value interface{}) (err error) { if stmt.Schema, err = schema.Parse(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil && stmt.Table == "" { + if tables := strings.Split(stmt.Schema.Table, "."); len(tables) == 2 { + stmt.TableExpr = &clause.Expr{SQL: stmt.Quote(stmt.Schema.Table)} + stmt.Table = tables[1] + return + } + stmt.Table = stmt.Schema.Table } return err diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 2c593a70..1b002049 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -79,7 +79,7 @@ func TestMigrateWithUniqueIndex(t *testing.T) { } } -func TestTable(t *testing.T) { +func TestMigrateTable(t *testing.T) { type TableStruct struct { gorm.Model Name string @@ -112,7 +112,7 @@ func TestTable(t *testing.T) { } } -func TestIndexes(t *testing.T) { +func TestMigrateIndexes(t *testing.T) { type IndexStruct struct { gorm.Model Name string `gorm:"size:255;index"` @@ -162,7 +162,7 @@ func TestIndexes(t *testing.T) { } } -func TestColumns(t *testing.T) { +func TestMigrateColumns(t *testing.T) { type ColumnStruct struct { gorm.Model Name string diff --git a/tests/table_test.go b/tests/table_test.go new file mode 100644 index 00000000..b96af170 --- /dev/null +++ b/tests/table_test.go @@ -0,0 +1,52 @@ +package tests_test + +import ( + "regexp" + "testing" + + "gorm.io/gorm" + . "gorm.io/gorm/utils/tests" +) + +type UserWithTable struct { + gorm.Model + Name string +} + +func (UserWithTable) TableName() string { + return "gorm.user" +} + +func TestTable(t *testing.T) { + dryDB := DB.Session(&gorm.Session{DryRun: true}) + + r := dryDB.Table("`user`").Find(&User{}).Statement + if !regexp.MustCompile("SELECT \\* FROM `user`").MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + + r = dryDB.Table("user as u").Select("name").Find(&User{}).Statement + if !regexp.MustCompile("SELECT .name. FROM user as u WHERE .u.\\..deleted_at. IS NULL").MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + + r = dryDB.Table("gorm.user").Select("name").Find(&User{}).Statement + if !regexp.MustCompile("SELECT .name. FROM .gorm.\\..user. WHERE .user.\\..deleted_at. IS NULL").MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + + r = dryDB.Select("name").Find(&UserWithTable{}).Statement + if !regexp.MustCompile("SELECT .name. FROM .gorm.\\..user. WHERE .user.\\..deleted_at. IS NULL").MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + + r = dryDB.Table("(?) as u", DB.Model(&User{}).Select("name")).Find(&User{}).Statement + if !regexp.MustCompile("SELECT \\* FROM \\(SELECT .name. FROM .users. WHERE .users.\\..deleted_at. IS NULL\\) as u WHERE .u.\\..deleted_at. IS NULL").MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + + r = dryDB.Table("(?) as u, (?) as p", DB.Model(&User{}).Select("name"), DB.Model(&Pet{}).Select("name")).Find(&User{}).Statement + if !regexp.MustCompile("SELECT \\* FROM \\(SELECT .name. FROM .users. WHERE .users.\\..deleted_at. IS NULL\\) as u, \\(SELECT .name. FROM .pets. WHERE .pets.\\..deleted_at. IS NULL\\) as p WHERE .u.\\..deleted_at. IS NULL").MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } +} From a0477f94dd97ef33a442aadf7c710ac03d4a0590 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 19 Jul 2020 21:48:58 +0800 Subject: [PATCH 548/881] Allow Omit with Query, close #3165 --- callbacks/query.go | 8 ++++++++ tests/query_test.go | 15 +++++++++++++++ 2 files changed, 23 insertions(+) diff --git a/callbacks/query.go b/callbacks/query.go index 9601f9bd..5c322a05 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -64,6 +64,14 @@ func BuildQuerySQL(db *gorm.DB) { clauseSelect.Columns[idx] = clause.Column{Name: name, Raw: true} } } + } else if db.Statement.Schema != nil && len(db.Statement.Omits) > 0 { + selectColumns, _ := db.Statement.SelectAndOmitColumns(false, false) + clauseSelect.Columns = make([]clause.Column, 0, len(db.Statement.Schema.DBNames)) + for _, dbName := range db.Statement.Schema.DBNames { + if v, ok := selectColumns[dbName]; (ok && v) || !ok { + clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{Name: dbName}) + } + } } else if db.Statement.Schema != nil && db.Statement.ReflectValue.IsValid() { smallerStruct := false switch db.Statement.ReflectValue.Kind() { diff --git a/tests/query_test.go b/tests/query_test.go index 22807377..59f1130b 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -330,6 +330,21 @@ func TestSelect(t *testing.T) { } } +func TestOmit(t *testing.T) { + user := User{Name: "OmitUser1", Age: 20} + DB.Save(&user) + + var result User + DB.Where("name = ?", user.Name).Omit("name").Find(&result) + if result.ID == 0 { + t.Errorf("Should not have ID because only selected name, %+v", result.ID) + } + + if result.Name != "" || result.Age != 20 { + t.Errorf("User Name should be omitted, got %v, Age should be ok, got %v", result.Name, result.Age) + } +} + func TestPluckWithSelect(t *testing.T) { users := []User{ {Name: "pluck_with_select_1", Age: 25}, From 5d0544106744430c24d5772da1fb64395ddfe48d Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 20 Jul 2020 08:12:18 +0800 Subject: [PATCH 549/881] Test From SubQuery with vars --- tests/table_test.go | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/table_test.go b/tests/table_test.go index b96af170..faee6499 100644 --- a/tests/table_test.go +++ b/tests/table_test.go @@ -49,4 +49,11 @@ func TestTable(t *testing.T) { if !regexp.MustCompile("SELECT \\* FROM \\(SELECT .name. FROM .users. WHERE .users.\\..deleted_at. IS NULL\\) as u, \\(SELECT .name. FROM .pets. WHERE .pets.\\..deleted_at. IS NULL\\) as p WHERE .u.\\..deleted_at. IS NULL").MatchString(r.Statement.SQL.String()) { t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) } + + r = dryDB.Where("name = ?", 1).Table("(?) as u, (?) as p", DB.Model(&User{}).Select("name").Where("name = ?", 2), DB.Model(&Pet{}).Where("name = ?", 4).Select("name")).Where("name = ?", 3).Find(&User{}).Statement + if !regexp.MustCompile("SELECT \\* FROM \\(SELECT .name. FROM .users. WHERE name = .+ AND .users.\\..deleted_at. IS NULL\\) as u, \\(SELECT .name. FROM .pets. WHERE name = .+ AND .pets.\\..deleted_at. IS NULL\\) as p WHERE name = .+ AND name = .+ AND .u.\\..deleted_at. IS NULL").MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + + AssertEqual(t, r.Statement.Vars, []interface{}{2, 4, 1, 3}) } From ef002fd7accb973c9f36931e2b1c3112d2b062ea Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 20 Jul 2020 18:59:28 +0800 Subject: [PATCH 550/881] Add GORMDataType to Field, close #3171 --- callbacks/update.go | 4 ++-- gorm.go | 1 + schema/field.go | 7 +++++++ schema/relationship.go | 3 +++ schema/schema.go | 2 +- 5 files changed, 14 insertions(+), 3 deletions(-) diff --git a/callbacks/update.go b/callbacks/update.go index 97a0e893..d549f97b 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -202,7 +202,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { if field.AutoUpdateTime == schema.UnixNanosecond { set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano()}) - } else if field.DataType == schema.Time { + } else if field.GORMDataType == schema.Time { set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now}) } else { set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.Unix()}) @@ -223,7 +223,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { if field.AutoUpdateTime > 0 { if field.AutoUpdateTime == schema.UnixNanosecond { value = stmt.DB.NowFunc().UnixNano() - } else if field.DataType == schema.Time { + } else if field.GORMDataType == schema.Time { value = stmt.DB.NowFunc() } else { value = stmt.DB.NowFunc().Unix() diff --git a/gorm.go b/gorm.go index e3b1dd35..338a1473 100644 --- a/gorm.go +++ b/gorm.go @@ -300,6 +300,7 @@ func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interfac for _, ref := range relation.References { if f := joinSchema.LookUpField(ref.ForeignKey.DBName); f != nil { f.DataType = ref.ForeignKey.DataType + f.GORMDataType = ref.ForeignKey.GORMDataType ref.ForeignKey = f } else { return fmt.Errorf("missing field %v for join table", ref.ForeignKey.DBName) diff --git a/schema/field.go b/schema/field.go index bc3dbc62..a170e60e 100644 --- a/schema/field.go +++ b/schema/field.go @@ -38,6 +38,7 @@ type Field struct { DBName string BindNames []string DataType DataType + GORMDataType DataType PrimaryKey bool AutoIncrement bool Creatable bool @@ -221,6 +222,8 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } } + field.GORMDataType = field.DataType + if dataTyper, ok := fieldValue.Interface().(GormDataTypeInterface); ok { field.DataType = DataType(dataTyper.GormDataType()) } @@ -250,6 +253,10 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } } + if field.GORMDataType == "" { + field.GORMDataType = field.DataType + } + if field.Size == 0 { switch reflect.Indirect(fieldValue).Kind() { case reflect.Int, reflect.Int64, reflect.Uint, reflect.Uint64, reflect.Float64: diff --git a/schema/relationship.go b/schema/relationship.go index c290c5ba..e67092b4 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -157,6 +157,7 @@ func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Fi // use same data type for foreign keys relation.Polymorphic.PolymorphicID.DataType = primaryKeyField.DataType + relation.Polymorphic.PolymorphicID.GORMDataType = primaryKeyField.GORMDataType relation.References = append(relation.References, &Reference{ PrimaryKey: primaryKeyField, @@ -285,6 +286,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel for idx, f := range relation.JoinTable.Fields { // use same data type for foreign keys f.DataType = fieldsMap[f.Name].DataType + f.GORMDataType = fieldsMap[f.Name].GORMDataType relation.JoinTable.PrimaryFields[idx] = f ownPriamryField := schema == fieldsMap[f.Name].Schema && ownFieldsMap[f.Name] @@ -387,6 +389,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, guessH for idx, foreignField := range foreignFields { // use same data type for foreign keys foreignField.DataType = primaryFields[idx].DataType + foreignField.GORMDataType = primaryFields[idx].GORMDataType relation.References = append(relation.References, &Reference{ PrimaryKey: primaryFields[idx], diff --git a/schema/schema.go b/schema/schema.go index 66e02443..bcf65939 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -182,7 +182,7 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) } if field := schema.PrioritizedPrimaryField; field != nil { - switch field.DataType { + switch field.GORMDataType { case Int, Uint: if !field.HasDefaultValue || field.DefaultValueInterface != nil { schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field) From 0546b59743ec2759051cb921a4dc5f7c31f36e3d Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 22 Jul 2020 11:28:00 +0800 Subject: [PATCH 551/881] Fix save many2many associations with UUID primary key, close #3182 --- callbacks/create.go | 9 ++++++++- tests/postgres_test.go | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 41 insertions(+), 1 deletion(-) diff --git a/callbacks/create.go b/callbacks/create.go index eecb80a1..de5bf1f8 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -149,10 +149,17 @@ func CreateWithReturning(db *gorm.DB) { for rows.Next() { BEGIN: + reflectValue := db.Statement.ReflectValue.Index(int(db.RowsAffected)) for idx, field := range fields { - fieldValue := field.ReflectValueOf(db.Statement.ReflectValue.Index(int(db.RowsAffected))) + fieldValue := field.ReflectValueOf(reflectValue) + if onConflict.DoNothing && !fieldValue.IsZero() { db.RowsAffected++ + + if int(db.RowsAffected) >= db.Statement.ReflectValue.Len() { + return + } + goto BEGIN } diff --git a/tests/postgres_test.go b/tests/postgres_test.go index 98302d87..ab47a548 100644 --- a/tests/postgres_test.go +++ b/tests/postgres_test.go @@ -37,3 +37,36 @@ func TestPostgres(t *testing.T) { t.Errorf("No error should happen, but got %v", err) } } + +type Post struct { + ID uuid.UUID `gorm:"primary_key;type:uuid;default:uuid_generate_v4()"` + Title string + Categories []*Category `gorm:"Many2Many:post_categories"` +} + +type Category struct { + ID uuid.UUID `gorm:"primary_key;type:uuid;default:uuid_generate_v4()"` + Title string + Posts []*Post `gorm:"Many2Many:post_categories"` +} + +func TestMany2ManyWithDefaultValueUUID(t *testing.T) { + if DB.Dialector.Name() != "postgres" { + t.Skip() + } + + DB.Migrator().DropTable(&Post{}, &Category{}, "post_categories") + DB.AutoMigrate(&Post{}, &Category{}) + + post := Post{ + Title: "Hello World", + Categories: []*Category{ + {Title: "Coding"}, + {Title: "Golang"}, + }, + } + + if err := DB.Create(&post).Error; err != nil { + t.Errorf("Failed, got error: %v", err) + } +} From da16f7b4756ead84856448fab67ff6aeddf91f60 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 22 Jul 2020 12:13:40 +0800 Subject: [PATCH 552/881] Create extension uuid-ossp for postgres test database --- tests/postgres_test.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/postgres_test.go b/tests/postgres_test.go index ab47a548..a0b1fddb 100644 --- a/tests/postgres_test.go +++ b/tests/postgres_test.go @@ -55,6 +55,10 @@ func TestMany2ManyWithDefaultValueUUID(t *testing.T) { t.Skip() } + if err := DB.Exec(`create extension if not exists "uuid-ossp"`).Error; err != nil { + t.Fatalf("Failed to create 'uuid-ossp' extension, but got error %v", err) + } + DB.Migrator().DropTable(&Post{}, &Category{}, "post_categories") DB.AutoMigrate(&Post{}, &Category{}) From 87112ab1c711db2d8dd26ee32a4ccd0bb9307261 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 22 Jul 2020 15:05:38 +0800 Subject: [PATCH 553/881] Fix row callback name --- callbacks/callbacks.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/callbacks/callbacks.go b/callbacks/callbacks.go index f61252d4..0a12468c 100644 --- a/callbacks/callbacks.go +++ b/callbacks/callbacks.go @@ -45,6 +45,6 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) { updateCallback.Register("gorm:after_update", AfterUpdate) updateCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) - db.Callback().Row().Register("gorm:raw", RowQuery) + db.Callback().Row().Register("gorm:row", RowQuery) db.Callback().Raw().Register("gorm:raw", RawExec) } From 7021db3655381405b8c3f848319a66128b96041b Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 22 Jul 2020 19:03:19 +0800 Subject: [PATCH 554/881] Fix FieldsWithDefaultDBValue for primary field, close #3187 --- schema/schema.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/schema/schema.go b/schema/schema.go index bcf65939..1106f0c5 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -184,11 +184,11 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) if field := schema.PrioritizedPrimaryField; field != nil { switch field.GORMDataType { case Int, Uint: - if !field.HasDefaultValue || field.DefaultValueInterface != nil { - schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field) - } - if _, ok := field.TagSettings["AUTOINCREMENT"]; !ok { + if !field.HasDefaultValue || field.DefaultValueInterface != nil { + schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field) + } + field.HasDefaultValue = true field.AutoIncrement = true } From 6ed697dd0225631c19bcfc43bf8762ced235742c Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 23 Jul 2020 23:41:56 +0800 Subject: [PATCH 555/881] TestFirstOrCreateWithPrimaryKey, close #3192 --- callbacks/create.go | 10 +--------- tests/create_test.go | 19 +++++++++++++++++++ tests/go.mod | 6 +++--- 3 files changed, 23 insertions(+), 12 deletions(-) diff --git a/callbacks/create.go b/callbacks/create.go index de5bf1f8..707b94c1 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -70,16 +70,8 @@ func Create(config *Config) func(db *gorm.DB) { } } } else { - allUpdated := int(db.RowsAffected) == db.Statement.ReflectValue.Len() - isZero := true - for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - - if !allUpdated { - _, isZero = db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue.Index(i)) - } - - if isZero { + if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue.Index(i)); isZero { db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID) insertID++ } diff --git a/tests/create_test.go b/tests/create_test.go index 46cc06c6..ae6e1232 100644 --- a/tests/create_test.go +++ b/tests/create_test.go @@ -352,3 +352,22 @@ func TestOmitWithCreate(t *testing.T) { CheckUser(t, result2, user2) } + +func TestFirstOrCreateWithPrimaryKey(t *testing.T) { + company := Company{ID: 100, Name: "company100_with_primarykey"} + DB.FirstOrCreate(&company) + + if company.ID != 100 { + t.Errorf("invalid primary key after creating, got %v", company.ID) + } + + companies := []Company{ + {ID: 101, Name: "company101_with_primarykey"}, + {ID: 102, Name: "company102_with_primarykey"}, + } + DB.Create(&companies) + + if companies[0].ID != 101 || companies[1].ID != 102 { + t.Errorf("invalid primary key after creating, got %v, %v", companies[0].ID, companies[1].ID) + } +} diff --git a/tests/go.mod b/tests/go.mod index 3a5b4224..6eb6eb07 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -6,10 +6,10 @@ require ( github.com/google/uuid v1.1.1 github.com/jinzhu/now v1.1.1 github.com/lib/pq v1.6.0 - gorm.io/driver/mysql v0.2.9 - gorm.io/driver/postgres v0.2.5 + gorm.io/driver/mysql v0.3.1 + gorm.io/driver/postgres v0.2.6 gorm.io/driver/sqlite v1.0.8 - gorm.io/driver/sqlserver v0.2.4 + gorm.io/driver/sqlserver v0.2.5 gorm.io/gorm v0.2.19 ) From c3f52cee8b1e3d26fd0618399cc2a0cc012ff216 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 23 Jul 2020 23:56:13 +0800 Subject: [PATCH 556/881] Don't scan last insert id 0 --- callbacks/create.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/callbacks/create.go b/callbacks/create.go index 707b94c1..c86cefe4 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -78,7 +78,9 @@ func Create(config *Config) func(db *gorm.DB) { } } case reflect.Struct: - db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID) + if insertID > 0 { + db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID) + } } } else { db.AddError(err) From 69d81118936a761a140d35eb07f1cd249067a1a9 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 24 Jul 2020 08:32:50 +0800 Subject: [PATCH 557/881] Fix panic when using invalid data, close #3193 --- callbacks/create.go | 6 +++--- callbacks/delete.go | 2 +- callbacks/query.go | 2 +- callbacks/update.go | 2 +- errors.go | 6 ------ statement.go | 4 +++- 6 files changed, 9 insertions(+), 13 deletions(-) diff --git a/callbacks/create.go b/callbacks/create.go index c86cefe4..b41a3ef2 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -51,7 +51,7 @@ func Create(config *Config) func(db *gorm.DB) { db.Statement.Build("INSERT", "VALUES", "ON CONFLICT") } - if !db.DryRun { + if !db.DryRun && db.Error == nil { result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) if err == nil { @@ -130,7 +130,7 @@ func CreateWithReturning(db *gorm.DB) { db.Statement.WriteQuoted(field.DBName) } - if !db.DryRun { + if !db.DryRun && db.Error == nil { rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) if err == nil { @@ -179,7 +179,7 @@ func CreateWithReturning(db *gorm.DB) { db.AddError(err) } } - } else if !db.DryRun { + } else if !db.DryRun && db.Error == nil { if result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); err == nil { db.RowsAffected, _ = result.RowsAffected() } else { diff --git a/callbacks/delete.go b/callbacks/delete.go index 51a33bf0..288f2d69 100644 --- a/callbacks/delete.go +++ b/callbacks/delete.go @@ -60,7 +60,7 @@ func Delete(db *gorm.DB) { db.Statement.Build("DELETE", "FROM", "WHERE") } - if !db.DryRun { + if !db.DryRun && db.Error == nil { result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) if err == nil { diff --git a/callbacks/query.go b/callbacks/query.go index 5c322a05..66bbf805 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -23,7 +23,7 @@ func Query(db *gorm.DB) { BuildQuerySQL(db) } - if !db.DryRun { + if !db.DryRun && db.Error == nil { rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) if err != nil { db.AddError(err) diff --git a/callbacks/update.go b/callbacks/update.go index d549f97b..e492cfc9 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -74,7 +74,7 @@ func Update(db *gorm.DB) { return } - if !db.DryRun { + if !db.DryRun && db.Error == nil { result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) if err == nil { diff --git a/errors.go b/errors.go index e1b58835..12e64611 100644 --- a/errors.go +++ b/errors.go @@ -7,20 +7,14 @@ import ( var ( // ErrRecordNotFound record not found error ErrRecordNotFound = errors.New("record not found") - // ErrInvalidSQL invalid SQL error, happens when you passed invalid SQL - ErrInvalidSQL = errors.New("invalid SQL") // ErrInvalidTransaction invalid transaction when you are trying to `Commit` or `Rollback` ErrInvalidTransaction = errors.New("no valid transaction") - // ErrUnaddressable unaddressable value - ErrUnaddressable = errors.New("using unaddressable value") // ErrNotImplemented not implemented ErrNotImplemented = errors.New("not implemented") // ErrMissingWhereClause missing where clause ErrMissingWhereClause = errors.New("WHERE conditions required") // ErrUnsupportedRelation unsupported relations ErrUnsupportedRelation = errors.New("unsupported relations") - // ErrPtrStructSupported only ptr of struct supported - ErrPtrStructSupported = errors.New("only ptr of struct supported") // ErrorPrimaryKeyRequired primary keys required ErrorPrimaryKeyRequired = errors.New("primary key required") // ErrorModelValueRequired model value required diff --git a/statement.go b/statement.go index 5f4238ef..310484d8 100644 --- a/statement.go +++ b/statement.go @@ -95,7 +95,9 @@ func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) { } if v.Name == clause.PrimaryKey { - if stmt.Schema != nil && stmt.Schema.PrioritizedPrimaryField != nil { + if stmt.Schema == nil { + stmt.DB.AddError(ErrorModelValueRequired) + } else if stmt.Schema.PrioritizedPrimaryField != nil { stmt.DB.Dialector.QuoteTo(writer, stmt.Schema.PrioritizedPrimaryField.DBName) } else if len(stmt.Schema.DBNames) > 0 { stmt.DB.Dialector.QuoteTo(writer, stmt.Schema.DBNames[0]) From f4cfa9411bc3eae4488d52c30272cd3cdb6e2127 Mon Sep 17 00:00:00 2001 From: Qt Date: Sun, 26 Jul 2020 10:03:58 +0800 Subject: [PATCH 558/881] define err with the same code style (#3199) --- association.go | 2 +- errors.go | 8 ++++---- finisher_api.go | 2 +- statement.go | 2 +- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/association.go b/association.go index aa740fc5..e59b8938 100644 --- a/association.go +++ b/association.go @@ -170,7 +170,7 @@ func (association *Association) Replace(values ...interface{}) error { if column, values := schema.ToQueryValues(rel.JoinTable.Table, joinPrimaryKeys, pvs); len(values) > 0 { tx.Where(clause.IN{Column: column, Values: values}) } else { - return ErrorPrimaryKeyRequired + return ErrPrimaryKeyRequired } _, rvs := schema.GetIdentityFieldValuesMapFromValues(values, relPrimaryFields) diff --git a/errors.go b/errors.go index 12e64611..115b8e25 100644 --- a/errors.go +++ b/errors.go @@ -15,10 +15,10 @@ var ( ErrMissingWhereClause = errors.New("WHERE conditions required") // ErrUnsupportedRelation unsupported relations ErrUnsupportedRelation = errors.New("unsupported relations") - // ErrorPrimaryKeyRequired primary keys required - ErrorPrimaryKeyRequired = errors.New("primary key required") - // ErrorModelValueRequired model value required - ErrorModelValueRequired = errors.New("model value required") + // ErrPrimaryKeyRequired primary keys required + ErrPrimaryKeyRequired = errors.New("primary key required") + // ErrModelValueRequired model value required + ErrModelValueRequired = errors.New("model value required") // ErrUnsupportedDriver unsupported driver ErrUnsupportedDriver = errors.New("unsupported driver") // ErrRegistered registered diff --git a/finisher_api.go b/finisher_api.go index 6bfe5d20..77bea578 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -325,7 +325,7 @@ func (db *DB) Pluck(column string, dest interface{}) (tx *DB) { } } } else if tx.Statement.Table == "" { - tx.AddError(ErrorModelValueRequired) + tx.AddError(ErrModelValueRequired) } fields := strings.FieldsFunc(column, utils.IsChar) diff --git a/statement.go b/statement.go index 310484d8..e9d826c4 100644 --- a/statement.go +++ b/statement.go @@ -96,7 +96,7 @@ func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) { if v.Name == clause.PrimaryKey { if stmt.Schema == nil { - stmt.DB.AddError(ErrorModelValueRequired) + stmt.DB.AddError(ErrModelValueRequired) } else if stmt.Schema.PrioritizedPrimaryField != nil { stmt.DB.Dialector.QuoteTo(writer, stmt.Schema.PrioritizedPrimaryField.DBName) } else if len(stmt.Schema.DBNames) > 0 { From c7667e9299134799da6f16e19eaf50cb8419736f Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 28 Jul 2020 14:26:09 +0800 Subject: [PATCH 559/881] Refactor Prepared Statement --- gorm.go | 22 +++++++++++++++------- prepare_stmt.go | 14 +++++++++----- tests/.gitignore | 1 + 3 files changed, 25 insertions(+), 12 deletions(-) create mode 100644 tests/.gitignore diff --git a/gorm.go b/gorm.go index 338a1473..c786b5a5 100644 --- a/gorm.go +++ b/gorm.go @@ -108,11 +108,15 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) { err = config.Dialector.Initialize(db) } + preparedStmt := &PreparedStmtDB{ + ConnPool: db.ConnPool, + Stmts: map[string]*sql.Stmt{}, + PreparedSQL: make([]string, 0, 100), + } + db.cacheStore.Store("preparedStmt", preparedStmt) + if config.PrepareStmt { - db.ConnPool = &PreparedStmtDB{ - ConnPool: db.ConnPool, - Stmts: map[string]*sql.Stmt{}, - } + db.ConnPool = preparedStmt } db.Statement = &Statement{ @@ -157,9 +161,13 @@ func (db *DB) Session(config *Session) *DB { } if config.PrepareStmt { - tx.Statement.ConnPool = &PreparedStmtDB{ - ConnPool: db.Config.ConnPool, - Stmts: map[string]*sql.Stmt{}, + if v, ok := db.cacheStore.Load("preparedStmt"); ok { + preparedStmt := v.(*PreparedStmtDB) + tx.Statement.ConnPool = &PreparedStmtDB{ + ConnPool: db.Config.ConnPool, + mux: preparedStmt.mux, + Stmts: preparedStmt.Stmts, + } } } diff --git a/prepare_stmt.go b/prepare_stmt.go index 197c257c..2f4e1d57 100644 --- a/prepare_stmt.go +++ b/prepare_stmt.go @@ -7,16 +7,19 @@ import ( ) type PreparedStmtDB struct { - Stmts map[string]*sql.Stmt - mux sync.RWMutex + Stmts map[string]*sql.Stmt + PreparedSQL []string + mux sync.RWMutex ConnPool } func (db *PreparedStmtDB) Close() { db.mux.Lock() - for k, stmt := range db.Stmts { - delete(db.Stmts, k) - stmt.Close() + for _, query := range db.PreparedSQL { + if stmt, ok := db.Stmts[query]; ok { + delete(db.Stmts, query) + stmt.Close() + } } db.mux.Unlock() @@ -40,6 +43,7 @@ func (db *PreparedStmtDB) prepare(query string) (*sql.Stmt, error) { stmt, err := db.ConnPool.PrepareContext(context.Background(), query) if err == nil { db.Stmts[query] = stmt + db.PreparedSQL = append(db.PreparedSQL, query) } db.mux.Unlock() diff --git a/tests/.gitignore b/tests/.gitignore new file mode 100644 index 00000000..08cb523c --- /dev/null +++ b/tests/.gitignore @@ -0,0 +1 @@ +go.sum From a140908839f5f6f3b2e493fbe7b779fb9fffc3ff Mon Sep 17 00:00:00 2001 From: Qt Date: Tue, 28 Jul 2020 17:25:03 +0800 Subject: [PATCH 560/881] refactor function convertParams's default case (#3208) --- logger/sql.go | 32 ++++++++++++-------------------- 1 file changed, 12 insertions(+), 20 deletions(-) diff --git a/logger/sql.go b/logger/sql.go index d3c0bf10..02d559c5 100644 --- a/logger/sql.go +++ b/logger/sql.go @@ -50,30 +50,22 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a case string: vars[idx] = escaper + strings.Replace(v, escaper, "\\"+escaper, -1) + escaper default: - if v == nil { + rv := reflect.ValueOf(v) + if v == nil || !rv.IsValid() || rv.Kind() == reflect.Ptr && rv.IsNil() { vars[idx] = "NULL" + } else if valuer, ok := v.(driver.Valuer); ok { + v, _ = valuer.Value() + convertParams(v, idx) + } else if rv.Kind() == reflect.Ptr && !rv.IsZero() { + convertParams(reflect.Indirect(rv).Interface(), idx) } else { - rv := reflect.ValueOf(v) - - if !rv.IsValid() { - vars[idx] = "NULL" - } else if rv.Kind() == reflect.Ptr && rv.IsNil() { - vars[idx] = "NULL" - } else if valuer, ok := v.(driver.Valuer); ok { - v, _ = valuer.Value() - convertParams(v, idx) - } else if rv.Kind() == reflect.Ptr && !rv.IsZero() { - convertParams(reflect.Indirect(rv).Interface(), idx) - } else { - for _, t := range convertableTypes { - if rv.Type().ConvertibleTo(t) { - convertParams(rv.Convert(t).Interface(), idx) - return - } + for _, t := range convertableTypes { + if rv.Type().ConvertibleTo(t) { + convertParams(rv.Convert(t).Interface(), idx) + return } - - vars[idx] = escaper + strings.Replace(fmt.Sprint(v), escaper, "\\"+escaper, -1) + escaper } + vars[idx] = escaper + strings.Replace(fmt.Sprint(v), escaper, "\\"+escaper, -1) + escaper } } } From 2cbdd29f26eeb81e7c1b9f014bf1a0a8066f76ba Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 29 Jul 2020 10:23:14 +0800 Subject: [PATCH 561/881] Returns error for invalid embedded field, close #3209 --- schema/field.go | 78 ++++++++++++++++++++++++++----------------------- 1 file changed, 41 insertions(+), 37 deletions(-) diff --git a/schema/field.go b/schema/field.go index a170e60e..f377a34a 100644 --- a/schema/field.go +++ b/schema/field.go @@ -304,44 +304,48 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } if _, ok := field.TagSettings["EMBEDDED"]; ok || (fieldStruct.Anonymous && !isValuer) { - var err error - field.Creatable = false - field.Updatable = false - field.Readable = false - if field.EmbeddedSchema, err = Parse(fieldValue.Interface(), &sync.Map{}, schema.namer); err != nil { - schema.err = err + if reflect.Indirect(fieldValue).Kind() == reflect.Struct { + var err error + field.Creatable = false + field.Updatable = false + field.Readable = false + if field.EmbeddedSchema, err = Parse(fieldValue.Interface(), &sync.Map{}, schema.namer); err != nil { + schema.err = err + } + for _, ef := range field.EmbeddedSchema.Fields { + ef.Schema = schema + ef.BindNames = append([]string{fieldStruct.Name}, ef.BindNames...) + // index is negative means is pointer + if field.FieldType.Kind() == reflect.Struct { + ef.StructField.Index = append([]int{fieldStruct.Index[0]}, ef.StructField.Index...) + } else { + ef.StructField.Index = append([]int{-fieldStruct.Index[0] - 1}, ef.StructField.Index...) + } + + if prefix, ok := field.TagSettings["EMBEDDEDPREFIX"]; ok { + ef.DBName = prefix + ef.DBName + } + + if val, ok := ef.TagSettings["PRIMARYKEY"]; ok && utils.CheckTruth(val) { + ef.PrimaryKey = true + } else if val, ok := ef.TagSettings["PRIMARY_KEY"]; ok && utils.CheckTruth(val) { + ef.PrimaryKey = true + } else { + ef.PrimaryKey = false + } + + for k, v := range field.TagSettings { + ef.TagSettings[k] = v + } + } + + field.Schema.CreateClauses = append(field.Schema.CreateClauses, field.EmbeddedSchema.CreateClauses...) + field.Schema.QueryClauses = append(field.Schema.QueryClauses, field.EmbeddedSchema.QueryClauses...) + field.Schema.UpdateClauses = append(field.Schema.UpdateClauses, field.EmbeddedSchema.UpdateClauses...) + field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, field.EmbeddedSchema.DeleteClauses...) + } else { + schema.err = fmt.Errorf("invalid embedded struct for %v's field %v, should be struct, but got %v", field.Schema.Name, field.Name, field.FieldType) } - for _, ef := range field.EmbeddedSchema.Fields { - ef.Schema = schema - ef.BindNames = append([]string{fieldStruct.Name}, ef.BindNames...) - // index is negative means is pointer - if field.FieldType.Kind() == reflect.Struct { - ef.StructField.Index = append([]int{fieldStruct.Index[0]}, ef.StructField.Index...) - } else { - ef.StructField.Index = append([]int{-fieldStruct.Index[0] - 1}, ef.StructField.Index...) - } - - if prefix, ok := field.TagSettings["EMBEDDEDPREFIX"]; ok { - ef.DBName = prefix + ef.DBName - } - - if val, ok := ef.TagSettings["PRIMARYKEY"]; ok && utils.CheckTruth(val) { - ef.PrimaryKey = true - } else if val, ok := ef.TagSettings["PRIMARY_KEY"]; ok && utils.CheckTruth(val) { - ef.PrimaryKey = true - } else { - ef.PrimaryKey = false - } - - for k, v := range field.TagSettings { - ef.TagSettings[k] = v - } - } - - field.Schema.CreateClauses = append(field.Schema.CreateClauses, field.EmbeddedSchema.CreateClauses...) - field.Schema.QueryClauses = append(field.Schema.QueryClauses, field.EmbeddedSchema.QueryClauses...) - field.Schema.UpdateClauses = append(field.Schema.UpdateClauses, field.EmbeddedSchema.UpdateClauses...) - field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, field.EmbeddedSchema.DeleteClauses...) } return field From 7c2ecdfc1c738f118b892d593ac3899d8e92b74b Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 30 Jul 2020 10:23:35 +0800 Subject: [PATCH 562/881] Fix use pointer of Valuer as foreign key, close #3212 --- schema/field.go | 5 +++-- tests/scanner_valuer_test.go | 2 ++ 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/schema/field.go b/schema/field.go index f377a34a..329ae41c 100644 --- a/schema/field.go +++ b/schema/field.go @@ -742,15 +742,16 @@ func (field *Field) setupValuerAndSetter() { } else if _, ok := fieldValue.Elem().Interface().(sql.Scanner); ok { // pointer scanner field.Set = func(value reflect.Value, v interface{}) (err error) { + reflectV := reflect.ValueOf(v) + if valuer, ok := v.(driver.Valuer); ok { - if valuer == nil { + if valuer == nil || reflectV.IsNil() { field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) } else { v, _ = valuer.Value() } } - reflectV := reflect.ValueOf(v) if reflectV.Type().AssignableTo(field.FieldType) { field.ReflectValueOf(value).Set(reflectV) } else if reflectV.Kind() == reflect.Ptr { diff --git a/tests/scanner_valuer_test.go b/tests/scanner_valuer_test.go index 632bd74a..bee0ae98 100644 --- a/tests/scanner_valuer_test.go +++ b/tests/scanner_valuer_test.go @@ -136,6 +136,8 @@ type ScannerValuerStruct struct { Strings StringsSlice Structs StructsSlice Role Role + UserID *sql.NullInt64 + User User } type EncryptedData []byte From 47a5196734de9f4d8486a1be568c8341991b4ac8 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 30 Jul 2020 17:36:39 +0800 Subject: [PATCH 563/881] Fix uninitialized Valuer return time.Time, close #3214 --- schema/field.go | 2 ++ tests/scanner_valuer_test.go | 44 ++++++++++++++++++++++++------------ 2 files changed, 32 insertions(+), 14 deletions(-) diff --git a/schema/field.go b/schema/field.go index 329ae41c..6d0fd1cc 100644 --- a/schema/field.go +++ b/schema/field.go @@ -213,6 +213,8 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { case reflect.Struct: if _, ok := fieldValue.Interface().(*time.Time); ok { field.DataType = Time + } else if fieldValue.Type().ConvertibleTo(reflect.TypeOf(time.Time{})) { + field.DataType = Time } else if fieldValue.Type().ConvertibleTo(reflect.TypeOf(&time.Time{})) { field.DataType = Time } diff --git a/tests/scanner_valuer_test.go b/tests/scanner_valuer_test.go index bee0ae98..2c2c1e18 100644 --- a/tests/scanner_valuer_test.go +++ b/tests/scanner_valuer_test.go @@ -124,20 +124,21 @@ func TestInvalidValuer(t *testing.T) { type ScannerValuerStruct struct { gorm.Model - Name sql.NullString - Gender *sql.NullString - Age sql.NullInt64 - Male sql.NullBool - Height sql.NullFloat64 - Birthday sql.NullTime - Password EncryptedData - Bytes []byte - Num Num - Strings StringsSlice - Structs StructsSlice - Role Role - UserID *sql.NullInt64 - User User + Name sql.NullString + Gender *sql.NullString + Age sql.NullInt64 + Male sql.NullBool + Height sql.NullFloat64 + Birthday sql.NullTime + Password EncryptedData + Bytes []byte + Num Num + Strings StringsSlice + Structs StructsSlice + Role Role + UserID *sql.NullInt64 + User User + EmptyTime EmptyTime } type EncryptedData []byte @@ -244,3 +245,18 @@ func (role Role) Value() (driver.Value, error) { func (role Role) IsAdmin() bool { return role.Name == "admin" } + +type EmptyTime struct { + time.Time +} + +func (t *EmptyTime) Scan(v interface{}) error { + nullTime := sql.NullTime{} + err := nullTime.Scan(v) + t.Time = nullTime.Time + return err +} + +func (t EmptyTime) Value() (driver.Value, error) { + return t.Time, nil +} From 7bb883b665082c0506991f8c87e5f02d86254920 Mon Sep 17 00:00:00 2001 From: lninl Date: Thu, 30 Jul 2020 17:39:57 +0800 Subject: [PATCH 564/881] Auto creating/updating time with unix (milli) second (#3213) * Auto creating/updating time with unix (milli) second * add test for 'Auto creating/updating time with unix (milli) second' --- callbacks/update.go | 10 +++++++--- schema/field.go | 13 +++++++++++-- tests/customize_field_test.go | 36 +++++++++++++++++++++++------------ 3 files changed, 42 insertions(+), 17 deletions(-) diff --git a/callbacks/update.go b/callbacks/update.go index e492cfc9..12806af6 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -140,7 +140,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { if !updatingValue.CanAddr() || stmt.Dest != stmt.Model { switch stmt.ReflectValue.Kind() { case reflect.Slice, reflect.Array: - var priamryKeyExprs []clause.Expression + var primaryKeyExprs []clause.Expression for i := 0; i < stmt.ReflectValue.Len(); i++ { var exprs = make([]clause.Expression, len(stmt.Schema.PrimaryFields)) var notZero bool @@ -150,10 +150,10 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { notZero = notZero || !isZero } if notZero { - priamryKeyExprs = append(priamryKeyExprs, clause.And(exprs...)) + primaryKeyExprs = append(primaryKeyExprs, clause.And(exprs...)) } } - stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(priamryKeyExprs...)}}) + stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(primaryKeyExprs...)}}) case reflect.Struct: for _, field := range stmt.Schema.PrimaryFields { if value, isZero := field.ValueOf(stmt.ReflectValue); !isZero { @@ -202,6 +202,8 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { if field.AutoUpdateTime == schema.UnixNanosecond { set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano()}) + } else if field.AutoUpdateTime == schema.UnixMillisecond { + set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano() / 1e6}) } else if field.GORMDataType == schema.Time { set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now}) } else { @@ -223,6 +225,8 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { if field.AutoUpdateTime > 0 { if field.AutoUpdateTime == schema.UnixNanosecond { value = stmt.DB.NowFunc().UnixNano() + } else if field.AutoUpdateTime == schema.UnixMillisecond { + value = stmt.DB.NowFunc().UnixNano() / 1e6 } else if field.GORMDataType == schema.Time { value = stmt.DB.NowFunc() } else { diff --git a/schema/field.go b/schema/field.go index 6d0fd1cc..4eb95b98 100644 --- a/schema/field.go +++ b/schema/field.go @@ -19,8 +19,9 @@ type DataType string type TimeType int64 const ( - UnixSecond TimeType = 1 - UnixNanosecond TimeType = 2 + UnixSecond TimeType = 1 + UnixMillisecond TimeType = 2 + UnixNanosecond TimeType = 3 ) const ( @@ -233,6 +234,8 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { if v, ok := field.TagSettings["AUTOCREATETIME"]; ok || (field.Name == "CreatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) { if strings.ToUpper(v) == "NANO" { field.AutoCreateTime = UnixNanosecond + } else if strings.ToUpper(v) == "MILLI" { + field.AutoCreateTime = UnixMillisecond } else { field.AutoCreateTime = UnixSecond } @@ -241,6 +244,8 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { if v, ok := field.TagSettings["AUTOUPDATETIME"]; ok || (field.Name == "UpdatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) { if strings.ToUpper(v) == "NANO" { field.AutoUpdateTime = UnixNanosecond + } else if strings.ToUpper(v) == "MILLI" { + field.AutoUpdateTime = UnixMillisecond } else { field.AutoUpdateTime = UnixSecond } @@ -551,6 +556,8 @@ func (field *Field) setupValuerAndSetter() { case time.Time: if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond { field.ReflectValueOf(value).SetInt(data.UnixNano()) + } else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond { + field.ReflectValueOf(value).SetInt(data.UnixNano() / 1e6) } else { field.ReflectValueOf(value).SetInt(data.Unix()) } @@ -558,6 +565,8 @@ func (field *Field) setupValuerAndSetter() { if data != nil { if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond { field.ReflectValueOf(value).SetInt(data.UnixNano()) + } else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond { + field.ReflectValueOf(value).SetInt(data.UnixNano() / 1e6) } else { field.ReflectValueOf(value).SetInt(data.Unix()) } diff --git a/tests/customize_field_test.go b/tests/customize_field_test.go index 9c6ab948..bf3c78fa 100644 --- a/tests/customize_field_test.go +++ b/tests/customize_field_test.go @@ -61,18 +61,20 @@ func TestCustomColumnAndIgnoredFieldClash(t *testing.T) { func TestCustomizeField(t *testing.T) { type CustomizeFieldStruct struct { gorm.Model - Name string - FieldAllowCreate string `gorm:"<-:create"` - FieldAllowUpdate string `gorm:"<-:update"` - FieldAllowSave string `gorm:"<-"` - FieldAllowSave2 string `gorm:"<-:create,update"` - FieldAllowSave3 string `gorm:"->:false;<-:create"` - FieldReadonly string `gorm:"->"` - FieldIgnore string `gorm:"-"` - AutoUnixCreateTime int64 `gorm:"autocreatetime"` - AutoUnixNanoCreateTime int64 `gorm:"autocreatetime:nano"` - AutoUnixUpdateTime int64 `gorm:"autoupdatetime"` - AutoUnixNanoUpdateTime int64 `gorm:"autoupdatetime:nano"` + Name string + FieldAllowCreate string `gorm:"<-:create"` + FieldAllowUpdate string `gorm:"<-:update"` + FieldAllowSave string `gorm:"<-"` + FieldAllowSave2 string `gorm:"<-:create,update"` + FieldAllowSave3 string `gorm:"->:false;<-:create"` + FieldReadonly string `gorm:"->"` + FieldIgnore string `gorm:"-"` + AutoUnixCreateTime int64 `gorm:"autocreatetime"` + AutoUnixMilliCreateTime int64 `gorm:"autocreatetime:milli"` + AutoUnixNanoCreateTime int64 `gorm:"autocreatetime:nano"` + AutoUnixUpdateTime int64 `gorm:"autoupdatetime"` + AutoUnixMilliUpdateTime int64 `gorm:"autoupdatetime:milli"` + AutoUnixNanoUpdateTime int64 `gorm:"autoupdatetime:nano"` } DB.Migrator().DropTable(&CustomizeFieldStruct{}) @@ -118,6 +120,10 @@ func TestCustomizeField(t *testing.T) { t.Fatalf("invalid create/update unix time: %#v", result) } + if result.AutoUnixMilliCreateTime != result.AutoUnixMilliUpdateTime || result.AutoUnixMilliCreateTime == 0 || result.AutoUnixMilliCreateTime/result.AutoUnixCreateTime < 1e3 { + t.Fatalf("invalid create/update unix milli time: %#v", result) + } + if result.AutoUnixNanoCreateTime != result.AutoUnixNanoUpdateTime || result.AutoUnixNanoCreateTime == 0 || result.AutoUnixNanoCreateTime/result.AutoUnixCreateTime < 1e6 { t.Fatalf("invalid create/update unix nano time: %#v", result) } @@ -163,6 +169,8 @@ func TestCustomizeField(t *testing.T) { createWithDefaultTime := generateStruct("create_with_default_time") createWithDefaultTime.AutoUnixCreateTime = 100 createWithDefaultTime.AutoUnixUpdateTime = 100 + createWithDefaultTime.AutoUnixMilliCreateTime = 100 + createWithDefaultTime.AutoUnixMilliUpdateTime = 100 createWithDefaultTime.AutoUnixNanoCreateTime = 100 createWithDefaultTime.AutoUnixNanoUpdateTime = 100 DB.Create(&createWithDefaultTime) @@ -174,6 +182,10 @@ func TestCustomizeField(t *testing.T) { t.Fatalf("invalid create/update unix time: %#v", createWithDefaultTimeResult) } + if createWithDefaultTimeResult.AutoUnixMilliCreateTime != createWithDefaultTimeResult.AutoUnixMilliUpdateTime || createWithDefaultTimeResult.AutoUnixMilliCreateTime != 100 { + t.Fatalf("invalid create/update unix milli time: %#v", createWithDefaultTimeResult) + } + if createWithDefaultTimeResult.AutoUnixNanoCreateTime != createWithDefaultTimeResult.AutoUnixNanoUpdateTime || createWithDefaultTimeResult.AutoUnixNanoCreateTime != 100 { t.Fatalf("invalid create/update unix nano time: %#v", createWithDefaultTimeResult) } From 07ce8caf7df21e067de87a048d3cf638426bfe33 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 30 Jul 2020 17:42:41 +0800 Subject: [PATCH 565/881] Remove labeler workflows --- .github/workflows/labeler.yml | 19 ------------------- 1 file changed, 19 deletions(-) delete mode 100644 .github/workflows/labeler.yml diff --git a/.github/workflows/labeler.yml b/.github/workflows/labeler.yml deleted file mode 100644 index 1490730b..00000000 --- a/.github/workflows/labeler.yml +++ /dev/null @@ -1,19 +0,0 @@ -name: "Issue Labeler" -on: - issues: - types: [opened, edited, reopened] - pull_request: - types: [opened, edited, reopened, ready_for_review, synchronize] - -jobs: - triage: - runs-on: ubuntu-latest - name: Label issues and pull requests - steps: - - name: check out - uses: actions/checkout@v2 - - - name: labeler - uses: jinzhu/super-labeler-action@develop - with: - GITHUB_TOKEN: "${{ secrets.GITHUB_TOKEN }}" From 81c68db87fe8c4dc18a86caf198466d6fe29b0d3 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 30 Jul 2020 17:56:16 +0800 Subject: [PATCH 566/881] Fix zero time failed on mysql 8 --- tests/scanner_valuer_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/scanner_valuer_test.go b/tests/scanner_valuer_test.go index 2c2c1e18..63a7c63c 100644 --- a/tests/scanner_valuer_test.go +++ b/tests/scanner_valuer_test.go @@ -258,5 +258,5 @@ func (t *EmptyTime) Scan(v interface{}) error { } func (t EmptyTime) Value() (driver.Value, error) { - return t.Time, nil + return time.Now() /* pass tests, mysql 8 doesn't support 0000-00-00 by default */, nil } From dc299b900f5916c101b36b23edc77801ca76d056 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 31 Jul 2020 14:47:26 +0800 Subject: [PATCH 567/881] Use specified table when preloading data with Join --- callbacks/query.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/callbacks/query.go b/callbacks/query.go index 66bbf805..be829fbc 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -124,13 +124,13 @@ func BuildQuerySQL(db *gorm.DB) { for idx, ref := range relation.References { if ref.OwnPrimaryKey { exprs[idx] = clause.Eq{ - Column: clause.Column{Table: db.Statement.Schema.Table, Name: ref.PrimaryKey.DBName}, + Column: clause.Column{Table: clause.CurrentTable, Name: ref.PrimaryKey.DBName}, Value: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName}, } } else { if ref.PrimaryValue == "" { exprs[idx] = clause.Eq{ - Column: clause.Column{Table: db.Statement.Schema.Table, Name: ref.ForeignKey.DBName}, + Column: clause.Column{Table: clause.CurrentTable, Name: ref.ForeignKey.DBName}, Value: clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName}, } } else { From 2676fa4fb8e3c2b11c6bc72c1fb639c1586f6f3b Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 31 Jul 2020 18:19:25 +0800 Subject: [PATCH 568/881] Remove autoincrement tag for join table, close #3217 --- schema/relationship.go | 4 ++-- schema/utils.go | 2 +- schema/utils_test.go | 1 + tests/postgres_test.go | 4 ++-- 4 files changed, 6 insertions(+), 5 deletions(-) diff --git a/schema/relationship.go b/schema/relationship.go index e67092b4..b7ab4f66 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -220,7 +220,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel Name: joinFieldName, PkgPath: ownField.StructField.PkgPath, Type: ownField.StructField.Type, - Tag: removeSettingFromTag(ownField.StructField.Tag, "column"), + Tag: removeSettingFromTag(removeSettingFromTag(ownField.StructField.Tag, "column"), "autoincrement"), }) } @@ -243,7 +243,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel Name: joinFieldName, PkgPath: relField.StructField.PkgPath, Type: relField.StructField.Type, - Tag: removeSettingFromTag(relField.StructField.Tag, "column"), + Tag: removeSettingFromTag(removeSettingFromTag(relField.StructField.Tag, "column"), "autoincrement"), }) } diff --git a/schema/utils.go b/schema/utils.go index defa83af..1481d428 100644 --- a/schema/utils.go +++ b/schema/utils.go @@ -50,7 +50,7 @@ func toColumns(val string) (results []string) { } func removeSettingFromTag(tag reflect.StructTag, name string) reflect.StructTag { - return reflect.StructTag(regexp.MustCompile(`(?i)(gorm:.*?)(`+name+`:.*?)(;|("))`).ReplaceAllString(string(tag), "${1}${4}")) + return reflect.StructTag(regexp.MustCompile(`(?i)(gorm:.*?)(`+name+`(:.*?)?)(;|("))`).ReplaceAllString(string(tag), "${1}${5}")) } // GetRelationsValues get relations's values from a reflect value diff --git a/schema/utils_test.go b/schema/utils_test.go index e70169bf..1b47ef25 100644 --- a/schema/utils_test.go +++ b/schema/utils_test.go @@ -13,6 +13,7 @@ func TestRemoveSettingFromTag(t *testing.T) { `gorm:"column:db" other:"before:value;column:db;after:value"`: `gorm:"" other:"before:value;column:db;after:value"`, `gorm:"before:value;column:db ;after:value" other:"before:value;column:db;after:value"`: `gorm:"before:value;after:value" other:"before:value;column:db;after:value"`, `gorm:"before:value;column:db; after:value" other:"before:value;column:db;after:value"`: `gorm:"before:value; after:value" other:"before:value;column:db;after:value"`, + `gorm:"before:value;column; after:value" other:"before:value;column:db;after:value"`: `gorm:"before:value; after:value" other:"before:value;column:db;after:value"`, } for k, v := range tags { diff --git a/tests/postgres_test.go b/tests/postgres_test.go index a0b1fddb..85cd34d4 100644 --- a/tests/postgres_test.go +++ b/tests/postgres_test.go @@ -39,13 +39,13 @@ func TestPostgres(t *testing.T) { } type Post struct { - ID uuid.UUID `gorm:"primary_key;type:uuid;default:uuid_generate_v4()"` + ID uuid.UUID `gorm:"primary_key;type:uuid;default:uuid_generate_v4();autoincrement"` Title string Categories []*Category `gorm:"Many2Many:post_categories"` } type Category struct { - ID uuid.UUID `gorm:"primary_key;type:uuid;default:uuid_generate_v4()"` + ID uuid.UUID `gorm:"primary_key;type:uuid;default:uuid_generate_v4();autoincrement"` Title string Posts []*Post `gorm:"Many2Many:post_categories"` } From f83b00d20dd57bb0df964cacfefa8f7b259a09d3 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 3 Aug 2020 10:30:25 +0800 Subject: [PATCH 569/881] Fix Count with Select when Model not specfied, close #3220 --- finisher_api.go | 11 +++++++++-- schema/schema.go | 4 ++++ tests/count_test.go | 12 ++++++++++++ 3 files changed, 25 insertions(+), 2 deletions(-) diff --git a/finisher_api.go b/finisher_api.go index 77bea578..33a4f121 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -274,11 +274,18 @@ func (db *DB) Count(count *int64) (tx *DB) { expr := clause.Expr{SQL: "count(1)"} if len(tx.Statement.Selects) == 1 { + dbName := tx.Statement.Selects[0] if tx.Statement.Parse(tx.Statement.Model) == nil { - if f := tx.Statement.Schema.LookUpField(tx.Statement.Selects[0]); f != nil { - expr = clause.Expr{SQL: "COUNT(DISTINCT(?))", Vars: []interface{}{clause.Column{Name: f.DBName}}} + if f := tx.Statement.Schema.LookUpField(dbName); f != nil { + dbName = f.DBName } } + + if tx.Statement.Distinct { + expr = clause.Expr{SQL: "COUNT(DISTINCT(?))", Vars: []interface{}{clause.Column{Name: dbName}}} + } else { + expr = clause.Expr{SQL: "COUNT(?)", Vars: []interface{}{clause.Column{Name: dbName}}} + } } tx.Statement.AddClause(clause.Select{Expression: expr}) diff --git a/schema/schema.go b/schema/schema.go index 1106f0c5..9206c24e 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -72,6 +72,10 @@ type Tabler interface { // get data type from dialector func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) { + if dest == nil { + return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) + } + modelType := reflect.ValueOf(dest).Type() for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr { modelType = modelType.Elem() diff --git a/tests/count_test.go b/tests/count_test.go index 826d6a36..05661ae8 100644 --- a/tests/count_test.go +++ b/tests/count_test.go @@ -2,6 +2,7 @@ package tests_test import ( "fmt" + "regexp" "testing" "gorm.io/gorm" @@ -55,4 +56,15 @@ func TestCount(t *testing.T) { if count3 != 2 { t.Errorf("Should get correct count for count with group, but got %v", count3) } + + dryDB := DB.Session(&gorm.Session{DryRun: true}) + result := dryDB.Table("users").Select("name").Count(&count) + if !regexp.MustCompile(`SELECT COUNT\(.name.\) FROM .*users.*`).MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build count with select, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Table("users").Distinct("name").Count(&count) + if !regexp.MustCompile(`SELECT COUNT\(DISTINCT\(.name.\)\) FROM .*users.*`).MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build count with select, but got %v", result.Statement.SQL.String()) + } } From c11c939b959c489c96bd6b5967b6a47c8b402ceb Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 3 Aug 2020 21:48:36 +0800 Subject: [PATCH 570/881] callbacks support sort with wildcard --- callbacks.go | 16 ++++++++++++++-- gorm.go | 2 +- prepare_stmt.go | 34 +++++++++++++++++----------------- tests/callbacks_test.go | 8 ++++++++ 4 files changed, 40 insertions(+), 20 deletions(-) diff --git a/callbacks.go b/callbacks.go index c917a678..baeb6c09 100644 --- a/callbacks.go +++ b/callbacks.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "reflect" + "sort" "time" "gorm.io/gorm/logger" @@ -207,6 +208,9 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) { names, sorted []string sortCallback func(*callback) error ) + sort.Slice(cs, func(i, j int) bool { + return cs[j].before == "*" || cs[j].after == "*" + }) for _, c := range cs { // show warning message the callback name already exists @@ -218,7 +222,11 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) { sortCallback = func(c *callback) error { if c.before != "" { // if defined before callback - if sortedIdx := getRIndex(sorted, c.before); sortedIdx != -1 { + if c.before == "*" && len(sorted) > 0 { + if curIdx := getRIndex(sorted, c.name); curIdx == -1 { + sorted = append([]string{c.name}, sorted...) + } + } else if sortedIdx := getRIndex(sorted, c.before); sortedIdx != -1 { if curIdx := getRIndex(sorted, c.name); curIdx == -1 { // if before callback already sorted, append current callback just after it sorted = append(sorted[:sortedIdx], append([]string{c.name}, sorted[sortedIdx:]...)...) @@ -232,7 +240,11 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) { } if c.after != "" { // if defined after callback - if sortedIdx := getRIndex(sorted, c.after); sortedIdx != -1 { + if c.after == "*" && len(sorted) > 0 { + if curIdx := getRIndex(sorted, c.name); curIdx == -1 { + sorted = append(sorted, c.name) + } + } else if sortedIdx := getRIndex(sorted, c.after); sortedIdx != -1 { if curIdx := getRIndex(sorted, c.name); curIdx == -1 { // if after callback sorted, append current callback to last sorted = append(sorted, c.name) diff --git a/gorm.go b/gorm.go index c786b5a5..1ace0099 100644 --- a/gorm.go +++ b/gorm.go @@ -165,7 +165,7 @@ func (db *DB) Session(config *Session) *DB { preparedStmt := v.(*PreparedStmtDB) tx.Statement.ConnPool = &PreparedStmtDB{ ConnPool: db.Config.ConnPool, - mux: preparedStmt.mux, + Mux: preparedStmt.Mux, Stmts: preparedStmt.Stmts, } } diff --git a/prepare_stmt.go b/prepare_stmt.go index 2f4e1d57..7e87558d 100644 --- a/prepare_stmt.go +++ b/prepare_stmt.go @@ -9,12 +9,12 @@ import ( type PreparedStmtDB struct { Stmts map[string]*sql.Stmt PreparedSQL []string - mux sync.RWMutex + Mux sync.RWMutex ConnPool } func (db *PreparedStmtDB) Close() { - db.mux.Lock() + db.Mux.Lock() for _, query := range db.PreparedSQL { if stmt, ok := db.Stmts[query]; ok { delete(db.Stmts, query) @@ -22,21 +22,21 @@ func (db *PreparedStmtDB) Close() { } } - db.mux.Unlock() + db.Mux.Unlock() } func (db *PreparedStmtDB) prepare(query string) (*sql.Stmt, error) { - db.mux.RLock() + db.Mux.RLock() if stmt, ok := db.Stmts[query]; ok { - db.mux.RUnlock() + db.Mux.RUnlock() return stmt, nil } - db.mux.RUnlock() + db.Mux.RUnlock() - db.mux.Lock() + db.Mux.Lock() // double check if stmt, ok := db.Stmts[query]; ok { - db.mux.Unlock() + db.Mux.Unlock() return stmt, nil } @@ -45,7 +45,7 @@ func (db *PreparedStmtDB) prepare(query string) (*sql.Stmt, error) { db.Stmts[query] = stmt db.PreparedSQL = append(db.PreparedSQL, query) } - db.mux.Unlock() + db.Mux.Unlock() return stmt, err } @@ -63,10 +63,10 @@ func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args .. if err == nil { result, err = stmt.ExecContext(ctx, args...) if err != nil { - db.mux.Lock() + db.Mux.Lock() stmt.Close() delete(db.Stmts, query) - db.mux.Unlock() + db.Mux.Unlock() } } return result, err @@ -77,10 +77,10 @@ func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args . if err == nil { rows, err = stmt.QueryContext(ctx, args...) if err != nil { - db.mux.Lock() + db.Mux.Lock() stmt.Close() delete(db.Stmts, query) - db.mux.Unlock() + db.Mux.Unlock() } } return rows, err @@ -104,10 +104,10 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args .. if err == nil { result, err = tx.Tx.Stmt(stmt).ExecContext(ctx, args...) if err != nil { - tx.PreparedStmtDB.mux.Lock() + tx.PreparedStmtDB.Mux.Lock() stmt.Close() delete(tx.PreparedStmtDB.Stmts, query) - tx.PreparedStmtDB.mux.Unlock() + tx.PreparedStmtDB.Mux.Unlock() } } return result, err @@ -118,10 +118,10 @@ func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args . if err == nil { rows, err = tx.Tx.Stmt(stmt).QueryContext(ctx, args...) if err != nil { - tx.PreparedStmtDB.mux.Lock() + tx.PreparedStmtDB.Mux.Lock() stmt.Close() delete(tx.PreparedStmtDB.Stmts, query) - tx.PreparedStmtDB.mux.Unlock() + tx.PreparedStmtDB.Mux.Unlock() } } return rows, err diff --git a/tests/callbacks_test.go b/tests/callbacks_test.go index 1dbae441..84f56165 100644 --- a/tests/callbacks_test.go +++ b/tests/callbacks_test.go @@ -96,6 +96,14 @@ func TestCallbacks(t *testing.T) { callbacks: []callback{{h: c1}, {name: "c", h: c2}, {h: c3}, {name: "c", h: c4, replace: true}}, results: []string{"c1", "c4", "c3"}, }, + { + callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3}, {h: c4}, {h: c5, before: "*"}}, + results: []string{"c5", "c1", "c2", "c3", "c4"}, + }, + { + callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3, after: "*"}, {h: c4}, {h: c5, before: "*"}}, + results: []string{"c5", "c1", "c2", "c4", "c3"}, + }, } for idx, data := range datas { From ff985b90cc0f2f11be492300dd9f6914cba0cf22 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 4 Aug 2020 12:10:19 +0800 Subject: [PATCH 571/881] Fix failed to guess relations for embedded types, close #3224 --- migrator/migrator.go | 1 + schema/field.go | 2 + schema/relationship.go | 69 +++++++++++++++++++++++++++-------- tests/callbacks_test.go | 8 +++- tests/embedded_struct_test.go | 14 +++++++ 5 files changed, 76 insertions(+), 18 deletions(-) diff --git a/migrator/migrator.go b/migrator/migrator.go index 3e5d86d3..d50159dd 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -120,6 +120,7 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { } return nil }); err != nil { + fmt.Println(err) return err } } diff --git a/schema/field.go b/schema/field.go index 4eb95b98..1ca4cb6d 100644 --- a/schema/field.go +++ b/schema/field.go @@ -62,6 +62,7 @@ type Field struct { TagSettings map[string]string Schema *Schema EmbeddedSchema *Schema + OwnerSchema *Schema ReflectValueOf func(reflect.Value) reflect.Value ValueOf func(reflect.Value) (value interface{}, zero bool) Set func(reflect.Value, interface{}) error @@ -321,6 +322,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } for _, ef := range field.EmbeddedSchema.Fields { ef.Schema = schema + ef.OwnerSchema = field.EmbeddedSchema ef.BindNames = append([]string{fieldStruct.Name}, ef.BindNames...) // index is negative means is pointer if field.FieldType.Kind() == reflect.Struct { diff --git a/schema/relationship.go b/schema/relationship.go index b7ab4f66..93080105 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -5,6 +5,7 @@ import ( "reflect" "regexp" "strings" + "sync" "github.com/jinzhu/inflection" "gorm.io/gorm/clause" @@ -66,9 +67,16 @@ func (schema *Schema) parseRelation(field *Field) { } ) - if relation.FieldSchema, err = Parse(fieldValue, schema.cacheStore, schema.namer); err != nil { - schema.err = err - return + if field.OwnerSchema != nil { + if relation.FieldSchema, err = Parse(fieldValue, &sync.Map{}, schema.namer); err != nil { + schema.err = err + return + } + } else { + if relation.FieldSchema, err = Parse(fieldValue, schema.cacheStore, schema.namer); err != nil { + schema.err = err + return + } } if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" { @@ -78,7 +86,7 @@ func (schema *Schema) parseRelation(field *Field) { } else { switch field.IndirectFieldType.Kind() { case reflect.Struct, reflect.Slice: - schema.guessRelation(relation, field, true) + schema.guessRelation(relation, field, guessHas) default: schema.err = fmt.Errorf("unsupported data type %v for %v on field %v", relation.FieldSchema, schema, field.Name) } @@ -316,21 +324,50 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel } } -func (schema *Schema) guessRelation(relation *Relationship, field *Field, guessHas bool) { +type guessLevel int + +const ( + guessHas guessLevel = iota + guessEmbeddedHas + guessBelongs + guessEmbeddedBelongs +) + +func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl guessLevel) { var ( primaryFields, foreignFields []*Field primarySchema, foreignSchema = schema, relation.FieldSchema ) - if !guessHas { - primarySchema, foreignSchema = relation.FieldSchema, schema + reguessOrErr := func(err string, args ...interface{}) { + switch gl { + case guessHas: + schema.guessRelation(relation, field, guessEmbeddedHas) + case guessEmbeddedHas: + schema.guessRelation(relation, field, guessBelongs) + case guessBelongs: + schema.guessRelation(relation, field, guessEmbeddedBelongs) + default: + schema.err = fmt.Errorf(err, args...) + } } - reguessOrErr := func(err string, args ...interface{}) { - if guessHas { - schema.guessRelation(relation, field, false) + switch gl { + case guessEmbeddedHas: + if field.OwnerSchema != nil { + primarySchema, foreignSchema = field.OwnerSchema, relation.FieldSchema } else { - schema.err = fmt.Errorf(err, args...) + reguessOrErr("failed to guess %v's relations with %v's field %v, guess level: %v", relation.FieldSchema, schema, field.Name, gl) + return + } + case guessBelongs: + primarySchema, foreignSchema = relation.FieldSchema, schema + case guessEmbeddedBelongs: + if field.OwnerSchema != nil { + primarySchema, foreignSchema = relation.FieldSchema, field.OwnerSchema + } else { + reguessOrErr("failed to guess %v's relations with %v's field %v, guess level: %v", relation.FieldSchema, schema, field.Name, gl) + return } } @@ -345,8 +382,8 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, guessH } } else { for _, primaryField := range primarySchema.PrimaryFields { - lookUpName := schema.Name + primaryField.Name - if !guessHas { + lookUpName := primarySchema.Name + primaryField.Name + if gl == guessBelongs { lookUpName = field.Name + primaryField.Name } @@ -358,7 +395,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, guessH } if len(foreignFields) == 0 { - reguessOrErr("failed to guess %v's relations with %v's field %v 1 g %v", relation.FieldSchema, schema, field.Name, guessHas) + reguessOrErr("failed to guess %v's relations with %v's field %v, guess level: %v", relation.FieldSchema, schema, field.Name, gl) return } else if len(relation.primaryKeys) > 0 { for idx, primaryKey := range relation.primaryKeys { @@ -394,11 +431,11 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, guessH relation.References = append(relation.References, &Reference{ PrimaryKey: primaryFields[idx], ForeignKey: foreignField, - OwnPrimaryKey: schema == primarySchema && guessHas, + OwnPrimaryKey: (schema == primarySchema && gl == guessHas) || (field.OwnerSchema == primarySchema && gl == guessEmbeddedHas), }) } - if guessHas { + if gl == guessHas || gl == guessEmbeddedHas { relation.Type = "has" } else { relation.Type = BelongsTo diff --git a/tests/callbacks_test.go b/tests/callbacks_test.go index 84f56165..02765b8c 100644 --- a/tests/callbacks_test.go +++ b/tests/callbacks_test.go @@ -101,8 +101,12 @@ func TestCallbacks(t *testing.T) { results: []string{"c5", "c1", "c2", "c3", "c4"}, }, { - callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3, after: "*"}, {h: c4}, {h: c5, before: "*"}}, - results: []string{"c5", "c1", "c2", "c4", "c3"}, + callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3, before: "*"}, {h: c4}, {h: c5, before: "*"}}, + results: []string{"c3", "c5", "c1", "c2", "c4"}, + }, + { + callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3, before: "c4", after: "*"}, {h: c4, after: "*"}, {h: c5, before: "*"}}, + results: []string{"c5", "c1", "c2", "c3", "c4"}, }, } diff --git a/tests/embedded_struct_test.go b/tests/embedded_struct_test.go index 7f40a0a4..fb0d6f23 100644 --- a/tests/embedded_struct_test.go +++ b/tests/embedded_struct_test.go @@ -7,6 +7,7 @@ import ( "testing" "gorm.io/gorm" + . "gorm.io/gorm/utils/tests" ) func TestEmbeddedStruct(t *testing.T) { @@ -152,3 +153,16 @@ func TestEmbeddedScanValuer(t *testing.T) { t.Errorf("Failed to create got error %v", err) } } + +func TestEmbeddedRelations(t *testing.T) { + type AdvancedUser struct { + User `gorm:"embedded"` + Advanced bool + } + + DB.Debug().Migrator().DropTable(&AdvancedUser{}) + + if err := DB.Debug().AutoMigrate(&AdvancedUser{}); err != nil { + t.Errorf("Failed to auto migrate advanced user, got error %v", err) + } +} From f962872b48fae9095c9309d1c94215c4636befe8 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 5 Aug 2020 14:22:35 +0800 Subject: [PATCH 572/881] Fix labeler --- .github/workflows/labeler.yml | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) create mode 100644 .github/workflows/labeler.yml diff --git a/.github/workflows/labeler.yml b/.github/workflows/labeler.yml new file mode 100644 index 00000000..bc1add53 --- /dev/null +++ b/.github/workflows/labeler.yml @@ -0,0 +1,19 @@ +name: "Issue Labeler" +on: + issues: + types: [opened, edited, reopened] + pull_request: + types: [opened, edited, reopened] + +jobs: + triage: + runs-on: ubuntu-latest + name: Label issues and pull requests + steps: + - name: check out + uses: actions/checkout@v2 + + - name: labeler + uses: jinzhu/super-labeler-action@develop + with: + GITHUB_TOKEN: "${{ secrets.GITHUB_TOKEN }}" From da1e54d5abb4482ca2accabbad0a1e1d65a9fc8f Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 6 Aug 2020 15:37:36 +0800 Subject: [PATCH 573/881] Add sql-cli --- tests/tests_test.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/tests_test.go b/tests/tests_test.go index 5aedc061..192160a0 100644 --- a/tests/tests_test.go +++ b/tests/tests_test.go @@ -64,6 +64,8 @@ func OpenTestConnection() (db *gorm.DB, err error) { // USE gorm; // CREATE USER gorm FROM LOGIN gorm; // sp_changedbowner 'gorm'; + // npm install -g sql-cli + // mssql -u gorm -p LoremIpsum86 -d gorm -o 9930 log.Println("testing sqlserver...") if dbDSN == "" { dbDSN = "sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm" From 3df249c127e637f8af6c99e5e4fed9c466803d79 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 6 Aug 2020 16:25:26 +0800 Subject: [PATCH 574/881] Use table expr when inserting table, close #3239 --- callbacks/create.go | 8 ++------ tests/go.mod | 4 ++-- tests/table_test.go | 11 +++++++++++ 3 files changed, 15 insertions(+), 8 deletions(-) diff --git a/callbacks/create.go b/callbacks/create.go index b41a3ef2..3a414dd7 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -43,9 +43,7 @@ func Create(config *Config) func(db *gorm.DB) { if db.Statement.SQL.String() == "" { db.Statement.SQL.Grow(180) - db.Statement.AddClauseIfNotExists(clause.Insert{ - Table: clause.Table{Name: db.Statement.Table}, - }) + db.Statement.AddClauseIfNotExists(clause.Insert{}) db.Statement.AddClause(ConvertToCreateValues(db.Statement)) db.Statement.Build("INSERT", "VALUES", "ON CONFLICT") @@ -105,9 +103,7 @@ func CreateWithReturning(db *gorm.DB) { } if db.Statement.SQL.String() == "" { - db.Statement.AddClauseIfNotExists(clause.Insert{ - Table: clause.Table{Name: db.Statement.Table}, - }) + db.Statement.AddClauseIfNotExists(clause.Insert{}) db.Statement.AddClause(ConvertToCreateValues(db.Statement)) db.Statement.Build("INSERT", "VALUES", "ON CONFLICT") diff --git a/tests/go.mod b/tests/go.mod index 6eb6eb07..82d4fdc8 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -8,8 +8,8 @@ require ( github.com/lib/pq v1.6.0 gorm.io/driver/mysql v0.3.1 gorm.io/driver/postgres v0.2.6 - gorm.io/driver/sqlite v1.0.8 - gorm.io/driver/sqlserver v0.2.5 + gorm.io/driver/sqlite v1.0.9 + gorm.io/driver/sqlserver v0.2.6 gorm.io/gorm v0.2.19 ) diff --git a/tests/table_test.go b/tests/table_test.go index faee6499..647b5e19 100644 --- a/tests/table_test.go +++ b/tests/table_test.go @@ -40,6 +40,17 @@ func TestTable(t *testing.T) { t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) } + r = dryDB.Create(&UserWithTable{}).Statement + if DB.Dialector.Name() != "sqlite" { + if !regexp.MustCompile(`INSERT INTO .gorm.\..user. (.*name.*) VALUES (.*)`).MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + } else { + if !regexp.MustCompile(`INSERT INTO .user. (.*name.*) VALUES (.*)`).MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + } + r = dryDB.Table("(?) as u", DB.Model(&User{}).Select("name")).Find(&User{}).Statement if !regexp.MustCompile("SELECT \\* FROM \\(SELECT .name. FROM .users. WHERE .users.\\..deleted_at. IS NULL\\) as u WHERE .u.\\..deleted_at. IS NULL").MatchString(r.Statement.SQL.String()) { t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) From 39c8d6220b75b5a28dfff6ae88da17485b35dc46 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 6 Aug 2020 17:48:46 +0800 Subject: [PATCH 575/881] Fix soft delete panic when using unaddressable value --- soft_delete.go | 2 +- tests/delete_test.go | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/soft_delete.go b/soft_delete.go index 6b88b1a5..180bf745 100644 --- a/soft_delete.go +++ b/soft_delete.go @@ -64,7 +64,7 @@ func (SoftDeleteClause) ModifyStatement(stmt *Statement) { stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}}) } - if stmt.Dest != stmt.Model && stmt.Model != nil { + if stmt.ReflectValue.CanAddr() && stmt.Dest != stmt.Model && stmt.Model != nil { _, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(stmt.Model), stmt.Schema.PrimaryFields) column, values = schema.ToQueryValues(stmt.Table, stmt.Schema.PrimaryFieldDBNames, queryValues) diff --git a/tests/delete_test.go b/tests/delete_test.go index 3d461f65..f5b3e784 100644 --- a/tests/delete_test.go +++ b/tests/delete_test.go @@ -43,6 +43,14 @@ func TestDelete(t *testing.T) { t.Errorf("no error should returns when query %v, but got %v", user.ID, err) } } + + if err := DB.Delete(users[0]).Error; err != nil { + t.Errorf("errors happened when delete: %v", err) + } + + if err := DB.Where("id = ?", users[0].ID).First(&result).Error; err == nil || !errors.Is(err, gorm.ErrRecordNotFound) { + t.Errorf("should returns record not found error, but got %v", err) + } } func TestDeleteWithTable(t *testing.T) { From 15b96ed3f482a29201b2c6c15fa0d3936d4d9a17 Mon Sep 17 00:00:00 2001 From: Caelansar Date: Mon, 10 Aug 2020 15:34:20 +0800 Subject: [PATCH 576/881] add testcase --- tests/scanner_valuer_test.go | 69 +++++++++++++++++++++++++++--------- 1 file changed, 53 insertions(+), 16 deletions(-) diff --git a/tests/scanner_valuer_test.go b/tests/scanner_valuer_test.go index 63a7c63c..6b8f086e 100644 --- a/tests/scanner_valuer_test.go +++ b/tests/scanner_valuer_test.go @@ -35,7 +35,9 @@ func TestScannerValuer(t *testing.T) { {"name1", "value1"}, {"name2", "value2"}, }, - Role: Role{Name: "admin"}, + Role: Role{Name: "admin"}, + ExampleStruct: ExampleStruct1{"name", "value"}, + ExampleStructPtr: &ExampleStruct1{"name", "value"}, } if err := DB.Create(&data).Error; err != nil { @@ -49,6 +51,14 @@ func TestScannerValuer(t *testing.T) { } AssertObjEqual(t, data, result, "Name", "Gender", "Age", "Male", "Height", "Birthday", "Password", "Bytes", "Num", "Strings", "Structs") + + if result.ExampleStructPtr.Val != "value" { + t.Errorf(`ExampleStructPtr.Val should equal to "value", but got %v`, result.ExampleStructPtr.Val) + } + + if result.ExampleStruct.Val != "value" { + t.Errorf(`ExampleStruct.Val should equal to "value", but got %v`, result.ExampleStruct.Val) + } } func TestScannerValuerWithFirstOrCreate(t *testing.T) { @@ -124,21 +134,23 @@ func TestInvalidValuer(t *testing.T) { type ScannerValuerStruct struct { gorm.Model - Name sql.NullString - Gender *sql.NullString - Age sql.NullInt64 - Male sql.NullBool - Height sql.NullFloat64 - Birthday sql.NullTime - Password EncryptedData - Bytes []byte - Num Num - Strings StringsSlice - Structs StructsSlice - Role Role - UserID *sql.NullInt64 - User User - EmptyTime EmptyTime + Name sql.NullString + Gender *sql.NullString + Age sql.NullInt64 + Male sql.NullBool + Height sql.NullFloat64 + Birthday sql.NullTime + Password EncryptedData + Bytes []byte + Num Num + Strings StringsSlice + Structs StructsSlice + Role Role + UserID *sql.NullInt64 + User User + EmptyTime EmptyTime + ExampleStruct ExampleStruct1 + ExampleStructPtr *ExampleStruct1 } type EncryptedData []byte @@ -207,6 +219,31 @@ type ExampleStruct struct { Value string } +type ExampleStruct1 struct { + Name string `json:"name,omitempty"` + Val string `json:"val,omitempty"` +} + +func (s ExampleStruct1) Value() (driver.Value, error) { + if len(s.Name) == 0 { + return nil, nil + } + //for test, has no practical meaning + s.Name = "" + return json.Marshal(s) +} + +func (s *ExampleStruct1) Scan(src interface{}) error { + switch value := src.(type) { + case string: + return json.Unmarshal([]byte(value), s) + case []byte: + return json.Unmarshal(value, s) + default: + return errors.New("not supported") + } +} + type StructsSlice []ExampleStruct func (l StructsSlice) Value() (driver.Value, error) { From 4a9d3a688aa47a7db7611902f6467f0b311aee79 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 11 Aug 2020 21:22:51 +0800 Subject: [PATCH 577/881] Don't parse ignored anonymous field --- schema/field.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/schema/field.go b/schema/field.go index 1ca4cb6d..ea6364a4 100644 --- a/schema/field.go +++ b/schema/field.go @@ -311,7 +311,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } } - if _, ok := field.TagSettings["EMBEDDED"]; ok || (fieldStruct.Anonymous && !isValuer) { + if _, ok := field.TagSettings["EMBEDDED"]; ok || (fieldStruct.Anonymous && !isValuer && (field.Creatable || field.Updatable || field.Readable)) { if reflect.Indirect(fieldValue).Kind() == reflect.Struct { var err error field.Creatable = false From a3dda47afac01b7430efb200d27473e24fe2fca9 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 11 Aug 2020 21:22:51 +0800 Subject: [PATCH 578/881] Don't parse ignored anonymous field --- schema/field.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/schema/field.go b/schema/field.go index 1ca4cb6d..ea6364a4 100644 --- a/schema/field.go +++ b/schema/field.go @@ -311,7 +311,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } } - if _, ok := field.TagSettings["EMBEDDED"]; ok || (fieldStruct.Anonymous && !isValuer) { + if _, ok := field.TagSettings["EMBEDDED"]; ok || (fieldStruct.Anonymous && !isValuer && (field.Creatable || field.Updatable || field.Readable)) { if reflect.Indirect(fieldValue).Kind() == reflect.Struct { var err error field.Creatable = false From 7d45833f3e309f9c15bb9ca301c1782b23cb9f0e Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 13 Aug 2020 12:05:55 +0800 Subject: [PATCH 579/881] Fix driver.Valuer interface returns nil, close #3248 --- schema/field.go | 60 +++++++++++++++++------------------- tests/scanner_valuer_test.go | 52 ++++++++++++++++--------------- 2 files changed, 56 insertions(+), 56 deletions(-) diff --git a/schema/field.go b/schema/field.go index ea6364a4..84fdb695 100644 --- a/schema/field.go +++ b/schema/field.go @@ -731,40 +731,10 @@ func (field *Field) setupValuerAndSetter() { return nil } default: - if _, ok := fieldValue.Interface().(sql.Scanner); ok { - // struct scanner - field.Set = func(value reflect.Value, v interface{}) (err error) { - if valuer, ok := v.(driver.Valuer); ok { - v, _ = valuer.Value() - } - - reflectV := reflect.ValueOf(v) - if !reflectV.IsValid() { - field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) - } else if reflectV.Kind() == reflect.Ptr { - if reflectV.Elem().IsNil() || !reflectV.Elem().IsValid() { - field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) - } else { - return field.Set(value, reflectV.Elem().Interface()) - } - } else { - err = field.ReflectValueOf(value).Addr().Interface().(sql.Scanner).Scan(v) - } - return - } - } else if _, ok := fieldValue.Elem().Interface().(sql.Scanner); ok { + if _, ok := fieldValue.Elem().Interface().(sql.Scanner); ok { // pointer scanner field.Set = func(value reflect.Value, v interface{}) (err error) { reflectV := reflect.ValueOf(v) - - if valuer, ok := v.(driver.Valuer); ok { - if valuer == nil || reflectV.IsNil() { - field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) - } else { - v, _ = valuer.Value() - } - } - if reflectV.Type().AssignableTo(field.FieldType) { field.ReflectValueOf(value).Set(reflectV) } else if reflectV.Kind() == reflect.Ptr { @@ -778,10 +748,38 @@ func (field *Field) setupValuerAndSetter() { if fieldValue.IsNil() { fieldValue.Set(reflect.New(field.FieldType.Elem())) } + + if valuer, ok := v.(driver.Valuer); ok { + v, _ = valuer.Value() + } + err = fieldValue.Interface().(sql.Scanner).Scan(v) } return } + } else if _, ok := fieldValue.Interface().(sql.Scanner); ok { + // struct scanner + field.Set = func(value reflect.Value, v interface{}) (err error) { + reflectV := reflect.ValueOf(v) + if !reflectV.IsValid() { + field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) + } else if reflectV.Type().AssignableTo(field.FieldType) { + field.ReflectValueOf(value).Set(reflectV) + } else if reflectV.Kind() == reflect.Ptr { + if reflectV.IsNil() || !reflectV.IsValid() { + field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) + } else { + return field.Set(value, reflectV.Elem().Interface()) + } + } else { + if valuer, ok := v.(driver.Valuer); ok { + v, _ = valuer.Value() + } + + err = field.ReflectValueOf(value).Addr().Interface().(sql.Scanner).Scan(v) + } + return + } } else { field.Set = func(value reflect.Value, v interface{}) (err error) { return fallbackSetter(value, v, field.Set) diff --git a/tests/scanner_valuer_test.go b/tests/scanner_valuer_test.go index 6b8f086e..b8306af7 100644 --- a/tests/scanner_valuer_test.go +++ b/tests/scanner_valuer_test.go @@ -36,8 +36,8 @@ func TestScannerValuer(t *testing.T) { {"name2", "value2"}, }, Role: Role{Name: "admin"}, - ExampleStruct: ExampleStruct1{"name", "value"}, - ExampleStructPtr: &ExampleStruct1{"name", "value"}, + ExampleStruct: ExampleStruct{"name", "value1"}, + ExampleStructPtr: &ExampleStruct{"name", "value2"}, } if err := DB.Create(&data).Error; err != nil { @@ -46,19 +46,18 @@ func TestScannerValuer(t *testing.T) { var result ScannerValuerStruct - if err := DB.Find(&result).Error; err != nil { + if err := DB.Find(&result, "id = ?", data.ID).Error; err != nil { t.Fatalf("no error should happen when query scanner, valuer struct, but got %v", err) } + if result.ExampleStructPtr.Val != "value2" { + t.Errorf(`ExampleStructPtr.Val should equal to "value2", but got %v`, result.ExampleStructPtr.Val) + } + + if result.ExampleStruct.Val != "value1" { + t.Errorf(`ExampleStruct.Val should equal to "value1", but got %#v`, result.ExampleStruct) + } AssertObjEqual(t, data, result, "Name", "Gender", "Age", "Male", "Height", "Birthday", "Password", "Bytes", "Num", "Strings", "Structs") - - if result.ExampleStructPtr.Val != "value" { - t.Errorf(`ExampleStructPtr.Val should equal to "value", but got %v`, result.ExampleStructPtr.Val) - } - - if result.ExampleStruct.Val != "value" { - t.Errorf(`ExampleStruct.Val should equal to "value", but got %v`, result.ExampleStruct.Val) - } } func TestScannerValuerWithFirstOrCreate(t *testing.T) { @@ -68,9 +67,11 @@ func TestScannerValuerWithFirstOrCreate(t *testing.T) { } data := ScannerValuerStruct{ - Name: sql.NullString{String: "name", Valid: true}, - Gender: &sql.NullString{String: "M", Valid: true}, - Age: sql.NullInt64{Int64: 18, Valid: true}, + Name: sql.NullString{String: "name", Valid: true}, + Gender: &sql.NullString{String: "M", Valid: true}, + Age: sql.NullInt64{Int64: 18, Valid: true}, + ExampleStruct: ExampleStruct{"name", "value1"}, + ExampleStructPtr: &ExampleStruct{"name", "value2"}, } var result ScannerValuerStruct @@ -109,7 +110,9 @@ func TestInvalidValuer(t *testing.T) { } data := ScannerValuerStruct{ - Password: EncryptedData("xpass1"), + Password: EncryptedData("xpass1"), + ExampleStruct: ExampleStruct{"name", "value1"}, + ExampleStructPtr: &ExampleStruct{"name", "value2"}, } if err := DB.Create(&data).Error; err == nil { @@ -149,8 +152,8 @@ type ScannerValuerStruct struct { UserID *sql.NullInt64 User User EmptyTime EmptyTime - ExampleStruct ExampleStruct1 - ExampleStructPtr *ExampleStruct1 + ExampleStruct ExampleStruct + ExampleStructPtr *ExampleStruct } type EncryptedData []byte @@ -215,25 +218,24 @@ func (l *StringsSlice) Scan(input interface{}) error { } type ExampleStruct struct { - Name string - Value string + Name string + Val string } -type ExampleStruct1 struct { - Name string `json:"name,omitempty"` - Val string `json:"val,omitempty"` +func (ExampleStruct) GormDataType() string { + return "bytes" } -func (s ExampleStruct1) Value() (driver.Value, error) { +func (s ExampleStruct) Value() (driver.Value, error) { if len(s.Name) == 0 { return nil, nil } - //for test, has no practical meaning + // for test, has no practical meaning s.Name = "" return json.Marshal(s) } -func (s *ExampleStruct1) Scan(src interface{}) error { +func (s *ExampleStruct) Scan(src interface{}) error { switch value := src.(type) { case string: return json.Unmarshal([]byte(value), s) From 045d5f853838b9800acdb8ae204969ba3d93e00a Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 13 Aug 2020 12:18:36 +0800 Subject: [PATCH 580/881] Fix count with join and no model, close #3255 --- callbacks/query.go | 2 +- tests/count_test.go | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/callbacks/query.go b/callbacks/query.go index be829fbc..5ae1e904 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -96,7 +96,7 @@ func BuildQuerySQL(db *gorm.DB) { // inline joins if len(db.Statement.Joins) != 0 { - if len(db.Statement.Selects) == 0 { + if len(db.Statement.Selects) == 0 && db.Statement.Schema != nil { clauseSelect.Columns = make([]clause.Column, len(db.Statement.Schema.DBNames)) for idx, dbName := range db.Statement.Schema.DBNames { clauseSelect.Columns[idx] = clause.Column{Table: db.Statement.Table, Name: dbName} diff --git a/tests/count_test.go b/tests/count_test.go index 05661ae8..216fa3a1 100644 --- a/tests/count_test.go +++ b/tests/count_test.go @@ -67,4 +67,9 @@ func TestCount(t *testing.T) { if !regexp.MustCompile(`SELECT COUNT\(DISTINCT\(.name.\)\) FROM .*users.*`).MatchString(result.Statement.SQL.String()) { t.Fatalf("Build count with select, but got %v", result.Statement.SQL.String()) } + + var count4 int64 + if err := DB.Debug().Table("users").Joins("LEFT JOIN companies on companies.name = users.name").Where("users.name = ?", user1.Name).Count(&count4).Error; err != nil || count4 != 1 { + t.Errorf("count with join, got error: %v, count %v", err, count) + } } From ecc946be6e93a108bbdcc10cf2719d08baa50f3f Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 13 Aug 2020 16:05:06 +0800 Subject: [PATCH 581/881] Test update from sub query --- callbacks/update.go | 9 +++++++-- tests/update_test.go | 18 ++++++++++++++++++ 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/callbacks/update.go b/callbacks/update.go index 12806af6..0ced3ffb 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -174,11 +174,16 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { sort.Strings(keys) for _, k := range keys { + kv := value[k] + if _, ok := kv.(*gorm.DB); ok { + kv = []interface{}{kv} + } + if stmt.Schema != nil { if field := stmt.Schema.LookUpField(k); field != nil { if field.DBName != "" { if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { - set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value[k]}) + set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: kv}) assignValue(field, value[k]) } } else if v, ok := selectColumns[field.Name]; (ok && v) || (!ok && !restricted) { @@ -189,7 +194,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { } if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) { - set = append(set, clause.Assignment{Column: clause.Column{Name: k}, Value: value[k]}) + set = append(set, clause.Assignment{Column: clause.Column{Name: k}, Value: kv}) } } diff --git a/tests/update_test.go b/tests/update_test.go index 2ff150dd..83a7b9a2 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -545,3 +545,21 @@ func TestUpdatesTableWithIgnoredValues(t *testing.T) { t.Errorf("element's ignored field should not be updated") } } + +func TestUpdateFromSubQuery(t *testing.T) { + user := *GetUser("update_from_sub_query", Config{Company: true}) + if err := DB.Create(&user).Error; err != nil { + t.Errorf("failed to create user, got error: %v", err) + } + + if err := DB.Model(&user).Update("name", DB.Model(&Company{}).Select("name").Where("companies.id = users.company_id")).Error; err != nil { + t.Errorf("failed to update with sub query, got error %v", err) + } + + var result User + DB.First(&result, user.ID) + + if result.Name != user.Company.Name { + t.Errorf("name should be %v, but got %v", user.Company.Name, result.Name) + } +} From dea93edb6acdccdb398a5f9d89412f9bd0be5b39 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 13 Aug 2020 16:28:21 +0800 Subject: [PATCH 582/881] Copy TableExpr when clone statement --- statement.go | 1 + tests/update_test.go | 10 ++++++++++ 2 files changed, 11 insertions(+) diff --git a/statement.go b/statement.go index e9d826c4..b5b5db5a 100644 --- a/statement.go +++ b/statement.go @@ -392,6 +392,7 @@ func (stmt *Statement) Parse(value interface{}) (err error) { func (stmt *Statement) clone() *Statement { newStmt := &Statement{ + TableExpr: stmt.TableExpr, Table: stmt.Table, Model: stmt.Model, Dest: stmt.Dest, diff --git a/tests/update_test.go b/tests/update_test.go index 83a7b9a2..a59a8856 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -562,4 +562,14 @@ func TestUpdateFromSubQuery(t *testing.T) { if result.Name != user.Company.Name { t.Errorf("name should be %v, but got %v", user.Company.Name, result.Name) } + + DB.Model(&user.Company).Update("Name", "new company name") + if err := DB.Table("users").Where("1 = 1").Update("name", DB.Table("companies").Select("name").Where("companies.id = users.company_id")).Error; err != nil { + t.Errorf("failed to update with sub query, got error %v", err) + } + + DB.First(&result, user.ID) + if result.Name != "new company name" { + t.Errorf("name should be %v, but got %v", user.Company.Name, result.Name) + } } From 2c4e8571259bf6193cf5d396594104fca7fa727d Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 13 Aug 2020 18:09:04 +0800 Subject: [PATCH 583/881] Should ignore association conditions when querying with struct --- statement.go | 12 ++++++------ tests/query_test.go | 16 ++++++++++++++++ 2 files changed, 22 insertions(+), 6 deletions(-) diff --git a/statement.go b/statement.go index b5b5db5a..6114f468 100644 --- a/statement.go +++ b/statement.go @@ -309,10 +309,10 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) (c for _, field := range s.Fields { if field.Readable { if v, isZero := field.ValueOf(reflectValue); !isZero { - if field.DBName == "" { - conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.Name}, Value: v}) - } else { + if field.DBName != "" { conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.DBName}, Value: v}) + } else if field.DataType != "" { + conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.Name}, Value: v}) } } } @@ -322,10 +322,10 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) (c for _, field := range s.Fields { if field.Readable { if v, isZero := field.ValueOf(reflectValue.Index(i)); !isZero { - if field.DBName == "" { - conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.Name}, Value: v}) - } else { + if field.DBName != "" { conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.DBName}, Value: v}) + } else if field.DataType != "" { + conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.Name}, Value: v}) } } } diff --git a/tests/query_test.go b/tests/query_test.go index 59f1130b..4c2a2abd 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -103,6 +103,22 @@ func TestFind(t *testing.T) { }) } +func TestQueryWithAssociation(t *testing.T) { + user := *GetUser("query_with_association", Config{Account: true, Pets: 2, Toys: 1, Company: true, Manager: true, Team: 2, Languages: 1, Friends: 3}) + + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create user: %v", err) + } + + if err := DB.Where(&user).First(&User{}).Error; err != nil { + t.Errorf("search with struct with association should returns no error, but got %v", err) + } + + if err := DB.Where(user).First(&User{}).Error; err != nil { + t.Errorf("search with struct with association should returns no error, but got %v", err) + } +} + func TestFindInBatches(t *testing.T) { var users = []User{ *GetUser("find_in_batches", Config{}), From 2faff25dfbcfff9e3fb37c8fcf1a20a468f887a8 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 13 Aug 2020 18:38:39 +0800 Subject: [PATCH 584/881] Fix FirstOr(Init/Create) when assigning with association --- finisher_api.go | 67 +++++++++++++++++++++++++++++++-------------- tests/query_test.go | 2 ++ 2 files changed, 48 insertions(+), 21 deletions(-) diff --git a/finisher_api.go b/finisher_api.go index 33a4f121..8a3d4199 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -8,6 +8,7 @@ import ( "strings" "gorm.io/gorm/clause" + "gorm.io/gorm/schema" "gorm.io/gorm/utils" ) @@ -132,19 +133,47 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat return } -func (tx *DB) assignExprsToValue(exprs []clause.Expression) { - for _, expr := range exprs { - if eq, ok := expr.(clause.Eq); ok { - switch column := eq.Column.(type) { - case string: - if field := tx.Statement.Schema.LookUpField(column); field != nil { - tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value)) +func (tx *DB) assignInterfacesToValue(values ...interface{}) { + for _, value := range values { + switch v := value.(type) { + case []clause.Expression: + for _, expr := range v { + if eq, ok := expr.(clause.Eq); ok { + switch column := eq.Column.(type) { + case string: + if field := tx.Statement.Schema.LookUpField(column); field != nil { + tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value)) + } + case clause.Column: + if field := tx.Statement.Schema.LookUpField(column.Name); field != nil { + tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value)) + } + default: + } } - case clause.Column: - if field := tx.Statement.Schema.LookUpField(column.Name); field != nil { - tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value)) + } + case clause.Expression, map[string]string, map[interface{}]interface{}, map[string]interface{}: + exprs := tx.Statement.BuildCondition(value) + tx.assignInterfacesToValue(exprs) + default: + if s, err := schema.Parse(value, tx.cacheStore, tx.NamingStrategy); err == nil { + reflectValue := reflect.Indirect(reflect.ValueOf(value)) + switch reflectValue.Kind() { + case reflect.Struct: + for _, f := range s.Fields { + if f.Readable { + if v, isZero := f.ValueOf(reflectValue); !isZero { + if field := tx.Statement.Schema.LookUpField(f.Name); field != nil { + tx.AddError(field.Set(tx.Statement.ReflectValue, v)) + } + } + } + } } - default: + } else if len(values) > 0 { + exprs := tx.Statement.BuildCondition(values[0], values[1:]...) + tx.assignInterfacesToValue(exprs) + return } } } @@ -154,22 +183,20 @@ func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { if tx = db.First(dest, conds...); errors.Is(tx.Error, ErrRecordNotFound) { if c, ok := tx.Statement.Clauses["WHERE"]; ok { if where, ok := c.Expression.(clause.Where); ok { - tx.assignExprsToValue(where.Exprs) + tx.assignInterfacesToValue(where.Exprs) } } // initialize with attrs, conds if len(tx.Statement.attrs) > 0 { - exprs := tx.Statement.BuildCondition(tx.Statement.attrs[0], tx.Statement.attrs[1:]...) - tx.assignExprsToValue(exprs) + tx.assignInterfacesToValue(tx.Statement.attrs...) } tx.Error = nil } // initialize with attrs, conds if len(tx.Statement.assigns) > 0 { - exprs := tx.Statement.BuildCondition(tx.Statement.assigns[0], tx.Statement.assigns[1:]...) - tx.assignExprsToValue(exprs) + tx.assignInterfacesToValue(tx.Statement.assigns...) } return } @@ -180,20 +207,18 @@ func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) { if c, ok := tx.Statement.Clauses["WHERE"]; ok { if where, ok := c.Expression.(clause.Where); ok { - tx.assignExprsToValue(where.Exprs) + tx.assignInterfacesToValue(where.Exprs) } } // initialize with attrs, conds if len(tx.Statement.attrs) > 0 { - exprs := tx.Statement.BuildCondition(tx.Statement.attrs[0], tx.Statement.attrs[1:]...) - tx.assignExprsToValue(exprs) + tx.assignInterfacesToValue(tx.Statement.attrs...) } // initialize with attrs, conds if len(tx.Statement.assigns) > 0 { - exprs := tx.Statement.BuildCondition(tx.Statement.assigns[0], tx.Statement.assigns[1:]...) - tx.assignExprsToValue(exprs) + tx.assignInterfacesToValue(tx.Statement.assigns...) } return tx.Create(dest) diff --git a/tests/query_test.go b/tests/query_test.go index 4c2a2abd..72dd89b9 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -110,6 +110,8 @@ func TestQueryWithAssociation(t *testing.T) { t.Fatalf("errors happened when create user: %v", err) } + user.CreatedAt = time.Time{} + user.UpdatedAt = time.Time{} if err := DB.Where(&user).First(&User{}).Error; err != nil { t.Errorf("search with struct with association should returns no error, but got %v", err) } From 6834c25cec6b037299970cc845de1a186e04ba1f Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 17 Aug 2020 12:02:41 +0800 Subject: [PATCH 585/881] Fix stack overflow for embedded self-referred associations, close #3269 --- schema/field.go | 8 +++++++- schema/model_test.go | 22 ++++++++++++++++++++++ schema/relationship.go | 17 +++++++---------- schema/schema_test.go | 6 ++++++ 4 files changed, 42 insertions(+), 11 deletions(-) diff --git a/schema/field.go b/schema/field.go index 84fdb695..78eeccdc 100644 --- a/schema/field.go +++ b/schema/field.go @@ -317,7 +317,13 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { field.Creatable = false field.Updatable = false field.Readable = false - if field.EmbeddedSchema, err = Parse(fieldValue.Interface(), &sync.Map{}, schema.namer); err != nil { + + cacheStore := schema.cacheStore + if _, embedded := schema.cacheStore.Load("embedded_cache_store"); !embedded { + cacheStore = &sync.Map{} + cacheStore.Store("embedded_cache_store", true) + } + if field.EmbeddedSchema, err = Parse(fieldValue.Interface(), cacheStore, schema.namer); err != nil { schema.err = err } for _, ef := range field.EmbeddedSchema.Fields { diff --git a/schema/model_test.go b/schema/model_test.go index a13372b5..84c7b327 100644 --- a/schema/model_test.go +++ b/schema/model_test.go @@ -39,3 +39,25 @@ type AdvancedDataTypeUser struct { Active mybool Admin *mybool } + +type BaseModel struct { + ID uint `gorm:"primarykey"` + CreatedAt time.Time + CreatedBy *int + Created *VersionUser `gorm:"foreignKey:CreatedBy"` + UpdatedAt time.Time + DeletedAt gorm.DeletedAt `gorm:"index"` +} + +type VersionModel struct { + BaseModel + Version int + CompanyID int +} + +type VersionUser struct { + VersionModel + Name string + Age uint + Birthday *time.Time +} diff --git a/schema/relationship.go b/schema/relationship.go index 93080105..537a3582 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -5,7 +5,6 @@ import ( "reflect" "regexp" "strings" - "sync" "github.com/jinzhu/inflection" "gorm.io/gorm/clause" @@ -67,16 +66,14 @@ func (schema *Schema) parseRelation(field *Field) { } ) + cacheStore := schema.cacheStore if field.OwnerSchema != nil { - if relation.FieldSchema, err = Parse(fieldValue, &sync.Map{}, schema.namer); err != nil { - schema.err = err - return - } - } else { - if relation.FieldSchema, err = Parse(fieldValue, schema.cacheStore, schema.namer); err != nil { - schema.err = err - return - } + cacheStore = field.OwnerSchema.cacheStore + } + + if relation.FieldSchema, err = Parse(fieldValue, cacheStore, schema.namer); err != nil { + schema.err = err + return } if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" { diff --git a/schema/schema_test.go b/schema/schema_test.go index 99781e47..966f80e4 100644 --- a/schema/schema_test.go +++ b/schema/schema_test.go @@ -160,3 +160,9 @@ func TestCustomizeTableName(t *testing.T) { t.Errorf("Failed to customize table with TableName method") } } + +func TestNestedModel(t *testing.T) { + if _, err := schema.Parse(&VersionUser{}, &sync.Map{}, schema.NamingStrategy{}); err != nil { + t.Fatalf("failed to parse nested user, got error %v", err) + } +} From 2a716e04e6528f1979dc0a7a2de509f0350e9e04 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 17 Aug 2020 12:16:42 +0800 Subject: [PATCH 586/881] Avoid panic for invalid transaction, close #3271 --- finisher_api.go | 6 ++++-- tests/transaction_test.go | 20 ++++++++++++++++++++ 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/finisher_api.go b/finisher_api.go index 8a3d4199..19534460 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -445,7 +445,7 @@ func (db *DB) Begin(opts ...*sql.TxOptions) *DB { // Commit commit a transaction func (db *DB) Commit() *DB { - if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil { + if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil && !reflect.ValueOf(committer).IsNil() { db.AddError(committer.Commit()) } else { db.AddError(ErrInvalidTransaction) @@ -456,7 +456,9 @@ func (db *DB) Commit() *DB { // Rollback rollback a transaction func (db *DB) Rollback() *DB { if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil { - db.AddError(committer.Rollback()) + if !reflect.ValueOf(committer).IsNil() { + db.AddError(committer.Rollback()) + } } else { db.AddError(ErrInvalidTransaction) } diff --git a/tests/transaction_test.go b/tests/transaction_test.go index c101388a..aea151d9 100644 --- a/tests/transaction_test.go +++ b/tests/transaction_test.go @@ -1,6 +1,7 @@ package tests_test import ( + "context" "errors" "testing" @@ -57,6 +58,25 @@ func TestTransaction(t *testing.T) { } } +func TestCancelTransaction(t *testing.T) { + ctx := context.Background() + ctx, cancelFunc := context.WithCancel(ctx) + cancelFunc() + + user := *GetUser("cancel_transaction", Config{}) + DB.Create(&user) + + err := DB.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + var result User + tx.First(&result, user.ID) + return nil + }) + + if err == nil { + t.Fatalf("Transaction should get error when using cancelled context") + } +} + func TestTransactionWithBlock(t *testing.T) { assertPanic := func(f func()) { defer func() { From 681268cc43a2aa665e5577680b88ac77b9e5b64c Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 17 Aug 2020 16:31:09 +0800 Subject: [PATCH 587/881] Refactor Create/Query/Update/DeleteClauses interface --- schema/field.go | 22 -------------------- schema/interfaces.go | 8 ++++---- schema/schema.go | 17 ++++++++++++++++ soft_delete.go | 48 +++++++++++++++++++++++++++++++++----------- 4 files changed, 57 insertions(+), 38 deletions(-) diff --git a/schema/field.go b/schema/field.go index 78eeccdc..bc47e543 100644 --- a/schema/field.go +++ b/schema/field.go @@ -88,23 +88,6 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } fieldValue := reflect.New(field.IndirectFieldType) - - if fc, ok := fieldValue.Interface().(CreateClausesInterface); ok { - field.Schema.CreateClauses = append(field.Schema.CreateClauses, fc.CreateClauses()...) - } - - if fc, ok := fieldValue.Interface().(QueryClausesInterface); ok { - field.Schema.QueryClauses = append(field.Schema.QueryClauses, fc.QueryClauses()...) - } - - if fc, ok := fieldValue.Interface().(UpdateClausesInterface); ok { - field.Schema.UpdateClauses = append(field.Schema.UpdateClauses, fc.UpdateClauses()...) - } - - if fc, ok := fieldValue.Interface().(DeleteClausesInterface); ok { - field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, fc.DeleteClauses()...) - } - // if field is valuer, used its value or first fields as data type valuer, isValuer := fieldValue.Interface().(driver.Valuer) if isValuer { @@ -353,11 +336,6 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { ef.TagSettings[k] = v } } - - field.Schema.CreateClauses = append(field.Schema.CreateClauses, field.EmbeddedSchema.CreateClauses...) - field.Schema.QueryClauses = append(field.Schema.QueryClauses, field.EmbeddedSchema.QueryClauses...) - field.Schema.UpdateClauses = append(field.Schema.UpdateClauses, field.EmbeddedSchema.UpdateClauses...) - field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, field.EmbeddedSchema.DeleteClauses...) } else { schema.err = fmt.Errorf("invalid embedded struct for %v's field %v, should be struct, but got %v", field.Schema.Name, field.Name, field.FieldType) } diff --git a/schema/interfaces.go b/schema/interfaces.go index f5d07843..e8e51e4c 100644 --- a/schema/interfaces.go +++ b/schema/interfaces.go @@ -7,17 +7,17 @@ type GormDataTypeInterface interface { } type CreateClausesInterface interface { - CreateClauses() []clause.Interface + CreateClauses(*Field) []clause.Interface } type QueryClausesInterface interface { - QueryClauses() []clause.Interface + QueryClauses(*Field) []clause.Interface } type UpdateClausesInterface interface { - UpdateClauses() []clause.Interface + UpdateClauses(*Field) []clause.Interface } type DeleteClausesInterface interface { - DeleteClauses() []clause.Interface + DeleteClauses(*Field) []clause.Interface } diff --git a/schema/schema.go b/schema/schema.go index 9206c24e..d81da4b8 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -219,6 +219,23 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) return schema, schema.err } } + + fieldValue := reflect.New(field.IndirectFieldType) + if fc, ok := fieldValue.Interface().(CreateClausesInterface); ok { + field.Schema.CreateClauses = append(field.Schema.CreateClauses, fc.CreateClauses(field)...) + } + + if fc, ok := fieldValue.Interface().(QueryClausesInterface); ok { + field.Schema.QueryClauses = append(field.Schema.QueryClauses, fc.QueryClauses(field)...) + } + + if fc, ok := fieldValue.Interface().(UpdateClausesInterface); ok { + field.Schema.UpdateClauses = append(field.Schema.UpdateClauses, fc.UpdateClauses(field)...) + } + + if fc, ok := fieldValue.Interface().(DeleteClausesInterface); ok { + field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, fc.DeleteClauses(field)...) + } } } diff --git a/soft_delete.go b/soft_delete.go index 180bf745..875623bc 100644 --- a/soft_delete.go +++ b/soft_delete.go @@ -24,37 +24,61 @@ func (n DeletedAt) Value() (driver.Value, error) { return n.Time, nil } -func (DeletedAt) QueryClauses() []clause.Interface { +func (DeletedAt) QueryClauses(f *schema.Field) []clause.Interface { return []clause.Interface{ clause.Where{Exprs: []clause.Expression{ clause.Eq{ - Column: clause.Column{Table: clause.CurrentTable, Name: "deleted_at"}, + Column: clause.Column{Table: clause.CurrentTable, Name: f.DBName}, Value: nil, }, }}, } } -func (DeletedAt) DeleteClauses() []clause.Interface { - return []clause.Interface{SoftDeleteClause{}} +type SoftDeleteQueryClause struct { + Field *schema.Field } -type SoftDeleteClause struct { -} - -func (SoftDeleteClause) Name() string { +func (sd SoftDeleteQueryClause) Name() string { return "" } -func (SoftDeleteClause) Build(clause.Builder) { +func (sd SoftDeleteQueryClause) Build(clause.Builder) { } -func (SoftDeleteClause) MergeClause(*clause.Clause) { +func (sd SoftDeleteQueryClause) MergeClause(*clause.Clause) { } -func (SoftDeleteClause) ModifyStatement(stmt *Statement) { +func (sd SoftDeleteQueryClause) ModifyStatement(stmt *Statement) { + if _, ok := stmt.Clauses["soft_delete_enabled"]; !ok { + stmt.AddClause(clause.Where{Exprs: []clause.Expression{ + clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: sd.Field.DBName}, Value: nil}, + }}) + stmt.Clauses["soft_delete_enabled"] = clause.Clause{} + } +} + +func (DeletedAt) DeleteClauses(f *schema.Field) []clause.Interface { + return []clause.Interface{SoftDeleteDeleteClause{Field: f}} +} + +type SoftDeleteDeleteClause struct { + Field *schema.Field +} + +func (sd SoftDeleteDeleteClause) Name() string { + return "" +} + +func (sd SoftDeleteDeleteClause) Build(clause.Builder) { +} + +func (sd SoftDeleteDeleteClause) MergeClause(*clause.Clause) { +} + +func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) { if stmt.SQL.String() == "" { - stmt.AddClause(clause.Set{{Column: clause.Column{Name: "deleted_at"}, Value: stmt.DB.NowFunc()}}) + stmt.AddClause(clause.Set{{Column: clause.Column{Name: sd.Field.DBName}, Value: stmt.DB.NowFunc()}}) if stmt.Schema != nil { _, queryValues := schema.GetIdentityFieldValuesMap(stmt.ReflectValue, stmt.Schema.PrimaryFields) From 9fcc337bd1ccfccfddcdbd4a9b8b08ad08bf465c Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 17 Aug 2020 17:41:36 +0800 Subject: [PATCH 588/881] Fix create from map --- callbacks/associations.go | 59 ++++++++++++++++++++++++--------------- callbacks/create.go | 22 ++++++++++++--- callbacks/helper.go | 10 ++++++- tests/create_test.go | 39 ++++++++++++++++++++++++++ tests/go.mod | 2 +- 5 files changed, 103 insertions(+), 29 deletions(-) diff --git a/callbacks/associations.go b/callbacks/associations.go index 3508335a..2710ffe9 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -48,14 +48,19 @@ func SaveBeforeAssociations(db *gorm.DB) { elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 0) for i := 0; i < db.Statement.ReflectValue.Len(); i++ { obj := db.Statement.ReflectValue.Index(i) - if _, zero := rel.Field.ValueOf(obj); !zero { // check belongs to relation value - rv := rel.Field.ReflectValueOf(obj) // relation reflect value - objs = append(objs, obj) - if isPtr { - elems = reflect.Append(elems, rv) - } else { - elems = reflect.Append(elems, rv.Addr()) + + if reflect.Indirect(obj).Kind() == reflect.Struct { + if _, zero := rel.Field.ValueOf(obj); !zero { // check belongs to relation value + rv := rel.Field.ReflectValueOf(obj) // relation reflect value + objs = append(objs, obj) + if isPtr { + elems = reflect.Append(elems, rv) + } else { + elems = reflect.Append(elems, rv.Addr()) + } } + } else { + break } } @@ -112,22 +117,24 @@ func SaveAfterAssociations(db *gorm.DB) { for i := 0; i < db.Statement.ReflectValue.Len(); i++ { obj := db.Statement.ReflectValue.Index(i) - if _, zero := rel.Field.ValueOf(obj); !zero { - rv := rel.Field.ReflectValueOf(obj) - if rv.Kind() != reflect.Ptr { - rv = rv.Addr() - } - - for _, ref := range rel.References { - if ref.OwnPrimaryKey { - fv, _ := ref.PrimaryKey.ValueOf(obj) - db.AddError(ref.ForeignKey.Set(rv, fv)) - } else if ref.PrimaryValue != "" { - db.AddError(ref.ForeignKey.Set(rv, ref.PrimaryValue)) + if reflect.Indirect(obj).Kind() == reflect.Struct { + if _, zero := rel.Field.ValueOf(obj); !zero { + rv := rel.Field.ReflectValueOf(obj) + if rv.Kind() != reflect.Ptr { + rv = rv.Addr() } - } - elems = reflect.Append(elems, rv) + for _, ref := range rel.References { + if ref.OwnPrimaryKey { + fv, _ := ref.PrimaryKey.ValueOf(obj) + db.AddError(ref.ForeignKey.Set(rv, fv)) + } else if ref.PrimaryValue != "" { + db.AddError(ref.ForeignKey.Set(rv, ref.PrimaryValue)) + } + } + + elems = reflect.Append(elems, rv) + } } } @@ -207,7 +214,10 @@ func SaveAfterAssociations(db *gorm.DB) { switch db.Statement.ReflectValue.Kind() { case reflect.Slice, reflect.Array: for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - appendToElems(db.Statement.ReflectValue.Index(i)) + obj := db.Statement.ReflectValue.Index(i) + if reflect.Indirect(obj).Kind() == reflect.Struct { + appendToElems(obj) + } } case reflect.Struct: appendToElems(db.Statement.ReflectValue) @@ -277,7 +287,10 @@ func SaveAfterAssociations(db *gorm.DB) { switch db.Statement.ReflectValue.Kind() { case reflect.Slice, reflect.Array: for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - appendToElems(db.Statement.ReflectValue.Index(i)) + obj := db.Statement.ReflectValue.Index(i) + if reflect.Indirect(obj).Kind() == reflect.Struct { + appendToElems(obj) + } } case reflect.Struct: appendToElems(db.Statement.ReflectValue) diff --git a/callbacks/create.go b/callbacks/create.go index 3a414dd7..4cc0f555 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -61,16 +61,26 @@ func Create(config *Config) func(db *gorm.DB) { case reflect.Slice, reflect.Array: if config.LastInsertIDReversed { for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- { - _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue.Index(i)) + rv := db.Statement.ReflectValue.Index(i) + if reflect.Indirect(rv).Kind() != reflect.Struct { + break + } + + _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv) if isZero { - db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID) + db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID) insertID-- } } } else { for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue.Index(i)); isZero { - db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID) + rv := db.Statement.ReflectValue.Index(i) + if reflect.Indirect(rv).Kind() != reflect.Struct { + break + } + + if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv); isZero { + db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID) insertID++ } } @@ -140,6 +150,10 @@ func CreateWithReturning(db *gorm.DB) { for rows.Next() { BEGIN: reflectValue := db.Statement.ReflectValue.Index(int(db.RowsAffected)) + if reflect.Indirect(reflectValue).Kind() != reflect.Struct { + break + } + for idx, field := range fields { fieldValue := field.ReflectValueOf(reflectValue) diff --git a/callbacks/helper.go b/callbacks/helper.go index 7bd910f6..80fbc2a1 100644 --- a/callbacks/helper.go +++ b/callbacks/helper.go @@ -26,6 +26,10 @@ func ConvertMapToValuesForCreate(stmt *gorm.Statement, mapValue map[string]inter if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) { values.Columns = append(values.Columns, clause.Column{Name: k}) + if len(values.Values) == 0 { + values.Values = [][]interface{}{{}} + } + values.Values[0] = append(values.Values[0], value) } } @@ -61,11 +65,15 @@ func ConvertSliceOfMapToValuesForCreate(stmt *gorm.Statement, mapValues []map[st sort.Strings(columns) values.Values = make([][]interface{}, len(mapValues)) + values.Columns = make([]clause.Column, len(columns)) for idx, column := range columns { + values.Columns[idx] = clause.Column{Name: column} + for i, v := range result[column] { - if i == 0 { + if len(values.Values[i]) == 0 { values.Values[i] = make([]interface{}, len(columns)) } + values.Values[i][idx] = v } } diff --git a/tests/create_test.go b/tests/create_test.go index ae6e1232..ab0a78d4 100644 --- a/tests/create_test.go +++ b/tests/create_test.go @@ -39,6 +39,45 @@ func TestCreate(t *testing.T) { } } +func TestCreateFromMap(t *testing.T) { + if err := DB.Model(&User{}).Create(map[string]interface{}{"Name": "create_from_map", "Age": 18}).Error; err != nil { + t.Fatalf("failed to create data from map, got error: %v", err) + } + + var result User + if err := DB.Where("name = ?", "create_from_map").First(&result).Error; err != nil || result.Age != 18 { + t.Fatalf("failed to create from map, got error %v", err) + } + + if err := DB.Model(&User{}).Create(map[string]interface{}{"name": "create_from_map_1", "age": 18}).Error; err != nil { + t.Fatalf("failed to create data from map, got error: %v", err) + } + + var result1 User + if err := DB.Where("name = ?", "create_from_map_1").First(&result1).Error; err != nil || result1.Age != 18 { + t.Fatalf("failed to create from map, got error %v", err) + } + + datas := []map[string]interface{}{ + {"Name": "create_from_map_2", "Age": 19}, + {"name": "create_from_map_3", "Age": 20}, + } + + if err := DB.Model(&User{}).Create(datas).Error; err != nil { + t.Fatalf("failed to create data from slice of map, got error: %v", err) + } + + var result2 User + if err := DB.Where("name = ?", "create_from_map_2").First(&result2).Error; err != nil || result2.Age != 19 { + t.Fatalf("failed to query data after create from slice of map, got error %v", err) + } + + var result3 User + if err := DB.Where("name = ?", "create_from_map_3").First(&result3).Error; err != nil || result3.Age != 20 { + t.Fatalf("failed to query data after create from slice of map, got error %v", err) + } +} + func TestCreateWithAssociations(t *testing.T) { var user = *GetUser("create_with_associations", Config{ Account: true, diff --git a/tests/go.mod b/tests/go.mod index 82d4fdc8..54a808d0 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -9,7 +9,7 @@ require ( gorm.io/driver/mysql v0.3.1 gorm.io/driver/postgres v0.2.6 gorm.io/driver/sqlite v1.0.9 - gorm.io/driver/sqlserver v0.2.6 + gorm.io/driver/sqlserver v0.2.7 gorm.io/gorm v0.2.19 ) From dc48e04896aa529bb4014390347e21e2c4c509b2 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 18 Aug 2020 11:21:40 +0800 Subject: [PATCH 589/881] Fix nested embedded struct, close #3278 --- schema/field.go | 8 +++----- schema/model_test.go | 5 ++--- schema/schema.go | 37 ++++++++++++++++++----------------- schema/schema_test.go | 18 ++++++++++++++++- schema/utils.go | 2 ++ tests/embedded_struct_test.go | 4 ++-- utils/tests/utils.go | 2 +- 7 files changed, 46 insertions(+), 30 deletions(-) diff --git a/schema/field.go b/schema/field.go index bc47e543..35c1e44d 100644 --- a/schema/field.go +++ b/schema/field.go @@ -301,14 +301,12 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { field.Updatable = false field.Readable = false - cacheStore := schema.cacheStore - if _, embedded := schema.cacheStore.Load("embedded_cache_store"); !embedded { - cacheStore = &sync.Map{} - cacheStore.Store("embedded_cache_store", true) - } + cacheStore := &sync.Map{} + cacheStore.Store(embeddedCacheKey, true) if field.EmbeddedSchema, err = Parse(fieldValue.Interface(), cacheStore, schema.namer); err != nil { schema.err = err } + for _, ef := range field.EmbeddedSchema.Fields { ef.Schema = schema ef.OwnerSchema = field.EmbeddedSchema diff --git a/schema/model_test.go b/schema/model_test.go index 84c7b327..1f2b0948 100644 --- a/schema/model_test.go +++ b/schema/model_test.go @@ -41,7 +41,7 @@ type AdvancedDataTypeUser struct { } type BaseModel struct { - ID uint `gorm:"primarykey"` + ID uint CreatedAt time.Time CreatedBy *int Created *VersionUser `gorm:"foreignKey:CreatedBy"` @@ -51,8 +51,7 @@ type BaseModel struct { type VersionModel struct { BaseModel - Version int - CompanyID int + Version int } type VersionUser struct { diff --git a/schema/schema.go b/schema/schema.go index d81da4b8..458256d1 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -212,29 +212,30 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) } if _, loaded := cacheStore.LoadOrStore(modelType, schema); !loaded { - // parse relations for unidentified fields - for _, field := range schema.Fields { - if field.DataType == "" && field.Creatable { - if schema.parseRelation(field); schema.err != nil { - return schema, schema.err + if _, embedded := schema.cacheStore.Load(embeddedCacheKey); !embedded { + for _, field := range schema.Fields { + if field.DataType == "" && field.Creatable { + if schema.parseRelation(field); schema.err != nil { + return schema, schema.err + } } - } - fieldValue := reflect.New(field.IndirectFieldType) - if fc, ok := fieldValue.Interface().(CreateClausesInterface); ok { - field.Schema.CreateClauses = append(field.Schema.CreateClauses, fc.CreateClauses(field)...) - } + fieldValue := reflect.New(field.IndirectFieldType) + if fc, ok := fieldValue.Interface().(CreateClausesInterface); ok { + field.Schema.CreateClauses = append(field.Schema.CreateClauses, fc.CreateClauses(field)...) + } - if fc, ok := fieldValue.Interface().(QueryClausesInterface); ok { - field.Schema.QueryClauses = append(field.Schema.QueryClauses, fc.QueryClauses(field)...) - } + if fc, ok := fieldValue.Interface().(QueryClausesInterface); ok { + field.Schema.QueryClauses = append(field.Schema.QueryClauses, fc.QueryClauses(field)...) + } - if fc, ok := fieldValue.Interface().(UpdateClausesInterface); ok { - field.Schema.UpdateClauses = append(field.Schema.UpdateClauses, fc.UpdateClauses(field)...) - } + if fc, ok := fieldValue.Interface().(UpdateClausesInterface); ok { + field.Schema.UpdateClauses = append(field.Schema.UpdateClauses, fc.UpdateClauses(field)...) + } - if fc, ok := fieldValue.Interface().(DeleteClausesInterface); ok { - field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, fc.DeleteClauses(field)...) + if fc, ok := fieldValue.Interface().(DeleteClausesInterface); ok { + field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, fc.DeleteClauses(field)...) + } } } } diff --git a/schema/schema_test.go b/schema/schema_test.go index 966f80e4..c0ad3c25 100644 --- a/schema/schema_test.go +++ b/schema/schema_test.go @@ -162,7 +162,23 @@ func TestCustomizeTableName(t *testing.T) { } func TestNestedModel(t *testing.T) { - if _, err := schema.Parse(&VersionUser{}, &sync.Map{}, schema.NamingStrategy{}); err != nil { + versionUser, err := schema.Parse(&VersionUser{}, &sync.Map{}, schema.NamingStrategy{}) + + if err != nil { t.Fatalf("failed to parse nested user, got error %v", err) } + + fields := []schema.Field{ + {Name: "ID", DBName: "id", BindNames: []string{"VersionModel", "BaseModel", "ID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, HasDefaultValue: true, AutoIncrement: true}, + {Name: "CreatedBy", DBName: "created_by", BindNames: []string{"VersionModel", "BaseModel", "CreatedBy"}, DataType: schema.Int, Size: 64}, + {Name: "Version", DBName: "version", BindNames: []string{"VersionModel", "Version"}, DataType: schema.Int, Size: 64}, + } + + for _, f := range fields { + checkSchemaField(t, versionUser, &f, func(f *schema.Field) { + f.Creatable = true + f.Updatable = true + f.Readable = true + }) + } } diff --git a/schema/utils.go b/schema/utils.go index 1481d428..29f2fefb 100644 --- a/schema/utils.go +++ b/schema/utils.go @@ -9,6 +9,8 @@ import ( "gorm.io/gorm/utils" ) +var embeddedCacheKey = "embedded_cache_store" + func ParseTagSetting(str string, sep string) map[string]string { settings := map[string]string{} names := strings.Split(str, sep) diff --git a/tests/embedded_struct_test.go b/tests/embedded_struct_test.go index fb0d6f23..c29078bd 100644 --- a/tests/embedded_struct_test.go +++ b/tests/embedded_struct_test.go @@ -160,9 +160,9 @@ func TestEmbeddedRelations(t *testing.T) { Advanced bool } - DB.Debug().Migrator().DropTable(&AdvancedUser{}) + DB.Migrator().DropTable(&AdvancedUser{}) - if err := DB.Debug().AutoMigrate(&AdvancedUser{}); err != nil { + if err := DB.AutoMigrate(&AdvancedUser{}); err != nil { t.Errorf("Failed to auto migrate advanced user, got error %v", err) } } diff --git a/utils/tests/utils.go b/utils/tests/utils.go index 0067d5c6..817e4b0b 100644 --- a/utils/tests/utils.go +++ b/utils/tests/utils.go @@ -76,7 +76,7 @@ func AssertEqual(t *testing.T, got, expect interface{}) { } } else { name := reflect.ValueOf(got).Type().Elem().Name() - t.Errorf("%v expects length: %v, got %v", name, reflect.ValueOf(expect).Len(), reflect.ValueOf(got).Len()) + t.Errorf("%v expects length: %v, got %v (expects: %+v, got %+v)", name, reflect.ValueOf(expect).Len(), reflect.ValueOf(got).Len(), expect, got) } return } From 50826742fd0bd26caf55a7a5a96b2c85b612f4ae Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 18 Aug 2020 18:00:36 +0800 Subject: [PATCH 590/881] Add error gorm.ErrInvalidData --- callbacks/create.go | 2 ++ callbacks/update.go | 2 ++ errors.go | 2 ++ tests/update_test.go | 9 +++++++++ 4 files changed, 15 insertions(+) diff --git a/callbacks/create.go b/callbacks/create.go index 4cc0f555..7a32ed5c 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -309,6 +309,8 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { } } } + default: + stmt.AddError(gorm.ErrInvalidData) } } diff --git a/callbacks/update.go b/callbacks/update.go index 0ced3ffb..5656d166 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -252,6 +252,8 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { } } } + default: + stmt.AddError(gorm.ErrInvalidData) } } diff --git a/errors.go b/errors.go index 115b8e25..32ff8ec1 100644 --- a/errors.go +++ b/errors.go @@ -19,6 +19,8 @@ var ( ErrPrimaryKeyRequired = errors.New("primary key required") // ErrModelValueRequired model value required ErrModelValueRequired = errors.New("model value required") + // ErrInvalidData unsupported data + ErrInvalidData = errors.New("unsupported data") // ErrUnsupportedDriver unsupported driver ErrUnsupportedDriver = errors.New("unsupported driver") // ErrRegistered registered diff --git a/tests/update_test.go b/tests/update_test.go index a59a8856..49a13be9 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -334,6 +334,15 @@ func TestSelectWithUpdateWithMap(t *testing.T) { AssertObjEqual(t, result2, result, "Name", "Account", "Toys", "Manager", "ManagerID", "Languages") } +func TestWithUpdateWithInvalidMap(t *testing.T) { + user := *GetUser("update_with_invalid_map", Config{}) + DB.Create(&user) + + if err := DB.Model(&user).Updates(map[string]string{"name": "jinzhu"}).Error; !errors.Is(err, gorm.ErrInvalidData) { + t.Errorf("should returns error for unsupported updating data") + } +} + func TestOmitWithUpdate(t *testing.T) { user := *GetUser("omit_update", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4}) DB.Create(&user) From b5de8aeb425cc9eccf92b8c3252fc0a7201ed52e Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 18 Aug 2020 18:58:53 +0800 Subject: [PATCH 591/881] Fix overrite SELECT clause --- chainable_api.go | 3 +++ finisher_api.go | 2 +- tests/query_test.go | 5 +++++ 3 files changed, 9 insertions(+), 1 deletion(-) diff --git a/chainable_api.go b/chainable_api.go index 4df8780e..78724cc8 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -91,6 +91,7 @@ func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) { return } } + delete(tx.Statement.Clauses, "SELECT") case string: fields := strings.FieldsFunc(v, utils.IsChar) @@ -112,6 +113,8 @@ func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) { return } } + + delete(tx.Statement.Clauses, "SELECT") } else { tx.Statement.AddClause(clause.Select{ Distinct: db.Statement.Distinct, diff --git a/finisher_api.go b/finisher_api.go index 19534460..88873948 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -294,7 +294,7 @@ func (db *DB) Count(count *int64) (tx *DB) { if len(tx.Statement.Selects) == 0 { tx.Statement.AddClause(clause.Select{Expression: clause.Expr{SQL: "count(1)"}}) - defer tx.Statement.AddClause(clause.Select{}) + defer delete(tx.Statement.Clauses, "SELECT") } else if !strings.Contains(strings.ToLower(tx.Statement.Selects[0]), "count(") { expr := clause.Expr{SQL: "count(1)"} diff --git a/tests/query_test.go b/tests/query_test.go index 72dd89b9..d71c813a 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -346,6 +346,11 @@ func TestSelect(t *testing.T) { if !regexp.MustCompile(`SELECT u\.\* FROM .*users.*`).MatchString(r.Statement.SQL.String()) { t.Fatalf("Build Select with u.*, but got %v", r.Statement.SQL.String()) } + + r = dryDB.Select("count(*)").Select("u.*").Table("users as u").First(&User{}, user.ID) + if !regexp.MustCompile(`SELECT u\.\* FROM .*users.*`).MatchString(r.Statement.SQL.String()) { + t.Fatalf("Build Select with u.*, but got %v", r.Statement.SQL.String()) + } } func TestOmit(t *testing.T) { From 3411425d651e540cf19f9845d83cc507d929f2e6 Mon Sep 17 00:00:00 2001 From: deepoli <67894732+deepoil@users.noreply.github.com> Date: Tue, 18 Aug 2020 20:03:09 +0900 Subject: [PATCH 592/881] fix return value and delete unused default (#3280) --- chainable_api.go | 2 +- finisher_api.go | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/chainable_api.go b/chainable_api.go index 78724cc8..9b46a95b 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -69,7 +69,7 @@ func (db *DB) Distinct(args ...interface{}) (tx *DB) { if len(args) > 0 { tx = tx.Select(args[0], args[1:]...) } - return tx + return } // Select specify fields that you want when querying, creating, updating diff --git a/finisher_api.go b/finisher_api.go index 88873948..db069c5c 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -148,7 +148,6 @@ func (tx *DB) assignInterfacesToValue(values ...interface{}) { if field := tx.Statement.Schema.LookUpField(column.Name); field != nil { tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value)) } - default: } } } From c1782d60c149483111b021e29c412d9139bd46ba Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 19 Aug 2020 15:47:08 +0800 Subject: [PATCH 593/881] Fix embedded scanner/valuer, close #3283 --- schema/field.go | 34 +++++++++++++++++++++------------- tests/scanner_valuer_test.go | 6 ++++++ 2 files changed, 27 insertions(+), 13 deletions(-) diff --git a/schema/field.go b/schema/field.go index 35c1e44d..59367399 100644 --- a/schema/field.go +++ b/schema/field.go @@ -92,32 +92,40 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { valuer, isValuer := fieldValue.Interface().(driver.Valuer) if isValuer { if _, ok := fieldValue.Interface().(GormDataTypeInterface); !ok { - var overrideFieldValue bool - if v, err := valuer.Value(); v != nil && err == nil { - overrideFieldValue = true + if v, err := valuer.Value(); reflect.ValueOf(v).IsValid() && err == nil { fieldValue = reflect.ValueOf(v) } - if field.IndirectFieldType.Kind() == reflect.Struct { - for i := 0; i < field.IndirectFieldType.NumField(); i++ { - if !overrideFieldValue { - newFieldType := field.IndirectFieldType.Field(i).Type + var getRealFieldValue func(reflect.Value) + getRealFieldValue = func(v reflect.Value) { + rv := reflect.Indirect(v) + if rv.Kind() == reflect.Struct && !rv.Type().ConvertibleTo(reflect.TypeOf(time.Time{})) { + for i := 0; i < rv.Type().NumField(); i++ { + newFieldType := rv.Type().Field(i).Type for newFieldType.Kind() == reflect.Ptr { newFieldType = newFieldType.Elem() } fieldValue = reflect.New(newFieldType) - overrideFieldValue = true - } - // copy tag settings from valuer - for key, value := range ParseTagSetting(field.IndirectFieldType.Field(i).Tag.Get("gorm"), ";") { - if _, ok := field.TagSettings[key]; !ok { - field.TagSettings[key] = value + if rv.Type() != reflect.Indirect(fieldValue).Type() { + getRealFieldValue(fieldValue) + } + + if fieldValue.IsValid() { + return + } + + for key, value := range ParseTagSetting(field.IndirectFieldType.Field(i).Tag.Get("gorm"), ";") { + if _, ok := field.TagSettings[key]; !ok { + field.TagSettings[key] = value + } } } } } + + getRealFieldValue(fieldValue) } } diff --git a/tests/scanner_valuer_test.go b/tests/scanner_valuer_test.go index b8306af7..ce8a2b50 100644 --- a/tests/scanner_valuer_test.go +++ b/tests/scanner_valuer_test.go @@ -27,6 +27,7 @@ func TestScannerValuer(t *testing.T) { Male: sql.NullBool{Bool: true, Valid: true}, Height: sql.NullFloat64{Float64: 1.8888, Valid: true}, Birthday: sql.NullTime{Time: time.Now(), Valid: true}, + Allergen: NullString{sql.NullString{String: "Allergen", Valid: true}}, Password: EncryptedData("pass1"), Bytes: []byte("byte"), Num: 18, @@ -143,6 +144,7 @@ type ScannerValuerStruct struct { Male sql.NullBool Height sql.NullFloat64 Birthday sql.NullTime + Allergen NullString Password EncryptedData Bytes []byte Num Num @@ -299,3 +301,7 @@ func (t *EmptyTime) Scan(v interface{}) error { func (t EmptyTime) Value() (driver.Value, error) { return time.Now() /* pass tests, mysql 8 doesn't support 0000-00-00 by default */, nil } + +type NullString struct { + sql.NullString +} From 3313c11888538af30abed9b168550b426a4af082 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 19 Aug 2020 19:02:32 +0800 Subject: [PATCH 594/881] Fix embedded struct containing field named ID, close #3286 --- schema/field.go | 8 ++++++++ schema/schema_helper_test.go | 9 +++++++-- schema/schema_test.go | 32 ++++++++++++++++++++++++++++++++ 3 files changed, 47 insertions(+), 2 deletions(-) diff --git a/schema/field.go b/schema/field.go index 59367399..de937132 100644 --- a/schema/field.go +++ b/schema/field.go @@ -336,6 +336,14 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { ef.PrimaryKey = true } else { ef.PrimaryKey = false + + if val, ok := ef.TagSettings["AUTOINCREMENT"]; !ok || !utils.CheckTruth(val) { + ef.AutoIncrement = false + } + + if ef.DefaultValue == "" { + ef.HasDefaultValue = false + } } for k, v := range field.TagSettings { diff --git a/schema/schema_helper_test.go b/schema/schema_helper_test.go index f202b487..4e916f84 100644 --- a/schema/schema_helper_test.go +++ b/schema/schema_helper_test.go @@ -49,7 +49,12 @@ func checkSchemaField(t *testing.T, s *schema.Schema, f *schema.Field, fc func(* } } - if parsedField, ok := s.FieldsByName[f.Name]; !ok { + parsedField, ok := s.FieldsByDBName[f.DBName] + if !ok { + parsedField, ok = s.FieldsByName[f.Name] + } + + if !ok { t.Errorf("schema %v failed to look up field with name %v", s, f.Name) } else { tests.AssertObjEqual(t, parsedField, f, "Name", "DBName", "BindNames", "DataType", "PrimaryKey", "AutoIncrement", "Creatable", "Updatable", "Readable", "HasDefaultValue", "DefaultValue", "NotNull", "Unique", "Comment", "Size", "Precision", "Tag", "TagSettings") @@ -62,7 +67,7 @@ func checkSchemaField(t *testing.T, s *schema.Schema, f *schema.Field, fc func(* for _, name := range []string{f.DBName, f.Name} { if name != "" { - if field := s.LookUpField(name); field == nil || parsedField != field { + if field := s.LookUpField(name); field == nil || (field.Name != name && field.DBName != name) { t.Errorf("schema %v failed to look up field with dbname %v", s, f.DBName) } } diff --git a/schema/schema_test.go b/schema/schema_test.go index c0ad3c25..c28812af 100644 --- a/schema/schema_test.go +++ b/schema/schema_test.go @@ -182,3 +182,35 @@ func TestNestedModel(t *testing.T) { }) } } + +func TestEmbeddedStruct(t *testing.T) { + type Company struct { + ID int + Name string + } + + type Corp struct { + ID uint + Base Company `gorm:"embedded;embeddedPrefix:company_"` + } + + cropSchema, err := schema.Parse(&Corp{}, &sync.Map{}, schema.NamingStrategy{}) + + if err != nil { + t.Fatalf("failed to parse embedded struct with primary key, got error %v", err) + } + + fields := []schema.Field{ + {Name: "ID", DBName: "id", BindNames: []string{"ID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, HasDefaultValue: true, AutoIncrement: true}, + {Name: "ID", DBName: "company_id", BindNames: []string{"Base", "ID"}, DataType: schema.Int, Size: 64, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}}, + {Name: "Name", DBName: "company_name", BindNames: []string{"Base", "Name"}, DataType: schema.String, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}}, + } + + for _, f := range fields { + checkSchemaField(t, cropSchema, &f, func(f *schema.Field) { + f.Creatable = true + f.Updatable = true + f.Readable = true + }) + } +} From 528e5ba5c41b647367d48e527b9fe9ad7dfcdd72 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 19 Aug 2020 20:30:39 +0800 Subject: [PATCH 595/881] Cleanup Model after Count --- finisher_api.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/finisher_api.go b/finisher_api.go index db069c5c..cf46f78a 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -289,6 +289,9 @@ func (db *DB) Count(count *int64) (tx *DB) { tx = db.getInstance() if tx.Statement.Model == nil { tx.Statement.Model = tx.Statement.Dest + defer func() { + tx.Statement.Model = nil + }() } if len(tx.Statement.Selects) == 0 { From 0c9870d1ae52a466837daf7f8386e3f2c0c1505c Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 20 Aug 2020 10:39:01 +0800 Subject: [PATCH 596/881] Test Association Mode with conditions --- tests/associations_has_many_test.go | 23 ++++++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/tests/associations_has_many_test.go b/tests/associations_has_many_test.go index d8befd8a..173e9231 100644 --- a/tests/associations_has_many_test.go +++ b/tests/associations_has_many_test.go @@ -21,6 +21,23 @@ func TestHasManyAssociation(t *testing.T) { DB.Model(&user2).Association("Pets").Find(&user2.Pets) CheckUser(t, user2, user) + var pets []Pet + DB.Model(&user).Where("name = ?", user.Pets[0].Name).Association("Pets").Find(&pets) + + if len(pets) != 1 { + t.Fatalf("should only find one pets, but got %v", len(pets)) + } + + CheckPet(t, pets[0], *user.Pets[0]) + + if count := DB.Model(&user).Where("name = ?", user.Pets[1].Name).Association("Pets").Count(); count != 1 { + t.Fatalf("should only find one pets, but got %v", count) + } + + if count := DB.Model(&user).Where("name = ?", "not found").Association("Pets").Count(); count != 0 { + t.Fatalf("should only find no pet with invalid conditions, but got %v", count) + } + // Count AssertAssociationCount(t, user, "Pets", 2, "") @@ -40,13 +57,13 @@ func TestHasManyAssociation(t *testing.T) { AssertAssociationCount(t, user, "Pets", 3, "AfterAppend") - var pets = []Pet{{Name: "pet-has-many-append-1-1"}, {Name: "pet-has-many-append-1-1"}} + var pets2 = []Pet{{Name: "pet-has-many-append-1-1"}, {Name: "pet-has-many-append-1-1"}} - if err := DB.Model(&user2).Association("Pets").Append(&pets); err != nil { + if err := DB.Model(&user2).Association("Pets").Append(&pets2); err != nil { t.Fatalf("Error happened when append pet, got %v", err) } - for _, pet := range pets { + for _, pet := range pets2 { var pet = pet if pet.ID == 0 { t.Fatalf("Pet's ID should be created") From 06de6e8834baf8ed56230727cdf715809e2c7f27 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 20 Aug 2020 10:58:35 +0800 Subject: [PATCH 597/881] Test same field name from embedded field, close #3291 --- schema/schema_helper_test.go | 2 +- schema/schema_test.go | 17 +++++++++++++---- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/schema/schema_helper_test.go b/schema/schema_helper_test.go index 4e916f84..cc0306e0 100644 --- a/schema/schema_helper_test.go +++ b/schema/schema_helper_test.go @@ -57,7 +57,7 @@ func checkSchemaField(t *testing.T, s *schema.Schema, f *schema.Field, fc func(* if !ok { t.Errorf("schema %v failed to look up field with name %v", s, f.Name) } else { - tests.AssertObjEqual(t, parsedField, f, "Name", "DBName", "BindNames", "DataType", "PrimaryKey", "AutoIncrement", "Creatable", "Updatable", "Readable", "HasDefaultValue", "DefaultValue", "NotNull", "Unique", "Comment", "Size", "Precision", "Tag", "TagSettings") + tests.AssertObjEqual(t, parsedField, f, "Name", "DBName", "BindNames", "DataType", "PrimaryKey", "AutoIncrement", "Creatable", "Updatable", "Readable", "HasDefaultValue", "DefaultValue", "NotNull", "Unique", "Comment", "Size", "Precision", "TagSettings") if f.DBName != "" { if field, ok := s.FieldsByDBName[f.DBName]; !ok || parsedField != field { diff --git a/schema/schema_test.go b/schema/schema_test.go index c28812af..8bd1e5ca 100644 --- a/schema/schema_test.go +++ b/schema/schema_test.go @@ -4,6 +4,7 @@ import ( "sync" "testing" + "gorm.io/gorm" "gorm.io/gorm/schema" "gorm.io/gorm/utils/tests" ) @@ -184,13 +185,19 @@ func TestNestedModel(t *testing.T) { } func TestEmbeddedStruct(t *testing.T) { + type CorpBase struct { + gorm.Model + OwnerID string + } + type Company struct { - ID int - Name string + ID int + OwnerID int + Name string } type Corp struct { - ID uint + CorpBase Base Company `gorm:"embedded;embeddedPrefix:company_"` } @@ -201,9 +208,11 @@ func TestEmbeddedStruct(t *testing.T) { } fields := []schema.Field{ - {Name: "ID", DBName: "id", BindNames: []string{"ID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, HasDefaultValue: true, AutoIncrement: true}, + {Name: "ID", DBName: "id", BindNames: []string{"CorpBase", "Model", "ID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, HasDefaultValue: true, AutoIncrement: true, TagSettings: map[string]string{"PRIMARYKEY": "PRIMARYKEY"}}, {Name: "ID", DBName: "company_id", BindNames: []string{"Base", "ID"}, DataType: schema.Int, Size: 64, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}}, {Name: "Name", DBName: "company_name", BindNames: []string{"Base", "Name"}, DataType: schema.String, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}}, + {Name: "OwnerID", DBName: "company_owner_id", BindNames: []string{"Base", "OwnerID"}, DataType: schema.Int, Size: 64, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}}, + {Name: "OwnerID", DBName: "owner_id", BindNames: []string{"CorpBase", "OwnerID"}, DataType: schema.String}, } for _, f := range fields { From f88e8b072c6e9dc5ecb0530823ee957f9cff5f6f Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 20 Aug 2020 18:13:29 +0800 Subject: [PATCH 598/881] Check valid pointer before use it as Valuer --- schema/field.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/schema/field.go b/schema/field.go index de937132..497aa02d 100644 --- a/schema/field.go +++ b/schema/field.go @@ -473,16 +473,16 @@ func (field *Field) setupValuerAndSetter() { } } - if valuer, ok := v.(driver.Valuer); ok { - if v, err = valuer.Value(); err == nil { - err = setter(value, v) - } - } else if reflectV.Kind() == reflect.Ptr { + if reflectV.Kind() == reflect.Ptr { if reflectV.IsNil() { field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) } else { err = setter(value, reflectV.Elem().Interface()) } + } else if valuer, ok := v.(driver.Valuer); ok { + if v, err = valuer.Value(); err == nil { + err = setter(value, v) + } } else { return fmt.Errorf("failed to set value %+v to field %v", v, field.Name) } From 2b510d6423f6299d53eee6a69252a6acc4c431c7 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 21 Aug 2020 15:40:50 +0800 Subject: [PATCH 599/881] Don't create index for join table, close #3294 --- schema/relationship.go | 4 ++-- schema/utils.go | 7 +++++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/schema/relationship.go b/schema/relationship.go index 537a3582..c8d129f2 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -225,7 +225,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel Name: joinFieldName, PkgPath: ownField.StructField.PkgPath, Type: ownField.StructField.Type, - Tag: removeSettingFromTag(removeSettingFromTag(ownField.StructField.Tag, "column"), "autoincrement"), + Tag: removeSettingFromTag(ownField.StructField.Tag, "column", "autoincrement", "index", "unique", "uniqueindex"), }) } @@ -248,7 +248,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel Name: joinFieldName, PkgPath: relField.StructField.PkgPath, Type: relField.StructField.Type, - Tag: removeSettingFromTag(removeSettingFromTag(relField.StructField.Tag, "column"), "autoincrement"), + Tag: removeSettingFromTag(relField.StructField.Tag, "column", "autoincrement", "index", "unique", "uniqueindex"), }) } diff --git a/schema/utils.go b/schema/utils.go index 29f2fefb..41bd9d60 100644 --- a/schema/utils.go +++ b/schema/utils.go @@ -51,8 +51,11 @@ func toColumns(val string) (results []string) { return } -func removeSettingFromTag(tag reflect.StructTag, name string) reflect.StructTag { - return reflect.StructTag(regexp.MustCompile(`(?i)(gorm:.*?)(`+name+`(:.*?)?)(;|("))`).ReplaceAllString(string(tag), "${1}${5}")) +func removeSettingFromTag(tag reflect.StructTag, names ...string) reflect.StructTag { + for _, name := range names { + tag = reflect.StructTag(regexp.MustCompile(`(?i)(gorm:.*?)(`+name+`(:.*?)?)(;|("))`).ReplaceAllString(string(tag), "${1}${5}")) + } + return tag } // GetRelationsValues get relations's values from a reflect value From 3a97639880a6a965c5e8209e2ff5557008e8b191 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 23 Aug 2020 10:40:37 +0800 Subject: [PATCH 600/881] Fix unordered joins, close #3267 --- callbacks/query.go | 8 ++++---- chainable_api.go | 5 +---- statement.go | 13 +++++++++---- tests/joins_test.go | 8 ++++++++ 4 files changed, 22 insertions(+), 12 deletions(-) diff --git a/callbacks/query.go b/callbacks/query.go index 5ae1e904..f6cb32d5 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -104,12 +104,12 @@ func BuildQuerySQL(db *gorm.DB) { } joins := []clause.Join{} - for name, conds := range db.Statement.Joins { + for _, join := range db.Statement.Joins { if db.Statement.Schema == nil { joins = append(joins, clause.Join{ - Expression: clause.Expr{SQL: name, Vars: conds}, + Expression: clause.Expr{SQL: join.Name, Vars: join.Conds}, }) - } else if relation, ok := db.Statement.Schema.Relationships.Relations[name]; ok { + } else if relation, ok := db.Statement.Schema.Relationships.Relations[join.Name]; ok { tableAliasName := relation.Name for _, s := range relation.FieldSchema.DBNames { @@ -149,7 +149,7 @@ func BuildQuerySQL(db *gorm.DB) { }) } else { joins = append(joins, clause.Join{ - Expression: clause.Expr{SQL: name, Vars: conds}, + Expression: clause.Expr{SQL: join.Name, Vars: join.Conds}, }) } } diff --git a/chainable_api.go b/chainable_api.go index 9b46a95b..e1b73457 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -172,10 +172,7 @@ func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) { // db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user) func (db *DB) Joins(query string, args ...interface{}) (tx *DB) { tx = db.getInstance() - if tx.Statement.Joins == nil { - tx.Statement.Joins = map[string][]interface{}{} - } - tx.Statement.Joins[query] = args + tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args}) return } diff --git a/statement.go b/statement.go index 6114f468..214a15bb 100644 --- a/statement.go +++ b/statement.go @@ -29,7 +29,7 @@ type Statement struct { Distinct bool Selects []string // selected columns Omits []string // omit columns - Joins map[string][]interface{} + Joins []join Preloads map[string][]interface{} Settings sync.Map ConnPool ConnPool @@ -44,6 +44,11 @@ type Statement struct { assigns []interface{} } +type join struct { + Name string + Conds []interface{} +} + // StatementModifier statement modifier interface type StatementModifier interface { ModifyStatement(*Statement) @@ -401,7 +406,6 @@ func (stmt *Statement) clone() *Statement { Distinct: stmt.Distinct, Selects: stmt.Selects, Omits: stmt.Omits, - Joins: map[string][]interface{}{}, Preloads: map[string][]interface{}{}, ConnPool: stmt.ConnPool, Schema: stmt.Schema, @@ -417,8 +421,9 @@ func (stmt *Statement) clone() *Statement { newStmt.Preloads[k] = p } - for k, j := range stmt.Joins { - newStmt.Joins[k] = j + if len(stmt.Joins) > 0 { + newStmt.Joins = make([]join, len(stmt.Joins)) + copy(newStmt.Joins, stmt.Joins) } stmt.Settings.Range(func(k, v interface{}) bool { diff --git a/tests/joins_test.go b/tests/joins_test.go index e54d3784..f78ddf67 100644 --- a/tests/joins_test.go +++ b/tests/joins_test.go @@ -1,6 +1,7 @@ package tests_test import ( + "regexp" "sort" "testing" @@ -88,6 +89,13 @@ func TestJoinConds(t *testing.T) { if db5.Error != nil { t.Errorf("Should not raise error for join where identical fields in different tables. Error: %s", db5.Error.Error()) } + + dryDB := DB.Session(&gorm.Session{DryRun: true}) + stmt := dryDB.Joins("left join pets on pets.user_id = users.id AND pets.name = ?", user.Pets[0].Name).Joins("join accounts on accounts.user_id = users.id AND accounts.number = ?", user.Account.Number).Where(User{Model: gorm.Model{ID: 1}}).Where(Account{Model: gorm.Model{ID: 1}}).Not(Pet{Model: gorm.Model{ID: 1}}).Find(&users5).Statement + + if !regexp.MustCompile("SELECT .* FROM .users. left join pets.*join accounts.*").MatchString(stmt.SQL.String()) { + t.Errorf("joins should be ordered, but got %v", stmt.SQL.String()) + } } func TestJoinsWithSelect(t *testing.T) { From cc6a64adfb0ed47d5f8ccf8de13eaf8145656973 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 23 Aug 2020 15:40:19 +0800 Subject: [PATCH 601/881] Support smart migrate, close #3078 --- migrator.go | 1 + migrator/migrator.go | 63 ++++++++++++++++++++++++++++++++-- schema/field.go | 5 +++ statement.go | 1 - tests/go.mod | 6 ++-- tests/migrate_test.go | 80 +++++++++++++++++++++++++++++++++++++++++++ 6 files changed, 149 insertions(+), 7 deletions(-) diff --git a/migrator.go b/migrator.go index 37051f81..ed8a8e26 100644 --- a/migrator.go +++ b/migrator.go @@ -42,6 +42,7 @@ type Migrator interface { AddColumn(dst interface{}, field string) error DropColumn(dst interface{}, field string) error AlterColumn(dst interface{}, field string) error + MigrateColumn(dst interface{}, field *schema.Field, columnType *sql.ColumnType) error HasColumn(dst interface{}, field string) bool RenameColumn(dst interface{}, oldName, field string) error ColumnTypes(dst interface{}) ([]*sql.ColumnType, error) diff --git a/migrator/migrator.go b/migrator/migrator.go index d50159dd..d93b8a6d 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -5,6 +5,7 @@ import ( "database/sql" "fmt" "reflect" + "regexp" "strings" "gorm.io/gorm" @@ -80,7 +81,6 @@ func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) { // AutoMigrate func (m Migrator) AutoMigrate(values ...interface{}) error { - // TODO smart migrate data type for _, value := range m.ReorderModels(values, true) { tx := m.DB.Session(&gorm.Session{}) if !tx.Migrator().HasTable(value) { @@ -89,11 +89,26 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { } } else { if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) { + columnTypes, _ := m.DB.Migrator().ColumnTypes(value) + for _, field := range stmt.Schema.FieldsByDBName { - if !tx.Migrator().HasColumn(value, field.DBName) { + var foundColumn *sql.ColumnType + + for _, columnType := range columnTypes { + if columnType.Name() == field.DBName { + foundColumn = columnType + break + } + } + + if foundColumn == nil { + // not found, add column if err := tx.Migrator().AddColumn(value, field.DBName); err != nil { return err } + } else if err := m.DB.Migrator().MigrateColumn(value, field, foundColumn); err != nil { + // found, smart migrate + return err } } @@ -120,7 +135,6 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { } return nil }); err != nil { - fmt.Println(err) return err } } @@ -327,6 +341,49 @@ func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error }) } +func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnType *sql.ColumnType) error { + // found, smart migrate + fullDataType := strings.ToLower(m.DB.Migrator().FullDataTypeOf(field).SQL) + realDataType := strings.ToLower(columnType.DatabaseTypeName()) + + alterColumn := false + + // check size + if length, _ := columnType.Length(); length != int64(field.Size) { + if length > 0 && field.Size > 0 { + alterColumn = true + } else { + // has size in data type and not equal + matches := regexp.MustCompile(`[^\d](\d+)[^\d]`).FindAllString(realDataType, 1) + matches2 := regexp.MustCompile(`[^\d]*(\d+)[^\d]`).FindAllStringSubmatch(fullDataType, -1) + if len(matches) > 0 && matches[1] != fmt.Sprint(field.Size) || len(matches2) == 1 && matches2[0][1] != fmt.Sprint(length) { + alterColumn = true + } + } + } + + // check precision + if precision, _, ok := columnType.DecimalSize(); ok && int64(field.Precision) != precision { + if strings.Contains(fullDataType, fmt.Sprint(field.Precision)) { + alterColumn = true + } + } + + // check nullable + if nullable, ok := columnType.Nullable(); ok && nullable == field.NotNull { + // not primary key & database is nullable + if !field.PrimaryKey && nullable { + alterColumn = true + } + } + + if alterColumn { + return m.DB.Migrator().AlterColumn(value, field.Name) + } + + return nil +} + func (m Migrator) ColumnTypes(value interface{}) (columnTypes []*sql.ColumnType, err error) { err = m.RunWithValue(value, func(stmt *gorm.Statement) error { rows, err := m.DB.Raw("select * from ?", clause.Table{Name: stmt.Table}).Rows() diff --git a/schema/field.go b/schema/field.go index 497aa02d..524d19fb 100644 --- a/schema/field.go +++ b/schema/field.go @@ -55,6 +55,7 @@ type Field struct { Comment string Size int Precision int + Scale int FieldType reflect.Type IndirectFieldType reflect.Type StructField reflect.StructField @@ -160,6 +161,10 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { field.Precision, _ = strconv.Atoi(p) } + if s, ok := field.TagSettings["SCALE"]; ok { + field.Scale, _ = strconv.Atoi(s) + } + if val, ok := field.TagSettings["NOT NULL"]; ok && utils.CheckTruth(val) { field.NotNull = true } diff --git a/statement.go b/statement.go index 214a15bb..95d23fa5 100644 --- a/statement.go +++ b/statement.go @@ -379,7 +379,6 @@ func (stmt *Statement) Build(clauses ...string) { } } } - // TODO handle named vars } func (stmt *Statement) Parse(value interface{}) (err error) { diff --git a/tests/go.mod b/tests/go.mod index 54a808d0..9d4e892d 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -6,11 +6,11 @@ require ( github.com/google/uuid v1.1.1 github.com/jinzhu/now v1.1.1 github.com/lib/pq v1.6.0 - gorm.io/driver/mysql v0.3.1 - gorm.io/driver/postgres v0.2.6 + gorm.io/driver/mysql v0.3.2 + gorm.io/driver/postgres v0.2.9 gorm.io/driver/sqlite v1.0.9 gorm.io/driver/sqlserver v0.2.7 - gorm.io/gorm v0.2.19 + gorm.io/gorm v0.2.36 ) replace gorm.io/gorm => ../ diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 1b002049..4cc8a7c3 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -47,6 +47,86 @@ func TestMigrate(t *testing.T) { } } +func TestSmartMigrateColumn(t *testing.T) { + type UserMigrateColumn struct { + ID uint + Name string + Salary float64 + Birthday time.Time + } + + DB.Migrator().DropTable(&UserMigrateColumn{}) + + DB.AutoMigrate(&UserMigrateColumn{}) + + type UserMigrateColumn2 struct { + ID uint + Name string `gorm:"size:128"` + Salary float64 `gorm:"precision:2"` + Birthday time.Time `gorm:"precision:2"` + } + + if err := DB.Table("user_migrate_columns").AutoMigrate(&UserMigrateColumn2{}); err != nil { + t.Fatalf("failed to auto migrate, got error: %v", err) + } + + columnTypes, err := DB.Table("user_migrate_columns").Migrator().ColumnTypes(&UserMigrateColumn{}) + if err != nil { + t.Fatalf("failed to get column types, got error: %v", err) + } + + for _, columnType := range columnTypes { + switch columnType.Name() { + case "name": + if length, _ := columnType.Length(); length != 0 && length != 128 { + t.Fatalf("name's length should be 128, but got %v", length) + } + case "salary": + if precision, o, _ := columnType.DecimalSize(); precision != 0 && precision != 2 { + t.Fatalf("salary's precision should be 2, but got %v %v", precision, o) + } + case "birthday": + if precision, _, _ := columnType.DecimalSize(); precision != 0 && precision != 2 { + t.Fatalf("birthday's precision should be 2, but got %v", precision) + } + } + } + + type UserMigrateColumn3 struct { + ID uint + Name string `gorm:"size:256"` + Salary float64 `gorm:"precision:3"` + Birthday time.Time `gorm:"precision:3"` + } + + if err := DB.Table("user_migrate_columns").AutoMigrate(&UserMigrateColumn3{}); err != nil { + t.Fatalf("failed to auto migrate, got error: %v", err) + } + + columnTypes, err = DB.Table("user_migrate_columns").Migrator().ColumnTypes(&UserMigrateColumn{}) + if err != nil { + t.Fatalf("failed to get column types, got error: %v", err) + } + + for _, columnType := range columnTypes { + switch columnType.Name() { + case "name": + if length, _ := columnType.Length(); length != 0 && length != 256 { + t.Fatalf("name's length should be 128, but got %v", length) + } + case "salary": + if precision, _, _ := columnType.DecimalSize(); precision != 0 && precision != 3 { + t.Fatalf("salary's precision should be 2, but got %v", precision) + } + case "birthday": + if precision, _, _ := columnType.DecimalSize(); precision != 0 && precision != 3 { + t.Fatalf("birthday's precision should be 2, but got %v", precision) + } + } + } + +} + func TestMigrateWithComment(t *testing.T) { type UserWithComment struct { gorm.Model From ebdb4edda8363fdd79c87ab323ca19b2be7a8872 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 23 Aug 2020 20:08:23 +0800 Subject: [PATCH 602/881] Add AllowGlobalUpdate mode --- callbacks/delete.go | 2 +- callbacks/update.go | 2 +- gorm.go | 7 +++++++ soft_delete.go | 2 +- tests/delete_test.go | 4 ++++ tests/update_test.go | 4 ++++ 6 files changed, 18 insertions(+), 3 deletions(-) diff --git a/callbacks/delete.go b/callbacks/delete.go index 288f2d69..f444f020 100644 --- a/callbacks/delete.go +++ b/callbacks/delete.go @@ -51,7 +51,7 @@ func Delete(db *gorm.DB) { } } - if _, ok := db.Statement.Clauses["WHERE"]; !ok { + if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok { db.AddError(gorm.ErrMissingWhereClause) return } diff --git a/callbacks/update.go b/callbacks/update.go index 5656d166..bd8a4150 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -69,7 +69,7 @@ func Update(db *gorm.DB) { db.Statement.Build("UPDATE", "SET", "WHERE") } - if _, ok := db.Statement.Clauses["WHERE"]; !ok { + if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok { db.AddError(gorm.ErrMissingWhereClause) return } diff --git a/gorm.go b/gorm.go index 1ace0099..3c187f42 100644 --- a/gorm.go +++ b/gorm.go @@ -32,6 +32,8 @@ type Config struct { DisableAutomaticPing bool // DisableForeignKeyConstraintWhenMigrating DisableForeignKeyConstraintWhenMigrating bool + // AllowGlobalUpdate allow global update + AllowGlobalUpdate bool // ClauseBuilders clause builder ClauseBuilders map[string]clause.ClauseBuilder @@ -61,6 +63,7 @@ type Session struct { PrepareStmt bool WithConditions bool SkipDefaultTransaction bool + AllowGlobalUpdate bool Context context.Context Logger logger.Interface NowFunc func() time.Time @@ -154,6 +157,10 @@ func (db *DB) Session(config *Session) *DB { tx.Config.SkipDefaultTransaction = true } + if config.AllowGlobalUpdate { + txConfig.AllowGlobalUpdate = true + } + if config.Context != nil { tx.Statement = tx.Statement.clone() tx.Statement.DB = tx diff --git a/soft_delete.go b/soft_delete.go index 875623bc..d33bf866 100644 --- a/soft_delete.go +++ b/soft_delete.go @@ -98,7 +98,7 @@ func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) { } } - if _, ok := stmt.Clauses["WHERE"]; !ok { + if _, ok := stmt.Clauses["WHERE"]; !stmt.DB.AllowGlobalUpdate && !ok { stmt.DB.AddError(ErrMissingWhereClause) return } diff --git a/tests/delete_test.go b/tests/delete_test.go index f5b3e784..09c1a075 100644 --- a/tests/delete_test.go +++ b/tests/delete_test.go @@ -118,4 +118,8 @@ func TestBlockGlobalDelete(t *testing.T) { if err := DB.Delete(&User{}).Error; err == nil || !errors.Is(err, gorm.ErrMissingWhereClause) { t.Errorf("should returns missing WHERE clause while deleting error") } + + if err := DB.Session(&gorm.Session{AllowGlobalUpdate: true}).Delete(&User{}).Error; err != nil { + t.Errorf("should returns no error while enable global update, but got err %v", err) + } } diff --git a/tests/update_test.go b/tests/update_test.go index 49a13be9..e52dc652 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -222,6 +222,10 @@ func TestBlockGlobalUpdate(t *testing.T) { if err := DB.Model(&User{}).Update("name", "jinzhu").Error; err == nil || !errors.Is(err, gorm.ErrMissingWhereClause) { t.Errorf("should returns missing WHERE clause while updating error, got err %v", err) } + + if err := DB.Session(&gorm.Session{AllowGlobalUpdate: true}).Model(&User{}).Update("name", "jinzhu").Error; err != nil { + t.Errorf("should returns no error while enable global update, but got err %v", err) + } } func TestSelectWithUpdate(t *testing.T) { From 84dbb36d3bd91a5e7b3c1ee5a617ea923a4098d2 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 24 Aug 2020 20:24:25 +0800 Subject: [PATCH 603/881] Add Golang v1.15 --- .github/workflows/tests.yml | 10 +++++----- tests/default_value_test.go | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index b626ce94..4388c31d 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -13,7 +13,7 @@ jobs: sqlite: strategy: matrix: - go: ['1.14', '1.13'] + go: ['1.15', '1.14', '1.13'] platform: [ubuntu-latest, macos-latest] # can not run in windows OS runs-on: ${{ matrix.platform }} @@ -38,7 +38,7 @@ jobs: sqlite_windows: strategy: matrix: - go: ['1.14', '1.13'] + go: ['1.15', '1.14', '1.13'] platform: [windows-latest] runs-on: ${{ matrix.platform }} @@ -64,7 +64,7 @@ jobs: strategy: matrix: dbversion: ['mysql:latest', 'mysql:5.7', 'mysql:5.6', 'mariadb:latest'] - go: ['1.14', '1.13'] + go: ['1.15', '1.14', '1.13'] platform: [ubuntu-latest] runs-on: ${{ matrix.platform }} @@ -108,7 +108,7 @@ jobs: strategy: matrix: dbversion: ['postgres:latest', 'postgres:11', 'postgres:10'] - go: ['1.14', '1.13'] + go: ['1.15', '1.14', '1.13'] platform: [ubuntu-latest] # can not run in macOS and widnowsOS runs-on: ${{ matrix.platform }} @@ -150,7 +150,7 @@ jobs: sqlserver: strategy: matrix: - go: ['1.14', '1.13'] + go: ['1.15', '1.14', '1.13'] platform: [ubuntu-latest] # can not run test in macOS and windows runs-on: ${{ matrix.platform }} diff --git a/tests/default_value_test.go b/tests/default_value_test.go index ea496d60..aa4a511a 100644 --- a/tests/default_value_test.go +++ b/tests/default_value_test.go @@ -10,7 +10,7 @@ func TestDefaultValue(t *testing.T) { type Harumph struct { gorm.Model Email string `gorm:"not null;index:,unique"` - Name string `gorm:"not null;default:'foo'"` + Name string `gorm:"not null;default:foo"` Name2 string `gorm:"size:233;not null;default:'foo'"` Name3 string `gorm:"size:233;not null;default:''"` Age int `gorm:"default:18"` From 3dfa8a66f1bef0a7469c34968cb298c208e59fb9 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 25 Aug 2020 17:27:28 +0800 Subject: [PATCH 604/881] Fix panic when delet without pointer, close #3308 --- callbacks/delete.go | 12 ++++++------ soft_delete.go | 5 ----- tests/delete_test.go | 4 ++++ 3 files changed, 10 insertions(+), 11 deletions(-) diff --git a/callbacks/delete.go b/callbacks/delete.go index f444f020..76b78fb4 100644 --- a/callbacks/delete.go +++ b/callbacks/delete.go @@ -41,7 +41,7 @@ func Delete(db *gorm.DB) { db.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}}) } - if db.Statement.Dest != db.Statement.Model && db.Statement.Model != nil { + if db.Statement.ReflectValue.CanAddr() && db.Statement.Dest != db.Statement.Model && db.Statement.Model != nil { _, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(db.Statement.Model), db.Statement.Schema.PrimaryFields) column, values = schema.ToQueryValues(db.Statement.Table, db.Statement.Schema.PrimaryFieldDBNames, queryValues) @@ -51,15 +51,15 @@ func Delete(db *gorm.DB) { } } - if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok { - db.AddError(gorm.ErrMissingWhereClause) - return - } - db.Statement.AddClauseIfNotExists(clause.From{}) db.Statement.Build("DELETE", "FROM", "WHERE") } + if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok { + db.AddError(gorm.ErrMissingWhereClause) + return + } + if !db.DryRun && db.Error == nil { result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) diff --git a/soft_delete.go b/soft_delete.go index d33bf866..484f565c 100644 --- a/soft_delete.go +++ b/soft_delete.go @@ -98,11 +98,6 @@ func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) { } } - if _, ok := stmt.Clauses["WHERE"]; !stmt.DB.AllowGlobalUpdate && !ok { - stmt.DB.AddError(ErrMissingWhereClause) - return - } - stmt.AddClauseIfNotExists(clause.Update{}) stmt.Build("UPDATE", "SET", "WHERE") } diff --git a/tests/delete_test.go b/tests/delete_test.go index 09c1a075..17299677 100644 --- a/tests/delete_test.go +++ b/tests/delete_test.go @@ -48,6 +48,10 @@ func TestDelete(t *testing.T) { t.Errorf("errors happened when delete: %v", err) } + if err := DB.Delete(User{}).Error; err != gorm.ErrMissingWhereClause { + t.Errorf("errors happened when delete: %v", err) + } + if err := DB.Where("id = ?", users[0].ID).First(&result).Error; err == nil || !errors.Is(err, gorm.ErrRecordNotFound) { t.Errorf("should returns record not found error, but got %v", err) } From 0f3201e73b97c358d2b7d98d24185fab91e5dd73 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 25 Aug 2020 18:18:16 +0800 Subject: [PATCH 605/881] friendly invalid field error message --- schema/relationship.go | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/schema/relationship.go b/schema/relationship.go index c8d129f2..dad2e629 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -336,7 +336,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl gue primarySchema, foreignSchema = schema, relation.FieldSchema ) - reguessOrErr := func(err string, args ...interface{}) { + reguessOrErr := func() { switch gl { case guessHas: schema.guessRelation(relation, field, guessEmbeddedHas) @@ -345,7 +345,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl gue case guessBelongs: schema.guessRelation(relation, field, guessEmbeddedBelongs) default: - schema.err = fmt.Errorf(err, args...) + schema.err = fmt.Errorf("invalid field found for struct %v's field %v, need to define a foreign key for relations or it need to implement the Valuer/Scanner interface", schema, field.Name) } } @@ -354,7 +354,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl gue if field.OwnerSchema != nil { primarySchema, foreignSchema = field.OwnerSchema, relation.FieldSchema } else { - reguessOrErr("failed to guess %v's relations with %v's field %v, guess level: %v", relation.FieldSchema, schema, field.Name, gl) + reguessOrErr() return } case guessBelongs: @@ -363,7 +363,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl gue if field.OwnerSchema != nil { primarySchema, foreignSchema = relation.FieldSchema, field.OwnerSchema } else { - reguessOrErr("failed to guess %v's relations with %v's field %v, guess level: %v", relation.FieldSchema, schema, field.Name, gl) + reguessOrErr() return } } @@ -373,7 +373,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl gue if f := foreignSchema.LookUpField(foreignKey); f != nil { foreignFields = append(foreignFields, f) } else { - reguessOrErr("unsupported relations %v for %v on field %v with foreign keys %v", relation.FieldSchema, schema, field.Name, relation.foreignKeys) + reguessOrErr() return } } @@ -392,7 +392,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl gue } if len(foreignFields) == 0 { - reguessOrErr("failed to guess %v's relations with %v's field %v, guess level: %v", relation.FieldSchema, schema, field.Name, gl) + reguessOrErr() return } else if len(relation.primaryKeys) > 0 { for idx, primaryKey := range relation.primaryKeys { @@ -400,11 +400,11 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl gue if len(primaryFields) < idx+1 { primaryFields = append(primaryFields, f) } else if f != primaryFields[idx] { - reguessOrErr("unsupported relations %v for %v on field %v with primary keys %v", relation.FieldSchema, schema, field.Name, relation.primaryKeys) + reguessOrErr() return } } else { - reguessOrErr("unsupported relations %v for %v on field %v with primary keys %v", relation.FieldSchema, schema, field.Name, relation.primaryKeys) + reguessOrErr() return } } @@ -414,7 +414,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl gue } else if len(primarySchema.PrimaryFields) == len(foreignFields) { primaryFields = append(primaryFields, primarySchema.PrimaryFields...) } else { - reguessOrErr("unsupported relations %v for %v on field %v", relation.FieldSchema, schema, field.Name) + reguessOrErr() return } } From 3195ae12072f51d15064a3428f4e906c6873c4e2 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 25 Aug 2020 18:59:19 +0800 Subject: [PATCH 606/881] Allow override alias table in preload conditions --- callbacks/preload.go | 6 +++--- tests/preload_test.go | 15 +++++++++++++++ 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/callbacks/preload.go b/callbacks/preload.go index cd09a6d6..25b8cb2b 100644 --- a/callbacks/preload.go +++ b/callbacks/preload.go @@ -50,7 +50,7 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { joinResults := rel.JoinTable.MakeSlice().Elem() column, values := schema.ToQueryValues(rel.JoinTable.Table, joinForeignKeys, joinForeignValues) - tx.Where(clause.IN{Column: column, Values: values}).Find(joinResults.Addr().Interface()) + db.AddError(tx.Where(clause.IN{Column: column, Values: values}).Find(joinResults.Addr().Interface()).Error) // convert join identity map to relation identity map fieldValues := make([]interface{}, len(joinForeignFields)) @@ -93,7 +93,7 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { } reflectResults := rel.FieldSchema.MakeSlice().Elem() - column, values := schema.ToQueryValues(rel.FieldSchema.Table, relForeignKeys, foreignValues) + column, values := schema.ToQueryValues(clause.CurrentTable, relForeignKeys, foreignValues) for _, cond := range conds { if fc, ok := cond.(func(*gorm.DB) *gorm.DB); ok { @@ -103,7 +103,7 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { } } - tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), inlineConds...) + db.AddError(tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), inlineConds...).Error) fieldValues := make([]interface{}, len(relForeignFields)) diff --git a/tests/preload_test.go b/tests/preload_test.go index 3caa17b4..7e5d2622 100644 --- a/tests/preload_test.go +++ b/tests/preload_test.go @@ -5,6 +5,7 @@ import ( "strconv" "testing" + "gorm.io/gorm" "gorm.io/gorm/clause" . "gorm.io/gorm/utils/tests" ) @@ -108,6 +109,20 @@ func TestPreloadWithConds(t *testing.T) { } CheckUser(t, users2[0], users[0]) + + var users3 []User + if err := DB.Preload("Account", func(tx *gorm.DB) *gorm.DB { + return tx.Table("accounts AS a").Select("a.*") + }).Find(&users3, "id IN ?", userIDs).Error; err != nil { + t.Errorf("failed to query, got error %v", err) + } + sort.Slice(users3, func(i, j int) bool { + return users2[i].ID < users2[j].ID + }) + + for i, u := range users3 { + CheckUser(t, u, users[i]) + } } func TestNestedPreloadWithConds(t *testing.T) { From 0d96f99499f2501a0d3a5e0d93ef157cc287e44f Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 26 Aug 2020 12:22:11 +0800 Subject: [PATCH 607/881] Update README --- README.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/README.md b/README.md index b51297c4..c727e2cf 100644 --- a/README.md +++ b/README.md @@ -23,7 +23,7 @@ The fantastic ORM library for Golang, aims to be developer friendly. * Composite Primary Key * Auto Migrations * Logger -* Extendable, write Plugins based on GORM callbacks +* Extendable, flexible plugin API: Database Resolver (Multiple Databases, Read/Write Splitting) / Prometheus… * Every feature comes with tests * Developer Friendly @@ -40,4 +40,3 @@ The fantastic ORM library for Golang, aims to be developer friendly. © Jinzhu, 2013~time.Now Released under the [MIT License](https://github.com/go-gorm/gorm/blob/master/License) - From ce8853e7a6142420a786be1b0f0c5ffeb8778778 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 27 Aug 2020 15:03:57 +0800 Subject: [PATCH 608/881] Add GormValuer interface support --- README.md | 2 +- callbacks/create.go | 8 +++--- callbacks/delete.go | 4 +-- callbacks/interfaces.go | 39 ++++++++++++++++++++++++++++ callbacks/query.go | 2 +- callbacks/update.go | 8 +++--- interfaces.go | 37 +++------------------------ schema/interfaces.go | 4 ++- statement.go | 2 ++ tests/scanner_valuer_test.go | 49 ++++++++++++++++++++++++++++++++++++ 10 files changed, 108 insertions(+), 47 deletions(-) create mode 100644 callbacks/interfaces.go diff --git a/README.md b/README.md index c727e2cf..9c0aded0 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,7 @@ The fantastic ORM library for Golang, aims to be developer friendly. * Transactions, Nested Transactions, Save Point, RollbackTo to Saved Point * Context, Prepared Statment Mode, DryRun Mode * Batch Insert, FindInBatches, Find To Map -* SQL Builder, Upsert, Locking, Optimizer/Index/Comment Hints, NamedArg +* SQL Builder, Upsert, Locking, Optimizer/Index/Comment Hints, NamedArg, Search/Update/Create with SQL Expr * Composite Primary Key * Auto Migrations * Logger diff --git a/callbacks/create.go b/callbacks/create.go index 7a32ed5c..cc7e2671 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -12,14 +12,14 @@ func BeforeCreate(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeCreate) { callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { if db.Statement.Schema.BeforeSave { - if i, ok := value.(gorm.BeforeSaveInterface); ok { + if i, ok := value.(BeforeSaveInterface); ok { called = true db.AddError(i.BeforeSave(tx)) } } if db.Statement.Schema.BeforeCreate { - if i, ok := value.(gorm.BeforeCreateInterface); ok { + if i, ok := value.(BeforeCreateInterface); ok { called = true db.AddError(i.BeforeCreate(tx)) } @@ -203,14 +203,14 @@ func AfterCreate(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) { callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { if db.Statement.Schema.AfterSave { - if i, ok := value.(gorm.AfterSaveInterface); ok { + if i, ok := value.(AfterSaveInterface); ok { called = true db.AddError(i.AfterSave(tx)) } } if db.Statement.Schema.AfterCreate { - if i, ok := value.(gorm.AfterCreateInterface); ok { + if i, ok := value.(AfterCreateInterface); ok { called = true db.AddError(i.AfterCreate(tx)) } diff --git a/callbacks/delete.go b/callbacks/delete.go index 76b78fb4..e95117a1 100644 --- a/callbacks/delete.go +++ b/callbacks/delete.go @@ -11,7 +11,7 @@ import ( func BeforeDelete(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.BeforeDelete { callMethod(db, func(value interface{}, tx *gorm.DB) bool { - if i, ok := value.(gorm.BeforeDeleteInterface); ok { + if i, ok := value.(BeforeDeleteInterface); ok { db.AddError(i.BeforeDelete(tx)) return true } @@ -75,7 +75,7 @@ func Delete(db *gorm.DB) { func AfterDelete(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.AfterDelete { callMethod(db, func(value interface{}, tx *gorm.DB) bool { - if i, ok := value.(gorm.AfterDeleteInterface); ok { + if i, ok := value.(AfterDeleteInterface); ok { db.AddError(i.AfterDelete(tx)) return true } diff --git a/callbacks/interfaces.go b/callbacks/interfaces.go new file mode 100644 index 00000000..2302470f --- /dev/null +++ b/callbacks/interfaces.go @@ -0,0 +1,39 @@ +package callbacks + +import "gorm.io/gorm" + +type BeforeCreateInterface interface { + BeforeCreate(*gorm.DB) error +} + +type AfterCreateInterface interface { + AfterCreate(*gorm.DB) error +} + +type BeforeUpdateInterface interface { + BeforeUpdate(*gorm.DB) error +} + +type AfterUpdateInterface interface { + AfterUpdate(*gorm.DB) error +} + +type BeforeSaveInterface interface { + BeforeSave(*gorm.DB) error +} + +type AfterSaveInterface interface { + AfterSave(*gorm.DB) error +} + +type BeforeDeleteInterface interface { + BeforeDelete(*gorm.DB) error +} + +type AfterDeleteInterface interface { + AfterDelete(*gorm.DB) error +} + +type AfterFindInterface interface { + AfterFind(*gorm.DB) error +} diff --git a/callbacks/query.go b/callbacks/query.go index f6cb32d5..0703b92e 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -214,7 +214,7 @@ func Preload(db *gorm.DB) { func AfterQuery(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.AfterFind { callMethod(db, func(value interface{}, tx *gorm.DB) bool { - if i, ok := value.(gorm.AfterFindInterface); ok { + if i, ok := value.(AfterFindInterface); ok { db.AddError(i.AfterFind(tx)) return true } diff --git a/callbacks/update.go b/callbacks/update.go index bd8a4150..73c062b4 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -32,14 +32,14 @@ func BeforeUpdate(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil && !db.Statement.UpdatingColumn && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) { callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { if db.Statement.Schema.BeforeSave { - if i, ok := value.(gorm.BeforeSaveInterface); ok { + if i, ok := value.(BeforeSaveInterface); ok { called = true db.AddError(i.BeforeSave(tx)) } } if db.Statement.Schema.BeforeUpdate { - if i, ok := value.(gorm.BeforeUpdateInterface); ok { + if i, ok := value.(BeforeUpdateInterface); ok { called = true db.AddError(i.BeforeUpdate(tx)) } @@ -90,14 +90,14 @@ func AfterUpdate(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil && !db.Statement.UpdatingColumn && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) { callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { if db.Statement.Schema.AfterSave { - if i, ok := value.(gorm.AfterSaveInterface); ok { + if i, ok := value.(AfterSaveInterface); ok { called = true db.AddError(i.AfterSave(tx)) } } if db.Statement.Schema.AfterUpdate { - if i, ok := value.(gorm.AfterUpdateInterface); ok { + if i, ok := value.(AfterUpdateInterface); ok { called = true db.AddError(i.AfterUpdate(tx)) } diff --git a/interfaces.go b/interfaces.go index b2ce59b3..e933952b 100644 --- a/interfaces.go +++ b/interfaces.go @@ -53,38 +53,7 @@ type TxCommitter interface { Rollback() error } -type BeforeCreateInterface interface { - BeforeCreate(*DB) error -} - -type AfterCreateInterface interface { - AfterCreate(*DB) error -} - -type BeforeUpdateInterface interface { - BeforeUpdate(*DB) error -} - -type AfterUpdateInterface interface { - AfterUpdate(*DB) error -} - -type BeforeSaveInterface interface { - BeforeSave(*DB) error -} - -type AfterSaveInterface interface { - AfterSave(*DB) error -} - -type BeforeDeleteInterface interface { - BeforeDelete(*DB) error -} - -type AfterDeleteInterface interface { - AfterDelete(*DB) error -} - -type AfterFindInterface interface { - AfterFind(*DB) error +// Valuer gorm valuer interface +type Valuer interface { + GormValue(context.Context, *DB) clause.Expr } diff --git a/schema/interfaces.go b/schema/interfaces.go index e8e51e4c..98abffbd 100644 --- a/schema/interfaces.go +++ b/schema/interfaces.go @@ -1,6 +1,8 @@ package schema -import "gorm.io/gorm/clause" +import ( + "gorm.io/gorm/clause" +) type GormDataTypeInterface interface { GormDataType() string diff --git a/statement.go b/statement.go index 95d23fa5..fba1991d 100644 --- a/statement.go +++ b/statement.go @@ -161,6 +161,8 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) { stmt.Vars = append(stmt.Vars, v.Value) case clause.Column, clause.Table: stmt.QuoteTo(writer, v) + case Valuer: + stmt.AddVar(writer, v.GormValue(stmt.Context, stmt.DB)) case clause.Expr: var varStr strings.Builder var sql = v.SQL diff --git a/tests/scanner_valuer_test.go b/tests/scanner_valuer_test.go index ce8a2b50..ec16ccf6 100644 --- a/tests/scanner_valuer_test.go +++ b/tests/scanner_valuer_test.go @@ -1,16 +1,20 @@ package tests_test import ( + "context" "database/sql" "database/sql/driver" "encoding/json" "errors" + "fmt" "reflect" + "regexp" "strconv" "testing" "time" "gorm.io/gorm" + "gorm.io/gorm/clause" . "gorm.io/gorm/utils/tests" ) @@ -305,3 +309,48 @@ func (t EmptyTime) Value() (driver.Value, error) { type NullString struct { sql.NullString } + +type Point struct { + X, Y int +} + +func (point *Point) Scan(v interface{}) error { + return nil +} + +func (point Point) GormDataType() string { + return "geo" +} + +func (point Point) GormValue(ctx context.Context, db *gorm.DB) clause.Expr { + return clause.Expr{ + SQL: "ST_PointFromText(?)", + Vars: []interface{}{fmt.Sprintf("POINT(%d %d)", point.X, point.Y)}, + } +} + +func TestGORMValuer(t *testing.T) { + type UserWithPoint struct { + Name string + Point Point + } + + dryRunDB := DB.Session(&gorm.Session{DryRun: true}) + + stmt := dryRunDB.Create(&UserWithPoint{ + Name: "jinzhu", + Point: Point{X: 100, Y: 100}, + }).Statement + + if stmt.SQL.String() == "" || len(stmt.Vars) != 2 { + t.Errorf("Failed to generate sql, got %v", stmt.SQL.String()) + } + + if !regexp.MustCompile(`INSERT INTO .user_with_points. \(.name.,.point.\) VALUES \(.+,ST_PointFromText\(.+\)\)`).MatchString(stmt.SQL.String()) { + t.Errorf("insert with sql.Expr, but got %v", stmt.SQL.String()) + } + + if !reflect.DeepEqual([]interface{}{"jinzhu", "POINT(100 100)"}, stmt.Vars) { + t.Errorf("generated vars is not equal, got %v", stmt.Vars) + } +} From 7a90496701f7b81e06daaa134a8f8853c1f935d5 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 27 Aug 2020 16:27:59 +0800 Subject: [PATCH 609/881] Test create from sql expr with map --- callbacks/create.go | 4 ++++ callbacks/helper.go | 12 ++++++++---- tests/scanner_valuer_test.go | 26 ++++++++++++++++++++++++++ 3 files changed, 38 insertions(+), 4 deletions(-) diff --git a/callbacks/create.go b/callbacks/create.go index cc7e2671..c59b14b5 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -225,8 +225,12 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { switch value := stmt.Dest.(type) { case map[string]interface{}: values = ConvertMapToValuesForCreate(stmt, value) + case *map[string]interface{}: + values = ConvertMapToValuesForCreate(stmt, *value) case []map[string]interface{}: values = ConvertSliceOfMapToValuesForCreate(stmt, value) + case *[]map[string]interface{}: + values = ConvertSliceOfMapToValuesForCreate(stmt, *value) default: var ( selectColumns, restricted = stmt.SelectAndOmitColumns(true, false) diff --git a/callbacks/helper.go b/callbacks/helper.go index 80fbc2a1..e0a66dd2 100644 --- a/callbacks/helper.go +++ b/callbacks/helper.go @@ -20,8 +20,10 @@ func ConvertMapToValuesForCreate(stmt *gorm.Statement, mapValue map[string]inter for _, k := range keys { value := mapValue[k] - if field := stmt.Schema.LookUpField(k); field != nil { - k = field.DBName + if stmt.Schema != nil { + if field := stmt.Schema.LookUpField(k); field != nil { + k = field.DBName + } } if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) { @@ -46,8 +48,10 @@ func ConvertSliceOfMapToValuesForCreate(stmt *gorm.Statement, mapValues []map[st for idx, mapValue := range mapValues { for k, v := range mapValue { - if field := stmt.Schema.LookUpField(k); field != nil { - k = field.DBName + if stmt.Schema != nil { + if field := stmt.Schema.LookUpField(k); field != nil { + k = field.DBName + } } if _, ok := result[k]; !ok { diff --git a/tests/scanner_valuer_test.go b/tests/scanner_valuer_test.go index ec16ccf6..dbf5adac 100644 --- a/tests/scanner_valuer_test.go +++ b/tests/scanner_valuer_test.go @@ -353,4 +353,30 @@ func TestGORMValuer(t *testing.T) { if !reflect.DeepEqual([]interface{}{"jinzhu", "POINT(100 100)"}, stmt.Vars) { t.Errorf("generated vars is not equal, got %v", stmt.Vars) } + + stmt = dryRunDB.Model(UserWithPoint{}).Create(map[string]interface{}{ + "Name": "jinzhu", + "Point": clause.Expr{SQL: "ST_PointFromText(?)", Vars: []interface{}{"POINT(100 100)"}}, + }).Statement + + if !regexp.MustCompile(`INSERT INTO .user_with_points. \(.name.,.point.\) VALUES \(.+,ST_PointFromText\(.+\)\)`).MatchString(stmt.SQL.String()) { + t.Errorf("insert with sql.Expr, but got %v", stmt.SQL.String()) + } + + if !reflect.DeepEqual([]interface{}{"jinzhu", "POINT(100 100)"}, stmt.Vars) { + t.Errorf("generated vars is not equal, got %v", stmt.Vars) + } + + stmt = dryRunDB.Table("user_with_points").Create(&map[string]interface{}{ + "Name": "jinzhu", + "Point": clause.Expr{SQL: "ST_PointFromText(?)", Vars: []interface{}{"POINT(100 100)"}}, + }).Statement + + if !regexp.MustCompile(`INSERT INTO .user_with_points. \(.Name.,.Point.\) VALUES \(.+,ST_PointFromText\(.+\)\)`).MatchString(stmt.SQL.String()) { + t.Errorf("insert with sql.Expr, but got %v", stmt.SQL.String()) + } + + if !reflect.DeepEqual([]interface{}{"jinzhu", "POINT(100 100)"}, stmt.Vars) { + t.Errorf("generated vars is not equal, got %v", stmt.Vars) + } } From cd54dddd94a992edd446611aeccc939a64ad2658 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 27 Aug 2020 18:42:40 +0800 Subject: [PATCH 610/881] Test update with GormValuer --- tests/go.mod | 2 +- tests/scanner_valuer_test.go | 19 +++++++++++++++---- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/tests/go.mod b/tests/go.mod index 9d4e892d..b0ed4497 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -9,7 +9,7 @@ require ( gorm.io/driver/mysql v0.3.2 gorm.io/driver/postgres v0.2.9 gorm.io/driver/sqlite v1.0.9 - gorm.io/driver/sqlserver v0.2.7 + gorm.io/driver/sqlserver v0.2.8 gorm.io/gorm v0.2.36 ) diff --git a/tests/scanner_valuer_test.go b/tests/scanner_valuer_test.go index dbf5adac..f42daae7 100644 --- a/tests/scanner_valuer_test.go +++ b/tests/scanner_valuer_test.go @@ -314,10 +314,6 @@ type Point struct { X, Y int } -func (point *Point) Scan(v interface{}) error { - return nil -} - func (point Point) GormDataType() string { return "geo" } @@ -379,4 +375,19 @@ func TestGORMValuer(t *testing.T) { if !reflect.DeepEqual([]interface{}{"jinzhu", "POINT(100 100)"}, stmt.Vars) { t.Errorf("generated vars is not equal, got %v", stmt.Vars) } + + stmt = dryRunDB.Session(&gorm.Session{ + AllowGlobalUpdate: true, + }).Model(&UserWithPoint{}).Updates(UserWithPoint{ + Name: "jinzhu", + Point: Point{X: 100, Y: 100}, + }).Statement + + if !regexp.MustCompile(`UPDATE .user_with_points. SET .name.=.+,.point.=ST_PointFromText\(.+\)`).MatchString(stmt.SQL.String()) { + t.Errorf("insert with sql.Expr, but got %v", stmt.SQL.String()) + } + + if !reflect.DeepEqual([]interface{}{"jinzhu", "POINT(100 100)"}, stmt.Vars) { + t.Errorf("generated vars is not equal, got %v", stmt.Vars) + } } From d50dbb0896100640d61a8b4017aa46946f3bc6c5 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 27 Aug 2020 19:15:40 +0800 Subject: [PATCH 611/881] Fix check valid db name, close #3315 --- chainable_api.go | 6 +++--- finisher_api.go | 2 +- utils/utils.go | 4 ++-- utils/utils_test.go | 14 ++++++++++++++ 4 files changed, 20 insertions(+), 6 deletions(-) create mode 100644 utils/utils_test.go diff --git a/chainable_api.go b/chainable_api.go index e1b73457..c8417a6d 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -93,7 +93,7 @@ func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) { } delete(tx.Statement.Clauses, "SELECT") case string: - fields := strings.FieldsFunc(v, utils.IsChar) + fields := strings.FieldsFunc(v, utils.IsValidDBNameChar) // normal field names if len(fields) == 1 || (len(fields) == 3 && strings.ToUpper(fields[1]) == "AS") { @@ -133,7 +133,7 @@ func (db *DB) Omit(columns ...string) (tx *DB) { tx = db.getInstance() if len(columns) == 1 && strings.ContainsRune(columns[0], ',') { - tx.Statement.Omits = strings.FieldsFunc(columns[0], utils.IsChar) + tx.Statement.Omits = strings.FieldsFunc(columns[0], utils.IsValidDBNameChar) } else { tx.Statement.Omits = columns } @@ -180,7 +180,7 @@ func (db *DB) Joins(query string, args ...interface{}) (tx *DB) { func (db *DB) Group(name string) (tx *DB) { tx = db.getInstance() - fields := strings.FieldsFunc(name, utils.IsChar) + fields := strings.FieldsFunc(name, utils.IsValidDBNameChar) tx.Statement.AddClause(clause.GroupBy{ Columns: []clause.Column{{Name: name, Raw: len(fields) != 1}}, }) diff --git a/finisher_api.go b/finisher_api.go index cf46f78a..2cde3c31 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -362,7 +362,7 @@ func (db *DB) Pluck(column string, dest interface{}) (tx *DB) { tx.AddError(ErrModelValueRequired) } - fields := strings.FieldsFunc(column, utils.IsChar) + fields := strings.FieldsFunc(column, utils.IsValidDBNameChar) tx.Statement.AddClauseIfNotExists(clause.Select{ Distinct: tx.Statement.Distinct, Columns: []clause.Column{{Name: column, Raw: len(fields) != 1}}, diff --git a/utils/utils.go b/utils/utils.go index e93f3055..71336f4b 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -29,8 +29,8 @@ func FileWithLineNum() string { return "" } -func IsChar(c rune) bool { - return !unicode.IsLetter(c) && !unicode.IsNumber(c) && c != '.' && c != '*' +func IsValidDBNameChar(c rune) bool { + return !unicode.IsLetter(c) && !unicode.IsNumber(c) && c != '.' && c != '*' && c != '_' && c != '$' } func CheckTruth(val interface{}) bool { diff --git a/utils/utils_test.go b/utils/utils_test.go new file mode 100644 index 00000000..5737c511 --- /dev/null +++ b/utils/utils_test.go @@ -0,0 +1,14 @@ +package utils + +import ( + "strings" + "testing" +) + +func TestIsValidDBNameChar(t *testing.T) { + for _, db := range []string{"db", "dbName", "db_name", "db1", "1dbname", "db$name"} { + if fields := strings.FieldsFunc(db, IsValidDBNameChar); len(fields) != 1 { + t.Fatalf("failed to parse db name %v", db) + } + } +} From dacbaa5f02bf40efa5d8841047c047f7a5340d9f Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 27 Aug 2020 19:52:01 +0800 Subject: [PATCH 612/881] Fix update attrs order --- callbacks/update.go | 6 ++++-- tests/scanner_valuer_test.go | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/callbacks/update.go b/callbacks/update.go index 73c062b4..46f59157 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -199,7 +199,8 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { } if !stmt.UpdatingColumn && stmt.Schema != nil { - for _, field := range stmt.Schema.FieldsByDBName { + for _, dbName := range stmt.Schema.DBNames { + field := stmt.Schema.LookUpField(dbName) if field.AutoUpdateTime > 0 && value[field.Name] == nil && value[field.DBName] == nil { if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { now := stmt.DB.NowFunc() @@ -222,7 +223,8 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { switch updatingValue.Kind() { case reflect.Struct: set = make([]clause.Assignment, 0, len(stmt.Schema.FieldsByDBName)) - for _, field := range stmt.Schema.FieldsByDBName { + for _, dbName := range stmt.Schema.DBNames { + field := stmt.Schema.LookUpField(dbName) if !field.PrimaryKey || (!updatingValue.CanAddr() || stmt.Dest != stmt.Model) { if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { value, isZero := field.ValueOf(updatingValue) diff --git a/tests/scanner_valuer_test.go b/tests/scanner_valuer_test.go index f42daae7..fb1f5791 100644 --- a/tests/scanner_valuer_test.go +++ b/tests/scanner_valuer_test.go @@ -384,7 +384,7 @@ func TestGORMValuer(t *testing.T) { }).Statement if !regexp.MustCompile(`UPDATE .user_with_points. SET .name.=.+,.point.=ST_PointFromText\(.+\)`).MatchString(stmt.SQL.String()) { - t.Errorf("insert with sql.Expr, but got %v", stmt.SQL.String()) + t.Errorf("update with sql.Expr, but got %v", stmt.SQL.String()) } if !reflect.DeepEqual([]interface{}{"jinzhu", "POINT(100 100)"}, stmt.Vars) { From c19a3abefb2aef853e4541ae1af7fa93f2dc0848 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 28 Aug 2020 11:31:13 +0800 Subject: [PATCH 613/881] Fix self-referential belongs to, close #3319 --- association.go | 4 ++-- schema/relationship.go | 34 +++++++++++++++++++--------------- schema/relationship_test.go | 14 ++++++++++++++ schema/schema_test.go | 2 +- 4 files changed, 36 insertions(+), 18 deletions(-) diff --git a/association.go b/association.go index e59b8938..25e1fe8d 100644 --- a/association.go +++ b/association.go @@ -54,7 +54,7 @@ func (association *Association) Find(out interface{}, conds ...interface{}) erro for _, queryClause := range association.Relationship.JoinTable.QueryClauses { joinStmt.AddClause(queryClause) } - joinStmt.Build("WHERE", "LIMIT") + joinStmt.Build("WHERE") tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars}) } @@ -112,7 +112,7 @@ func (association *Association) Replace(values ...interface{}) error { updateMap[ref.ForeignKey.DBName] = nil } - association.DB.UpdateColumns(updateMap) + association.Error = association.DB.UpdateColumns(updateMap).Error } case schema.HasOne, schema.HasMany: var ( diff --git a/schema/relationship.go b/schema/relationship.go index dad2e629..5132ff74 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -82,7 +82,9 @@ func (schema *Schema) parseRelation(field *Field) { schema.buildMany2ManyRelation(relation, field, many2many) } else { switch field.IndirectFieldType.Kind() { - case reflect.Struct, reflect.Slice: + case reflect.Struct: + schema.guessRelation(relation, field, guessBelongs) + case reflect.Slice: schema.guessRelation(relation, field, guessHas) default: schema.err = fmt.Errorf("unsupported data type %v for %v on field %v", relation.FieldSchema, schema, field.Name) @@ -324,10 +326,10 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel type guessLevel int const ( - guessHas guessLevel = iota - guessEmbeddedHas - guessBelongs + guessBelongs guessLevel = iota guessEmbeddedBelongs + guessHas + guessEmbeddedHas ) func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl guessLevel) { @@ -338,25 +340,19 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl gue reguessOrErr := func() { switch gl { - case guessHas: - schema.guessRelation(relation, field, guessEmbeddedHas) - case guessEmbeddedHas: - schema.guessRelation(relation, field, guessBelongs) case guessBelongs: schema.guessRelation(relation, field, guessEmbeddedBelongs) + case guessEmbeddedBelongs: + schema.guessRelation(relation, field, guessHas) + case guessHas: + schema.guessRelation(relation, field, guessEmbeddedHas) + // case guessEmbeddedHas: default: schema.err = fmt.Errorf("invalid field found for struct %v's field %v, need to define a foreign key for relations or it need to implement the Valuer/Scanner interface", schema, field.Name) } } switch gl { - case guessEmbeddedHas: - if field.OwnerSchema != nil { - primarySchema, foreignSchema = field.OwnerSchema, relation.FieldSchema - } else { - reguessOrErr() - return - } case guessBelongs: primarySchema, foreignSchema = relation.FieldSchema, schema case guessEmbeddedBelongs: @@ -366,6 +362,14 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl gue reguessOrErr() return } + case guessHas: + case guessEmbeddedHas: + if field.OwnerSchema != nil { + primarySchema, foreignSchema = field.OwnerSchema, relation.FieldSchema + } else { + reguessOrErr() + return + } } if len(relation.foreignKeys) > 0 { diff --git a/schema/relationship_test.go b/schema/relationship_test.go index 2c09f528..2e85c538 100644 --- a/schema/relationship_test.go +++ b/schema/relationship_test.go @@ -55,6 +55,20 @@ func TestBelongsToOverrideReferences(t *testing.T) { }) } +func TestSelfReferentialBelongsToOverrideReferences(t *testing.T) { + type User struct { + ID int32 `gorm:"primaryKey"` + Name string + CreatedBy *int32 + Creator *User `gorm:"foreignKey:CreatedBy;references:ID"` + } + + checkStructRelation(t, &User{}, Relation{ + Name: "Creator", Type: schema.BelongsTo, Schema: "User", FieldSchema: "User", + References: []Reference{{"ID", "User", "CreatedBy", "User", "", false}}, + }) +} + func TestHasOneOverrideForeignKey(t *testing.T) { type Profile struct { gorm.Model diff --git a/schema/schema_test.go b/schema/schema_test.go index 8bd1e5ca..4d13ebd2 100644 --- a/schema/schema_test.go +++ b/schema/schema_test.go @@ -171,7 +171,7 @@ func TestNestedModel(t *testing.T) { fields := []schema.Field{ {Name: "ID", DBName: "id", BindNames: []string{"VersionModel", "BaseModel", "ID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, HasDefaultValue: true, AutoIncrement: true}, - {Name: "CreatedBy", DBName: "created_by", BindNames: []string{"VersionModel", "BaseModel", "CreatedBy"}, DataType: schema.Int, Size: 64}, + {Name: "CreatedBy", DBName: "created_by", BindNames: []string{"VersionModel", "BaseModel", "CreatedBy"}, DataType: schema.Uint, Size: 64}, {Name: "Version", DBName: "version", BindNames: []string{"VersionModel", "Version"}, DataType: schema.Int, Size: 64}, } From 94c6bb980b8c3775d98121d5d42109cefe596c5c Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 28 Aug 2020 12:25:25 +0800 Subject: [PATCH 614/881] Refactor association --- association.go | 92 ++++++++++++++++++++------------------------------ 1 file changed, 37 insertions(+), 55 deletions(-) diff --git a/association.go b/association.go index 25e1fe8d..db77cc4e 100644 --- a/association.go +++ b/association.go @@ -43,32 +43,8 @@ func (db *DB) Association(column string) *Association { func (association *Association) Find(out interface{}, conds ...interface{}) error { if association.Error == nil { - var ( - queryConds = association.Relationship.ToQueryConditions(association.DB.Statement.ReflectValue) - tx = association.DB.Model(out) - ) - - if association.Relationship.JoinTable != nil { - if !tx.Statement.Unscoped && len(association.Relationship.JoinTable.QueryClauses) > 0 { - joinStmt := Statement{DB: tx, Schema: association.Relationship.JoinTable, Table: association.Relationship.JoinTable.Table, Clauses: map[string]clause.Clause{}} - for _, queryClause := range association.Relationship.JoinTable.QueryClauses { - joinStmt.AddClause(queryClause) - } - joinStmt.Build("WHERE") - tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars}) - } - - tx.Clauses(clause.From{Joins: []clause.Join{{ - Table: clause.Table{Name: association.Relationship.JoinTable.Table}, - ON: clause.Where{Exprs: queryConds}, - }}}) - } else { - tx.Clauses(clause.Where{Exprs: queryConds}) - } - - association.Error = tx.Find(out, conds...).Error + association.Error = association.buildCondition().Find(out, conds...).Error } - return association.Error } @@ -80,7 +56,7 @@ func (association *Association) Append(values ...interface{}) error { association.Error = association.Replace(values...) } default: - association.saveAssociation(false, values...) + association.saveAssociation( /*clear*/ false, values...) } } @@ -90,7 +66,7 @@ func (association *Association) Append(values ...interface{}) error { func (association *Association) Replace(values ...interface{}) error { if association.Error == nil { // save associations - association.saveAssociation(true, values...) + association.saveAssociation( /*clear*/ true, values...) // set old associations's foreign key to null reflectValue := association.DB.Statement.ReflectValue @@ -234,7 +210,7 @@ func (association *Association) Delete(values ...interface{}) error { var ( primaryFields, relPrimaryFields []*schema.Field joinPrimaryKeys, joinRelPrimaryKeys []string - modelValue = reflect.New(rel.JoinTable.ModelType).Interface() + joinValue = reflect.New(rel.JoinTable.ModelType).Interface() ) for _, ref := range rel.References { @@ -259,10 +235,11 @@ func (association *Association) Delete(values ...interface{}) error { relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs) conds = append(conds, clause.IN{Column: relColumn, Values: relValues}) - association.Error = association.DB.Where(clause.Where{Exprs: conds}).Model(nil).Delete(modelValue).Error + association.Error = association.DB.Where(clause.Where{Exprs: conds}).Model(nil).Delete(joinValue).Error } if association.Error == nil { + // clean up deleted values's foreign key relValuesMap, _ := schema.GetIdentityFieldValuesMapFromValues(values, rel.FieldSchema.PrimaryFields) cleanUpDeletedRelations := func(data reflect.Value) { @@ -328,33 +305,8 @@ func (association *Association) Clear() error { func (association *Association) Count() (count int64) { if association.Error == nil { - var ( - conds = association.Relationship.ToQueryConditions(association.DB.Statement.ReflectValue) - modelValue = reflect.New(association.Relationship.FieldSchema.ModelType).Interface() - tx = association.DB.Model(modelValue) - ) - - if association.Relationship.JoinTable != nil { - if !tx.Statement.Unscoped && len(association.Relationship.JoinTable.QueryClauses) > 0 { - joinStmt := Statement{DB: tx, Schema: association.Relationship.JoinTable, Table: association.Relationship.JoinTable.Table, Clauses: map[string]clause.Clause{}} - for _, queryClause := range association.Relationship.JoinTable.QueryClauses { - joinStmt.AddClause(queryClause) - } - joinStmt.Build("WHERE", "LIMIT") - tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars}) - } - - tx.Clauses(clause.From{Joins: []clause.Join{{ - Table: clause.Table{Name: association.Relationship.JoinTable.Table}, - ON: clause.Where{Exprs: conds}, - }}}) - } else { - tx.Clauses(clause.Where{Exprs: conds}) - } - - association.Error = tx.Count(&count).Error + association.Error = association.buildCondition().Count(&count).Error } - return } @@ -435,6 +387,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ switch reflectValue.Kind() { case reflect.Slice, reflect.Array: if len(values) != reflectValue.Len() { + // clear old data if clear && len(values) == 0 { for i := 0; i < reflectValue.Len(); i++ { if err := association.Relationship.Field.Set(reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType).Interface()); err != nil { @@ -467,6 +420,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ association.Error = association.DB.Session(&Session{}).Select(selectedSaveColumns).Model(nil).Save(reflectValue.Index(i).Addr().Interface()).Error } case reflect.Struct: + // clear old data if clear && len(values) == 0 { association.Error = association.Relationship.Field.Set(reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType).Interface()) @@ -498,3 +452,31 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ } } } + +func (association *Association) buildCondition() *DB { + var ( + queryConds = association.Relationship.ToQueryConditions(association.DB.Statement.ReflectValue) + modelValue = reflect.New(association.Relationship.FieldSchema.ModelType).Interface() + tx = association.DB.Model(modelValue) + ) + + if association.Relationship.JoinTable != nil { + if !tx.Statement.Unscoped && len(association.Relationship.JoinTable.QueryClauses) > 0 { + joinStmt := Statement{DB: tx, Schema: association.Relationship.JoinTable, Table: association.Relationship.JoinTable.Table, Clauses: map[string]clause.Clause{}} + for _, queryClause := range association.Relationship.JoinTable.QueryClauses { + joinStmt.AddClause(queryClause) + } + joinStmt.Build("WHERE") + tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars}) + } + + tx.Clauses(clause.From{Joins: []clause.Join{{ + Table: clause.Table{Name: association.Relationship.JoinTable.Table}, + ON: clause.Where{Exprs: queryConds}, + }}}) + } else { + tx.Clauses(clause.Where{Exprs: queryConds}) + } + + return tx +} From 06461b32549fb13090b92713703228da2e8290aa Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 28 Aug 2020 21:16:47 +0800 Subject: [PATCH 615/881] GORM V2.0.0 --- tests/go.mod | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/go.mod b/tests/go.mod index b0ed4497..1a6fe7a8 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -6,11 +6,11 @@ require ( github.com/google/uuid v1.1.1 github.com/jinzhu/now v1.1.1 github.com/lib/pq v1.6.0 - gorm.io/driver/mysql v0.3.2 - gorm.io/driver/postgres v0.2.9 - gorm.io/driver/sqlite v1.0.9 - gorm.io/driver/sqlserver v0.2.8 - gorm.io/gorm v0.2.36 + gorm.io/driver/mysql v1.0.0 + gorm.io/driver/postgres v1.0.0 + gorm.io/driver/sqlite v1.1.0 + gorm.io/driver/sqlserver v1.0.0 + gorm.io/gorm v1.9.19 ) replace gorm.io/gorm => ../ From 677edf9d9e3fc2f435e0668f74126a118fa97c25 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 29 Aug 2020 22:09:07 +0800 Subject: [PATCH 616/881] ignore AS when alias table as it doesn't work on oracle db, close #3328 --- statement.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/statement.go b/statement.go index fba1991d..d72a086f 100644 --- a/statement.go +++ b/statement.go @@ -86,7 +86,7 @@ func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) { } if v.Alias != "" { - writer.WriteString(" AS ") + writer.WriteByte(' ') stmt.DB.Dialector.QuoteTo(writer, v.Alias) } case clause.Column: From 59586dcd313bd067c2b94c118a9d20663ab3c8d0 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 29 Aug 2020 23:02:19 +0800 Subject: [PATCH 617/881] Fix unnecessary duplicated primary condition when using Save, close #3330 --- finisher_api.go | 9 ++------- tests/update_test.go | 24 ++++++++++++++++++++++++ 2 files changed, 26 insertions(+), 7 deletions(-) diff --git a/finisher_api.go b/finisher_api.go index 2cde3c31..824f2a2e 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -32,17 +32,12 @@ func (db *DB) Save(value interface{}) (tx *DB) { tx.callbacks.Create().Execute(tx) case reflect.Struct: if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil { - where := clause.Where{Exprs: make([]clause.Expression, len(tx.Statement.Schema.PrimaryFields))} - for idx, pf := range tx.Statement.Schema.PrimaryFields { - if pv, isZero := pf.ValueOf(reflectValue); isZero { + for _, pf := range tx.Statement.Schema.PrimaryFields { + if _, isZero := pf.ValueOf(reflectValue); isZero { tx.callbacks.Create().Execute(tx) return - } else { - where.Exprs[idx] = clause.Eq{Column: pf.DBName, Value: pv} } } - - tx.Statement.AddClause(where) } fallthrough diff --git a/tests/update_test.go b/tests/update_test.go index e52dc652..d566c04d 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -2,6 +2,7 @@ package tests_test import ( "errors" + "regexp" "sort" "strings" "testing" @@ -586,3 +587,26 @@ func TestUpdateFromSubQuery(t *testing.T) { t.Errorf("name should be %v, but got %v", user.Company.Name, result.Name) } } + +func TestSave(t *testing.T) { + user := *GetUser("save", Config{}) + DB.Create(&user) + + if err := DB.First(&User{}, "name = ?", "save").Error; err != nil { + t.Fatalf("failed to find created user") + } + + user.Name = "save2" + DB.Save(&user) + + var result User + if err := DB.First(&result, "name = ?", "save2").Error; err != nil || result.ID != user.ID { + t.Fatalf("failed to find updated user") + } + + dryDB := DB.Session(&gorm.Session{DryRun: true}) + stmt := dryDB.Save(&user).Statement + if !regexp.MustCompile("WHERE .id. = [^ ]+$").MatchString(stmt.SQL.String()) { + t.Fatalf("invalid updating SQL, got %v", stmt.SQL.String()) + } +} From b4166d9515c3a86da2a1c7a695bc73d83861737d Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 30 Aug 2020 10:12:49 +0800 Subject: [PATCH 618/881] Fix V2 Save compatibility, close #3332 --- association.go | 4 ++-- callbacks/create.go | 2 +- finisher_api.go | 10 +++++++++- tests/go.mod | 2 +- tests/update_test.go | 20 ++++++++++++++++++++ 5 files changed, 33 insertions(+), 5 deletions(-) diff --git a/association.go b/association.go index db77cc4e..140ae6ac 100644 --- a/association.go +++ b/association.go @@ -417,7 +417,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ appendToRelations(reflectValue.Index(i), reflect.Indirect(reflect.ValueOf(values[i])), clear) // TODO support save slice data, sql with case? - association.Error = association.DB.Session(&Session{}).Select(selectedSaveColumns).Model(nil).Save(reflectValue.Index(i).Addr().Interface()).Error + association.Error = association.DB.Session(&Session{}).Select(selectedSaveColumns).Model(nil).Updates(reflectValue.Index(i).Addr().Interface()).Error } case reflect.Struct: // clear old data @@ -439,7 +439,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ } if len(values) > 0 { - association.Error = association.DB.Session(&Session{}).Select(selectedSaveColumns).Model(nil).Save(reflectValue.Addr().Interface()).Error + association.Error = association.DB.Session(&Session{}).Select(selectedSaveColumns).Model(nil).Updates(reflectValue.Addr().Interface()).Error } } diff --git a/callbacks/create.go b/callbacks/create.go index c59b14b5..5de19d35 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -319,7 +319,7 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { } if stmt.UpdatingColumn { - if stmt.Schema != nil { + if stmt.Schema != nil && len(values.Columns) > 1 { columns := make([]string, 0, len(values.Columns)-1) for _, column := range values.Columns { if field := stmt.Schema.LookUpField(column.Name); field != nil { diff --git a/finisher_api.go b/finisher_api.go index 824f2a2e..a205b859 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -42,11 +42,19 @@ func (db *DB) Save(value interface{}) (tx *DB) { fallthrough default: - if len(tx.Statement.Selects) == 0 { + selectedUpdate := len(tx.Statement.Selects) != 0 + // when updating, use all fields including those zero-value fields + if !selectedUpdate { tx.Statement.Selects = append(tx.Statement.Selects, "*") } tx.callbacks.Update().Execute(tx) + + if tx.Error == nil && tx.RowsAffected == 0 && !tx.DryRun && !selectedUpdate { + if err := tx.Session(&Session{}).First(value).Error; errors.Is(err, ErrRecordNotFound) { + return tx.Create(value) + } + } } return diff --git a/tests/go.mod b/tests/go.mod index 1a6fe7a8..c09747ab 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -9,7 +9,7 @@ require ( gorm.io/driver/mysql v1.0.0 gorm.io/driver/postgres v1.0.0 gorm.io/driver/sqlite v1.1.0 - gorm.io/driver/sqlserver v1.0.0 + gorm.io/driver/sqlserver v1.0.1 gorm.io/gorm v1.9.19 ) diff --git a/tests/update_test.go b/tests/update_test.go index d566c04d..1944ed3f 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -610,3 +610,23 @@ func TestSave(t *testing.T) { t.Fatalf("invalid updating SQL, got %v", stmt.SQL.String()) } } + +func TestSaveWithPrimaryValue(t *testing.T) { + lang := Language{Code: "save", Name: "save"} + if result := DB.Save(&lang); result.RowsAffected != 1 { + t.Errorf("should create language, rows affected: %v", result.RowsAffected) + } + + var result Language + DB.First(&result, "code = ?", "save") + AssertEqual(t, result, lang) + + lang.Name = "save name2" + if result := DB.Save(&lang); result.RowsAffected != 1 { + t.Errorf("should update language") + } + + var result2 Language + DB.First(&result2, "code = ?", "save") + AssertEqual(t, result2, lang) +} From 53f8c9fc1c5d24324308673cc9ae3afd4442516a Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 30 Aug 2020 20:57:58 +0800 Subject: [PATCH 619/881] More compatible prioritized primary field #3156 --- schema/schema.go | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/schema/schema.go b/schema/schema.go index 458256d1..ea81d683 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -161,13 +161,18 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) field.setupValuerAndSetter() } - if f := schema.LookUpField("id"); f != nil { - if f.PrimaryKey { - schema.PrioritizedPrimaryField = f + prioritizedPrimaryField := schema.LookUpField("id") + if prioritizedPrimaryField == nil { + prioritizedPrimaryField = schema.LookUpField("ID") + } + + if prioritizedPrimaryField != nil { + if prioritizedPrimaryField.PrimaryKey { + schema.PrioritizedPrimaryField = prioritizedPrimaryField } else if len(schema.PrimaryFields) == 0 { - f.PrimaryKey = true - schema.PrioritizedPrimaryField = f - schema.PrimaryFields = append(schema.PrimaryFields, f) + prioritizedPrimaryField.PrimaryKey = true + schema.PrioritizedPrimaryField = prioritizedPrimaryField + schema.PrimaryFields = append(schema.PrimaryFields, prioritizedPrimaryField) } } From 9b0ad4730f16d6ac7cf18d1aa42d74714959745b Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 31 Aug 2020 12:08:33 +0800 Subject: [PATCH 620/881] Squashed commit of the following: MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit commit 759038a126122d5b3323979fdd7d867a4ab85585 Author: Jinzhu Date: Mon Aug 31 12:06:31 2020 +0800 Add PreparedStmt tests commit 066d54db1fc93ea58c190195104a2d7086623f69 Author: 王岚 Date: Fri Aug 28 18:40:59 2020 +0800 prepare_stmt add ctx --- gorm.go | 1 + prepare_stmt.go | 22 ++++++++--------- tests/prepared_stmt_test.go | 48 +++++++++++++++++++++++++++++++++++++ 3 files changed, 60 insertions(+), 11 deletions(-) create mode 100644 tests/prepared_stmt_test.go diff --git a/gorm.go b/gorm.go index 3c187f42..fec4310b 100644 --- a/gorm.go +++ b/gorm.go @@ -169,6 +169,7 @@ func (db *DB) Session(config *Session) *DB { if config.PrepareStmt { if v, ok := db.cacheStore.Load("preparedStmt"); ok { + tx.Statement = tx.Statement.clone() preparedStmt := v.(*PreparedStmtDB) tx.Statement.ConnPool = &PreparedStmtDB{ ConnPool: db.Config.ConnPool, diff --git a/prepare_stmt.go b/prepare_stmt.go index 7e87558d..7c80bafe 100644 --- a/prepare_stmt.go +++ b/prepare_stmt.go @@ -25,7 +25,7 @@ func (db *PreparedStmtDB) Close() { db.Mux.Unlock() } -func (db *PreparedStmtDB) prepare(query string) (*sql.Stmt, error) { +func (db *PreparedStmtDB) prepare(ctx context.Context, query string) (*sql.Stmt, error) { db.Mux.RLock() if stmt, ok := db.Stmts[query]; ok { db.Mux.RUnlock() @@ -40,7 +40,7 @@ func (db *PreparedStmtDB) prepare(query string) (*sql.Stmt, error) { return stmt, nil } - stmt, err := db.ConnPool.PrepareContext(context.Background(), query) + stmt, err := db.ConnPool.PrepareContext(ctx, query) if err == nil { db.Stmts[query] = stmt db.PreparedSQL = append(db.PreparedSQL, query) @@ -59,7 +59,7 @@ func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (Conn } func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args ...interface{}) (result sql.Result, err error) { - stmt, err := db.prepare(query) + stmt, err := db.prepare(ctx, query) if err == nil { result, err = stmt.ExecContext(ctx, args...) if err != nil { @@ -73,7 +73,7 @@ func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args .. } func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) { - stmt, err := db.prepare(query) + stmt, err := db.prepare(ctx, query) if err == nil { rows, err = stmt.QueryContext(ctx, args...) if err != nil { @@ -87,7 +87,7 @@ func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args . } func (db *PreparedStmtDB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { - stmt, err := db.prepare(query) + stmt, err := db.prepare(ctx, query) if err == nil { return stmt.QueryRowContext(ctx, args...) } @@ -100,9 +100,9 @@ type PreparedStmtTX struct { } func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ...interface{}) (result sql.Result, err error) { - stmt, err := tx.PreparedStmtDB.prepare(query) + stmt, err := tx.PreparedStmtDB.prepare(ctx, query) if err == nil { - result, err = tx.Tx.Stmt(stmt).ExecContext(ctx, args...) + result, err = tx.Tx.StmtContext(ctx, stmt).ExecContext(ctx, args...) if err != nil { tx.PreparedStmtDB.Mux.Lock() stmt.Close() @@ -114,9 +114,9 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args .. } func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) { - stmt, err := tx.PreparedStmtDB.prepare(query) + stmt, err := tx.PreparedStmtDB.prepare(ctx, query) if err == nil { - rows, err = tx.Tx.Stmt(stmt).QueryContext(ctx, args...) + rows, err = tx.Tx.StmtContext(ctx, stmt).QueryContext(ctx, args...) if err != nil { tx.PreparedStmtDB.Mux.Lock() stmt.Close() @@ -128,9 +128,9 @@ func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args . } func (tx *PreparedStmtTX) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { - stmt, err := tx.PreparedStmtDB.prepare(query) + stmt, err := tx.PreparedStmtDB.prepare(ctx, query) if err == nil { - return tx.Tx.Stmt(stmt).QueryRowContext(ctx, args...) + return tx.Tx.StmtContext(ctx, stmt).QueryRowContext(ctx, args...) } return &sql.Row{} } diff --git a/tests/prepared_stmt_test.go b/tests/prepared_stmt_test.go new file mode 100644 index 00000000..b81318d3 --- /dev/null +++ b/tests/prepared_stmt_test.go @@ -0,0 +1,48 @@ +package tests_test + +import ( + "context" + "testing" + "time" + + "gorm.io/gorm" + . "gorm.io/gorm/utils/tests" +) + +func TestPreparedStmt(t *testing.T) { + tx := DB.Session(&gorm.Session{PrepareStmt: true}) + + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + txCtx := tx.WithContext(ctx) + + user := *GetUser("prepared_stmt", Config{}) + + txCtx.Create(&user) + + var result1 User + if err := txCtx.Find(&result1, user.ID).Error; err != nil { + t.Fatalf("no error should happen but got %v", err) + } + + time.Sleep(time.Second) + + var result2 User + if err := tx.Find(&result2, user.ID).Error; err != nil { + t.Fatalf("no error should happen but got %v", err) + } + + user2 := *GetUser("prepared_stmt2", Config{}) + if err := txCtx.Create(&user2).Error; err == nil { + t.Fatalf("should failed to create with timeout context") + } + + if err := tx.Create(&user2).Error; err != nil { + t.Fatalf("failed to create, got error %v", err) + } + + var result3 User + if err := tx.Find(&result3, user2.ID).Error; err != nil { + t.Fatalf("no error should happen but got %v", err) + } +} From 496db1f13e51ef20db2a68f6591047df6b20e292 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 31 Aug 2020 15:45:56 +0800 Subject: [PATCH 621/881] Fix named argument with multiple line SQL, fix #3336 --- clause/expression.go | 2 +- prepare_stmt.go | 2 +- tests/go.mod | 2 ++ tests/named_argument_test.go | 14 +++++++++++++- 4 files changed, 17 insertions(+), 3 deletions(-) diff --git a/clause/expression.go b/clause/expression.go index 4d5e328b..3b914e68 100644 --- a/clause/expression.go +++ b/clause/expression.go @@ -94,7 +94,7 @@ func (expr NamedExpr) Build(builder Builder) { if v == '@' && !inName { inName = true name = []byte{} - } else if v == ' ' || v == ',' || v == ')' || v == '"' || v == '\'' || v == '`' { + } else if v == ' ' || v == ',' || v == ')' || v == '"' || v == '\'' || v == '`' || v == '\n' { if inName { if nv, ok := namedMap[string(name)]; ok { builder.AddVar(builder, nv) diff --git a/prepare_stmt.go b/prepare_stmt.go index 7c80bafe..de7e2a26 100644 --- a/prepare_stmt.go +++ b/prepare_stmt.go @@ -116,7 +116,7 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args .. func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) { stmt, err := tx.PreparedStmtDB.prepare(ctx, query) if err == nil { - rows, err = tx.Tx.StmtContext(ctx, stmt).QueryContext(ctx, args...) + rows, err = tx.Tx.Stmt(stmt).QueryContext(ctx, args...) if err != nil { tx.PreparedStmtDB.Mux.Lock() stmt.Close() diff --git a/tests/go.mod b/tests/go.mod index c09747ab..f3dd6dbc 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -14,3 +14,5 @@ require ( ) replace gorm.io/gorm => ../ + +replace github.com/jackc/pgx/v4 => github.com/jinzhu/pgx/v4 v4.8.2 diff --git a/tests/named_argument_test.go b/tests/named_argument_test.go index 56fad5f4..d0a6f915 100644 --- a/tests/named_argument_test.go +++ b/tests/named_argument_test.go @@ -48,10 +48,22 @@ func TestNamedArg(t *testing.T) { t.Errorf("failed to update with named arg") } + namedUser.Name1 = "jinzhu-new" + namedUser.Name2 = "jinzhu-new2" + namedUser.Name3 = "jinzhu-new" + var result5 NamedUser if err := DB.Raw("SELECT * FROM named_users WHERE (name1 = @name AND name3 = @name) AND name2 = @name2", map[string]interface{}{"name": "jinzhu-new", "name2": "jinzhu-new2"}).Find(&result5).Error; err != nil { t.Errorf("failed to update with named arg") } - AssertEqual(t, result4, namedUser) + AssertEqual(t, result5, namedUser) + + var result6 NamedUser + if err := DB.Raw(`SELECT * FROM named_users WHERE (name1 = @name + AND name3 = @name) AND name2 = @name2`, map[string]interface{}{"name": "jinzhu-new", "name2": "jinzhu-new2"}).Find(&result6).Error; err != nil { + t.Errorf("failed to update with named arg") + } + + AssertEqual(t, result6, namedUser) } From 0273856e4d9744c98aa42b98d485d726099e9020 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 31 Aug 2020 16:27:22 +0800 Subject: [PATCH 622/881] Don't alter column with full column data type, close #3339 --- migrator/migrator.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/migrator/migrator.go b/migrator/migrator.go index d93b8a6d..c736a3e0 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -297,10 +297,12 @@ func (m Migrator) DropColumn(value interface{}, name string) error { func (m Migrator) AlterColumn(value interface{}, field string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { if field := stmt.Schema.LookUpField(field); field != nil { + fileType := clause.Expr{SQL: m.DataTypeOf(field)} return m.DB.Exec( "ALTER TABLE ? ALTER COLUMN ? TYPE ?", - clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, m.DB.Migrator().FullDataTypeOf(field), + clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, fileType, ).Error + } return fmt.Errorf("failed to look up field with name: %s", field) }) From 162367be7d1d10aa59dc08bb507c356b4495c95e Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 1 Sep 2020 11:30:16 +0800 Subject: [PATCH 623/881] Fix multiple M2M relations on one table, close #3347 --- schema/relationship.go | 62 +++++++++++++++++++++---------------- schema/relationship_test.go | 31 +++++++++++++++++++ 2 files changed, 66 insertions(+), 27 deletions(-) diff --git a/schema/relationship.go b/schema/relationship.go index 5132ff74..aa992b84 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -254,12 +254,18 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel }) } + joinTableFields = append(joinTableFields, reflect.StructField{ + Name: schema.Name + field.Name, + Type: schema.ModelType, + Tag: `gorm:"-"`, + }) + if relation.JoinTable, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, schema.namer); err != nil { schema.err = err } relation.JoinTable.Name = many2many relation.JoinTable.Table = schema.namer.JoinTableName(many2many) - relation.JoinTable.PrimaryFields = make([]*Field, len(relation.JoinTable.Fields)) + relation.JoinTable.PrimaryFields = make([]*Field, 0, len(relation.JoinTable.Fields)) relName := relation.Schema.Name relRefName := relation.FieldSchema.Name @@ -290,36 +296,38 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel } // build references - for idx, f := range relation.JoinTable.Fields { - // use same data type for foreign keys - f.DataType = fieldsMap[f.Name].DataType - f.GORMDataType = fieldsMap[f.Name].GORMDataType - relation.JoinTable.PrimaryFields[idx] = f - ownPriamryField := schema == fieldsMap[f.Name].Schema && ownFieldsMap[f.Name] + for _, f := range relation.JoinTable.Fields { + if f.Creatable { + // use same data type for foreign keys + f.DataType = fieldsMap[f.Name].DataType + f.GORMDataType = fieldsMap[f.Name].GORMDataType + relation.JoinTable.PrimaryFields = append(relation.JoinTable.PrimaryFields, f) + ownPriamryField := schema == fieldsMap[f.Name].Schema && ownFieldsMap[f.Name] - if ownPriamryField { - joinRel := relation.JoinTable.Relationships.Relations[relName] - joinRel.Field = relation.Field - joinRel.References = append(joinRel.References, &Reference{ - PrimaryKey: fieldsMap[f.Name], - ForeignKey: f, - }) - } else { - joinRefRel := relation.JoinTable.Relationships.Relations[relRefName] - if joinRefRel.Field == nil { - joinRefRel.Field = relation.Field + if ownPriamryField { + joinRel := relation.JoinTable.Relationships.Relations[relName] + joinRel.Field = relation.Field + joinRel.References = append(joinRel.References, &Reference{ + PrimaryKey: fieldsMap[f.Name], + ForeignKey: f, + }) + } else { + joinRefRel := relation.JoinTable.Relationships.Relations[relRefName] + if joinRefRel.Field == nil { + joinRefRel.Field = relation.Field + } + joinRefRel.References = append(joinRefRel.References, &Reference{ + PrimaryKey: fieldsMap[f.Name], + ForeignKey: f, + }) } - joinRefRel.References = append(joinRefRel.References, &Reference{ - PrimaryKey: fieldsMap[f.Name], - ForeignKey: f, + + relation.References = append(relation.References, &Reference{ + PrimaryKey: fieldsMap[f.Name], + ForeignKey: f, + OwnPrimaryKey: ownPriamryField, }) } - - relation.References = append(relation.References, &Reference{ - PrimaryKey: fieldsMap[f.Name], - ForeignKey: f, - OwnPrimaryKey: ownPriamryField, - }) } } diff --git a/schema/relationship_test.go b/schema/relationship_test.go index 2e85c538..f2d63323 100644 --- a/schema/relationship_test.go +++ b/schema/relationship_test.go @@ -267,3 +267,34 @@ func TestMany2ManyWithMultiPrimaryKeys(t *testing.T) { }, ) } + +func TestMultipleMany2Many(t *testing.T) { + type Thing struct { + ID int + } + + type Person struct { + ID int + Likes []Thing `gorm:"many2many:likes"` + Dislikes []Thing `gorm:"many2many:dislikes"` + } + + checkStructRelation(t, &Person{}, + Relation{ + Name: "Likes", Type: schema.Many2Many, Schema: "Person", FieldSchema: "Thing", + JoinTable: JoinTable{Name: "likes", Table: "likes"}, + References: []Reference{ + {"ID", "Person", "PersonID", "likes", "", true}, + {"ID", "Thing", "ThingID", "likes", "", false}, + }, + }, + Relation{ + Name: "Dislikes", Type: schema.Many2Many, Schema: "Person", FieldSchema: "Thing", + JoinTable: JoinTable{Name: "dislikes", Table: "dislikes"}, + References: []Reference{ + {"ID", "Person", "PersonID", "dislikes", "", true}, + {"ID", "Thing", "ThingID", "dislikes", "", false}, + }, + }, + ) +} From 308d22b166eb3b71d2a3374bfc565be29ed88eda Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 1 Sep 2020 13:48:37 +0800 Subject: [PATCH 624/881] Clean up associations before Preload, close #3345 --- callbacks/preload.go | 10 ++++++++++ tests/helper_test.go | 10 +++++----- tests/preload_test.go | 14 ++++++++++++++ 3 files changed, 29 insertions(+), 5 deletions(-) diff --git a/callbacks/preload.go b/callbacks/preload.go index 25b8cb2b..9b8f762a 100644 --- a/callbacks/preload.go +++ b/callbacks/preload.go @@ -107,6 +107,16 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { fieldValues := make([]interface{}, len(relForeignFields)) + // clean up old values before preloading + switch reflectValue.Kind() { + case reflect.Struct: + rel.Field.Set(reflectValue, reflect.New(rel.Field.FieldType).Interface()) + case reflect.Slice, reflect.Array: + for i := 0; i < reflectValue.Len(); i++ { + rel.Field.Set(reflectValue.Index(i), reflect.New(rel.Field.FieldType).Interface()) + } + } + for i := 0; i < reflectResults.Len(); i++ { elem := reflectResults.Index(i) for idx, field := range relForeignFields { diff --git a/tests/helper_test.go b/tests/helper_test.go index cc0d808c..eee34e99 100644 --- a/tests/helper_test.go +++ b/tests/helper_test.go @@ -115,7 +115,7 @@ func CheckUser(t *testing.T, user User, expect User) { t.Run("Pets", func(t *testing.T) { if len(user.Pets) != len(expect.Pets) { - t.Errorf("pets should equal, expect: %v, got %v", len(expect.Pets), len(user.Pets)) + t.Fatalf("pets should equal, expect: %v, got %v", len(expect.Pets), len(user.Pets)) } sort.Slice(user.Pets, func(i, j int) bool { @@ -137,7 +137,7 @@ func CheckUser(t *testing.T, user User, expect User) { t.Run("Toys", func(t *testing.T) { if len(user.Toys) != len(expect.Toys) { - t.Errorf("toys should equal, expect: %v, got %v", len(expect.Toys), len(user.Toys)) + t.Fatalf("toys should equal, expect: %v, got %v", len(expect.Toys), len(user.Toys)) } sort.Slice(user.Toys, func(i, j int) bool { @@ -177,7 +177,7 @@ func CheckUser(t *testing.T, user User, expect User) { t.Run("Team", func(t *testing.T) { if len(user.Team) != len(expect.Team) { - t.Errorf("Team should equal, expect: %v, got %v", len(expect.Team), len(user.Team)) + t.Fatalf("Team should equal, expect: %v, got %v", len(expect.Team), len(user.Team)) } sort.Slice(user.Team, func(i, j int) bool { @@ -195,7 +195,7 @@ func CheckUser(t *testing.T, user User, expect User) { t.Run("Languages", func(t *testing.T) { if len(user.Languages) != len(expect.Languages) { - t.Errorf("Languages should equal, expect: %v, got %v", len(expect.Languages), len(user.Languages)) + t.Fatalf("Languages should equal, expect: %v, got %v", len(expect.Languages), len(user.Languages)) } sort.Slice(user.Languages, func(i, j int) bool { @@ -212,7 +212,7 @@ func CheckUser(t *testing.T, user User, expect User) { t.Run("Friends", func(t *testing.T) { if len(user.Friends) != len(expect.Friends) { - t.Errorf("Friends should equal, expect: %v, got %v", len(expect.Friends), len(user.Friends)) + t.Fatalf("Friends should equal, expect: %v, got %v", len(expect.Friends), len(user.Friends)) } sort.Slice(user.Friends, func(i, j int) bool { diff --git a/tests/preload_test.go b/tests/preload_test.go index 7e5d2622..76b72f14 100644 --- a/tests/preload_test.go +++ b/tests/preload_test.go @@ -31,6 +31,20 @@ func TestPreloadWithAssociations(t *testing.T) { var user2 User DB.Preload(clause.Associations).Find(&user2, "id = ?", user.ID) CheckUser(t, user2, user) + + var user3 = *GetUser("preload_with_associations_new", Config{ + Account: true, + Pets: 2, + Toys: 3, + Company: true, + Manager: true, + Team: 4, + Languages: 3, + Friends: 1, + }) + + DB.Preload(clause.Associations).Find(&user3, "id = ?", user.ID) + CheckUser(t, user3, user) } func TestNestedPreload(t *testing.T) { From e98a4a3a4ef602a20803c1fc4deb3f8bdbf84fec Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 1 Sep 2020 14:01:59 +0800 Subject: [PATCH 625/881] Change default timeout interval to avoid test fail on CI --- tests/prepared_stmt_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/prepared_stmt_test.go b/tests/prepared_stmt_test.go index b81318d3..af610165 100644 --- a/tests/prepared_stmt_test.go +++ b/tests/prepared_stmt_test.go @@ -12,7 +12,7 @@ import ( func TestPreparedStmt(t *testing.T) { tx := DB.Session(&gorm.Session{PrepareStmt: true}) - ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) defer cancel() txCtx := tx.WithContext(ctx) From e6f4b711a7e1f885a2200b22e40786cf0dacddcc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=AE=8B=E5=B0=8F=E5=8C=97?= Date: Tue, 1 Sep 2020 15:50:53 +0800 Subject: [PATCH 626/881] fix order case (#3350) --- chainable_api.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chainable_api.go b/chainable_api.go index c8417a6d..ae2ac4f1 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -198,7 +198,7 @@ func (db *DB) Having(query interface{}, args ...interface{}) (tx *DB) { // Order specify order when retrieve records from database // db.Order("name DESC") -// db.Order(gorm.Expr("name = ? DESC", "first")) // sql expression +// db.Order(clause.OrderByColumn{Column: clause.Column{Name: "name"}, Desc: true}) func (db *DB) Order(value interface{}) (tx *DB) { tx = db.getInstance() From e73147fa8e25bea98257444ae1d65e19a1af089d Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 1 Sep 2020 16:55:30 +0800 Subject: [PATCH 627/881] Better support for scan into map, fix unfriendly data type for interface, close #3351 --- scan.go | 72 +++++++++++++++++++----------- tests/query_test.go | 104 +++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 149 insertions(+), 27 deletions(-) diff --git a/scan.go b/scan.go index 0b199029..89d9a07a 100644 --- a/scan.go +++ b/scan.go @@ -2,12 +2,52 @@ package gorm import ( "database/sql" + "database/sql/driver" "reflect" "strings" "gorm.io/gorm/schema" ) +func prepareValues(values []interface{}, db *DB, columnTypes []*sql.ColumnType, columns []string) { + if db.Statement.Schema != nil { + for idx, name := range columns { + if field := db.Statement.Schema.LookUpField(name); field != nil { + values[idx] = reflect.New(reflect.PtrTo(field.FieldType)).Interface() + continue + } + values[idx] = new(interface{}) + } + } else if len(columnTypes) > 0 { + for idx, columnType := range columnTypes { + if columnType.ScanType() != nil { + values[idx] = reflect.New(reflect.PtrTo(columnType.ScanType())).Interface() + } else { + values[idx] = new(interface{}) + } + } + } else { + for idx := range columns { + values[idx] = new(interface{}) + } + } +} + +func scanIntoMap(mapValue map[string]interface{}, values []interface{}, columns []string) { + for idx, column := range columns { + if reflectValue := reflect.Indirect(reflect.Indirect(reflect.ValueOf(values[idx]))); reflectValue.IsValid() { + mapValue[column] = reflectValue.Interface() + if valuer, ok := mapValue[column].(driver.Valuer); ok { + mapValue[column], _ = valuer.Value() + } else if b, ok := mapValue[column].(sql.RawBytes); ok { + mapValue[column] = string(b) + } + } else { + mapValue[column] = nil + } + } +} + func Scan(rows *sql.Rows, db *DB, initialized bool) { columns, _ := rows.Columns() values := make([]interface{}, len(columns)) @@ -15,9 +55,8 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { switch dest := db.Statement.Dest.(type) { case map[string]interface{}, *map[string]interface{}: if initialized || rows.Next() { - for idx := range columns { - values[idx] = new(interface{}) - } + columnTypes, _ := rows.ColumnTypes() + prepareValues(values, db, columnTypes, columns) db.RowsAffected++ db.AddError(rows.Scan(values...)) @@ -28,38 +67,19 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { mapValue = *v } } - - for idx, column := range columns { - if v, ok := values[idx].(*interface{}); ok { - if v == nil { - mapValue[column] = nil - } else { - mapValue[column] = *v - } - } - } + scanIntoMap(mapValue, values, columns) } case *[]map[string]interface{}: + columnTypes, _ := rows.ColumnTypes() for initialized || rows.Next() { - for idx := range columns { - values[idx] = new(interface{}) - } + prepareValues(values, db, columnTypes, columns) initialized = false db.RowsAffected++ db.AddError(rows.Scan(values...)) mapValue := map[string]interface{}{} - for idx, column := range columns { - if v, ok := values[idx].(*interface{}); ok { - if v == nil { - mapValue[column] = nil - } else { - mapValue[column] = *v - } - } - } - + scanIntoMap(mapValue, values, columns) *dest = append(*dest, mapValue) } case *int, *int64, *uint, *uint64, *float32, *float64: diff --git a/tests/query_test.go b/tests/query_test.go index d71c813a..6bb68cd3 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -6,6 +6,7 @@ import ( "regexp" "sort" "strconv" + "strings" "testing" "time" @@ -61,6 +62,54 @@ func TestFind(t *testing.T) { for _, name := range []string{"Name", "Age", "Birthday"} { t.Run(name, func(t *testing.T) { dbName := DB.NamingStrategy.ColumnName("", name) + + switch name { + case "Name": + if _, ok := first[dbName].(string); !ok { + t.Errorf("invalid data type for %v, got %#v", dbName, first[dbName]) + } + case "Age": + if _, ok := first[dbName].(uint); !ok { + t.Errorf("invalid data type for %v, got %#v", dbName, first[dbName]) + } + case "Birthday": + if _, ok := first[dbName].(*time.Time); !ok { + t.Errorf("invalid data type for %v, got %#v", dbName, first[dbName]) + } + } + + reflectValue := reflect.Indirect(reflect.ValueOf(users[0])) + AssertEqual(t, first[dbName], reflectValue.FieldByName(name).Interface()) + }) + } + } + }) + + t.Run("FirstMapWithTable", func(t *testing.T) { + var first = map[string]interface{}{} + if err := DB.Table("users").Where("name = ?", "find").Find(first).Error; err != nil { + t.Errorf("errors happened when query first: %v", err) + } else { + for _, name := range []string{"Name", "Age", "Birthday"} { + t.Run(name, func(t *testing.T) { + dbName := DB.NamingStrategy.ColumnName("", name) + resultType := reflect.ValueOf(first[dbName]).Type().Name() + + switch name { + case "Name": + if !strings.Contains(resultType, "string") { + t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, first[dbName]) + } + case "Age": + if !strings.Contains(resultType, "int") { + t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, first[dbName]) + } + case "Birthday": + if !strings.Contains(resultType, "Time") && !(DB.Dialector.Name() == "sqlite" && strings.Contains(resultType, "string")) { + t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, first[dbName]) + } + } + reflectValue := reflect.Indirect(reflect.ValueOf(users[0])) AssertEqual(t, first[dbName], reflectValue.FieldByName(name).Interface()) }) @@ -86,13 +135,29 @@ func TestFind(t *testing.T) { t.Run("FirstSliceOfMap", func(t *testing.T) { var allMap = []map[string]interface{}{} if err := DB.Model(&User{}).Where("name = ?", "find").Find(&allMap).Error; err != nil { - t.Errorf("errors happened when query first: %v", err) + t.Errorf("errors happened when query find: %v", err) } else { for idx, user := range users { t.Run("FindAllMap#"+strconv.Itoa(idx+1), func(t *testing.T) { for _, name := range []string{"Name", "Age", "Birthday"} { t.Run(name, func(t *testing.T) { dbName := DB.NamingStrategy.ColumnName("", name) + + switch name { + case "Name": + if _, ok := allMap[idx][dbName].(string); !ok { + t.Errorf("invalid data type for %v, got %#v", dbName, allMap[idx][dbName]) + } + case "Age": + if _, ok := allMap[idx][dbName].(uint); !ok { + t.Errorf("invalid data type for %v, got %#v", dbName, allMap[idx][dbName]) + } + case "Birthday": + if _, ok := allMap[idx][dbName].(*time.Time); !ok { + t.Errorf("invalid data type for %v, got %#v", dbName, allMap[idx][dbName]) + } + } + reflectValue := reflect.Indirect(reflect.ValueOf(user)) AssertEqual(t, allMap[idx][dbName], reflectValue.FieldByName(name).Interface()) }) @@ -101,6 +166,43 @@ func TestFind(t *testing.T) { } } }) + + t.Run("FindSliceOfMapWithTable", func(t *testing.T) { + var allMap = []map[string]interface{}{} + if err := DB.Table("users").Where("name = ?", "find").Find(&allMap).Error; err != nil { + t.Errorf("errors happened when query find: %v", err) + } else { + for idx, user := range users { + t.Run("FindAllMap#"+strconv.Itoa(idx+1), func(t *testing.T) { + for _, name := range []string{"Name", "Age", "Birthday"} { + t.Run(name, func(t *testing.T) { + dbName := DB.NamingStrategy.ColumnName("", name) + resultType := reflect.ValueOf(allMap[idx][dbName]).Type().Name() + + switch name { + case "Name": + if !strings.Contains(resultType, "string") { + t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, allMap[idx][dbName]) + } + case "Age": + if !strings.Contains(resultType, "int") { + t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, allMap[idx][dbName]) + } + case "Birthday": + if !strings.Contains(resultType, "Time") && !(DB.Dialector.Name() == "sqlite" && strings.Contains(resultType, "string")) { + t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, allMap[idx][dbName]) + } + } + + reflectValue := reflect.Indirect(reflect.ValueOf(user)) + AssertEqual(t, allMap[idx][dbName], reflectValue.FieldByName(name).Interface()) + }) + } + }) + } + } + }) + } func TestQueryWithAssociation(t *testing.T) { From bf6123b01e265ecfe709738b290c3ea3f6ad9bdc Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 1 Sep 2020 18:05:26 +0800 Subject: [PATCH 628/881] Fix duplicated soft delete clause --- soft_delete.go | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/soft_delete.go b/soft_delete.go index 484f565c..b13fc63f 100644 --- a/soft_delete.go +++ b/soft_delete.go @@ -25,14 +25,7 @@ func (n DeletedAt) Value() (driver.Value, error) { } func (DeletedAt) QueryClauses(f *schema.Field) []clause.Interface { - return []clause.Interface{ - clause.Where{Exprs: []clause.Expression{ - clause.Eq{ - Column: clause.Column{Table: clause.CurrentTable, Name: f.DBName}, - Value: nil, - }, - }}, - } + return []clause.Interface{SoftDeleteQueryClause{Field: f}} } type SoftDeleteQueryClause struct { From 22317b43c007f1a4aa21d6bf6c3e5088ce0ca507 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 1 Sep 2020 18:58:16 +0800 Subject: [PATCH 629/881] Fix migrate field, failed to migrate when field size changed --- migrator/migrator.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/migrator/migrator.go b/migrator/migrator.go index c736a3e0..1aebc50d 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -356,9 +356,9 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy alterColumn = true } else { // has size in data type and not equal - matches := regexp.MustCompile(`[^\d](\d+)[^\d]`).FindAllString(realDataType, 1) + matches := regexp.MustCompile(`[^\d](\d+)[^\d]`).FindAllStringSubmatch(realDataType, -1) matches2 := regexp.MustCompile(`[^\d]*(\d+)[^\d]`).FindAllStringSubmatch(fullDataType, -1) - if len(matches) > 0 && matches[1] != fmt.Sprint(field.Size) || len(matches2) == 1 && matches2[0][1] != fmt.Sprint(length) { + if (len(matches) == 1 && matches[0][1] != fmt.Sprint(field.Size)) && (len(matches2) == 1 && matches2[0][1] != fmt.Sprint(length)) { alterColumn = true } } From d1e17d549fc3fb9a66e150d425e090dca838ab07 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 1 Sep 2020 20:52:06 +0800 Subject: [PATCH 630/881] request ColumnTypes after new session method --- migrator/migrator.go | 2 +- tests/go.mod | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/migrator/migrator.go b/migrator/migrator.go index 1aebc50d..29d26c9e 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -388,7 +388,7 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy func (m Migrator) ColumnTypes(value interface{}) (columnTypes []*sql.ColumnType, err error) { err = m.RunWithValue(value, func(stmt *gorm.Statement) error { - rows, err := m.DB.Raw("select * from ?", clause.Table{Name: stmt.Table}).Rows() + rows, err := m.DB.Session(&gorm.Session{}).Raw("select * from ?", clause.Table{Name: stmt.Table}).Rows() if err == nil { columnTypes, err = rows.ColumnTypes() } diff --git a/tests/go.mod b/tests/go.mod index f3dd6dbc..30a7dda7 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -6,10 +6,10 @@ require ( github.com/google/uuid v1.1.1 github.com/jinzhu/now v1.1.1 github.com/lib/pq v1.6.0 - gorm.io/driver/mysql v1.0.0 + gorm.io/driver/mysql v1.0.1 gorm.io/driver/postgres v1.0.0 - gorm.io/driver/sqlite v1.1.0 - gorm.io/driver/sqlserver v1.0.1 + gorm.io/driver/sqlite v1.1.1 + gorm.io/driver/sqlserver v1.0.2 gorm.io/gorm v1.9.19 ) From 9a101c8a089b724fc19af525fcdca58bff0b7997 Mon Sep 17 00:00:00 2001 From: aimuz Date: Tue, 1 Sep 2020 21:03:37 +0800 Subject: [PATCH 631/881] fmt.Sprint() to strconv.Format (#3354) --- logger/sql.go | 14 +++++++------- schema/field.go | 2 +- utils/utils.go | 28 ++++++++++++++++++++++++++++ 3 files changed, 36 insertions(+), 8 deletions(-) diff --git a/logger/sql.go b/logger/sql.go index 02d559c5..0efc0971 100644 --- a/logger/sql.go +++ b/logger/sql.go @@ -3,6 +3,7 @@ package logger import ( "database/sql/driver" "fmt" + "gorm.io/gorm/utils" "reflect" "regexp" "strconv" @@ -24,13 +25,12 @@ var convertableTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeO func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, avars ...interface{}) string { var convertParams func(interface{}, int) - var vars = make([]interface{}, len(avars)) - copy(vars, avars) + var vars = make([]string, len(avars)) convertParams = func(v interface{}, idx int) { switch v := v.(type) { case bool: - vars[idx] = fmt.Sprint(v) + vars[idx] = strconv.FormatBool(v) case time.Time: if v.IsZero() { vars[idx] = escaper + "0000-00-00 00:00:00" + escaper @@ -44,7 +44,7 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a vars[idx] = escaper + "" + escaper } case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: - vars[idx] = fmt.Sprintf("%d", v) + vars[idx] = utils.ToString(v) case float64, float32: vars[idx] = fmt.Sprintf("%.6f", v) case string: @@ -70,18 +70,18 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a } } - for idx, v := range vars { + for idx, v := range avars { convertParams(v, idx) } if numericPlaceholder == nil { for _, v := range vars { - sql = strings.Replace(sql, "?", v.(string), 1) + sql = strings.Replace(sql, "?", v, 1) } } else { sql = numericPlaceholder.ReplaceAllString(sql, "$$$1$$") for idx, v := range vars { - sql = strings.Replace(sql, "$"+strconv.Itoa(idx+1)+"$", v.(string), 1) + sql = strings.Replace(sql, "$"+strconv.Itoa(idx+1)+"$", v, 1) } } diff --git a/schema/field.go b/schema/field.go index 524d19fb..2e649d81 100644 --- a/schema/field.go +++ b/schema/field.go @@ -671,7 +671,7 @@ func (field *Field) setupValuerAndSetter() { case []byte: field.ReflectValueOf(value).SetString(string(data)) case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: - field.ReflectValueOf(value).SetString(fmt.Sprint(data)) + field.ReflectValueOf(value).SetString(utils.ToString(data)) case float64, float32: field.ReflectValueOf(value).SetString(fmt.Sprintf("%."+strconv.Itoa(field.Precision)+"f", data)) default: diff --git a/utils/utils.go b/utils/utils.go index 71336f4b..905001a5 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -83,3 +83,31 @@ func AssertEqual(src, dst interface{}) bool { } return true } + +func ToString(value interface{}) string { + switch v := value.(type) { + case string: + return v + case int: + return strconv.FormatInt(int64(v), 10) + case int8: + return strconv.FormatInt(int64(v), 10) + case int16: + return strconv.FormatInt(int64(v), 10) + case int32: + return strconv.FormatInt(int64(v), 10) + case int64: + return strconv.FormatInt(v, 10) + case uint: + return strconv.FormatUint(uint64(v), 10) + case uint8: + return strconv.FormatUint(uint64(v), 10) + case uint16: + return strconv.FormatUint(uint64(v), 10) + case uint32: + return strconv.FormatUint(uint64(v), 10) + case uint64: + return strconv.FormatUint(v, 10) + } + return "" +} From dbaa6b0ec3f451903c2983fd091c52e5efc60669 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 2 Sep 2020 16:14:26 +0800 Subject: [PATCH 632/881] Fix Scan struct with primary key, close #3357 --- callbacks.go | 2 ++ callbacks/row.go | 2 +- finisher_api.go | 19 ++++++++++++++----- logger/sql.go | 3 ++- migrator.go | 2 +- tests/scan_test.go | 18 +++++++++++++++--- 6 files changed, 35 insertions(+), 11 deletions(-) diff --git a/callbacks.go b/callbacks.go index baeb6c09..eace06ca 100644 --- a/callbacks.go +++ b/callbacks.go @@ -79,6 +79,8 @@ func (p *processor) Execute(db *DB) { if stmt.Model == nil { stmt.Model = stmt.Dest + } else if stmt.Dest == nil { + stmt.Dest = stmt.Model } if stmt.Model != nil { diff --git a/callbacks/row.go b/callbacks/row.go index 7e70382e..a36c0116 100644 --- a/callbacks/row.go +++ b/callbacks/row.go @@ -11,7 +11,7 @@ func RowQuery(db *gorm.DB) { } if !db.DryRun { - if _, ok := db.Get("rows"); ok { + if isRows, ok := db.InstanceGet("rows"); ok && isRows.(bool) { db.Statement.Dest, db.Error = db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) } else { db.Statement.Dest = db.Statement.ConnPool.QueryRowContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) diff --git a/finisher_api.go b/finisher_api.go index a205b859..1d5ef5fc 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -331,13 +331,13 @@ func (db *DB) Count(count *int64) (tx *DB) { } func (db *DB) Row() *sql.Row { - tx := db.getInstance() + tx := db.getInstance().InstanceSet("rows", false) tx.callbacks.Row().Execute(tx) return tx.Statement.Dest.(*sql.Row) } func (db *DB) Rows() (*sql.Rows, error) { - tx := db.Set("rows", true) + tx := db.getInstance().InstanceSet("rows", true) tx.callbacks.Row().Execute(tx) return tx.Statement.Dest.(*sql.Rows), tx.Error } @@ -345,8 +345,14 @@ func (db *DB) Rows() (*sql.Rows, error) { // Scan scan value to a struct func (db *DB) Scan(dest interface{}) (tx *DB) { tx = db.getInstance() - tx.Statement.Dest = dest - tx.callbacks.Query().Execute(tx) + if rows, err := tx.Rows(); err != nil { + tx.AddError(err) + } else { + defer rows.Close() + if rows.Next() { + tx.ScanRows(rows, dest) + } + } return } @@ -379,7 +385,10 @@ func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error { tx := db.getInstance() tx.Error = tx.Statement.Parse(dest) tx.Statement.Dest = dest - tx.Statement.ReflectValue = reflect.Indirect(reflect.ValueOf(dest)) + tx.Statement.ReflectValue = reflect.ValueOf(dest) + for tx.Statement.ReflectValue.Kind() == reflect.Ptr { + tx.Statement.ReflectValue = tx.Statement.ReflectValue.Elem() + } Scan(rows, tx, true) return tx.Error } diff --git a/logger/sql.go b/logger/sql.go index 0efc0971..80645b0c 100644 --- a/logger/sql.go +++ b/logger/sql.go @@ -3,13 +3,14 @@ package logger import ( "database/sql/driver" "fmt" - "gorm.io/gorm/utils" "reflect" "regexp" "strconv" "strings" "time" "unicode" + + "gorm.io/gorm/utils" ) func isPrintable(s []byte) bool { diff --git a/migrator.go b/migrator.go index ed8a8e26..162fe680 100644 --- a/migrator.go +++ b/migrator.go @@ -9,7 +9,7 @@ import ( // Migrator returns migrator func (db *DB) Migrator() Migrator { - return db.Dialector.Migrator(db) + return db.Dialector.Migrator(db.Session(&Session{WithConditions: true})) } // AutoMigrate run auto migration for given models diff --git a/tests/scan_test.go b/tests/scan_test.go index d6a372bb..3e66a25a 100644 --- a/tests/scan_test.go +++ b/tests/scan_test.go @@ -6,6 +6,7 @@ import ( "strings" "testing" + "gorm.io/gorm" . "gorm.io/gorm/utils/tests" ) @@ -16,14 +17,25 @@ func TestScan(t *testing.T) { DB.Save(&user1).Save(&user2).Save(&user3) type result struct { + ID uint Name string Age int } var res result - DB.Table("users").Select("name, age").Where("id = ?", user3.ID).Scan(&res) - if res.Name != user3.Name || res.Age != int(user3.Age) { - t.Errorf("Scan into struct should work") + DB.Table("users").Select("id, name, age").Where("id = ?", user3.ID).Scan(&res) + if res.ID != user3.ID || res.Name != user3.Name || res.Age != int(user3.Age) { + t.Fatalf("Scan into struct should work, got %#v, should %#v", res, user3) + } + + DB.Table("users").Select("id, name, age").Where("id = ?", user2.ID).Scan(&res) + if res.ID != user2.ID || res.Name != user2.Name || res.Age != int(user2.Age) { + t.Fatalf("Scan into struct should work, got %#v, should %#v", res, user2) + } + + DB.Model(&User{Model: gorm.Model{ID: user3.ID}}).Select("id, name, age").Scan(&res) + if res.ID != user3.ID || res.Name != user3.Name || res.Age != int(user3.Age) { + t.Fatalf("Scan into struct should work, got %#v, should %#v", res, user3) } var doubleAgeRes = &result{} From 680dda2c159d21c0b8f677b25519ec7fec29cd4d Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 2 Sep 2020 20:09:51 +0800 Subject: [PATCH 633/881] Fix combine conditions when using string conditions, close #3358 --- clause/where.go | 52 ++++++++++++++++++++++++++++++++++++- tests/sql_builder_test.go | 54 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 105 insertions(+), 1 deletion(-) diff --git a/clause/where.go b/clause/where.go index 9af9701c..a3774e1c 100644 --- a/clause/where.go +++ b/clause/where.go @@ -1,5 +1,9 @@ package clause +import ( + "strings" +) + // Where where clause type Where struct { Exprs []Expression @@ -22,6 +26,7 @@ func (where Where) Build(builder Builder) { } } + wrapInParentheses := false for idx, expr := range where.Exprs { if idx > 0 { if v, ok := expr.(OrConditions); ok && len(v.Exprs) == 1 { @@ -31,7 +36,36 @@ func (where Where) Build(builder Builder) { } } - expr.Build(builder) + if len(where.Exprs) > 1 { + switch v := expr.(type) { + case OrConditions: + if len(v.Exprs) == 1 { + if e, ok := v.Exprs[0].(Expr); ok { + sql := strings.ToLower(e.SQL) + wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or") + } + } + case AndConditions: + if len(v.Exprs) == 1 { + if e, ok := v.Exprs[0].(Expr); ok { + sql := strings.ToLower(e.SQL) + wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or") + } + } + case Expr: + sql := strings.ToLower(v.SQL) + wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or") + } + } + + if wrapInParentheses { + builder.WriteString(`(`) + expr.Build(builder) + builder.WriteString(`)`) + wrapInParentheses = false + } else { + expr.Build(builder) + } } } @@ -50,6 +84,8 @@ func (where Where) MergeClause(clause *Clause) { func And(exprs ...Expression) Expression { if len(exprs) == 0 { return nil + } else if len(exprs) == 1 { + return exprs[0] } return AndConditions{Exprs: exprs} } @@ -118,6 +154,7 @@ func (not NotConditions) Build(builder Builder) { if len(not.Exprs) > 1 { builder.WriteByte('(') } + for idx, c := range not.Exprs { if idx > 0 { builder.WriteString(" AND ") @@ -127,9 +164,22 @@ func (not NotConditions) Build(builder Builder) { negationBuilder.NegationBuild(builder) } else { builder.WriteString("NOT ") + e, wrapInParentheses := c.(Expr) + if wrapInParentheses { + sql := strings.ToLower(e.SQL) + if wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or"); wrapInParentheses { + builder.WriteByte('(') + } + } + c.Build(builder) + + if wrapInParentheses { + builder.WriteByte(')') + } } } + if len(not.Exprs) > 1 { builder.WriteByte(')') } diff --git a/tests/sql_builder_test.go b/tests/sql_builder_test.go index e6038947..c0176fc3 100644 --- a/tests/sql_builder_test.go +++ b/tests/sql_builder_test.go @@ -1,6 +1,7 @@ package tests_test import ( + "regexp" "strings" "testing" @@ -188,3 +189,56 @@ func TestGroupConditions(t *testing.T) { t.Errorf("expects: %v, got %v", expects, result) } } + +func TestCombineStringConditions(t *testing.T) { + dryRunDB := DB.Session(&gorm.Session{DryRun: true}) + sql := dryRunDB.Where("a = ? or b = ?", "a", "b").Find(&User{}).Statement.SQL.String() + if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) AND .users.\..deleted_at. IS NULL`).MatchString(sql) { + t.Fatalf("invalid sql generated, got %v", sql) + } + + sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Or("c = ? and d = ?", "c", "d").Find(&User{}).Statement.SQL.String() + if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) OR \(c = .+ and d = .+\) AND .users.\..deleted_at. IS NULL`).MatchString(sql) { + t.Fatalf("invalid sql generated, got %v", sql) + } + + sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Or("c = ?", "c").Find(&User{}).Statement.SQL.String() + if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) OR c = .+ AND .users.\..deleted_at. IS NULL`).MatchString(sql) { + t.Fatalf("invalid sql generated, got %v", sql) + } + + sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Or("c = ? and d = ?", "c", "d").Or("e = ? and f = ?", "e", "f").Find(&User{}).Statement.SQL.String() + if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) OR \(c = .+ and d = .+\) OR \(e = .+ and f = .+\) AND .users.\..deleted_at. IS NULL`).MatchString(sql) { + t.Fatalf("invalid sql generated, got %v", sql) + } + + sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Where("c = ? and d = ?", "c", "d").Not("e = ? and f = ?", "e", "f").Find(&User{}).Statement.SQL.String() + if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) AND \(c = .+ and d = .+\) AND NOT \(e = .+ and f = .+\) AND .users.\..deleted_at. IS NULL`).MatchString(sql) { + t.Fatalf("invalid sql generated, got %v", sql) + } + + sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Where("c = ?", "c").Not("e = ? and f = ?", "e", "f").Find(&User{}).Statement.SQL.String() + if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) AND c = .+ AND NOT \(e = .+ and f = .+\) AND .users.\..deleted_at. IS NULL`).MatchString(sql) { + t.Fatalf("invalid sql generated, got %v", sql) + } + + sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Where("c = ? and d = ?", "c", "d").Not("e = ?", "e").Find(&User{}).Statement.SQL.String() + if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) AND \(c = .+ and d = .+\) AND NOT e = .+ AND .users.\..deleted_at. IS NULL`).MatchString(sql) { + t.Fatalf("invalid sql generated, got %v", sql) + } + + sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Unscoped().Find(&User{}).Statement.SQL.String() + if !regexp.MustCompile(`WHERE a = .+ or b = .+$`).MatchString(sql) { + t.Fatalf("invalid sql generated, got %v", sql) + } + + sql = dryRunDB.Or("a = ? or b = ?", "a", "b").Unscoped().Find(&User{}).Statement.SQL.String() + if !regexp.MustCompile(`WHERE a = .+ or b = .+$`).MatchString(sql) { + t.Fatalf("invalid sql generated, got %v", sql) + } + + sql = dryRunDB.Not("a = ? or b = ?", "a", "b").Unscoped().Find(&User{}).Statement.SQL.String() + if !regexp.MustCompile(`WHERE NOT \(a = .+ or b = .+\)$`).MatchString(sql) { + t.Fatalf("invalid sql generated, got %v", sql) + } +} From dbe0f4d8d7dad471d7e3931ecb7e24610adb76f5 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 2 Sep 2020 20:15:12 +0800 Subject: [PATCH 634/881] Allow use NULL as default value for string, close #3363 --- schema/field.go | 2 +- tests/default_value_test.go | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/schema/field.go b/schema/field.go index 2e649d81..b49b7de6 100644 --- a/schema/field.go +++ b/schema/field.go @@ -201,7 +201,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { case reflect.String: field.DataType = String isFunc := strings.Contains(field.DefaultValue, "(") && - strings.Contains(field.DefaultValue, ")") + strings.Contains(field.DefaultValue, ")") || strings.ToLower(field.DefaultValue) == "null" if field.HasDefaultValue && !isFunc { field.DefaultValue = strings.Trim(field.DefaultValue, "'") diff --git a/tests/default_value_test.go b/tests/default_value_test.go index aa4a511a..44309eab 100644 --- a/tests/default_value_test.go +++ b/tests/default_value_test.go @@ -13,6 +13,7 @@ func TestDefaultValue(t *testing.T) { Name string `gorm:"not null;default:foo"` Name2 string `gorm:"size:233;not null;default:'foo'"` Name3 string `gorm:"size:233;not null;default:''"` + Name4 string `gorm:"size:233;default:null"` Age int `gorm:"default:18"` Enabled bool `gorm:"default:true"` } From 130f24090db2b9862282281f9dd288c2a214a263 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 2 Sep 2020 21:03:47 +0800 Subject: [PATCH 635/881] update default_value_test --- tests/default_value_test.go | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/default_value_test.go b/tests/default_value_test.go index 44309eab..aa4a511a 100644 --- a/tests/default_value_test.go +++ b/tests/default_value_test.go @@ -13,7 +13,6 @@ func TestDefaultValue(t *testing.T) { Name string `gorm:"not null;default:foo"` Name2 string `gorm:"size:233;not null;default:'foo'"` Name3 string `gorm:"size:233;not null;default:''"` - Name4 string `gorm:"size:233;default:null"` Age int `gorm:"default:18"` Enabled bool `gorm:"default:true"` } From fcb666cfa31ecf0de77fcd23e60a67c6819ad7fa Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 3 Sep 2020 10:58:48 +0800 Subject: [PATCH 636/881] Fix associations using composite primary keys without ID field, close #3365 --- callbacks/associations.go | 18 +++++++++++++--- tests/multi_primary_keys_test.go | 36 ++++++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+), 3 deletions(-) diff --git a/callbacks/associations.go b/callbacks/associations.go index 2710ffe9..0c677f47 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -5,6 +5,7 @@ import ( "gorm.io/gorm" "gorm.io/gorm/clause" + "gorm.io/gorm/schema" ) func SaveBeforeAssociations(db *gorm.DB) { @@ -145,7 +146,7 @@ func SaveAfterAssociations(db *gorm.DB) { } db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{ - Columns: []clause.Column{{Name: rel.FieldSchema.PrioritizedPrimaryField.DBName}}, + Columns: onConflictColumns(rel.FieldSchema), DoUpdates: clause.AssignmentColumns(assignmentColumns), }).Create(elems.Interface()).Error) } @@ -168,7 +169,7 @@ func SaveAfterAssociations(db *gorm.DB) { } db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{ - Columns: []clause.Column{{Name: rel.FieldSchema.PrioritizedPrimaryField.DBName}}, + Columns: onConflictColumns(rel.FieldSchema), DoUpdates: clause.AssignmentColumns(assignmentColumns), }).Create(f.Interface()).Error) } @@ -230,7 +231,7 @@ func SaveAfterAssociations(db *gorm.DB) { } db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{ - Columns: []clause.Column{{Name: rel.FieldSchema.PrioritizedPrimaryField.DBName}}, + Columns: onConflictColumns(rel.FieldSchema), DoUpdates: clause.AssignmentColumns(assignmentColumns), }).Create(elems.Interface()).Error) } @@ -310,3 +311,14 @@ func SaveAfterAssociations(db *gorm.DB) { } } } + +func onConflictColumns(s *schema.Schema) (columns []clause.Column) { + if s.PrioritizedPrimaryField != nil { + return []clause.Column{{Name: s.PrioritizedPrimaryField.DBName}} + } + + for _, dbName := range s.PrimaryFieldDBNames { + columns = append(columns, clause.Column{Name: dbName}) + } + return +} diff --git a/tests/multi_primary_keys_test.go b/tests/multi_primary_keys_test.go index 051e3ee2..68da8a88 100644 --- a/tests/multi_primary_keys_test.go +++ b/tests/multi_primary_keys_test.go @@ -6,6 +6,7 @@ import ( "testing" "gorm.io/gorm" + . "gorm.io/gorm/utils/tests" ) type Blog struct { @@ -410,3 +411,38 @@ func TestManyToManyWithCustomizedForeignKeys2(t *testing.T) { t.Fatalf("EN Blog's tags should be cleared") } } + +func TestCompositePrimaryKeysAssociations(t *testing.T) { + type Label struct { + BookID *uint `gorm:"primarykey"` + Name string `gorm:"primarykey"` + Value string + } + + type Book struct { + ID int + Name string + Labels []Label + } + + DB.Migrator().DropTable(&Label{}, &Book{}) + if err := DB.AutoMigrate(&Label{}, &Book{}); err != nil { + t.Fatalf("failed to migrate") + } + + book := Book{ + Name: "my book", + Labels: []Label{ + {Name: "region", Value: "emea"}, + }, + } + + DB.Create(&book) + + var result Book + if err := DB.Preload("Labels").First(&result, book.ID).Error; err != nil { + t.Fatalf("failed to preload, got error %v", err) + } + + AssertEqual(t, book, result) +} From 48b395b760d86fddad7480972791444494a8ae68 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 3 Sep 2020 11:32:30 +0800 Subject: [PATCH 637/881] returns ErrEmptySlice when creating with zero length slice --- callbacks/create.go | 5 +++++ callbacks/helper.go | 5 +++++ errors.go | 2 ++ tests/create_test.go | 12 ++++++++++++ tests/go.mod | 2 ++ 5 files changed, 26 insertions(+) diff --git a/callbacks/create.go b/callbacks/create.go index 5de19d35..e37c2c60 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -252,6 +252,11 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { stmt.SQL.Grow(stmt.ReflectValue.Len() * 15) values.Values = make([][]interface{}, stmt.ReflectValue.Len()) defaultValueFieldsHavingValue := map[*schema.Field][]interface{}{} + if stmt.ReflectValue.Len() == 0 { + stmt.AddError(gorm.ErrEmptySlice) + return + } + for i := 0; i < stmt.ReflectValue.Len(); i++ { rv := reflect.Indirect(stmt.ReflectValue.Index(i)) values.Values[i] = make([]interface{}, len(values.Columns)) diff --git a/callbacks/helper.go b/callbacks/helper.go index e0a66dd2..09ec4582 100644 --- a/callbacks/helper.go +++ b/callbacks/helper.go @@ -46,6 +46,11 @@ func ConvertSliceOfMapToValuesForCreate(stmt *gorm.Statement, mapValues []map[st selectColumns, restricted = stmt.SelectAndOmitColumns(true, false) ) + if len(mapValues) == 0 { + stmt.AddError(gorm.ErrEmptySlice) + return + } + for idx, mapValue := range mapValues { for k, v := range mapValue { if stmt.Schema != nil { diff --git a/errors.go b/errors.go index 32ff8ec1..508f6957 100644 --- a/errors.go +++ b/errors.go @@ -27,4 +27,6 @@ var ( ErrRegistered = errors.New("registered") // ErrInvalidField invalid field ErrInvalidField = errors.New("invalid field") + // ErrEmptySlice empty slice found + ErrEmptySlice = errors.New("empty slice found") ) diff --git a/tests/create_test.go b/tests/create_test.go index ab0a78d4..59fdd8f1 100644 --- a/tests/create_test.go +++ b/tests/create_test.go @@ -287,6 +287,18 @@ func TestCreateEmptyStruct(t *testing.T) { } } +func TestCreateEmptySlice(t *testing.T) { + var data = []User{} + if err := DB.Create(&data).Error; err != gorm.ErrEmptySlice { + t.Errorf("no data should be created, got %v", err) + } + + var sliceMap = []map[string]interface{}{} + if err := DB.Model(&User{}).Create(&sliceMap).Error; err != gorm.ErrEmptySlice { + t.Errorf("no data should be created, got %v", err) + } +} + func TestCreateWithExistingTimestamp(t *testing.T) { user := User{Name: "CreateUserExistingTimestamp"} curTime := now.MustParse("2016-01-01") diff --git a/tests/go.mod b/tests/go.mod index 30a7dda7..2b336850 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -16,3 +16,5 @@ require ( replace gorm.io/gorm => ../ replace github.com/jackc/pgx/v4 => github.com/jinzhu/pgx/v4 v4.8.2 + +replace gorm.io/driver/sqlserver => /Users/jinzhu/Projects/jinzhu/sqlserver From ff3880292dc89da8061269e74cfdeb75e20aee6f Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 3 Sep 2020 11:48:44 +0800 Subject: [PATCH 638/881] Update missing playground template --- .github/workflows/missing_playground.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/missing_playground.yml b/.github/workflows/missing_playground.yml index 6fb714ca..422cb9f5 100644 --- a/.github/workflows/missing_playground.yml +++ b/.github/workflows/missing_playground.yml @@ -13,7 +13,7 @@ jobs: uses: actions/stale@v3.0.7 with: repo-token: ${{ secrets.GITHUB_TOKEN }} - stale-issue-message: "This issue has been automatically marked as stale as it missing playground pull request link, checkout [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground) for details, it will be closed in 2 days if no further activity occurs." + stale-issue-message: "This issue has been automatically marked as stale as it missing playground pull request link, checkout [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground) for details, it will be closed in 2 days if no further activity occurs. if you are asking question, please use the `Question` template, most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨" stale-issue-label: "status:stale" days-before-stale: 0 days-before-close: 2 From 98e15e0b95b39f9caefbb8b14a1e479a237e52fe Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 3 Sep 2020 12:54:26 +0800 Subject: [PATCH 639/881] Setup DB's ConnPool in PrepareStmt mode, fix #3362 --- gorm.go | 2 ++ tests/prepared_stmt_test.go | 4 ++++ 2 files changed, 6 insertions(+) diff --git a/gorm.go b/gorm.go index fec4310b..ed01ccfe 100644 --- a/gorm.go +++ b/gorm.go @@ -176,6 +176,8 @@ func (db *DB) Session(config *Session) *DB { Mux: preparedStmt.Mux, Stmts: preparedStmt.Stmts, } + txConfig.ConnPool = tx.Statement.ConnPool + txConfig.PrepareStmt = true } } diff --git a/tests/prepared_stmt_test.go b/tests/prepared_stmt_test.go index af610165..6b10b6dc 100644 --- a/tests/prepared_stmt_test.go +++ b/tests/prepared_stmt_test.go @@ -12,6 +12,10 @@ import ( func TestPreparedStmt(t *testing.T) { tx := DB.Session(&gorm.Session{PrepareStmt: true}) + if _, ok := tx.ConnPool.(*gorm.PreparedStmtDB); !ok { + t.Fatalf("should assign PreparedStatement Manager back to database when using PrepareStmt mode") + } + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) defer cancel() txCtx := tx.WithContext(ctx) From 3cc7a307122e1ca2d0fbb298c264c51fce1bdd62 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 3 Sep 2020 13:28:37 +0800 Subject: [PATCH 640/881] Fix tests/go.mod --- tests/go.mod | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/go.mod b/tests/go.mod index 2b336850..30a7dda7 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -16,5 +16,3 @@ require ( replace gorm.io/gorm => ../ replace github.com/jackc/pgx/v4 => github.com/jinzhu/pgx/v4 v4.8.2 - -replace gorm.io/driver/sqlserver => /Users/jinzhu/Projects/jinzhu/sqlserver From cf31508095ecae9a50ecfde1cf7c534d01fbe745 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 3 Sep 2020 15:02:04 +0800 Subject: [PATCH 641/881] Fix tests_all.sh --- tests/go.mod | 2 +- tests/tests_all.sh | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/go.mod b/tests/go.mod index 30a7dda7..4ddb0b69 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -9,7 +9,7 @@ require ( gorm.io/driver/mysql v1.0.1 gorm.io/driver/postgres v1.0.0 gorm.io/driver/sqlite v1.1.1 - gorm.io/driver/sqlserver v1.0.2 + gorm.io/driver/sqlserver v1.0.3 gorm.io/gorm v1.9.19 ) diff --git a/tests/tests_all.sh b/tests/tests_all.sh index e87ff045..744a40e9 100755 --- a/tests/tests_all.sh +++ b/tests/tests_all.sh @@ -10,7 +10,7 @@ if [ -d tests ] then cd tests cp go.mod go.mod.bak - sed '/$[[:space:]]*gorm.io\/driver/d' go.mod.bak > go.mod + sed '/^[[:blank:]]*gorm.io\/driver/d' go.mod.bak > go.mod cd .. fi From f2adb088c598400086b6e67506ffee38780e9c3f Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 3 Sep 2020 16:11:15 +0800 Subject: [PATCH 642/881] Set field size from primary fields to foreign fields --- gorm.go | 3 +++ schema/relationship.go | 9 +++++++++ 2 files changed, 12 insertions(+) diff --git a/gorm.go b/gorm.go index ed01ccfe..8efd8a73 100644 --- a/gorm.go +++ b/gorm.go @@ -319,6 +319,9 @@ func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interfac if f := joinSchema.LookUpField(ref.ForeignKey.DBName); f != nil { f.DataType = ref.ForeignKey.DataType f.GORMDataType = ref.ForeignKey.GORMDataType + if f.Size == 0 { + f.Size = ref.ForeignKey.Size + } ref.ForeignKey = f } else { return fmt.Errorf("missing field %v for join table", ref.ForeignKey.DBName) diff --git a/schema/relationship.go b/schema/relationship.go index aa992b84..47b948dc 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -165,6 +165,9 @@ func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Fi // use same data type for foreign keys relation.Polymorphic.PolymorphicID.DataType = primaryKeyField.DataType relation.Polymorphic.PolymorphicID.GORMDataType = primaryKeyField.GORMDataType + if relation.Polymorphic.PolymorphicID.Size == 0 { + relation.Polymorphic.PolymorphicID.Size = primaryKeyField.Size + } relation.References = append(relation.References, &Reference{ PrimaryKey: primaryKeyField, @@ -301,6 +304,9 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel // use same data type for foreign keys f.DataType = fieldsMap[f.Name].DataType f.GORMDataType = fieldsMap[f.Name].GORMDataType + if f.Size == 0 { + f.Size = fieldsMap[f.Name].Size + } relation.JoinTable.PrimaryFields = append(relation.JoinTable.PrimaryFields, f) ownPriamryField := schema == fieldsMap[f.Name].Schema && ownFieldsMap[f.Name] @@ -436,6 +442,9 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl gue // use same data type for foreign keys foreignField.DataType = primaryFields[idx].DataType foreignField.GORMDataType = primaryFields[idx].GORMDataType + if foreignField.Size == 0 { + foreignField.Size = primaryFields[idx].Size + } relation.References = append(relation.References, &Reference{ PrimaryKey: primaryFields[idx], From 78e9c9b7488fbc71bf2ab853db4490d241cb0ada Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 3 Sep 2020 18:20:57 +0800 Subject: [PATCH 643/881] raise error when failed to parse default value, close #3378 --- schema/field.go | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/schema/field.go b/schema/field.go index b49b7de6..0cb210f8 100644 --- a/schema/field.go +++ b/schema/field.go @@ -70,6 +70,8 @@ type Field struct { } func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { + var err error + field := &Field{ Name: fieldStruct.Name, BindNames: []string{fieldStruct.Name}, @@ -151,7 +153,6 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } if num, ok := field.TagSettings["SIZE"]; ok { - var err error if field.Size, err = strconv.Atoi(num); err != nil { field.Size = -1 } @@ -181,22 +182,30 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { case reflect.Bool: field.DataType = Bool if field.HasDefaultValue && field.DefaultValue != "" { - field.DefaultValueInterface, _ = strconv.ParseBool(field.DefaultValue) + if field.DefaultValueInterface, err = strconv.ParseBool(field.DefaultValue); err != nil { + schema.err = fmt.Errorf("failed to parse %v as default value for bool, got error: %v", field.DefaultValue, err) + } } case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: field.DataType = Int if field.HasDefaultValue && field.DefaultValue != "" { - field.DefaultValueInterface, _ = strconv.ParseInt(field.DefaultValue, 0, 64) + if field.DefaultValueInterface, err = strconv.ParseInt(field.DefaultValue, 0, 64); err != nil { + schema.err = fmt.Errorf("failed to parse %v as default value for int, got error: %v", field.DefaultValue, err) + } } case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: field.DataType = Uint if field.HasDefaultValue && field.DefaultValue != "" { - field.DefaultValueInterface, _ = strconv.ParseUint(field.DefaultValue, 0, 64) + if field.DefaultValueInterface, err = strconv.ParseUint(field.DefaultValue, 0, 64); err != nil { + schema.err = fmt.Errorf("failed to parse %v as default value for uint, got error: %v", field.DefaultValue, err) + } } case reflect.Float32, reflect.Float64: field.DataType = Float if field.HasDefaultValue && field.DefaultValue != "" { - field.DefaultValueInterface, _ = strconv.ParseFloat(field.DefaultValue, 64) + if field.DefaultValueInterface, err = strconv.ParseFloat(field.DefaultValue, 64); err != nil { + schema.err = fmt.Errorf("failed to parse %v as default value for float, got error: %v", field.DefaultValue, err) + } } case reflect.String: field.DataType = String From 3cd81ff646090931556cf5590c41ac5d5746357c Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 3 Sep 2020 18:42:32 +0800 Subject: [PATCH 644/881] Fix query with specified table and conditions, close #3382 --- statement.go | 8 ++++---- tests/query_test.go | 9 ++++++++- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/statement.go b/statement.go index d72a086f..e16cf0ff 100644 --- a/statement.go +++ b/statement.go @@ -317,9 +317,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) (c if field.Readable { if v, isZero := field.ValueOf(reflectValue); !isZero { if field.DBName != "" { - conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.DBName}, Value: v}) + conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v}) } else if field.DataType != "" { - conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.Name}, Value: v}) + conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v}) } } } @@ -330,9 +330,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) (c if field.Readable { if v, isZero := field.ValueOf(reflectValue.Index(i)); !isZero { if field.DBName != "" { - conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.DBName}, Value: v}) + conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v}) } else if field.DataType != "" { - conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.Name}, Value: v}) + conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v}) } } } diff --git a/tests/query_test.go b/tests/query_test.go index 6bb68cd3..795186da 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -202,7 +202,6 @@ func TestFind(t *testing.T) { } } }) - } func TestQueryWithAssociation(t *testing.T) { @@ -800,3 +799,11 @@ func TestScanNullValue(t *testing.T) { t.Fatalf("failed to query slice data with null age, got error %v", err) } } + +func TestQueryWithTableAndConditions(t *testing.T) { + result := DB.Session(&gorm.Session{DryRun: true}).Table("user").Find(&User{}, User{Name: "jinzhu"}) + + if !regexp.MustCompile(`SELECT \* FROM .user. WHERE .user.\..name. = .+ AND .user.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) { + t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String()) + } +} From dd0d74fad06342a792a1cdc20101a57ee019f447 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 3 Sep 2020 19:16:55 +0800 Subject: [PATCH 645/881] Fix transaction on closed conn when using prepared statement, close #3380 --- prepare_stmt.go | 14 ++++++++++++++ tests/tests_test.go | 4 ++-- tests/transaction_test.go | 21 +++++++++++++++++++++ 3 files changed, 37 insertions(+), 2 deletions(-) diff --git a/prepare_stmt.go b/prepare_stmt.go index de7e2a26..14a6aaec 100644 --- a/prepare_stmt.go +++ b/prepare_stmt.go @@ -99,6 +99,20 @@ type PreparedStmtTX struct { PreparedStmtDB *PreparedStmtDB } +func (tx *PreparedStmtTX) Commit() error { + if tx.Tx != nil { + return tx.Tx.Commit() + } + return ErrInvalidTransaction +} + +func (tx *PreparedStmtTX) Rollback() error { + if tx.Tx != nil { + return tx.Tx.Rollback() + } + return ErrInvalidTransaction +} + func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ...interface{}) (result sql.Result, err error) { stmt, err := tx.PreparedStmtDB.prepare(ctx, query) if err == nil { diff --git a/tests/tests_test.go b/tests/tests_test.go index 192160a0..cb73d267 100644 --- a/tests/tests_test.go +++ b/tests/tests_test.go @@ -21,7 +21,7 @@ var DB *gorm.DB func init() { var err error if DB, err = OpenTestConnection(); err != nil { - log.Printf("failed to connect database, got error %v\n", err) + log.Printf("failed to connect database, got error %v", err) os.Exit(1) } else { sqlDB, err := DB.DB() @@ -30,7 +30,7 @@ func init() { } if err != nil { - log.Printf("failed to connect database, got error %v\n", err) + log.Printf("failed to connect database, got error %v", err) } RunMigrations() diff --git a/tests/transaction_test.go b/tests/transaction_test.go index aea151d9..334600b8 100644 --- a/tests/transaction_test.go +++ b/tests/transaction_test.go @@ -282,3 +282,24 @@ func TestNestedTransactionWithBlock(t *testing.T) { t.Fatalf("Should find saved record") } } + +func TestTransactionOnClosedConn(t *testing.T) { + DB, err := OpenTestConnection() + if err != nil { + t.Fatalf("failed to connect database, got error %v", err) + } + rawDB, _ := DB.DB() + rawDB.Close() + + if err := DB.Transaction(func(tx *gorm.DB) error { + return nil + }); err == nil { + t.Errorf("should returns error when commit with closed conn, got error %v", err) + } + + if err := DB.Session(&gorm.Session{PrepareStmt: true}).Transaction(func(tx *gorm.DB) error { + return nil + }); err == nil { + t.Errorf("should returns error when commit with closed conn, got error %v", err) + } +} From 6a866464695e8b0291236f9038a032f68fb0b37d Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 3 Sep 2020 20:41:00 +0800 Subject: [PATCH 646/881] Fix use db function as integer's default value, close #3384 --- schema/field.go | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/schema/field.go b/schema/field.go index 0cb210f8..f8a73c60 100644 --- a/schema/field.go +++ b/schema/field.go @@ -178,41 +178,41 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { field.Comment = val } + defaultValueFunc := strings.Contains(field.DefaultValue, "(") && + strings.Contains(field.DefaultValue, ")") || strings.ToLower(field.DefaultValue) == "null" switch reflect.Indirect(fieldValue).Kind() { case reflect.Bool: field.DataType = Bool - if field.HasDefaultValue && field.DefaultValue != "" { + if field.HasDefaultValue && !defaultValueFunc { if field.DefaultValueInterface, err = strconv.ParseBool(field.DefaultValue); err != nil { schema.err = fmt.Errorf("failed to parse %v as default value for bool, got error: %v", field.DefaultValue, err) } } case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: field.DataType = Int - if field.HasDefaultValue && field.DefaultValue != "" { + if field.HasDefaultValue && !defaultValueFunc { if field.DefaultValueInterface, err = strconv.ParseInt(field.DefaultValue, 0, 64); err != nil { schema.err = fmt.Errorf("failed to parse %v as default value for int, got error: %v", field.DefaultValue, err) } } case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: field.DataType = Uint - if field.HasDefaultValue && field.DefaultValue != "" { + if field.HasDefaultValue && !defaultValueFunc { if field.DefaultValueInterface, err = strconv.ParseUint(field.DefaultValue, 0, 64); err != nil { schema.err = fmt.Errorf("failed to parse %v as default value for uint, got error: %v", field.DefaultValue, err) } } case reflect.Float32, reflect.Float64: field.DataType = Float - if field.HasDefaultValue && field.DefaultValue != "" { + if field.HasDefaultValue && !defaultValueFunc { if field.DefaultValueInterface, err = strconv.ParseFloat(field.DefaultValue, 64); err != nil { schema.err = fmt.Errorf("failed to parse %v as default value for float, got error: %v", field.DefaultValue, err) } } case reflect.String: field.DataType = String - isFunc := strings.Contains(field.DefaultValue, "(") && - strings.Contains(field.DefaultValue, ")") || strings.ToLower(field.DefaultValue) == "null" - if field.HasDefaultValue && !isFunc { + if field.HasDefaultValue && !defaultValueFunc { field.DefaultValue = strings.Trim(field.DefaultValue, "'") field.DefaultValue = strings.Trim(field.DefaultValue, "\"") field.DefaultValueInterface = field.DefaultValue From 28121d44554b1f5db07658e7cc8343ace65d940d Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 3 Sep 2020 20:59:41 +0800 Subject: [PATCH 647/881] Fix panic when batch creating from slice contains invalid data, close #3385 --- callbacks/create.go | 6 ++++++ tests/create_test.go | 13 +++++++++++++ 2 files changed, 19 insertions(+) diff --git a/callbacks/create.go b/callbacks/create.go index e37c2c60..c00a0a73 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -1,6 +1,7 @@ package callbacks import ( + "fmt" "reflect" "gorm.io/gorm" @@ -259,6 +260,11 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { for i := 0; i < stmt.ReflectValue.Len(); i++ { rv := reflect.Indirect(stmt.ReflectValue.Index(i)) + if !rv.IsValid() { + stmt.AddError(fmt.Errorf("slice data #%v is invalid: %w", i, gorm.ErrInvalidData)) + return + } + values.Values[i] = make([]interface{}, len(values.Columns)) for idx, column := range values.Columns { field := stmt.Schema.FieldsByDBName[column.Name] diff --git a/tests/create_test.go b/tests/create_test.go index 59fdd8f1..00674eec 100644 --- a/tests/create_test.go +++ b/tests/create_test.go @@ -1,6 +1,7 @@ package tests_test import ( + "errors" "testing" "time" @@ -299,6 +300,18 @@ func TestCreateEmptySlice(t *testing.T) { } } +func TestCreateInvalidSlice(t *testing.T) { + users := []*User{ + GetUser("invalid_slice_1", Config{}), + GetUser("invalid_slice_2", Config{}), + nil, + } + + if err := DB.Create(&users).Error; !errors.Is(err, gorm.ErrInvalidData) { + t.Errorf("should returns error invalid data when creating from slice that contains invalid data") + } +} + func TestCreateWithExistingTimestamp(t *testing.T) { user := User{Name: "CreateUserExistingTimestamp"} curTime := now.MustParse("2016-01-01") From f1216222284fc2f91bee7018c5c54a3662b9a2b3 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 4 Sep 2020 14:30:53 +0800 Subject: [PATCH 648/881] Don't add prefix for invalid embedded fields --- schema/field.go | 2 +- schema/schema_test.go | 10 +++++++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/schema/field.go b/schema/field.go index f8a73c60..db044c23 100644 --- a/schema/field.go +++ b/schema/field.go @@ -340,7 +340,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { ef.StructField.Index = append([]int{-fieldStruct.Index[0] - 1}, ef.StructField.Index...) } - if prefix, ok := field.TagSettings["EMBEDDEDPREFIX"]; ok { + if prefix, ok := field.TagSettings["EMBEDDEDPREFIX"]; ok && ef.DBName != "" { ef.DBName = prefix + ef.DBName } diff --git a/schema/schema_test.go b/schema/schema_test.go index 4d13ebd2..6ca5b269 100644 --- a/schema/schema_test.go +++ b/schema/schema_test.go @@ -194,6 +194,7 @@ func TestEmbeddedStruct(t *testing.T) { ID int OwnerID int Name string + Ignored string `gorm:"-"` } type Corp struct { @@ -211,15 +212,18 @@ func TestEmbeddedStruct(t *testing.T) { {Name: "ID", DBName: "id", BindNames: []string{"CorpBase", "Model", "ID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, HasDefaultValue: true, AutoIncrement: true, TagSettings: map[string]string{"PRIMARYKEY": "PRIMARYKEY"}}, {Name: "ID", DBName: "company_id", BindNames: []string{"Base", "ID"}, DataType: schema.Int, Size: 64, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}}, {Name: "Name", DBName: "company_name", BindNames: []string{"Base", "Name"}, DataType: schema.String, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}}, + {Name: "Ignored", BindNames: []string{"Base", "Ignored"}, TagSettings: map[string]string{"-": "-", "EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}}, {Name: "OwnerID", DBName: "company_owner_id", BindNames: []string{"Base", "OwnerID"}, DataType: schema.Int, Size: 64, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}}, {Name: "OwnerID", DBName: "owner_id", BindNames: []string{"CorpBase", "OwnerID"}, DataType: schema.String}, } for _, f := range fields { checkSchemaField(t, cropSchema, &f, func(f *schema.Field) { - f.Creatable = true - f.Updatable = true - f.Readable = true + if f.Name != "Ignored" { + f.Creatable = true + f.Updatable = true + f.Readable = true + } }) } } From d8ddccf1478bf1aaf3726f2301c08fe6a9ca4183 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 4 Sep 2020 19:02:37 +0800 Subject: [PATCH 649/881] Don't marshal to null for associations after preloading, close #3395 --- callbacks/preload.go | 14 ++++++++++++-- tests/preload_test.go | 24 ++++++++++++++++++++++++ tests/scan_test.go | 8 ++++++-- 3 files changed, 42 insertions(+), 4 deletions(-) diff --git a/callbacks/preload.go b/callbacks/preload.go index 9b8f762a..aec10ec5 100644 --- a/callbacks/preload.go +++ b/callbacks/preload.go @@ -110,10 +110,20 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { // clean up old values before preloading switch reflectValue.Kind() { case reflect.Struct: - rel.Field.Set(reflectValue, reflect.New(rel.Field.FieldType).Interface()) + switch rel.Type { + case schema.HasMany, schema.Many2Many: + rel.Field.Set(reflectValue, reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 0).Interface()) + default: + rel.Field.Set(reflectValue, reflect.New(rel.Field.FieldType).Interface()) + } case reflect.Slice, reflect.Array: for i := 0; i < reflectValue.Len(); i++ { - rel.Field.Set(reflectValue.Index(i), reflect.New(rel.Field.FieldType).Interface()) + switch rel.Type { + case schema.HasMany, schema.Many2Many: + rel.Field.Set(reflectValue.Index(i), reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 0).Interface()) + default: + rel.Field.Set(reflectValue.Index(i), reflect.New(rel.Field.FieldType).Interface()) + } } } diff --git a/tests/preload_test.go b/tests/preload_test.go index 76b72f14..d9035661 100644 --- a/tests/preload_test.go +++ b/tests/preload_test.go @@ -1,6 +1,8 @@ package tests_test import ( + "encoding/json" + "regexp" "sort" "strconv" "testing" @@ -188,3 +190,25 @@ func TestNestedPreloadWithConds(t *testing.T) { CheckPet(t, *users2[2].Pets[2], *users[2].Pets[2]) } } + +func TestPreloadEmptyData(t *testing.T) { + var user = *GetUser("user_without_associations", Config{}) + DB.Create(&user) + + DB.Preload("Team").Preload("Languages").Preload("Friends").First(&user, "name = ?", user.Name) + + if r, err := json.Marshal(&user); err != nil { + t.Errorf("failed to marshal users, got error %v", err) + } else if !regexp.MustCompile(`"Team":\[\],"Languages":\[\],"Friends":\[\]`).MatchString(string(r)) { + t.Errorf("json marshal is not empty slice, got %v", string(r)) + } + + var results []User + DB.Preload("Team").Preload("Languages").Preload("Friends").Find(&results, "name = ?", user.Name) + + if r, err := json.Marshal(&results); err != nil { + t.Errorf("failed to marshal users, got error %v", err) + } else if !regexp.MustCompile(`"Team":\[\],"Languages":\[\],"Friends":\[\]`).MatchString(string(r)) { + t.Errorf("json marshal is not empty slice, got %v", string(r)) + } +} diff --git a/tests/scan_test.go b/tests/scan_test.go index 3e66a25a..92e89521 100644 --- a/tests/scan_test.go +++ b/tests/scan_test.go @@ -51,11 +51,11 @@ func TestScan(t *testing.T) { DB.Table("users").Select("name, age").Where("id in ?", []uint{user2.ID, user3.ID}).Scan(&results) sort.Slice(results, func(i, j int) bool { - return strings.Compare(results[i].Name, results[j].Name) < -1 + return strings.Compare(results[i].Name, results[j].Name) <= -1 }) if len(results) != 2 || results[0].Name != user2.Name || results[1].Name != user3.Name { - t.Errorf("Scan into struct map") + t.Errorf("Scan into struct map, got %#v", results) } } @@ -84,6 +84,10 @@ func TestScanRows(t *testing.T) { results = append(results, result) } + sort.Slice(results, func(i, j int) bool { + return strings.Compare(results[i].Name, results[j].Name) <= -1 + }) + if !reflect.DeepEqual(results, []Result{{Name: "ScanRowsUser2", Age: 10}, {Name: "ScanRowsUser3", Age: 20}}) { t.Errorf("Should find expected results") } From 6e38a2c2d510a6823ad7b73c7e9321c8f7ceaff8 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 6 Sep 2020 10:51:21 +0800 Subject: [PATCH 650/881] Fix many2many join table name rule --- schema/naming.go | 4 ++++ schema/relationship_test.go | 8 ++++---- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/schema/naming.go b/schema/naming.go index 9b7c9471..ecdab791 100644 --- a/schema/naming.go +++ b/schema/naming.go @@ -41,6 +41,10 @@ func (ns NamingStrategy) ColumnName(table, column string) string { // JoinTableName convert string to join table name func (ns NamingStrategy) JoinTableName(str string) string { + if strings.ToLower(str) == str { + return str + } + if ns.SingularTable { return ns.TablePrefix + toDBName(str) } diff --git a/schema/relationship_test.go b/schema/relationship_test.go index f2d63323..b9279b9f 100644 --- a/schema/relationship_test.go +++ b/schema/relationship_test.go @@ -206,16 +206,16 @@ func TestMany2ManyOverrideJoinForeignKey(t *testing.T) { type User struct { gorm.Model - Profiles []Profile `gorm:"many2many:user_profiles;JoinForeignKey:UserReferID;JoinReferences:ProfileRefer"` + Profiles []Profile `gorm:"many2many:user_profile;JoinForeignKey:UserReferID;JoinReferences:ProfileRefer"` Refer uint } checkStructRelation(t, &User{}, Relation{ Name: "Profiles", Type: schema.Many2Many, Schema: "User", FieldSchema: "Profile", - JoinTable: JoinTable{Name: "user_profiles", Table: "user_profiles"}, + JoinTable: JoinTable{Name: "user_profile", Table: "user_profile"}, References: []Reference{ - {"ID", "User", "UserReferID", "user_profiles", "", true}, - {"ID", "Profile", "ProfileRefer", "user_profiles", "", false}, + {"ID", "User", "UserReferID", "user_profile", "", true}, + {"ID", "Profile", "ProfileRefer", "user_profile", "", false}, }, }) } From 05794298bd3d87dc8e98de8cde451b19093e2a4a Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 6 Sep 2020 12:22:05 +0800 Subject: [PATCH 651/881] Fix Save with specified table, close #3396 --- finisher_api.go | 3 ++- tests/update_test.go | 22 ++++++++++++++++++++++ 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/finisher_api.go b/finisher_api.go index 1d5ef5fc..6ece0f79 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -51,7 +51,8 @@ func (db *DB) Save(value interface{}) (tx *DB) { tx.callbacks.Update().Execute(tx) if tx.Error == nil && tx.RowsAffected == 0 && !tx.DryRun && !selectedUpdate { - if err := tx.Session(&Session{}).First(value).Error; errors.Is(err, ErrRecordNotFound) { + result := reflect.New(tx.Statement.Schema.ModelType).Interface() + if err := tx.Session(&Session{WithConditions: true}).First(result).Error; errors.Is(err, ErrRecordNotFound) { return tx.Create(value) } } diff --git a/tests/update_test.go b/tests/update_test.go index 1944ed3f..a660647c 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -629,4 +629,26 @@ func TestSaveWithPrimaryValue(t *testing.T) { var result2 Language DB.First(&result2, "code = ?", "save") AssertEqual(t, result2, lang) + + DB.Table("langs").Migrator().DropTable(&Language{}) + DB.Table("langs").AutoMigrate(&Language{}) + + if err := DB.Table("langs").Save(&lang).Error; err != nil { + t.Errorf("no error should happen when creating data, but got %v", err) + } + + var result3 Language + if err := DB.Table("langs").First(&result3, "code = ?", lang.Code).Error; err != nil || result3.Name != lang.Name { + t.Errorf("failed to find created record, got error: %v, result: %+v", err, result3) + } + + lang.Name += "name2" + if err := DB.Table("langs").Save(&lang).Error; err != nil { + t.Errorf("no error should happen when creating data, but got %v", err) + } + + var result4 Language + if err := DB.Table("langs").First(&result4, "code = ?", lang.Code).Error; err != nil || result4.Name != lang.Name { + t.Errorf("failed to find created record, got error: %v, result: %+v", err, result4) + } } From 6de0356a57f74da299e7cb2b8ccd44e86fe59675 Mon Sep 17 00:00:00 2001 From: egenchen Date: Tue, 8 Sep 2020 16:59:47 +0800 Subject: [PATCH 652/881] Fix monocolor log output inconsist with colorful log (#3425) --- logger/logger.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/logger/logger.go b/logger/logger.go index 49ae988c..0b0a7377 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -65,9 +65,9 @@ func New(writer Writer, config Config) Interface { infoStr = "%s\n[info] " warnStr = "%s\n[warn] " errStr = "%s\n[error] " - traceStr = "%s\n[%v] [rows:%d] %s" - traceWarnStr = "%s\n[%v] [rows:%d] %s" - traceErrStr = "%s %s\n[%v] [rows:%d] %s" + traceStr = "%s\n[%.3fms] [rows:%d] %s" + traceWarnStr = "%s\n[%.3fms] [rows:%d] %s" + traceErrStr = "%s %s\n[%.3fms] [rows:%d] %s" ) if config.Colorful { From c9d5c0b07aa7be8ed4bebeb376ccf158542730ea Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 8 Sep 2020 18:24:35 +0800 Subject: [PATCH 653/881] Fix create database foreign keys for same type having has many/one & many2many relationships, close #3424 --- migrator/migrator.go | 23 ++++++++++++++++++----- tests/embedded_struct_test.go | 4 +++- 2 files changed, 21 insertions(+), 6 deletions(-) diff --git a/migrator/migrator.go b/migrator/migrator.go index 29d26c9e..98e92c96 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -586,6 +586,7 @@ func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []i var ( modelNames, orderedModelNames []string orderedModelNamesMap = map[string]bool{} + parsedSchemas = map[*schema.Schema]bool{} valuesMap = map[string]Dependency{} insertIntoOrderedList func(name string) parseDependence func(value interface{}, addToList bool) @@ -595,23 +596,35 @@ func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []i dep := Dependency{ Statement: &gorm.Statement{DB: m.DB, Dest: value}, } + beDependedOn := map[*schema.Schema]bool{} if err := dep.Parse(value); err != nil { m.DB.Logger.Error(context.Background(), "failed to parse value %#v, got error %v", value, err) } + if _, ok := parsedSchemas[dep.Statement.Schema]; ok { + return + } + parsedSchemas[dep.Statement.Schema] = true for _, rel := range dep.Schema.Relationships.Relations { if c := rel.ParseConstraint(); c != nil && c.Schema == dep.Statement.Schema && c.Schema != c.ReferenceSchema { dep.Depends = append(dep.Depends, c.ReferenceSchema) } + if rel.Type == schema.HasOne || rel.Type == schema.HasMany { + beDependedOn[rel.FieldSchema] = true + } + if rel.JoinTable != nil { - if rel.Schema != rel.FieldSchema { - dep.Depends = append(dep.Depends, rel.FieldSchema) - } // append join value - defer func(joinValue interface{}) { + defer func(rel *schema.Relationship, joinValue interface{}) { + if !beDependedOn[rel.FieldSchema] { + dep.Depends = append(dep.Depends, rel.FieldSchema) + } else { + fieldValue := reflect.New(rel.FieldSchema.ModelType).Interface() + parseDependence(fieldValue, autoAdd) + } parseDependence(joinValue, autoAdd) - }(reflect.New(rel.JoinTable.ModelType).Interface()) + }(rel, reflect.New(rel.JoinTable.ModelType).Interface()) } } diff --git a/tests/embedded_struct_test.go b/tests/embedded_struct_test.go index c29078bd..312a5c37 100644 --- a/tests/embedded_struct_test.go +++ b/tests/embedded_struct_test.go @@ -163,6 +163,8 @@ func TestEmbeddedRelations(t *testing.T) { DB.Migrator().DropTable(&AdvancedUser{}) if err := DB.AutoMigrate(&AdvancedUser{}); err != nil { - t.Errorf("Failed to auto migrate advanced user, got error %v", err) + if DB.Dialector.Name() != "sqlite" { + t.Errorf("Failed to auto migrate advanced user, got error %v", err) + } } } From c70c097e88bd5372783da6af55c4742fa4fe83ab Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 8 Sep 2020 19:11:20 +0800 Subject: [PATCH 654/881] Refactor format SQL for driver.Valuer --- logger/sql.go | 20 ++++++++++++++++++++ tests/go.mod | 4 ---- 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/logger/sql.go b/logger/sql.go index 80645b0c..096b9407 100644 --- a/logger/sql.go +++ b/logger/sql.go @@ -38,6 +38,26 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a } else { vars[idx] = escaper + v.Format("2006-01-02 15:04:05.999") + escaper } + case *time.Time: + if v != nil { + if v.IsZero() { + vars[idx] = escaper + "0000-00-00 00:00:00" + escaper + } else { + vars[idx] = escaper + v.Format("2006-01-02 15:04:05.999") + escaper + } + } else { + vars[idx] = "NULL" + } + case fmt.Stringer: + vars[idx] = escaper + strings.Replace(fmt.Sprintf("%v", v), escaper, "\\"+escaper, -1) + escaper + case driver.Valuer: + reflectValue := reflect.ValueOf(v) + if v != nil && reflectValue.IsValid() && (reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) { + r, _ := v.Value() + vars[idx] = fmt.Sprintf("%v", r) + } else { + vars[idx] = "NULL" + } case []byte: if isPrintable(v) { vars[idx] = escaper + strings.Replace(string(v), escaper, "\\"+escaper, -1) + escaper diff --git a/tests/go.mod b/tests/go.mod index 4ddb0b69..76db6764 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -6,10 +6,6 @@ require ( github.com/google/uuid v1.1.1 github.com/jinzhu/now v1.1.1 github.com/lib/pq v1.6.0 - gorm.io/driver/mysql v1.0.1 - gorm.io/driver/postgres v1.0.0 - gorm.io/driver/sqlite v1.1.1 - gorm.io/driver/sqlserver v1.0.3 gorm.io/gorm v1.9.19 ) From aceb3dad3bbd43e79d0146992701f4f25f3eabb0 Mon Sep 17 00:00:00 2001 From: caelansar <819711623@qq.com> Date: Tue, 8 Sep 2020 21:28:04 +0800 Subject: [PATCH 655/881] correct generated sql --- clause/expression.go | 3 +++ tests/query_test.go | 16 ++++++++++++++++ 2 files changed, 19 insertions(+) diff --git a/clause/expression.go b/clause/expression.go index 3b914e68..55599571 100644 --- a/clause/expression.go +++ b/clause/expression.go @@ -37,6 +37,9 @@ func (expr Expr) Build(builder Builder) { } else { switch rv := reflect.ValueOf(expr.Vars[idx]); rv.Kind() { case reflect.Slice, reflect.Array: + if rv.Len() == 0 { + builder.AddVar(builder, nil) + } for i := 0; i < rv.Len(); i++ { if i > 0 { builder.WriteByte(',') diff --git a/tests/query_test.go b/tests/query_test.go index 795186da..e695e825 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -202,6 +202,22 @@ func TestFind(t *testing.T) { } } }) + + var models []User + if err := DB.Where("name in ?", []string{"find"}).Find(&models).Error; err != nil || len(models) != 3 { + t.Errorf("errors happened when query find with in clause: %v, length: %v", err, len(models)) + } else { + for idx, user := range users { + t.Run("FindWithInClause#"+strconv.Itoa(idx+1), func(t *testing.T) { + CheckUser(t, models[idx], user) + }) + } + } + + var none []User + if err := DB.Where("name in ?", []string{}).Find(&none).Error; err != nil || len(none) != 0 { + t.Errorf("errors happened when query find with in clause and zero length parameter: %v, length: %v", err, len(none)) + } } func TestQueryWithAssociation(t *testing.T) { From 222427c474a3146bf79cb782fe50fae7d80aae69 Mon Sep 17 00:00:00 2001 From: "Jonathan A. Sternberg" Date: Tue, 8 Sep 2020 18:12:14 -0500 Subject: [PATCH 656/881] Release the connection when discovering the column types in the migrator When the migrator is used to discover the column types, such as when used with `AutoMigrate()`, it does not close the query result. This changes the migrator to close the query result and it also changes the query to use `LIMIT 1` to prevent additional work against the database when only discovering the schema. Fixes #3432. --- migrator/migrator.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/migrator/migrator.go b/migrator/migrator.go index 98e92c96..c0e22ae0 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -388,9 +388,10 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy func (m Migrator) ColumnTypes(value interface{}) (columnTypes []*sql.ColumnType, err error) { err = m.RunWithValue(value, func(stmt *gorm.Statement) error { - rows, err := m.DB.Session(&gorm.Session{}).Raw("select * from ?", clause.Table{Name: stmt.Table}).Rows() + rows, err := m.DB.Session(&gorm.Session{}).Raw("select * from ? limit 1", clause.Table{Name: stmt.Table}).Rows() if err == nil { columnTypes, err = rows.ColumnTypes() + _ = rows.Close() } return err }) From 2242ac6c0ea490f7fa7c60c61126be0fdee0d72f Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 9 Sep 2020 10:31:48 +0800 Subject: [PATCH 657/881] Fix tests & refactor for PR #3429 --- clause/expression.go | 11 ++++++----- tests/go.mod | 4 ++++ tests/query_test.go | 4 ++-- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/clause/expression.go b/clause/expression.go index 55599571..dde236d3 100644 --- a/clause/expression.go +++ b/clause/expression.go @@ -39,12 +39,13 @@ func (expr Expr) Build(builder Builder) { case reflect.Slice, reflect.Array: if rv.Len() == 0 { builder.AddVar(builder, nil) - } - for i := 0; i < rv.Len(); i++ { - if i > 0 { - builder.WriteByte(',') + } else { + for i := 0; i < rv.Len(); i++ { + if i > 0 { + builder.WriteByte(',') + } + builder.AddVar(builder, rv.Index(i).Interface()) } - builder.AddVar(builder, rv.Index(i).Interface()) } default: builder.AddVar(builder, expr.Vars[idx]) diff --git a/tests/go.mod b/tests/go.mod index 76db6764..4ddb0b69 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -6,6 +6,10 @@ require ( github.com/google/uuid v1.1.1 github.com/jinzhu/now v1.1.1 github.com/lib/pq v1.6.0 + gorm.io/driver/mysql v1.0.1 + gorm.io/driver/postgres v1.0.0 + gorm.io/driver/sqlite v1.1.1 + gorm.io/driver/sqlserver v1.0.3 gorm.io/gorm v1.9.19 ) diff --git a/tests/query_test.go b/tests/query_test.go index e695e825..14150038 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -204,7 +204,7 @@ func TestFind(t *testing.T) { }) var models []User - if err := DB.Where("name in ?", []string{"find"}).Find(&models).Error; err != nil || len(models) != 3 { + if err := DB.Where("name in (?)", []string{"find"}).Find(&models).Error; err != nil || len(models) != 3 { t.Errorf("errors happened when query find with in clause: %v, length: %v", err, len(models)) } else { for idx, user := range users { @@ -215,7 +215,7 @@ func TestFind(t *testing.T) { } var none []User - if err := DB.Where("name in ?", []string{}).Find(&none).Error; err != nil || len(none) != 0 { + if err := DB.Where("name in (?)", []string{}).Find(&none).Error; err != nil || len(none) != 0 { t.Errorf("errors happened when query find with in clause and zero length parameter: %v, length: %v", err, len(none)) } } From 839e09e98558d946b4bf316bcd142edcf727ac37 Mon Sep 17 00:00:00 2001 From: caelansar <819711623@qq.com> Date: Tue, 8 Sep 2020 21:28:04 +0800 Subject: [PATCH 658/881] correct generated sql --- clause/expression.go | 3 +++ tests/query_test.go | 16 ++++++++++++++++ 2 files changed, 19 insertions(+) diff --git a/clause/expression.go b/clause/expression.go index 3b914e68..55599571 100644 --- a/clause/expression.go +++ b/clause/expression.go @@ -37,6 +37,9 @@ func (expr Expr) Build(builder Builder) { } else { switch rv := reflect.ValueOf(expr.Vars[idx]); rv.Kind() { case reflect.Slice, reflect.Array: + if rv.Len() == 0 { + builder.AddVar(builder, nil) + } for i := 0; i < rv.Len(); i++ { if i > 0 { builder.WriteByte(',') diff --git a/tests/query_test.go b/tests/query_test.go index 795186da..e695e825 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -202,6 +202,22 @@ func TestFind(t *testing.T) { } } }) + + var models []User + if err := DB.Where("name in ?", []string{"find"}).Find(&models).Error; err != nil || len(models) != 3 { + t.Errorf("errors happened when query find with in clause: %v, length: %v", err, len(models)) + } else { + for idx, user := range users { + t.Run("FindWithInClause#"+strconv.Itoa(idx+1), func(t *testing.T) { + CheckUser(t, models[idx], user) + }) + } + } + + var none []User + if err := DB.Where("name in ?", []string{}).Find(&none).Error; err != nil || len(none) != 0 { + t.Errorf("errors happened when query find with in clause and zero length parameter: %v, length: %v", err, len(none)) + } } func TestQueryWithAssociation(t *testing.T) { From e7188c04ca9d81767ff090bc584177f4b6fb9fcc Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 9 Sep 2020 10:31:48 +0800 Subject: [PATCH 659/881] Fix tests & refactor for PR #3429 --- clause/expression.go | 11 ++++++----- tests/go.mod | 4 ++++ tests/query_test.go | 4 ++-- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/clause/expression.go b/clause/expression.go index 55599571..dde236d3 100644 --- a/clause/expression.go +++ b/clause/expression.go @@ -39,12 +39,13 @@ func (expr Expr) Build(builder Builder) { case reflect.Slice, reflect.Array: if rv.Len() == 0 { builder.AddVar(builder, nil) - } - for i := 0; i < rv.Len(); i++ { - if i > 0 { - builder.WriteByte(',') + } else { + for i := 0; i < rv.Len(); i++ { + if i > 0 { + builder.WriteByte(',') + } + builder.AddVar(builder, rv.Index(i).Interface()) } - builder.AddVar(builder, rv.Index(i).Interface()) } default: builder.AddVar(builder, expr.Vars[idx]) diff --git a/tests/go.mod b/tests/go.mod index 76db6764..4ddb0b69 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -6,6 +6,10 @@ require ( github.com/google/uuid v1.1.1 github.com/jinzhu/now v1.1.1 github.com/lib/pq v1.6.0 + gorm.io/driver/mysql v1.0.1 + gorm.io/driver/postgres v1.0.0 + gorm.io/driver/sqlite v1.1.1 + gorm.io/driver/sqlserver v1.0.3 gorm.io/gorm v1.9.19 ) diff --git a/tests/query_test.go b/tests/query_test.go index e695e825..14150038 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -204,7 +204,7 @@ func TestFind(t *testing.T) { }) var models []User - if err := DB.Where("name in ?", []string{"find"}).Find(&models).Error; err != nil || len(models) != 3 { + if err := DB.Where("name in (?)", []string{"find"}).Find(&models).Error; err != nil || len(models) != 3 { t.Errorf("errors happened when query find with in clause: %v, length: %v", err, len(models)) } else { for idx, user := range users { @@ -215,7 +215,7 @@ func TestFind(t *testing.T) { } var none []User - if err := DB.Where("name in ?", []string{}).Find(&none).Error; err != nil || len(none) != 0 { + if err := DB.Where("name in (?)", []string{}).Find(&none).Error; err != nil || len(none) != 0 { t.Errorf("errors happened when query find with in clause and zero length parameter: %v, length: %v", err, len(none)) } } From 567597f000606b2266ff4b43950f5a801c2f2f63 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 9 Sep 2020 10:53:13 +0800 Subject: [PATCH 660/881] Fix fail on sqlserver, #3433 --- migrator/migrator.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/migrator/migrator.go b/migrator/migrator.go index c0e22ae0..53fd5ac0 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -388,10 +388,10 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy func (m Migrator) ColumnTypes(value interface{}) (columnTypes []*sql.ColumnType, err error) { err = m.RunWithValue(value, func(stmt *gorm.Statement) error { - rows, err := m.DB.Session(&gorm.Session{}).Raw("select * from ? limit 1", clause.Table{Name: stmt.Table}).Rows() + rows, err := m.DB.Session(&gorm.Session{}).Table(stmt.Table).Limit(1).Rows() if err == nil { + defer rows.Close() columnTypes, err = rows.ColumnTypes() - _ = rows.Close() } return err }) From f6117b7f3dd21629b8196c376b0284d71672d1c7 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 9 Sep 2020 16:26:11 +0800 Subject: [PATCH 661/881] Should not diplay SubQuery SQL log, close #3437 --- logger/logger.go | 14 +++++++++----- statement.go | 3 ++- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/logger/logger.go b/logger/logger.go index 0b0a7377..831192fc 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -2,6 +2,7 @@ package logger import ( "context" + "io/ioutil" "log" "os" "time" @@ -54,11 +55,14 @@ type Interface interface { Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) } -var Default = New(log.New(os.Stdout, "\r\n", log.LstdFlags), Config{ - SlowThreshold: 100 * time.Millisecond, - LogLevel: Warn, - Colorful: true, -}) +var ( + Discard = New(log.New(ioutil.Discard, "", log.LstdFlags), Config{}) + Default = New(log.New(os.Stdout, "\r\n", log.LstdFlags), Config{ + SlowThreshold: 100 * time.Millisecond, + LogLevel: Warn, + Colorful: true, + }) +) func New(writer Writer, config Config) Interface { var ( diff --git a/statement.go b/statement.go index e16cf0ff..ee80f8cd 100644 --- a/statement.go +++ b/statement.go @@ -12,6 +12,7 @@ import ( "sync" "gorm.io/gorm/clause" + "gorm.io/gorm/logger" "gorm.io/gorm/schema" "gorm.io/gorm/utils" ) @@ -189,7 +190,7 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) { writer.WriteString("(NULL)") } case *DB: - subdb := v.Session(&Session{DryRun: true, WithConditions: true}).getInstance() + subdb := v.Session(&Session{Logger: logger.Discard, DryRun: true, WithConditions: true}).getInstance() subdb.Statement.Vars = append(subdb.Statement.Vars, stmt.Vars...) subdb.callbacks.Query().Execute(subdb) writer.WriteString(subdb.Statement.SQL.String()) From f6ed895caffcde0b37d181201a5cadd442b8879e Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 9 Sep 2020 16:32:29 +0800 Subject: [PATCH 662/881] Build relationships if fields are not ignored, fix #3181 --- schema/relationship.go | 2 +- schema/relationship_test.go | 23 +++++++++++++++++++++++ schema/schema.go | 4 ++-- 3 files changed, 26 insertions(+), 3 deletions(-) diff --git a/schema/relationship.go b/schema/relationship.go index 47b948dc..35af111f 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -300,7 +300,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel // build references for _, f := range relation.JoinTable.Fields { - if f.Creatable { + if f.Creatable || f.Readable || f.Updatable { // use same data type for foreign keys f.DataType = fieldsMap[f.Name].DataType f.GORMDataType = fieldsMap[f.Name].GORMDataType diff --git a/schema/relationship_test.go b/schema/relationship_test.go index b9279b9f..7d7fd9c9 100644 --- a/schema/relationship_test.go +++ b/schema/relationship_test.go @@ -220,6 +220,29 @@ func TestMany2ManyOverrideJoinForeignKey(t *testing.T) { }) } +func TestBuildReadonlyMany2ManyRelation(t *testing.T) { + type Profile struct { + gorm.Model + Name string + UserRefer uint + } + + type User struct { + gorm.Model + Profiles []Profile `gorm:"->;many2many:user_profile;JoinForeignKey:UserReferID;JoinReferences:ProfileRefer"` + Refer uint + } + + checkStructRelation(t, &User{}, Relation{ + Name: "Profiles", Type: schema.Many2Many, Schema: "User", FieldSchema: "Profile", + JoinTable: JoinTable{Name: "user_profile", Table: "user_profile"}, + References: []Reference{ + {"ID", "User", "UserReferID", "user_profile", "", true}, + {"ID", "Profile", "ProfileRefer", "user_profile", "", false}, + }, + }) +} + func TestMany2ManyWithMultiPrimaryKeys(t *testing.T) { type Tag struct { ID uint `gorm:"primary_key"` diff --git a/schema/schema.go b/schema/schema.go index ea81d683..c3d3f6e0 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -133,7 +133,7 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) if field.DBName != "" { // nonexistence or shortest path or first appear prioritized if has permission - if v, ok := schema.FieldsByDBName[field.DBName]; !ok || (field.Creatable && len(field.BindNames) < len(v.BindNames)) { + if v, ok := schema.FieldsByDBName[field.DBName]; !ok || ((field.Creatable || field.Updatable || field.Readable) && len(field.BindNames) < len(v.BindNames)) { if _, ok := schema.FieldsByDBName[field.DBName]; !ok { schema.DBNames = append(schema.DBNames, field.DBName) } @@ -219,7 +219,7 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) if _, loaded := cacheStore.LoadOrStore(modelType, schema); !loaded { if _, embedded := schema.cacheStore.Load(embeddedCacheKey); !embedded { for _, field := range schema.Fields { - if field.DataType == "" && field.Creatable { + if field.DataType == "" && (field.Creatable || field.Updatable || field.Readable) { if schema.parseRelation(field); schema.err != nil { return schema, schema.err } From 619d306cef27adf4681bd04edfc0a620217471b2 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 10 Sep 2020 10:55:02 +0800 Subject: [PATCH 663/881] ignore (-) when creating default values, #3434 --- migrator/migrator.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/migrator/migrator.go b/migrator/migrator.go index 53fd5ac0..4b069c8a 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -71,7 +71,7 @@ func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) { defaultStmt := &gorm.Statement{Vars: []interface{}{field.DefaultValueInterface}} m.Dialector.BindVarTo(defaultStmt, defaultStmt, field.DefaultValueInterface) expr.SQL += " DEFAULT " + m.Dialector.Explain(defaultStmt.SQL.String(), field.DefaultValueInterface) - } else { + } else if field.DefaultValue != "(-)" { expr.SQL += " DEFAULT " + field.DefaultValue } } From 231effe119fd25f368fa6ff5b5724e519bf59cd9 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 10 Sep 2020 11:59:18 +0800 Subject: [PATCH 664/881] Fix parse blank default value, close #3442 --- schema/field.go | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/schema/field.go b/schema/field.go index db044c23..e52a8aef 100644 --- a/schema/field.go +++ b/schema/field.go @@ -178,33 +178,34 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { field.Comment = val } - defaultValueFunc := strings.Contains(field.DefaultValue, "(") && - strings.Contains(field.DefaultValue, ")") || strings.ToLower(field.DefaultValue) == "null" + // default value is function or null or blank (primary keys) + skipParseDefaultValue := strings.Contains(field.DefaultValue, "(") && + strings.Contains(field.DefaultValue, ")") || strings.ToLower(field.DefaultValue) == "null" || field.DefaultValue == "" switch reflect.Indirect(fieldValue).Kind() { case reflect.Bool: field.DataType = Bool - if field.HasDefaultValue && !defaultValueFunc { + if field.HasDefaultValue && !skipParseDefaultValue { if field.DefaultValueInterface, err = strconv.ParseBool(field.DefaultValue); err != nil { schema.err = fmt.Errorf("failed to parse %v as default value for bool, got error: %v", field.DefaultValue, err) } } case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: field.DataType = Int - if field.HasDefaultValue && !defaultValueFunc { + if field.HasDefaultValue && !skipParseDefaultValue { if field.DefaultValueInterface, err = strconv.ParseInt(field.DefaultValue, 0, 64); err != nil { schema.err = fmt.Errorf("failed to parse %v as default value for int, got error: %v", field.DefaultValue, err) } } case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: field.DataType = Uint - if field.HasDefaultValue && !defaultValueFunc { + if field.HasDefaultValue && !skipParseDefaultValue { if field.DefaultValueInterface, err = strconv.ParseUint(field.DefaultValue, 0, 64); err != nil { schema.err = fmt.Errorf("failed to parse %v as default value for uint, got error: %v", field.DefaultValue, err) } } case reflect.Float32, reflect.Float64: field.DataType = Float - if field.HasDefaultValue && !defaultValueFunc { + if field.HasDefaultValue && !skipParseDefaultValue { if field.DefaultValueInterface, err = strconv.ParseFloat(field.DefaultValue, 64); err != nil { schema.err = fmt.Errorf("failed to parse %v as default value for float, got error: %v", field.DefaultValue, err) } @@ -212,7 +213,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { case reflect.String: field.DataType = String - if field.HasDefaultValue && !defaultValueFunc { + if field.HasDefaultValue && !skipParseDefaultValue { field.DefaultValue = strings.Trim(field.DefaultValue, "'") field.DefaultValue = strings.Trim(field.DefaultValue, "\"") field.DefaultValueInterface = field.DefaultValue From 53caa85cf48f2ff4eee47fb55a07a3f3f16388fe Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 10 Sep 2020 19:20:47 +0800 Subject: [PATCH 665/881] Use db's Logger for callbacks logs, close #3448, #3447 --- callbacks.go | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/callbacks.go b/callbacks.go index eace06ca..83d103df 100644 --- a/callbacks.go +++ b/callbacks.go @@ -8,7 +8,6 @@ import ( "sort" "time" - "gorm.io/gorm/logger" "gorm.io/gorm/schema" "gorm.io/gorm/utils" ) @@ -156,7 +155,7 @@ func (p *processor) compile() (err error) { p.callbacks = callbacks if p.fns, err = sortCallbacks(p.callbacks); err != nil { - logger.Default.Error(context.Background(), "Got error when compile callbacks, got %v", err) + p.db.Logger.Error(context.Background(), "Got error when compile callbacks, got %v", err) } return } @@ -179,7 +178,7 @@ func (c *callback) Register(name string, fn func(*DB)) error { } func (c *callback) Remove(name string) error { - logger.Default.Warn(context.Background(), "removing callback `%v` from %v\n", name, utils.FileWithLineNum()) + c.processor.db.Logger.Warn(context.Background(), "removing callback `%v` from %v\n", name, utils.FileWithLineNum()) c.name = name c.remove = true c.processor.callbacks = append(c.processor.callbacks, c) @@ -187,7 +186,7 @@ func (c *callback) Remove(name string) error { } func (c *callback) Replace(name string, fn func(*DB)) error { - logger.Default.Info(context.Background(), "replacing callback `%v` from %v\n", name, utils.FileWithLineNum()) + c.processor.db.Logger.Info(context.Background(), "replacing callback `%v` from %v\n", name, utils.FileWithLineNum()) c.name = name c.handler = fn c.replace = true @@ -217,7 +216,7 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) { for _, c := range cs { // show warning message the callback name already exists if idx := getRIndex(names, c.name); idx > -1 && !c.replace && !c.remove && !cs[idx].remove { - logger.Default.Warn(context.Background(), "duplicated callback `%v` from %v\n", c.name, utils.FileWithLineNum()) + c.processor.db.Logger.Warn(context.Background(), "duplicated callback `%v` from %v\n", c.name, utils.FileWithLineNum()) } names = append(names, c.name) } From 70a7bd52ca2bbf64443b7227524e4600997ea1b9 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 10 Sep 2020 21:46:18 +0800 Subject: [PATCH 666/881] Support delete associations with Select when deleting --- callbacks/callbacks.go | 1 + callbacks/delete.go | 53 ++++++++++++++++++++++++++++++++++++++ tests/delete_test.go | 54 +++++++++++++++++++++++++++++++++++++++ tests/joins_table_test.go | 18 +++++++++++++ utils/utils.go | 2 +- 5 files changed, 127 insertions(+), 1 deletion(-) diff --git a/callbacks/callbacks.go b/callbacks/callbacks.go index 0a12468c..dda4b046 100644 --- a/callbacks/callbacks.go +++ b/callbacks/callbacks.go @@ -31,6 +31,7 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) { deleteCallback := db.Callback().Delete() deleteCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction) deleteCallback.Register("gorm:before_delete", BeforeDelete) + deleteCallback.Register("gorm:delete_before_associations", DeleteBeforeAssociations) deleteCallback.Register("gorm:delete", Delete) deleteCallback.Register("gorm:after_delete", AfterDelete) deleteCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) diff --git a/callbacks/delete.go b/callbacks/delete.go index e95117a1..510dfae4 100644 --- a/callbacks/delete.go +++ b/callbacks/delete.go @@ -21,6 +21,59 @@ func BeforeDelete(db *gorm.DB) { } } +func DeleteBeforeAssociations(db *gorm.DB) { + if db.Error == nil && db.Statement.Schema != nil { + selectColumns, restricted := db.Statement.SelectAndOmitColumns(true, false) + + if restricted { + for column, v := range selectColumns { + if v { + if rel, ok := db.Statement.Schema.Relationships.Relations[column]; ok { + switch rel.Type { + case schema.HasOne, schema.HasMany: + queryConds := rel.ToQueryConditions(db.Statement.ReflectValue) + modelValue := reflect.New(rel.FieldSchema.ModelType).Interface() + tx := db.Session(&gorm.Session{}).Model(modelValue) + if db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil { + return + } + case schema.Many2Many: + var ( + queryConds []clause.Expression + foreignFields []*schema.Field + relForeignKeys []string + modelValue = reflect.New(rel.JoinTable.ModelType).Interface() + table = rel.JoinTable.Table + tx = db.Session(&gorm.Session{}).Model(modelValue).Table(table) + ) + + for _, ref := range rel.References { + if ref.OwnPrimaryKey { + foreignFields = append(foreignFields, ref.PrimaryKey) + relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName) + } else if ref.PrimaryValue != "" { + queryConds = append(queryConds, clause.Eq{ + Column: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName}, + Value: ref.PrimaryValue, + }) + } + } + + _, foreignValues := schema.GetIdentityFieldValuesMap(db.Statement.ReflectValue, foreignFields) + column, values := schema.ToQueryValues(table, relForeignKeys, foreignValues) + queryConds = append(queryConds, clause.IN{Column: column, Values: values}) + + if db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil { + return + } + } + } + } + } + } + } +} + func Delete(db *gorm.DB) { if db.Error == nil { if db.Statement.Schema != nil && !db.Statement.Unscoped { diff --git a/tests/delete_test.go b/tests/delete_test.go index 17299677..4945e837 100644 --- a/tests/delete_test.go +++ b/tests/delete_test.go @@ -5,6 +5,7 @@ import ( "testing" "gorm.io/gorm" + "gorm.io/gorm/clause" . "gorm.io/gorm/utils/tests" ) @@ -127,3 +128,56 @@ func TestBlockGlobalDelete(t *testing.T) { t.Errorf("should returns no error while enable global update, but got err %v", err) } } + +func TestDeleteWithAssociations(t *testing.T) { + user := GetUser("delete_with_associations", Config{Account: true, Pets: 2, Toys: 4, Company: true, Manager: true, Team: 1, Languages: 1, Friends: 1}) + + if err := DB.Create(user).Error; err != nil { + t.Fatalf("failed to create user, got error %v", err) + } + + if err := DB.Select(clause.Associations).Delete(&user).Error; err != nil { + t.Fatalf("failed to delete user, got error %v", err) + } + + for key, value := range map[string]int64{"Account": 1, "Pets": 2, "Toys": 4, "Company": 1, "Manager": 1, "Team": 1, "Languages": 0, "Friends": 0} { + if count := DB.Unscoped().Model(&user).Association(key).Count(); count != value { + t.Errorf("user's %v expects: %v, got %v", key, value, count) + } + } + + for key, value := range map[string]int64{"Account": 0, "Pets": 0, "Toys": 0, "Company": 1, "Manager": 1, "Team": 0, "Languages": 0, "Friends": 0} { + if count := DB.Model(&user).Association(key).Count(); count != value { + t.Errorf("user's %v expects: %v, got %v", key, value, count) + } + } +} + +func TestDeleteSliceWithAssociations(t *testing.T) { + users := []User{ + *GetUser("delete_slice_with_associations1", Config{Account: true, Pets: 4, Toys: 1, Company: true, Manager: true, Team: 1, Languages: 1, Friends: 4}), + *GetUser("delete_slice_with_associations2", Config{Account: true, Pets: 3, Toys: 2, Company: true, Manager: true, Team: 2, Languages: 2, Friends: 3}), + *GetUser("delete_slice_with_associations3", Config{Account: true, Pets: 2, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 2}), + *GetUser("delete_slice_with_associations4", Config{Account: true, Pets: 1, Toys: 4, Company: true, Manager: true, Team: 4, Languages: 4, Friends: 1}), + } + + if err := DB.Create(users).Error; err != nil { + t.Fatalf("failed to create user, got error %v", err) + } + + if err := DB.Select(clause.Associations).Delete(&users).Error; err != nil { + t.Fatalf("failed to delete user, got error %v", err) + } + + for key, value := range map[string]int64{"Account": 4, "Pets": 10, "Toys": 10, "Company": 4, "Manager": 4, "Team": 10, "Languages": 0, "Friends": 0} { + if count := DB.Unscoped().Model(&users).Association(key).Count(); count != value { + t.Errorf("user's %v expects: %v, got %v", key, value, count) + } + } + + for key, value := range map[string]int64{"Account": 0, "Pets": 0, "Toys": 0, "Company": 4, "Manager": 4, "Team": 0, "Languages": 0, "Friends": 0} { + if count := DB.Model(&users).Association(key).Count(); count != value { + t.Errorf("user's %v expects: %v, got %v", key, value, count) + } + } +} diff --git a/tests/joins_table_test.go b/tests/joins_table_test.go index b8c1be77..084c2f2c 100644 --- a/tests/joins_table_test.go +++ b/tests/joins_table_test.go @@ -5,12 +5,14 @@ import ( "time" "gorm.io/gorm" + "gorm.io/gorm/clause" ) type Person struct { ID int Name string Addresses []Address `gorm:"many2many:person_addresses;"` + DeletedAt gorm.DeletedAt } type Address struct { @@ -95,4 +97,20 @@ func TestOverrideJoinTable(t *testing.T) { if DB.Unscoped().Model(&person).Association("Addresses").Count() != 0 { t.Fatalf("address should be deleted when clear with unscoped") } + + address2_1 := Address{Name: "address 2-1"} + address2_2 := Address{Name: "address 2-2"} + person2 := Person{Name: "person_2", Addresses: []Address{address2_1, address2_2}} + DB.Create(&person2) + if err := DB.Select(clause.Associations).Delete(&person2).Error; err != nil { + t.Fatalf("failed to delete person, got error: %v", err) + } + + if count := DB.Unscoped().Model(&person2).Association("Addresses").Count(); count != 2 { + t.Errorf("person's addresses expects 2, got %v", count) + } + + if count := DB.Model(&person2).Association("Addresses").Count(); count != 0 { + t.Errorf("person's addresses expects 2, got %v", count) + } } diff --git a/utils/utils.go b/utils/utils.go index 905001a5..ecba7fb9 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -30,7 +30,7 @@ func FileWithLineNum() string { } func IsValidDBNameChar(c rune) bool { - return !unicode.IsLetter(c) && !unicode.IsNumber(c) && c != '.' && c != '*' && c != '_' && c != '$' + return !unicode.IsLetter(c) && !unicode.IsNumber(c) && c != '.' && c != '*' && c != '_' && c != '$' && c != '@' } func CheckTruth(val interface{}) bool { From b8a74a80d732963df95580eae3316db140a882a4 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 11 Sep 2020 10:49:31 +0800 Subject: [PATCH 667/881] Fix embedded struct with default value, close #3451 --- schema/field.go | 24 +++++++++++++----------- tests/go.mod | 4 ++-- tests/query_test.go | 1 + 3 files changed, 16 insertions(+), 13 deletions(-) diff --git a/schema/field.go b/schema/field.go index e52a8aef..60dc8095 100644 --- a/schema/field.go +++ b/schema/field.go @@ -345,19 +345,21 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { ef.DBName = prefix + ef.DBName } - if val, ok := ef.TagSettings["PRIMARYKEY"]; ok && utils.CheckTruth(val) { - ef.PrimaryKey = true - } else if val, ok := ef.TagSettings["PRIMARY_KEY"]; ok && utils.CheckTruth(val) { - ef.PrimaryKey = true - } else { - ef.PrimaryKey = false + if ef.PrimaryKey { + if val, ok := ef.TagSettings["PRIMARYKEY"]; ok && utils.CheckTruth(val) { + ef.PrimaryKey = true + } else if val, ok := ef.TagSettings["PRIMARY_KEY"]; ok && utils.CheckTruth(val) { + ef.PrimaryKey = true + } else { + ef.PrimaryKey = false - if val, ok := ef.TagSettings["AUTOINCREMENT"]; !ok || !utils.CheckTruth(val) { - ef.AutoIncrement = false - } + if val, ok := ef.TagSettings["AUTOINCREMENT"]; !ok || !utils.CheckTruth(val) { + ef.AutoIncrement = false + } - if ef.DefaultValue == "" { - ef.HasDefaultValue = false + if ef.DefaultValue == "" { + ef.HasDefaultValue = false + } } } diff --git a/tests/go.mod b/tests/go.mod index 4ddb0b69..f62365f8 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -9,8 +9,8 @@ require ( gorm.io/driver/mysql v1.0.1 gorm.io/driver/postgres v1.0.0 gorm.io/driver/sqlite v1.1.1 - gorm.io/driver/sqlserver v1.0.3 - gorm.io/gorm v1.9.19 + gorm.io/driver/sqlserver v1.0.4 + gorm.io/gorm v1.20.0 ) replace gorm.io/gorm => ../ diff --git a/tests/query_test.go b/tests/query_test.go index 14150038..36229e2c 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -648,6 +648,7 @@ func TestOffset(t *testing.T) { if (len(users1) != len(users4)) || (len(users1)-len(users2) != 3) || (len(users1)-len(users3) != 5) { t.Errorf("Offset should work") } + DB.Where("name like ?", "OffsetUser%").Order("age desc").Find(&users1).Offset(3).Find(&users2).Offset(5).Find(&users3).Offset(-1).Find(&users4) if (len(users1) != len(users4)) || (len(users1)-len(users2) != 3) || (len(users1)-len(users3) != 5) { From e583dfa196400896932c073d05383fcf6cedeb4f Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 11 Sep 2020 11:44:58 +0800 Subject: [PATCH 668/881] Allow negative number for limit --- clause/limit.go | 4 +--- tests/go.mod | 2 +- tests/query_test.go | 3 ++- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/clause/limit.go b/clause/limit.go index 1946820d..2082f4d9 100644 --- a/clause/limit.go +++ b/clause/limit.go @@ -33,10 +33,8 @@ func (limit Limit) MergeClause(clause *Clause) { clause.Name = "" if v, ok := clause.Expression.(Limit); ok { - if limit.Limit == 0 && v.Limit > 0 { + if limit.Limit == 0 && v.Limit != 0 { limit.Limit = v.Limit - } else if limit.Limit < 0 { - limit.Limit = 0 } if limit.Offset == 0 && v.Offset > 0 { diff --git a/tests/go.mod b/tests/go.mod index f62365f8..17a3b156 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -8,7 +8,7 @@ require ( github.com/lib/pq v1.6.0 gorm.io/driver/mysql v1.0.1 gorm.io/driver/postgres v1.0.0 - gorm.io/driver/sqlite v1.1.1 + gorm.io/driver/sqlite v1.1.2 gorm.io/driver/sqlserver v1.0.4 gorm.io/gorm v1.20.0 ) diff --git a/tests/query_test.go b/tests/query_test.go index 36229e2c..d3bcbdbe 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -625,6 +625,7 @@ func TestLimit(t *testing.T) { {Name: "LimitUser3", Age: 20}, {Name: "LimitUser4", Age: 10}, {Name: "LimitUser5", Age: 20}, + {Name: "LimitUser6", Age: 20}, } DB.Create(&users) @@ -633,7 +634,7 @@ func TestLimit(t *testing.T) { DB.Order("age desc").Limit(3).Find(&users1).Limit(5).Find(&users2).Limit(-1).Find(&users3) if len(users1) != 3 || len(users2) != 5 || len(users3) <= 5 { - t.Errorf("Limit should works") + t.Errorf("Limit should works, users1 %v users2 %v users3 %v", len(users1), len(users2), len(users3)) } } From 02fb382ec0b67a320fc26cdd460a70468d037779 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 11 Sep 2020 15:01:02 +0800 Subject: [PATCH 669/881] Support scan into int, string data types --- finisher_api.go | 4 +++- scan.go | 2 +- tests/scan_test.go | 10 ++++++++++ 3 files changed, 14 insertions(+), 2 deletions(-) diff --git a/finisher_api.go b/finisher_api.go index 6ece0f79..f426839a 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -384,7 +384,9 @@ func (db *DB) Pluck(column string, dest interface{}) (tx *DB) { func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error { tx := db.getInstance() - tx.Error = tx.Statement.Parse(dest) + if err := tx.Statement.Parse(dest); !errors.Is(err, schema.ErrUnsupportedDataType) { + tx.AddError(err) + } tx.Statement.Dest = dest tx.Statement.ReflectValue = reflect.ValueOf(dest) for tx.Statement.ReflectValue.Kind() == reflect.Ptr { diff --git a/scan.go b/scan.go index 89d9a07a..be8782ed 100644 --- a/scan.go +++ b/scan.go @@ -82,7 +82,7 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { scanIntoMap(mapValue, values, columns) *dest = append(*dest, mapValue) } - case *int, *int64, *uint, *uint64, *float32, *float64: + case *int, *int64, *uint, *uint64, *float32, *float64, *string: for initialized || rows.Next() { initialized = false db.RowsAffected++ diff --git a/tests/scan_test.go b/tests/scan_test.go index 92e89521..785bb97e 100644 --- a/tests/scan_test.go +++ b/tests/scan_test.go @@ -91,4 +91,14 @@ func TestScanRows(t *testing.T) { if !reflect.DeepEqual(results, []Result{{Name: "ScanRowsUser2", Age: 10}, {Name: "ScanRowsUser3", Age: 20}}) { t.Errorf("Should find expected results") } + + var ages int + if err := DB.Table("users").Where("name = ? or name = ?", user2.Name, user3.Name).Select("SUM(age)").Scan(&ages).Error; err != nil || ages != 30 { + t.Fatalf("failed to scan ages, got error %v, ages: %v", err, ages) + } + + var name string + if err := DB.Table("users").Where("name = ?", user2.Name).Select("name").Scan(&name).Error; err != nil || name != user2.Name { + t.Fatalf("failed to scan ages, got error %v, ages: %v", err, name) + } } From ed1b134e1c6d8d791fc87a7286e9c534fa2840f2 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 11 Sep 2020 17:33:31 +0800 Subject: [PATCH 670/881] Fix use uint to for autoCreateTime, autoUpdateTime --- schema/field.go | 8 ++++++++ tests/customize_field_test.go | 22 +++++++++++----------- 2 files changed, 19 insertions(+), 11 deletions(-) diff --git a/schema/field.go b/schema/field.go index 60dc8095..4b8a5a2a 100644 --- a/schema/field.go +++ b/schema/field.go @@ -624,6 +624,14 @@ func (field *Field) setupValuerAndSetter() { field.ReflectValueOf(value).SetUint(uint64(data)) case []byte: return field.Set(value, string(data)) + case time.Time: + if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond { + field.ReflectValueOf(value).SetUint(uint64(data.UnixNano())) + } else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond { + field.ReflectValueOf(value).SetUint(uint64(data.UnixNano() / 1e6)) + } else { + field.ReflectValueOf(value).SetUint(uint64(data.Unix())) + } case string: if i, err := strconv.ParseUint(data, 0, 64); err == nil { field.ReflectValueOf(value).SetUint(i) diff --git a/tests/customize_field_test.go b/tests/customize_field_test.go index bf3c78fa..7802eb11 100644 --- a/tests/customize_field_test.go +++ b/tests/customize_field_test.go @@ -69,12 +69,12 @@ func TestCustomizeField(t *testing.T) { FieldAllowSave3 string `gorm:"->:false;<-:create"` FieldReadonly string `gorm:"->"` FieldIgnore string `gorm:"-"` - AutoUnixCreateTime int64 `gorm:"autocreatetime"` - AutoUnixMilliCreateTime int64 `gorm:"autocreatetime:milli"` + AutoUnixCreateTime int32 `gorm:"autocreatetime"` + AutoUnixMilliCreateTime int `gorm:"autocreatetime:milli"` AutoUnixNanoCreateTime int64 `gorm:"autocreatetime:nano"` - AutoUnixUpdateTime int64 `gorm:"autoupdatetime"` - AutoUnixMilliUpdateTime int64 `gorm:"autoupdatetime:milli"` - AutoUnixNanoUpdateTime int64 `gorm:"autoupdatetime:nano"` + AutoUnixUpdateTime uint32 `gorm:"autoupdatetime"` + AutoUnixMilliUpdateTime int `gorm:"autoupdatetime:milli"` + AutoUnixNanoUpdateTime uint64 `gorm:"autoupdatetime:nano"` } DB.Migrator().DropTable(&CustomizeFieldStruct{}) @@ -116,15 +116,15 @@ func TestCustomizeField(t *testing.T) { t.Fatalf("invalid result: %#v", result) } - if result.AutoUnixCreateTime != result.AutoUnixUpdateTime || result.AutoUnixCreateTime == 0 { + if int(result.AutoUnixCreateTime) != int(result.AutoUnixUpdateTime) || result.AutoUnixCreateTime == 0 { t.Fatalf("invalid create/update unix time: %#v", result) } - if result.AutoUnixMilliCreateTime != result.AutoUnixMilliUpdateTime || result.AutoUnixMilliCreateTime == 0 || result.AutoUnixMilliCreateTime/result.AutoUnixCreateTime < 1e3 { + if int(result.AutoUnixMilliCreateTime) != int(result.AutoUnixMilliUpdateTime) || result.AutoUnixMilliCreateTime == 0 || int(result.AutoUnixMilliCreateTime)/int(result.AutoUnixCreateTime) < 1e3 { t.Fatalf("invalid create/update unix milli time: %#v", result) } - if result.AutoUnixNanoCreateTime != result.AutoUnixNanoUpdateTime || result.AutoUnixNanoCreateTime == 0 || result.AutoUnixNanoCreateTime/result.AutoUnixCreateTime < 1e6 { + if int(result.AutoUnixNanoCreateTime) != int(result.AutoUnixNanoUpdateTime) || result.AutoUnixNanoCreateTime == 0 || int(result.AutoUnixNanoCreateTime)/int(result.AutoUnixCreateTime) < 1e6 { t.Fatalf("invalid create/update unix nano time: %#v", result) } @@ -178,15 +178,15 @@ func TestCustomizeField(t *testing.T) { var createWithDefaultTimeResult CustomizeFieldStruct DB.Find(&createWithDefaultTimeResult, "name = ?", createWithDefaultTime.Name) - if createWithDefaultTimeResult.AutoUnixCreateTime != createWithDefaultTimeResult.AutoUnixUpdateTime || createWithDefaultTimeResult.AutoUnixCreateTime != 100 { + if int(createWithDefaultTimeResult.AutoUnixCreateTime) != int(createWithDefaultTimeResult.AutoUnixUpdateTime) || createWithDefaultTimeResult.AutoUnixCreateTime != 100 { t.Fatalf("invalid create/update unix time: %#v", createWithDefaultTimeResult) } - if createWithDefaultTimeResult.AutoUnixMilliCreateTime != createWithDefaultTimeResult.AutoUnixMilliUpdateTime || createWithDefaultTimeResult.AutoUnixMilliCreateTime != 100 { + if int(createWithDefaultTimeResult.AutoUnixMilliCreateTime) != int(createWithDefaultTimeResult.AutoUnixMilliUpdateTime) || createWithDefaultTimeResult.AutoUnixMilliCreateTime != 100 { t.Fatalf("invalid create/update unix milli time: %#v", createWithDefaultTimeResult) } - if createWithDefaultTimeResult.AutoUnixNanoCreateTime != createWithDefaultTimeResult.AutoUnixNanoUpdateTime || createWithDefaultTimeResult.AutoUnixNanoCreateTime != 100 { + if int(createWithDefaultTimeResult.AutoUnixNanoCreateTime) != int(createWithDefaultTimeResult.AutoUnixNanoUpdateTime) || createWithDefaultTimeResult.AutoUnixNanoCreateTime != 100 { t.Fatalf("invalid create/update unix nano time: %#v", createWithDefaultTimeResult) } } From 0ec10d4907762e94ac942903670184a93e7ed456 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 14 Sep 2020 12:37:16 +0800 Subject: [PATCH 671/881] Fix format SQL log, close #3465 --- logger/sql.go | 16 ++++++++++++++-- logger/sql_test.go | 6 ++++++ 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/logger/sql.go b/logger/sql.go index 096b9407..69a6b10e 100644 --- a/logger/sql.go +++ b/logger/sql.go @@ -96,9 +96,21 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a } if numericPlaceholder == nil { - for _, v := range vars { - sql = strings.Replace(sql, "?", v, 1) + var idx int + var newSQL strings.Builder + + for _, v := range []byte(sql) { + if v == '?' { + if len(vars) > idx { + newSQL.WriteString(vars[idx]) + idx++ + continue + } + } + newSQL.WriteByte(v) } + + sql = newSQL.String() } else { sql = numericPlaceholder.ReplaceAllString(sql, "$$$1$$") for idx, v := range vars { diff --git a/logger/sql_test.go b/logger/sql_test.go index 180570b8..b78f761c 100644 --- a/logger/sql_test.go +++ b/logger/sql_test.go @@ -29,6 +29,12 @@ func TestExplainSQL(t *testing.T) { Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd}, Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass")`, }, + { + SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + NumericRegexp: nil, + Vars: []interface{}{"jinzhu?", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd}, + Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu?", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass")`, + }, { SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (@p1, @p2, @p3, @p4, @p5, @p6, @p7, @p8, @p9, @p10, @p11)", NumericRegexp: regexp.MustCompile(`@p(\d+)`), From 1d5f910b6e1a377f7f7defadb606a3e9c7a09c01 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 14 Sep 2020 15:29:47 +0800 Subject: [PATCH 672/881] Update workflows template --- .github/labels.json | 5 +++++ .github/workflows/invalid_question.yml | 22 ++++++++++++++++++++++ .github/workflows/missing_playground.yml | 2 +- 3 files changed, 28 insertions(+), 1 deletion(-) create mode 100644 .github/workflows/invalid_question.yml diff --git a/.github/labels.json b/.github/labels.json index 8b1ce849..6b9c2034 100644 --- a/.github/labels.json +++ b/.github/labels.json @@ -10,6 +10,11 @@ "colour": "#EDEDED", "description": "general questions" }, + "invalid_question": { + "name": "type:invalid question", + "colour": "#CF2E1F", + "description": "invalid question (not related to GORM or described in document or not enough information provided)" + }, "with_playground": { "name": "type:with reproduction steps", "colour": "#00ff00", diff --git a/.github/workflows/invalid_question.yml b/.github/workflows/invalid_question.yml new file mode 100644 index 00000000..5b0bd981 --- /dev/null +++ b/.github/workflows/invalid_question.yml @@ -0,0 +1,22 @@ +name: "Close invalid questions issues" +on: + schedule: + - cron: "*/10 * * * *" + +jobs: + stale: + runs-on: ubuntu-latest + env: + ACTIONS_STEP_DEBUG: true + steps: + - name: Close Stale Issues + uses: actions/stale@v3.0.7 + with: + repo-token: ${{ secrets.GITHUB_TOKEN }} + stale-issue-message: "This issue has been marked as invalid question, please give more information by following the `Question` template, if you believe there is a bug of GORM, please create a pull request that could reproduce the issue on [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground), the issue will be closed in 2 days if no further activity occurs. most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨" + stale-issue-label: "status:stale" + days-before-stale: 0 + days-before-close: 2 + remove-stale-when-updated: true + only-labels: "type:invalid question" + diff --git a/.github/workflows/missing_playground.yml b/.github/workflows/missing_playground.yml index 422cb9f5..ea3207d6 100644 --- a/.github/workflows/missing_playground.yml +++ b/.github/workflows/missing_playground.yml @@ -13,7 +13,7 @@ jobs: uses: actions/stale@v3.0.7 with: repo-token: ${{ secrets.GITHUB_TOKEN }} - stale-issue-message: "This issue has been automatically marked as stale as it missing playground pull request link, checkout [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground) for details, it will be closed in 2 days if no further activity occurs. if you are asking question, please use the `Question` template, most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨" + stale-issue-message: "The issue has been automatically marked as stale as it missing playground pull request link, which is important to help others understand your issue effectively and make sure the issue hasn't been fixed on latest master, checkout [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground) for details. it will be closed in 2 days if no further activity occurs. if you are asking question, please use the `Question` template, most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨" stale-issue-label: "status:stale" days-before-stale: 0 days-before-close: 2 From 06d534d6eaa7f8534e51742b9930818511aaf28c Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 15 Sep 2020 12:41:45 +0800 Subject: [PATCH 673/881] Cascade delete associations, close #3473 --- callbacks/delete.go | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/callbacks/delete.go b/callbacks/delete.go index 510dfae4..549a94e7 100644 --- a/callbacks/delete.go +++ b/callbacks/delete.go @@ -34,8 +34,23 @@ func DeleteBeforeAssociations(db *gorm.DB) { queryConds := rel.ToQueryConditions(db.Statement.ReflectValue) modelValue := reflect.New(rel.FieldSchema.ModelType).Interface() tx := db.Session(&gorm.Session{}).Model(modelValue) - if db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil { - return + withoutConditions := false + + if len(db.Statement.Selects) > 0 { + tx = tx.Select(db.Statement.Selects) + } + + for _, cond := range queryConds { + if c, ok := cond.(clause.IN); ok && len(c.Values) == 0 { + withoutConditions = true + break + } + } + + if !withoutConditions { + if db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil { + return + } } case schema.Many2Many: var ( From a932175ccf98130aaa3028b75daf047a32b6dca0 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 15 Sep 2020 14:28:26 +0800 Subject: [PATCH 674/881] Refactor cascade delete associations --- callbacks/delete.go | 14 +++++++++++++- tests/delete_test.go | 2 +- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/callbacks/delete.go b/callbacks/delete.go index 549a94e7..85f11f4b 100644 --- a/callbacks/delete.go +++ b/callbacks/delete.go @@ -2,6 +2,7 @@ package callbacks import ( "reflect" + "strings" "gorm.io/gorm" "gorm.io/gorm/clause" @@ -37,7 +38,18 @@ func DeleteBeforeAssociations(db *gorm.DB) { withoutConditions := false if len(db.Statement.Selects) > 0 { - tx = tx.Select(db.Statement.Selects) + var selects []string + for _, s := range db.Statement.Selects { + if s == clause.Associations { + selects = append(selects, s) + } else if strings.HasPrefix(s, column+".") { + selects = append(selects, strings.TrimPrefix(s, column+".")) + } + } + + if len(selects) > 0 { + tx = tx.Select(selects) + } } for _, cond := range queryConds { diff --git a/tests/delete_test.go b/tests/delete_test.go index 4945e837..ecd5ec39 100644 --- a/tests/delete_test.go +++ b/tests/delete_test.go @@ -136,7 +136,7 @@ func TestDeleteWithAssociations(t *testing.T) { t.Fatalf("failed to create user, got error %v", err) } - if err := DB.Select(clause.Associations).Delete(&user).Error; err != nil { + if err := DB.Select(clause.Associations, "Pets.Toy").Delete(&user).Error; err != nil { t.Fatalf("failed to delete user, got error %v", err) } From d002c70cf6ac6f35e4a2840606e65d84d33c5391 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 17 Sep 2020 21:52:41 +0800 Subject: [PATCH 675/881] Support named argument for struct --- clause/expression.go | 12 ++++++++++++ clause/expression_test.go | 10 ++++++++++ tests/go.mod | 4 ++-- 3 files changed, 24 insertions(+), 2 deletions(-) diff --git a/clause/expression.go b/clause/expression.go index dde236d3..49924ef7 100644 --- a/clause/expression.go +++ b/clause/expression.go @@ -3,6 +3,7 @@ package clause import ( "database/sql" "database/sql/driver" + "go/ast" "reflect" ) @@ -89,6 +90,17 @@ func (expr NamedExpr) Build(builder Builder) { for k, v := range value { namedMap[k] = v } + default: + reflectValue := reflect.Indirect(reflect.ValueOf(value)) + switch reflectValue.Kind() { + case reflect.Struct: + modelType := reflectValue.Type() + for i := 0; i < modelType.NumField(); i++ { + if fieldStruct := modelType.Field(i); ast.IsExported(fieldStruct.Name) { + namedMap[fieldStruct.Name] = reflectValue.Field(i).Interface() + } + } + } } } diff --git a/clause/expression_test.go b/clause/expression_test.go index 17af737d..53d79c8f 100644 --- a/clause/expression_test.go +++ b/clause/expression_test.go @@ -37,6 +37,11 @@ func TestExpr(t *testing.T) { } func TestNamedExpr(t *testing.T) { + type NamedArgument struct { + Name1 string + Name2 string + } + results := []struct { SQL string Result string @@ -66,6 +71,11 @@ func TestNamedExpr(t *testing.T) { Vars: []interface{}{sql.Named("name1", "jinzhu"), sql.Named("name2", "jinzhu2")}, Result: "@@test AND name1 = ? AND name2 = ? AND name3 = ? ?", ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu", nil}, + }, { + SQL: "@@test AND name1 = @Name1 AND name2 = @Name2 AND name3 = @Name1 @Notexist", + Vars: []interface{}{NamedArgument{Name1: "jinzhu", Name2: "jinzhu2"}}, + Result: "@@test AND name1 = ? AND name2 = ? AND name3 = ? ?", + ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu", nil}, }} for idx, result := range results { diff --git a/tests/go.mod b/tests/go.mod index 17a3b156..0db87934 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -8,9 +8,9 @@ require ( github.com/lib/pq v1.6.0 gorm.io/driver/mysql v1.0.1 gorm.io/driver/postgres v1.0.0 - gorm.io/driver/sqlite v1.1.2 + gorm.io/driver/sqlite v1.1.3 gorm.io/driver/sqlserver v1.0.4 - gorm.io/gorm v1.20.0 + gorm.io/gorm v1.20.1 ) replace gorm.io/gorm => ../ From 072f1de83a842a991ea76cecfd14a7e93d5e67c1 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 18 Sep 2020 21:34:44 +0800 Subject: [PATCH 676/881] Add DryRunModeUnsupported Error for Row/Rows --- errors.go | 2 ++ finisher_api.go | 12 ++++++++++-- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/errors.go b/errors.go index 508f6957..08755083 100644 --- a/errors.go +++ b/errors.go @@ -29,4 +29,6 @@ var ( ErrInvalidField = errors.New("invalid field") // ErrEmptySlice empty slice found ErrEmptySlice = errors.New("empty slice found") + // ErrDryRunModeUnsupported dry run mode unsupported + ErrDryRunModeUnsupported = errors.New("dry run mode unsupported") ) diff --git a/finisher_api.go b/finisher_api.go index f426839a..2c56d763 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -334,13 +334,21 @@ func (db *DB) Count(count *int64) (tx *DB) { func (db *DB) Row() *sql.Row { tx := db.getInstance().InstanceSet("rows", false) tx.callbacks.Row().Execute(tx) - return tx.Statement.Dest.(*sql.Row) + row, ok := tx.Statement.Dest.(*sql.Row) + if !ok && tx.DryRun { + db.Logger.Error(tx.Statement.Context, ErrDryRunModeUnsupported.Error()) + } + return row } func (db *DB) Rows() (*sql.Rows, error) { tx := db.getInstance().InstanceSet("rows", true) tx.callbacks.Row().Execute(tx) - return tx.Statement.Dest.(*sql.Rows), tx.Error + rows, ok := tx.Statement.Dest.(*sql.Rows) + if !ok && tx.DryRun && tx.Error == nil { + tx.Error = ErrDryRunModeUnsupported + } + return rows, tx.Error } // Scan scan value to a struct From c9165fe3cafc9a66e2513caae381e6864fa0a15b Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 18 Sep 2020 21:42:27 +0800 Subject: [PATCH 677/881] Don't panic when using unmatched vars in query, close #3488 --- clause/expression.go | 4 ++-- clause/expression_test.go | 4 ++++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/clause/expression.go b/clause/expression.go index 49924ef7..6a0dde8d 100644 --- a/clause/expression.go +++ b/clause/expression.go @@ -31,7 +31,7 @@ func (expr Expr) Build(builder Builder) { ) for _, v := range []byte(expr.SQL) { - if v == '?' { + if v == '?' && len(expr.Vars) > idx { if afterParenthesis { if _, ok := expr.Vars[idx].(driver.Valuer); ok { builder.AddVar(builder, expr.Vars[idx]) @@ -122,7 +122,7 @@ func (expr NamedExpr) Build(builder Builder) { } builder.WriteByte(v) - } else if v == '?' { + } else if v == '?' && len(expr.Vars) > idx { builder.AddVar(builder, expr.Vars[idx]) idx++ } else if inName { diff --git a/clause/expression_test.go b/clause/expression_test.go index 53d79c8f..19e30e6c 100644 --- a/clause/expression_test.go +++ b/clause/expression_test.go @@ -76,6 +76,10 @@ func TestNamedExpr(t *testing.T) { Vars: []interface{}{NamedArgument{Name1: "jinzhu", Name2: "jinzhu2"}}, Result: "@@test AND name1 = ? AND name2 = ? AND name3 = ? ?", ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu", nil}, + }, { + SQL: "create table ? (? ?, ? ?)", + Vars: []interface{}{}, + Result: "create table ? (? ?, ? ?)", }} for idx, result := range results { From 089939c767f89087366799e47ab24d5b7b36c5e4 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 18 Sep 2020 21:50:11 +0800 Subject: [PATCH 678/881] AutoMigrate should auto create indexes, close #3486 --- migrator/migrator.go | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/migrator/migrator.go b/migrator/migrator.go index 4b069c8a..f390ff9f 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -133,6 +133,15 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { } } } + + for _, idx := range stmt.Schema.ParseIndexes() { + if !tx.Migrator().HasIndex(value, idx.Name) { + if err := tx.Migrator().CreateIndex(value, idx.Name); err != nil { + return err + } + } + } + return nil }); err != nil { return err From 68920449f92f24c8b17d90986eb155c251ed8fc7 Mon Sep 17 00:00:00 2001 From: caelansar <31852257+caelansar@users.noreply.github.com> Date: Sat, 19 Sep 2020 13:48:34 +0800 Subject: [PATCH 679/881] Fix format sql log (#3492) --- logger/sql.go | 4 ++-- logger/sql_test.go | 42 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 44 insertions(+), 2 deletions(-) diff --git a/logger/sql.go b/logger/sql.go index 69a6b10e..138a35ec 100644 --- a/logger/sql.go +++ b/logger/sql.go @@ -52,9 +52,9 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a vars[idx] = escaper + strings.Replace(fmt.Sprintf("%v", v), escaper, "\\"+escaper, -1) + escaper case driver.Valuer: reflectValue := reflect.ValueOf(v) - if v != nil && reflectValue.IsValid() && (reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) { + if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) { r, _ := v.Value() - vars[idx] = fmt.Sprintf("%v", r) + convertParams(r, idx) } else { vars[idx] = "NULL" } diff --git a/logger/sql_test.go b/logger/sql_test.go index b78f761c..71aa841a 100644 --- a/logger/sql_test.go +++ b/logger/sql_test.go @@ -1,13 +1,39 @@ package logger_test import ( + "database/sql/driver" + "encoding/json" + "fmt" "regexp" + "strings" "testing" "github.com/jinzhu/now" "gorm.io/gorm/logger" ) +type JSON json.RawMessage + +func (j JSON) Value() (driver.Value, error) { + if len(j) == 0 { + return nil, nil + } + return json.RawMessage(j).MarshalJSON() +} + +type ExampleStruct struct { + Name string + Val string +} + +func (s ExampleStruct) Value() (driver.Value, error) { + return json.Marshal(s) +} + +func format(v []byte, escaper string) string { + return escaper + strings.Replace(string(v), escaper, "\\"+escaper, -1) + escaper +} + func TestExplainSQL(t *testing.T) { type role string type password []byte @@ -15,6 +41,10 @@ func TestExplainSQL(t *testing.T) { tt = now.MustParse("2020-02-23 11:10:10") myrole = role("admin") pwd = password([]byte("pass")) + jsVal = []byte(`{"Name":"test","Val":"test"}`) + js = JSON(jsVal) + esVal = []byte(`{"Name":"test","Val":"test"}`) + es = ExampleStruct{Name: "test", Val: "test"} ) results := []struct { @@ -53,6 +83,18 @@ func TestExplainSQL(t *testing.T) { Vars: []interface{}{"jinzhu", 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.com", myrole, pwd, 1}, Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`, }, + { + SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + NumericRegexp: nil, + Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, js, es}, + Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)), + }, + { + SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + NumericRegexp: nil, + Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es}, + Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)), + }, } for idx, r := range results { From 1a526e6802a9692a1340277551a9117644af21f0 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 24 Sep 2020 11:32:38 +0800 Subject: [PATCH 680/881] Fix NamingStrategy with embedded struct, close #3513 --- schema/field.go | 2 +- schema/naming.go | 2 +- schema/naming_test.go | 26 ++++++++++++++++ schema/schema.go | 3 ++ schema/schema_test.go | 70 +++++++++++++++++++++++++++++++++++++++++++ schema/utils.go | 5 ++++ tests/go.mod | 2 +- 7 files changed, 107 insertions(+), 3 deletions(-) diff --git a/schema/field.go b/schema/field.go index 4b8a5a2a..ce2808a8 100644 --- a/schema/field.go +++ b/schema/field.go @@ -326,7 +326,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { cacheStore := &sync.Map{} cacheStore.Store(embeddedCacheKey, true) - if field.EmbeddedSchema, err = Parse(fieldValue.Interface(), cacheStore, schema.namer); err != nil { + if field.EmbeddedSchema, err = Parse(fieldValue.Interface(), cacheStore, embeddedNamer{Table: schema.Table, Namer: schema.namer}); err != nil { schema.err = err } diff --git a/schema/naming.go b/schema/naming.go index ecdab791..af753ce5 100644 --- a/schema/naming.go +++ b/schema/naming.go @@ -14,7 +14,7 @@ import ( type Namer interface { TableName(table string) string ColumnName(table, column string) string - JoinTableName(table string) string + JoinTableName(joinTable string) string RelationshipFKName(Relationship) string CheckerName(table, column string) string IndexName(table, column string) string diff --git a/schema/naming_test.go b/schema/naming_test.go index 96b83ced..a4600ceb 100644 --- a/schema/naming_test.go +++ b/schema/naming_test.go @@ -1,6 +1,7 @@ package schema import ( + "strings" "testing" ) @@ -32,3 +33,28 @@ func TestToDBName(t *testing.T) { } } } + +type NewNamingStrategy struct { + NamingStrategy +} + +func (ns NewNamingStrategy) ColumnName(table, column string) string { + baseColumnName := ns.NamingStrategy.ColumnName(table, column) + + if table == "" { + return baseColumnName + } + + s := strings.Split(table, "_") + + var prefix string + switch len(s) { + case 1: + prefix = s[0][:3] + case 2: + prefix = s[0][:1] + s[1][:2] + default: + prefix = s[0][:1] + s[1][:1] + s[2][:1] + } + return prefix + "_" + baseColumnName +} diff --git a/schema/schema.go b/schema/schema.go index c3d3f6e0..cffc19a7 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -97,6 +97,9 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) if tabler, ok := modelValue.Interface().(Tabler); ok { tableName = tabler.TableName() } + if en, ok := namer.(embeddedNamer); ok { + tableName = en.Table + } schema := &Schema{ Name: modelType.Name(), diff --git a/schema/schema_test.go b/schema/schema_test.go index 6ca5b269..a426cd90 100644 --- a/schema/schema_test.go +++ b/schema/schema_test.go @@ -1,6 +1,7 @@ package schema_test import ( + "strings" "sync" "testing" @@ -227,3 +228,72 @@ func TestEmbeddedStruct(t *testing.T) { }) } } + +type CustomizedNamingStrategy struct { + schema.NamingStrategy +} + +func (ns CustomizedNamingStrategy) ColumnName(table, column string) string { + baseColumnName := ns.NamingStrategy.ColumnName(table, column) + + if table == "" { + return baseColumnName + } + + s := strings.Split(table, "_") + + var prefix string + switch len(s) { + case 1: + prefix = s[0][:3] + case 2: + prefix = s[0][:1] + s[1][:2] + default: + prefix = s[0][:1] + s[1][:1] + s[2][:1] + } + return prefix + "_" + baseColumnName +} + +func TestEmbeddedStructForCustomizedNamingStrategy(t *testing.T) { + type CorpBase struct { + gorm.Model + OwnerID string + } + + type Company struct { + ID int + OwnerID int + Name string + Ignored string `gorm:"-"` + } + + type Corp struct { + CorpBase + Base Company `gorm:"embedded;embeddedPrefix:company_"` + } + + cropSchema, err := schema.Parse(&Corp{}, &sync.Map{}, CustomizedNamingStrategy{schema.NamingStrategy{}}) + + if err != nil { + t.Fatalf("failed to parse embedded struct with primary key, got error %v", err) + } + + fields := []schema.Field{ + {Name: "ID", DBName: "cor_id", BindNames: []string{"CorpBase", "Model", "ID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, HasDefaultValue: true, AutoIncrement: true, TagSettings: map[string]string{"PRIMARYKEY": "PRIMARYKEY"}}, + {Name: "ID", DBName: "company_cor_id", BindNames: []string{"Base", "ID"}, DataType: schema.Int, Size: 64, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}}, + {Name: "Name", DBName: "company_cor_name", BindNames: []string{"Base", "Name"}, DataType: schema.String, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}}, + {Name: "Ignored", BindNames: []string{"Base", "Ignored"}, TagSettings: map[string]string{"-": "-", "EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}}, + {Name: "OwnerID", DBName: "company_cor_owner_id", BindNames: []string{"Base", "OwnerID"}, DataType: schema.Int, Size: 64, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}}, + {Name: "OwnerID", DBName: "cor_owner_id", BindNames: []string{"CorpBase", "OwnerID"}, DataType: schema.String}, + } + + for _, f := range fields { + checkSchemaField(t, cropSchema, &f, func(f *schema.Field) { + if f.Name != "Ignored" { + f.Creatable = true + f.Updatable = true + f.Readable = true + } + }) + } +} diff --git a/schema/utils.go b/schema/utils.go index 41bd9d60..55cbdeb4 100644 --- a/schema/utils.go +++ b/schema/utils.go @@ -190,3 +190,8 @@ func ToQueryValues(table string, foreignKeys []string, foreignValues [][]interfa return columns, queryValues } } + +type embeddedNamer struct { + Table string + Namer +} diff --git a/tests/go.mod b/tests/go.mod index 0db87934..c92fa0cf 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -7,7 +7,7 @@ require ( github.com/jinzhu/now v1.1.1 github.com/lib/pq v1.6.0 gorm.io/driver/mysql v1.0.1 - gorm.io/driver/postgres v1.0.0 + gorm.io/driver/postgres v1.0.1 gorm.io/driver/sqlite v1.1.3 gorm.io/driver/sqlserver v1.0.4 gorm.io/gorm v1.20.1 From 52287359153b5788d95960c963f74bebcdea88c7 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 24 Sep 2020 15:00:13 +0800 Subject: [PATCH 681/881] Don't build IN condition if value implemented Valuer interface, #3517 --- statement.go | 16 +++++++++++----- tests/query_test.go | 5 +++++ 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/statement.go b/statement.go index ee80f8cd..38d35926 100644 --- a/statement.go +++ b/statement.go @@ -299,12 +299,18 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) (c reflectValue := reflect.Indirect(reflect.ValueOf(v[key])) switch reflectValue.Kind() { case reflect.Slice, reflect.Array: - values := make([]interface{}, reflectValue.Len()) - for i := 0; i < reflectValue.Len(); i++ { - values[i] = reflectValue.Index(i).Interface() - } + if _, ok := v[key].(driver.Valuer); ok { + conds = append(conds, clause.Eq{Column: key, Value: v[key]}) + } else if _, ok := v[key].(Valuer); ok { + conds = append(conds, clause.Eq{Column: key, Value: v[key]}) + } else { + values := make([]interface{}, reflectValue.Len()) + for i := 0; i < reflectValue.Len(); i++ { + values[i] = reflectValue.Index(i).Interface() + } - conds = append(conds, clause.IN{Column: key, Values: values}) + conds = append(conds, clause.IN{Column: key, Values: values}) + } default: conds = append(conds, clause.Eq{Column: key, Value: v[key]}) } diff --git a/tests/query_test.go b/tests/query_test.go index d3bcbdbe..9c9ad9f2 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -345,6 +345,11 @@ func TestNot(t *testing.T) { t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) } + result = dryDB.Where(map[string]interface{}{"name": []string{"jinzhu", "jinzhu 2"}}).Find(&User{}) + if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*name.* IN \\(.+,.+\\)").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) + } + result = dryDB.Not("name = ?", "jinzhu").Find(&User{}) if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE NOT.*name.* = .+").MatchString(result.Statement.SQL.String()) { t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) From c0de3c505176b0fea74c2e09fb9cae7c595b7020 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 24 Sep 2020 19:28:52 +0800 Subject: [PATCH 682/881] Support FullSaveAssociations Mode, close #3487, #3506 --- callbacks/associations.go | 61 +++++++++++++++++++-------------- callbacks/create.go | 5 ++- gorm.go | 7 ++++ logger/logger.go | 7 ++-- tests/update_belongs_to_test.go | 19 ++++++++++ tests/update_has_many_test.go | 41 ++++++++++++++++++++++ tests/update_has_one_test.go | 35 +++++++++++++++++++ tests/update_many2many_test.go | 25 ++++++++++++++ 8 files changed, 171 insertions(+), 29 deletions(-) diff --git a/callbacks/associations.go b/callbacks/associations.go index 0c677f47..64d79f24 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -66,9 +66,7 @@ func SaveBeforeAssociations(db *gorm.DB) { } if elems.Len() > 0 { - if db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{ - DoNothing: true, - }).Create(elems.Interface()).Error) == nil { + if db.AddError(db.Session(&gorm.Session{}).Clauses(onConflictOption(db.Statement, rel.FieldSchema, nil)).Create(elems.Interface()).Error) == nil { for i := 0; i < elems.Len(); i++ { setupReferences(objs[i], elems.Index(i)) } @@ -81,9 +79,7 @@ func SaveBeforeAssociations(db *gorm.DB) { rv = rv.Addr() } - if db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{ - DoNothing: true, - }).Create(rv.Interface()).Error) == nil { + if db.AddError(db.Session(&gorm.Session{}).Clauses(onConflictOption(db.Statement, rel.FieldSchema, nil)).Create(rv.Interface()).Error) == nil { setupReferences(db.Statement.ReflectValue, rv) } } @@ -145,10 +141,9 @@ func SaveAfterAssociations(db *gorm.DB) { assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) } - db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{ - Columns: onConflictColumns(rel.FieldSchema), - DoUpdates: clause.AssignmentColumns(assignmentColumns), - }).Create(elems.Interface()).Error) + db.AddError(db.Session(&gorm.Session{}).Clauses( + onConflictOption(db.Statement, rel.FieldSchema, assignmentColumns), + ).Create(elems.Interface()).Error) } case reflect.Struct: if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero { @@ -168,10 +163,9 @@ func SaveAfterAssociations(db *gorm.DB) { assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) } - db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{ - Columns: onConflictColumns(rel.FieldSchema), - DoUpdates: clause.AssignmentColumns(assignmentColumns), - }).Create(f.Interface()).Error) + db.AddError(db.Session(&gorm.Session{}).Clauses( + onConflictOption(db.Statement, rel.FieldSchema, assignmentColumns), + ).Create(f.Interface()).Error) } } } @@ -230,10 +224,9 @@ func SaveAfterAssociations(db *gorm.DB) { assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) } - db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{ - Columns: onConflictColumns(rel.FieldSchema), - DoUpdates: clause.AssignmentColumns(assignmentColumns), - }).Create(elems.Interface()).Error) + db.AddError(db.Session(&gorm.Session{}).Clauses( + onConflictOption(db.Statement, rel.FieldSchema, assignmentColumns), + ).Create(elems.Interface()).Error) } } @@ -298,7 +291,7 @@ func SaveAfterAssociations(db *gorm.DB) { } if elems.Len() > 0 { - db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{DoNothing: true}).Create(elems.Interface()).Error) + db.AddError(db.Session(&gorm.Session{}).Clauses(onConflictOption(db.Statement, rel.FieldSchema, nil)).Create(elems.Interface()).Error) for i := 0; i < elems.Len(); i++ { appendToJoins(objs[i], elems.Index(i)) @@ -312,13 +305,31 @@ func SaveAfterAssociations(db *gorm.DB) { } } -func onConflictColumns(s *schema.Schema) (columns []clause.Column) { - if s.PrioritizedPrimaryField != nil { - return []clause.Column{{Name: s.PrioritizedPrimaryField.DBName}} +func onConflictOption(stmt *gorm.Statement, s *schema.Schema, defaultUpdatingColumns []string) clause.OnConflict { + if stmt.DB.FullSaveAssociations { + defaultUpdatingColumns = make([]string, 0, len(s.DBNames)) + for _, dbName := range s.DBNames { + if !s.LookUpField(dbName).PrimaryKey { + defaultUpdatingColumns = append(defaultUpdatingColumns, dbName) + } + } } - for _, dbName := range s.PrimaryFieldDBNames { - columns = append(columns, clause.Column{Name: dbName}) + if len(defaultUpdatingColumns) > 0 { + var columns []clause.Column + if s.PrioritizedPrimaryField != nil { + columns = []clause.Column{{Name: s.PrioritizedPrimaryField.DBName}} + } else { + for _, dbName := range s.PrimaryFieldDBNames { + columns = append(columns, clause.Column{Name: dbName}) + } + } + + return clause.OnConflict{ + Columns: columns, + DoUpdates: clause.AssignmentColumns(defaultUpdatingColumns), + } } - return + + return clause.OnConflict{DoNothing: true} } diff --git a/callbacks/create.go b/callbacks/create.go index c00a0a73..8e2454e8 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -88,7 +88,10 @@ func Create(config *Config) func(db *gorm.DB) { } case reflect.Struct: if insertID > 0 { - db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID) + if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue); isZero { + + db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID) + } } } } else { diff --git a/gorm.go b/gorm.go index 8efd8a73..e5c4a8a4 100644 --- a/gorm.go +++ b/gorm.go @@ -20,6 +20,8 @@ type Config struct { SkipDefaultTransaction bool // NamingStrategy tables, columns naming strategy NamingStrategy schema.Namer + // FullSaveAssociations full save associations + FullSaveAssociations bool // Logger Logger logger.Interface // NowFunc the function to be used when creating a new timestamp @@ -64,6 +66,7 @@ type Session struct { WithConditions bool SkipDefaultTransaction bool AllowGlobalUpdate bool + FullSaveAssociations bool Context context.Context Logger logger.Interface NowFunc func() time.Time @@ -161,6 +164,10 @@ func (db *DB) Session(config *Session) *DB { txConfig.AllowGlobalUpdate = true } + if config.FullSaveAssociations { + txConfig.FullSaveAssociations = true + } + if config.Context != nil { tx.Statement = tx.Statement.clone() tx.Statement.DB = tx diff --git a/logger/logger.go b/logger/logger.go index 831192fc..e568fb24 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -20,6 +20,7 @@ const ( Magenta = "\033[35m" Cyan = "\033[36m" White = "\033[37m" + BlueBold = "\033[34;1m" MagentaBold = "\033[35;1m" RedBold = "\033[31;1m" YellowBold = "\033[33;1m" @@ -76,11 +77,11 @@ func New(writer Writer, config Config) Interface { if config.Colorful { infoStr = Green + "%s\n" + Reset + Green + "[info] " + Reset - warnStr = Blue + "%s\n" + Reset + Magenta + "[warn] " + Reset + warnStr = BlueBold + "%s\n" + Reset + Magenta + "[warn] " + Reset errStr = Magenta + "%s\n" + Reset + Red + "[error] " + Reset - traceStr = Green + "%s\n" + Reset + Yellow + "[%.3fms] " + Blue + "[rows:%d]" + Reset + " %s" + traceStr = Green + "%s\n" + Reset + Yellow + "[%.3fms] " + BlueBold + "[rows:%d]" + Reset + " %s" traceWarnStr = Green + "%s\n" + Reset + RedBold + "[%.3fms] " + Yellow + "[rows:%d]" + Magenta + " %s" + Reset - traceErrStr = RedBold + "%s " + MagentaBold + "%s\n" + Reset + Yellow + "[%.3fms] " + Blue + "[rows:%d]" + Reset + " %s" + traceErrStr = RedBold + "%s " + MagentaBold + "%s\n" + Reset + Yellow + "[%.3fms] " + BlueBold + "[rows:%d]" + Reset + " %s" } return &logger{ diff --git a/tests/update_belongs_to_test.go b/tests/update_belongs_to_test.go index 47076e69..736dfc5b 100644 --- a/tests/update_belongs_to_test.go +++ b/tests/update_belongs_to_test.go @@ -3,6 +3,7 @@ package tests_test import ( "testing" + "gorm.io/gorm" . "gorm.io/gorm/utils/tests" ) @@ -22,4 +23,22 @@ func TestUpdateBelongsTo(t *testing.T) { var user2 User DB.Preload("Company").Preload("Manager").Find(&user2, "id = ?", user.ID) CheckUser(t, user2, user) + + user.Company.Name += "new" + user.Manager.Name += "new" + if err := DB.Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user3 User + DB.Preload("Company").Preload("Manager").Find(&user3, "id = ?", user.ID) + CheckUser(t, user2, user3) + + if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user4 User + DB.Preload("Company").Preload("Manager").Find(&user4, "id = ?", user.ID) + CheckUser(t, user4, user) } diff --git a/tests/update_has_many_test.go b/tests/update_has_many_test.go index 01ea2e3a..9066cbac 100644 --- a/tests/update_has_many_test.go +++ b/tests/update_has_many_test.go @@ -3,6 +3,7 @@ package tests_test import ( "testing" + "gorm.io/gorm" . "gorm.io/gorm/utils/tests" ) @@ -22,6 +23,26 @@ func TestUpdateHasManyAssociations(t *testing.T) { DB.Preload("Pets").Find(&user2, "id = ?", user.ID) CheckUser(t, user2, user) + for _, pet := range user.Pets { + pet.Name += "new" + } + + if err := DB.Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user3 User + DB.Preload("Pets").Find(&user3, "id = ?", user.ID) + CheckUser(t, user2, user3) + + if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user4 User + DB.Preload("Pets").Find(&user4, "id = ?", user.ID) + CheckUser(t, user4, user) + t.Run("Polymorphic", func(t *testing.T) { var user = *GetUser("update-has-many", Config{}) @@ -37,5 +58,25 @@ func TestUpdateHasManyAssociations(t *testing.T) { var user2 User DB.Preload("Toys").Find(&user2, "id = ?", user.ID) CheckUser(t, user2, user) + + for idx := range user.Toys { + user.Toys[idx].Name += "new" + } + + if err := DB.Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user3 User + DB.Preload("Toys").Find(&user3, "id = ?", user.ID) + CheckUser(t, user2, user3) + + if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user4 User + DB.Preload("Toys").Find(&user4, "id = ?", user.ID) + CheckUser(t, user4, user) }) } diff --git a/tests/update_has_one_test.go b/tests/update_has_one_test.go index 7b29f424..54568546 100644 --- a/tests/update_has_one_test.go +++ b/tests/update_has_one_test.go @@ -3,6 +3,7 @@ package tests_test import ( "testing" + "gorm.io/gorm" . "gorm.io/gorm/utils/tests" ) @@ -23,6 +24,23 @@ func TestUpdateHasOne(t *testing.T) { DB.Preload("Account").Find(&user2, "id = ?", user.ID) CheckUser(t, user2, user) + user.Account.Number += "new" + if err := DB.Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user3 User + DB.Preload("Account").Find(&user3, "id = ?", user.ID) + CheckUser(t, user2, user3) + + if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user4 User + DB.Preload("Account").Find(&user4, "id = ?", user.ID) + CheckUser(t, user4, user) + t.Run("Polymorphic", func(t *testing.T) { var pet = Pet{Name: "create"} @@ -39,5 +57,22 @@ func TestUpdateHasOne(t *testing.T) { var pet2 Pet DB.Preload("Toy").Find(&pet2, "id = ?", pet.ID) CheckPet(t, pet2, pet) + + pet.Toy.Name += "new" + if err := DB.Save(&pet).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var pet3 Pet + DB.Preload("Toy").Find(&pet3, "id = ?", pet.ID) + CheckPet(t, pet2, pet3) + + if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&pet).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var pet4 Pet + DB.Preload("Toy").Find(&pet4, "id = ?", pet.ID) + CheckPet(t, pet4, pet) }) } diff --git a/tests/update_many2many_test.go b/tests/update_many2many_test.go index a46deeb0..d94ef4ab 100644 --- a/tests/update_many2many_test.go +++ b/tests/update_many2many_test.go @@ -3,6 +3,7 @@ package tests_test import ( "testing" + "gorm.io/gorm" . "gorm.io/gorm/utils/tests" ) @@ -26,4 +27,28 @@ func TestUpdateMany2ManyAssociations(t *testing.T) { var user2 User DB.Preload("Languages").Preload("Friends").Find(&user2, "id = ?", user.ID) CheckUser(t, user2, user) + + for idx := range user.Friends { + user.Friends[idx].Name += "new" + } + + for idx := range user.Languages { + user.Languages[idx].Name += "new" + } + + if err := DB.Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user3 User + DB.Preload("Languages").Preload("Friends").Find(&user3, "id = ?", user.ID) + CheckUser(t, user2, user3) + + if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user4 User + DB.Preload("Languages").Preload("Friends").Find(&user4, "id = ?", user.ID) + CheckUser(t, user4, user) } From ba253982bf558543187f3eb88295b88610cdc83b Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 24 Sep 2020 20:08:24 +0800 Subject: [PATCH 683/881] Fix Pluck with Time and Scanner --- scan.go | 13 +++++++++++-- schema/field.go | 6 ++++-- tests/query_test.go | 28 ++++++++++++++++++++++++++++ 3 files changed, 43 insertions(+), 4 deletions(-) diff --git a/scan.go b/scan.go index be8782ed..d7cddbe6 100644 --- a/scan.go +++ b/scan.go @@ -5,6 +5,7 @@ import ( "database/sql/driver" "reflect" "strings" + "time" "gorm.io/gorm/schema" ) @@ -82,7 +83,7 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { scanIntoMap(mapValue, values, columns) *dest = append(*dest, mapValue) } - case *int, *int64, *uint, *uint64, *float32, *float64, *string: + case *int, *int64, *uint, *uint64, *float32, *float64, *string, *time.Time: for initialized || rows.Next() { initialized = false db.RowsAffected++ @@ -134,7 +135,15 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { } // pluck values into slice of data - isPluck := len(fields) == 1 && reflectValueType.Kind() != reflect.Struct + isPluck := false + if len(fields) == 1 { + if _, ok := reflect.New(reflectValueType).Interface().(sql.Scanner); ok { + isPluck = true + } else if reflectValueType.Kind() != reflect.Struct || reflectValueType.ConvertibleTo(schema.TimeReflectType) { + isPluck = true + } + } + for initialized || rows.Next() { initialized = false db.RowsAffected++ diff --git a/schema/field.go b/schema/field.go index ce2808a8..db516c33 100644 --- a/schema/field.go +++ b/schema/field.go @@ -18,6 +18,8 @@ type DataType string type TimeType int64 +var TimeReflectType = reflect.TypeOf(time.Time{}) + const ( UnixSecond TimeType = 1 UnixMillisecond TimeType = 2 @@ -102,7 +104,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { var getRealFieldValue func(reflect.Value) getRealFieldValue = func(v reflect.Value) { rv := reflect.Indirect(v) - if rv.Kind() == reflect.Struct && !rv.Type().ConvertibleTo(reflect.TypeOf(time.Time{})) { + if rv.Kind() == reflect.Struct && !rv.Type().ConvertibleTo(TimeReflectType) { for i := 0; i < rv.Type().NumField(); i++ { newFieldType := rv.Type().Field(i).Type for newFieldType.Kind() == reflect.Ptr { @@ -221,7 +223,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { case reflect.Struct: if _, ok := fieldValue.Interface().(*time.Time); ok { field.DataType = Time - } else if fieldValue.Type().ConvertibleTo(reflect.TypeOf(time.Time{})) { + } else if fieldValue.Type().ConvertibleTo(TimeReflectType) { field.DataType = Time } else if fieldValue.Type().ConvertibleTo(reflect.TypeOf(&time.Time{})) { field.DataType = Time diff --git a/tests/query_test.go b/tests/query_test.go index 9c9ad9f2..431ccce2 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -1,6 +1,7 @@ package tests_test import ( + "database/sql" "fmt" "reflect" "regexp" @@ -431,6 +432,33 @@ func TestPluck(t *testing.T) { t.Errorf("Unexpected result on pluck id, got %+v", ids) } } + + var times []time.Time + if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Pluck("created_at", ×).Error; err != nil { + t.Errorf("got error when pluck time: %v", err) + } + + for idx, tv := range times { + AssertEqual(t, tv, users[idx].CreatedAt) + } + + var ptrtimes []*time.Time + if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Pluck("created_at", &ptrtimes).Error; err != nil { + t.Errorf("got error when pluck time: %v", err) + } + + for idx, tv := range ptrtimes { + AssertEqual(t, tv, users[idx].CreatedAt) + } + + var nulltimes []sql.NullTime + if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Pluck("created_at", &nulltimes).Error; err != nil { + t.Errorf("got error when pluck time: %v", err) + } + + for idx, tv := range nulltimes { + AssertEqual(t, tv.Time, users[idx].CreatedAt) + } } func TestSelect(t *testing.T) { From 9eec6ae06638665661f9872e783a42613527e146 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 27 Sep 2020 12:25:38 +0800 Subject: [PATCH 684/881] Fix affected rows for Scan, change affected rows count for row/rows to '-', close #3532 --- callbacks.go | 1 - callbacks/row.go | 2 ++ finisher_api.go | 8 ++++++++ logger/logger.go | 49 +++++++++++++++++++++++++++++++++++++++--------- scan.go | 1 + 5 files changed, 51 insertions(+), 10 deletions(-) diff --git a/callbacks.go b/callbacks.go index 83d103df..fdde21e9 100644 --- a/callbacks.go +++ b/callbacks.go @@ -74,7 +74,6 @@ func (cs *callbacks) Raw() *processor { func (p *processor) Execute(db *DB) { curTime := time.Now() stmt := db.Statement - db.RowsAffected = 0 if stmt.Model == nil { stmt.Model = stmt.Dest diff --git a/callbacks/row.go b/callbacks/row.go index a36c0116..4f985d7b 100644 --- a/callbacks/row.go +++ b/callbacks/row.go @@ -16,6 +16,8 @@ func RowQuery(db *gorm.DB) { } else { db.Statement.Dest = db.Statement.ConnPool.QueryRowContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) } + + db.RowsAffected = -1 } } } diff --git a/finisher_api.go b/finisher_api.go index 2c56d763..63061553 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -8,6 +8,7 @@ import ( "strings" "gorm.io/gorm/clause" + "gorm.io/gorm/logger" "gorm.io/gorm/schema" "gorm.io/gorm/utils" ) @@ -353,7 +354,9 @@ func (db *DB) Rows() (*sql.Rows, error) { // Scan scan value to a struct func (db *DB) Scan(dest interface{}) (tx *DB) { + currentLogger, newLogger := db.Logger, logger.Recorder.New() tx = db.getInstance() + tx.Logger = newLogger if rows, err := tx.Rows(); err != nil { tx.AddError(err) } else { @@ -362,6 +365,11 @@ func (db *DB) Scan(dest interface{}) (tx *DB) { tx.ScanRows(rows, dest) } } + + currentLogger.Trace(tx.Statement.Context, newLogger.BeginAt, func() (string, int64) { + return newLogger.SQL, tx.RowsAffected + }, tx.Error) + tx.Logger = currentLogger return } diff --git a/logger/logger.go b/logger/logger.go index e568fb24..b278ad5d 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -63,6 +63,7 @@ var ( LogLevel: Warn, Colorful: true, }) + Recorder = traceRecorder{Interface: Default} ) func New(writer Writer, config Config) Interface { @@ -70,18 +71,18 @@ func New(writer Writer, config Config) Interface { infoStr = "%s\n[info] " warnStr = "%s\n[warn] " errStr = "%s\n[error] " - traceStr = "%s\n[%.3fms] [rows:%d] %s" - traceWarnStr = "%s\n[%.3fms] [rows:%d] %s" - traceErrStr = "%s %s\n[%.3fms] [rows:%d] %s" + traceStr = "%s\n[%.3fms] [rows:%v] %s" + traceWarnStr = "%s\n[%.3fms] [rows:%v] %s" + traceErrStr = "%s %s\n[%.3fms] [rows:%v] %s" ) if config.Colorful { infoStr = Green + "%s\n" + Reset + Green + "[info] " + Reset warnStr = BlueBold + "%s\n" + Reset + Magenta + "[warn] " + Reset errStr = Magenta + "%s\n" + Reset + Red + "[error] " + Reset - traceStr = Green + "%s\n" + Reset + Yellow + "[%.3fms] " + BlueBold + "[rows:%d]" + Reset + " %s" - traceWarnStr = Green + "%s\n" + Reset + RedBold + "[%.3fms] " + Yellow + "[rows:%d]" + Magenta + " %s" + Reset - traceErrStr = RedBold + "%s " + MagentaBold + "%s\n" + Reset + Yellow + "[%.3fms] " + BlueBold + "[rows:%d]" + Reset + " %s" + traceStr = Green + "%s\n" + Reset + Yellow + "[%.3fms] " + BlueBold + "[rows:%v]" + Reset + " %s" + traceWarnStr = Green + "%s\n" + Reset + RedBold + "[%.3fms] " + Yellow + "[rows:%v]" + Magenta + " %s" + Reset + traceErrStr = RedBold + "%s " + MagentaBold + "%s\n" + Reset + Yellow + "[%.3fms] " + BlueBold + "[rows:%v]" + Reset + " %s" } return &logger{ @@ -138,13 +139,43 @@ func (l logger) Trace(ctx context.Context, begin time.Time, fc func() (string, i switch { case err != nil && l.LogLevel >= Error: sql, rows := fc() - l.Printf(l.traceErrStr, utils.FileWithLineNum(), err, float64(elapsed.Nanoseconds())/1e6, rows, sql) + if rows == -1 { + l.Printf(l.traceErrStr, utils.FileWithLineNum(), err, float64(elapsed.Nanoseconds())/1e6, "-", sql) + } else { + l.Printf(l.traceErrStr, utils.FileWithLineNum(), err, float64(elapsed.Nanoseconds())/1e6, rows, sql) + } case elapsed > l.SlowThreshold && l.SlowThreshold != 0 && l.LogLevel >= Warn: sql, rows := fc() - l.Printf(l.traceWarnStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, rows, sql) + if rows == -1 { + l.Printf(l.traceWarnStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, "-", sql) + } else { + l.Printf(l.traceWarnStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, rows, sql) + } case l.LogLevel >= Info: sql, rows := fc() - l.Printf(l.traceStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, rows, sql) + if rows == -1 { + l.Printf(l.traceStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, "-", sql) + } else { + l.Printf(l.traceStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, rows, sql) + } } } } + +type traceRecorder struct { + Interface + BeginAt time.Time + SQL string + RowsAffected int64 + Err error +} + +func (l traceRecorder) New() *traceRecorder { + return &traceRecorder{Interface: l.Interface} +} + +func (l *traceRecorder) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) { + l.BeginAt = begin + l.SQL, l.RowsAffected = fc() + l.Err = err +} diff --git a/scan.go b/scan.go index d7cddbe6..8d737b17 100644 --- a/scan.go +++ b/scan.go @@ -52,6 +52,7 @@ func scanIntoMap(mapValue map[string]interface{}, values []interface{}, columns func Scan(rows *sql.Rows, db *DB, initialized bool) { columns, _ := rows.Columns() values := make([]interface{}, len(columns)) + db.RowsAffected = 0 switch dest := db.Statement.Dest.(type) { case map[string]interface{}, *map[string]interface{}: From a2faa41cbe55dc37e2e0c30cab0fcd1b6d00c5fe Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 28 Sep 2020 10:55:27 +0800 Subject: [PATCH 685/881] Refactor NamingStrategy, close #3540 --- schema/naming.go | 7 ++++--- schema/naming_test.go | 46 ++++++++++++++++++++++++------------------- 2 files changed, 30 insertions(+), 23 deletions(-) diff --git a/schema/naming.go b/schema/naming.go index af753ce5..dbc71e04 100644 --- a/schema/naming.go +++ b/schema/naming.go @@ -42,7 +42,7 @@ func (ns NamingStrategy) ColumnName(table, column string) string { // JoinTableName convert string to join table name func (ns NamingStrategy) JoinTableName(str string) string { if strings.ToLower(str) == str { - return str + return ns.TablePrefix + str } if ns.SingularTable { @@ -53,17 +53,18 @@ func (ns NamingStrategy) JoinTableName(str string) string { // RelationshipFKName generate fk name for relation func (ns NamingStrategy) RelationshipFKName(rel Relationship) string { - return fmt.Sprintf("fk_%s_%s", rel.Schema.Table, toDBName(rel.Name)) + return strings.Replace(fmt.Sprintf("fk_%s_%s", rel.Schema.Table, toDBName(rel.Name)), ".", "_", -1) } // CheckerName generate checker name func (ns NamingStrategy) CheckerName(table, column string) string { - return fmt.Sprintf("chk_%s_%s", table, column) + return strings.Replace(fmt.Sprintf("chk_%s_%s", table, column), ".", "_", -1) } // IndexName generate index name func (ns NamingStrategy) IndexName(table, column string) string { idxName := fmt.Sprintf("idx_%v_%v", table, toDBName(column)) + idxName = strings.Replace(idxName, ".", "_", -1) if utf8.RuneCountInString(idxName) > 64 { h := sha1.New() diff --git a/schema/naming_test.go b/schema/naming_test.go index a4600ceb..26b0dcf6 100644 --- a/schema/naming_test.go +++ b/schema/naming_test.go @@ -1,7 +1,6 @@ package schema import ( - "strings" "testing" ) @@ -34,27 +33,34 @@ func TestToDBName(t *testing.T) { } } -type NewNamingStrategy struct { - NamingStrategy -} +func TestNamingStrategy(t *testing.T) { + var ns = NamingStrategy{ + TablePrefix: "public.", + SingularTable: true, + } + idxName := ns.IndexName("public.table", "name") -func (ns NewNamingStrategy) ColumnName(table, column string) string { - baseColumnName := ns.NamingStrategy.ColumnName(table, column) - - if table == "" { - return baseColumnName + if idxName != "idx_public_table_name" { + t.Errorf("invalid index name generated, got %v", idxName) } - s := strings.Split(table, "_") - - var prefix string - switch len(s) { - case 1: - prefix = s[0][:3] - case 2: - prefix = s[0][:1] + s[1][:2] - default: - prefix = s[0][:1] + s[1][:1] + s[2][:1] + chkName := ns.CheckerName("public.table", "name") + if chkName != "chk_public_table_name" { + t.Errorf("invalid checker name generated, got %v", chkName) + } + + joinTable := ns.JoinTableName("user_languages") + if joinTable != "public.user_languages" { + t.Errorf("invalid join table generated, got %v", joinTable) + } + + joinTable2 := ns.JoinTableName("UserLanguage") + if joinTable2 != "public.user_language" { + t.Errorf("invalid join table generated, got %v", joinTable2) + } + + tableName := ns.TableName("Company") + if tableName != "public.company" { + t.Errorf("invalid table name generated, got %v", tableName) } - return prefix + "_" + baseColumnName } From dbc6b34dce7f5c4ce6f358d23bc70ac738af7793 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 29 Sep 2020 15:42:58 +0800 Subject: [PATCH 686/881] Add detailed error information when missing table name --- callbacks.go | 6 +++++- tests/go.mod | 4 ++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/callbacks.go b/callbacks.go index fdde21e9..e21e0718 100644 --- a/callbacks.go +++ b/callbacks.go @@ -83,7 +83,11 @@ func (p *processor) Execute(db *DB) { if stmt.Model != nil { if err := stmt.Parse(stmt.Model); err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || (stmt.Table == "" && stmt.SQL.Len() == 0)) { - db.AddError(err) + if errors.Is(err, schema.ErrUnsupportedDataType) && stmt.Table == "" { + db.AddError(fmt.Errorf("%w: Table not set, please set it like: db.Model(&user) or db.Table(\"users\")", err)) + } else { + db.AddError(err) + } } } diff --git a/tests/go.mod b/tests/go.mod index c92fa0cf..cbafcd7e 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -7,10 +7,10 @@ require ( github.com/jinzhu/now v1.1.1 github.com/lib/pq v1.6.0 gorm.io/driver/mysql v1.0.1 - gorm.io/driver/postgres v1.0.1 + gorm.io/driver/postgres v1.0.2 gorm.io/driver/sqlite v1.1.3 gorm.io/driver/sqlserver v1.0.4 - gorm.io/gorm v1.20.1 + gorm.io/gorm v1.20.2 ) replace gorm.io/gorm => ../ From 7faf1ca80fa00e0737f0c2efb2c57fb036adebdf Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 9 Oct 2020 11:52:12 +0800 Subject: [PATCH 687/881] Fix Select with AS, close #3581, #3567 --- chainable_api.go | 2 +- tests/go.mod | 2 +- tests/query_test.go | 10 ++++++++++ 3 files changed, 12 insertions(+), 2 deletions(-) diff --git a/chainable_api.go b/chainable_api.go index ae2ac4f1..c3a02d20 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -97,7 +97,7 @@ func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) { // normal field names if len(fields) == 1 || (len(fields) == 3 && strings.ToUpper(fields[1]) == "AS") { - tx.Statement.Selects = fields + tx.Statement.Selects = []string{v} for _, arg := range args { switch arg := arg.(type) { diff --git a/tests/go.mod b/tests/go.mod index cbafcd7e..9b36f1ed 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -6,7 +6,7 @@ require ( github.com/google/uuid v1.1.1 github.com/jinzhu/now v1.1.1 github.com/lib/pq v1.6.0 - gorm.io/driver/mysql v1.0.1 + gorm.io/driver/mysql v1.0.2 gorm.io/driver/postgres v1.0.2 gorm.io/driver/sqlite v1.1.3 gorm.io/driver/sqlserver v1.0.4 diff --git a/tests/query_test.go b/tests/query_test.go index 431ccce2..bb9aa26d 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -475,6 +475,16 @@ func TestSelect(t *testing.T) { t.Errorf("Should have user Name when selected it") } + var result2 User + DB.Where("name = ?", user.Name).Select("name as name").Find(&result2) + if result2.ID != 0 { + t.Errorf("Should not have ID because only selected name, %+v", result2.ID) + } + + if user.Name != result2.Name { + t.Errorf("Should have user Name when selected it") + } + dryDB := DB.Session(&gorm.Session{DryRun: true}) r := dryDB.Select("name", "age").Find(&User{}) if !regexp.MustCompile("SELECT .*name.*,.*age.* FROM .*users.*").MatchString(r.Statement.SQL.String()) { From 3d846957cd57c1660233ce7e0f6c56f21a030ccf Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 9 Oct 2020 17:39:35 +0800 Subject: [PATCH 688/881] Compatible with tag notNull --- schema/field.go | 2 ++ tests/default_value_test.go | 4 ++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/schema/field.go b/schema/field.go index db516c33..e7f5b708 100644 --- a/schema/field.go +++ b/schema/field.go @@ -170,6 +170,8 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { if val, ok := field.TagSettings["NOT NULL"]; ok && utils.CheckTruth(val) { field.NotNull = true + } else if val, ok := field.TagSettings["NOTNULL"]; ok && utils.CheckTruth(val) { + field.NotNull = true } if val, ok := field.TagSettings["UNIQUE"]; ok && utils.CheckTruth(val) { diff --git a/tests/default_value_test.go b/tests/default_value_test.go index aa4a511a..14a0a977 100644 --- a/tests/default_value_test.go +++ b/tests/default_value_test.go @@ -10,9 +10,9 @@ func TestDefaultValue(t *testing.T) { type Harumph struct { gorm.Model Email string `gorm:"not null;index:,unique"` - Name string `gorm:"not null;default:foo"` + Name string `gorm:"notNull;default:foo"` Name2 string `gorm:"size:233;not null;default:'foo'"` - Name3 string `gorm:"size:233;not null;default:''"` + Name3 string `gorm:"size:233;notNull;default:''"` Age int `gorm:"default:18"` Enabled bool `gorm:"default:true"` } From 063b1ca0c41740577655ff3b0c524bcbe587a54f Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 10 Oct 2020 10:56:00 +0800 Subject: [PATCH 689/881] Refactor SlowSQL log --- logger/logger.go | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/logger/logger.go b/logger/logger.go index b278ad5d..6782736c 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -2,6 +2,7 @@ package logger import ( "context" + "fmt" "io/ioutil" "log" "os" @@ -59,7 +60,7 @@ type Interface interface { var ( Discard = New(log.New(ioutil.Discard, "", log.LstdFlags), Config{}) Default = New(log.New(os.Stdout, "\r\n", log.LstdFlags), Config{ - SlowThreshold: 100 * time.Millisecond, + SlowThreshold: 200 * time.Millisecond, LogLevel: Warn, Colorful: true, }) @@ -72,7 +73,7 @@ func New(writer Writer, config Config) Interface { warnStr = "%s\n[warn] " errStr = "%s\n[error] " traceStr = "%s\n[%.3fms] [rows:%v] %s" - traceWarnStr = "%s\n[%.3fms] [rows:%v] %s" + traceWarnStr = "%s %s\n[%.3fms] [rows:%v] %s" traceErrStr = "%s %s\n[%.3fms] [rows:%v] %s" ) @@ -81,7 +82,7 @@ func New(writer Writer, config Config) Interface { warnStr = BlueBold + "%s\n" + Reset + Magenta + "[warn] " + Reset errStr = Magenta + "%s\n" + Reset + Red + "[error] " + Reset traceStr = Green + "%s\n" + Reset + Yellow + "[%.3fms] " + BlueBold + "[rows:%v]" + Reset + " %s" - traceWarnStr = Green + "%s\n" + Reset + RedBold + "[%.3fms] " + Yellow + "[rows:%v]" + Magenta + " %s" + Reset + traceWarnStr = Green + "%s " + Yellow + "%s\n" + Reset + RedBold + "[%.3fms] " + Yellow + "[rows:%v]" + Magenta + " %s" + Reset traceErrStr = RedBold + "%s " + MagentaBold + "%s\n" + Reset + Yellow + "[%.3fms] " + BlueBold + "[rows:%v]" + Reset + " %s" } @@ -146,10 +147,11 @@ func (l logger) Trace(ctx context.Context, begin time.Time, fc func() (string, i } case elapsed > l.SlowThreshold && l.SlowThreshold != 0 && l.LogLevel >= Warn: sql, rows := fc() + slowLog := fmt.Sprintf("SLOW SQL >= %v", l.SlowThreshold) if rows == -1 { - l.Printf(l.traceWarnStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, "-", sql) + l.Printf(l.traceWarnStr, utils.FileWithLineNum(), slowLog, float64(elapsed.Nanoseconds())/1e6, "-", sql) } else { - l.Printf(l.traceWarnStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, rows, sql) + l.Printf(l.traceWarnStr, utils.FileWithLineNum(), slowLog, float64(elapsed.Nanoseconds())/1e6, rows, sql) } case l.LogLevel >= Info: sql, rows := fc() From 689d6e23319ea84c07b4943341361bd0ea09b780 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 13 Oct 2020 14:12:03 +0800 Subject: [PATCH 690/881] Fix DeletedAt marshalling, close #3598 --- soft_delete.go | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/soft_delete.go b/soft_delete.go index b13fc63f..b15a8148 100644 --- a/soft_delete.go +++ b/soft_delete.go @@ -3,6 +3,7 @@ package gorm import ( "database/sql" "database/sql/driver" + "encoding/json" "reflect" "gorm.io/gorm/clause" @@ -24,6 +25,18 @@ func (n DeletedAt) Value() (driver.Value, error) { return n.Time, nil } +func (n DeletedAt) MarshalJSON() ([]byte, error) { + return json.Marshal(n.Time) +} + +func (n *DeletedAt) UnmarshalJSON(b []byte) error { + err := json.Unmarshal(b, &n.Time) + if err == nil { + n.Valid = true + } + return err +} + func (DeletedAt) QueryClauses(f *schema.Field) []clause.Interface { return []clause.Interface{SoftDeleteQueryClause{Field: f}} } From 08ecef8e0b12f8db0b2127b0bcddf7caea447fe3 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 13 Oct 2020 15:32:25 +0800 Subject: [PATCH 691/881] Fix NamedArguments with nested struct, close #3596 --- clause/expression.go | 23 ++++++++++++++++------- clause/expression_test.go | 8 ++++++-- 2 files changed, 22 insertions(+), 9 deletions(-) diff --git a/clause/expression.go b/clause/expression.go index 6a0dde8d..5822a314 100644 --- a/clause/expression.go +++ b/clause/expression.go @@ -91,16 +91,25 @@ func (expr NamedExpr) Build(builder Builder) { namedMap[k] = v } default: - reflectValue := reflect.Indirect(reflect.ValueOf(value)) - switch reflectValue.Kind() { - case reflect.Struct: - modelType := reflectValue.Type() - for i := 0; i < modelType.NumField(); i++ { - if fieldStruct := modelType.Field(i); ast.IsExported(fieldStruct.Name) { - namedMap[fieldStruct.Name] = reflectValue.Field(i).Interface() + var appendFieldsToMap func(reflect.Value) + appendFieldsToMap = func(reflectValue reflect.Value) { + reflectValue = reflect.Indirect(reflectValue) + switch reflectValue.Kind() { + case reflect.Struct: + modelType := reflectValue.Type() + for i := 0; i < modelType.NumField(); i++ { + if fieldStruct := modelType.Field(i); ast.IsExported(fieldStruct.Name) { + namedMap[fieldStruct.Name] = reflectValue.Field(i).Interface() + + if fieldStruct.Anonymous { + appendFieldsToMap(reflectValue.Field(i)) + } + } } } } + + appendFieldsToMap(reflect.ValueOf(value)) } } diff --git a/clause/expression_test.go b/clause/expression_test.go index 19e30e6c..83082486 100644 --- a/clause/expression_test.go +++ b/clause/expression_test.go @@ -37,9 +37,13 @@ func TestExpr(t *testing.T) { } func TestNamedExpr(t *testing.T) { + type Base struct { + Name2 string + } + type NamedArgument struct { Name1 string - Name2 string + Base } results := []struct { @@ -73,7 +77,7 @@ func TestNamedExpr(t *testing.T) { ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu", nil}, }, { SQL: "@@test AND name1 = @Name1 AND name2 = @Name2 AND name3 = @Name1 @Notexist", - Vars: []interface{}{NamedArgument{Name1: "jinzhu", Name2: "jinzhu2"}}, + Vars: []interface{}{NamedArgument{Name1: "jinzhu", Base: Base{Name2: "jinzhu2"}}}, Result: "@@test AND name1 = ? AND name2 = ? AND name3 = ? ?", ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu", nil}, }, { From d825554307ba34292c6d9cbbd425c550c2ddb818 Mon Sep 17 00:00:00 2001 From: TABRIZ ATAYI Date: Sun, 18 Oct 2020 00:05:43 +0200 Subject: [PATCH 692/881] nil point transfer '' not transfer NULL #3604 --- logger/sql.go | 2 -- 1 file changed, 2 deletions(-) diff --git a/logger/sql.go b/logger/sql.go index 138a35ec..0ffe6b41 100644 --- a/logger/sql.go +++ b/logger/sql.go @@ -48,8 +48,6 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a } else { vars[idx] = "NULL" } - case fmt.Stringer: - vars[idx] = escaper + strings.Replace(fmt.Sprintf("%v", v), escaper, "\\"+escaper, -1) + escaper case driver.Valuer: reflectValue := reflect.ValueOf(v) if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) { From a1ea1713b008c7e3bf01771701ffab50a98461d5 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 19 Oct 2020 11:04:18 +0800 Subject: [PATCH 693/881] Fix log Stringer --- logger/sql.go | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/logger/sql.go b/logger/sql.go index 0ffe6b41..d080def2 100644 --- a/logger/sql.go +++ b/logger/sql.go @@ -48,6 +48,13 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a } else { vars[idx] = "NULL" } + case fmt.Stringer: + reflectValue := reflect.ValueOf(v) + if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) { + vars[idx] = escaper + strings.Replace(fmt.Sprintf("%v", v), escaper, "\\"+escaper, -1) + escaper + } else { + vars[idx] = "NULL" + } case driver.Valuer: reflectValue := reflect.ValueOf(v) if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) { From 9dbef26feb3c9554aecdb792c4029fb3a68ac16e Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 19 Oct 2020 11:49:03 +0800 Subject: [PATCH 694/881] Fix feature request label --- .github/labels.json | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/.github/labels.json b/.github/labels.json index 6b9c2034..5c7eb7d1 100644 --- a/.github/labels.json +++ b/.github/labels.json @@ -10,6 +10,11 @@ "colour": "#EDEDED", "description": "general questions" }, + "feature": { + "name": "type:feature_request", + "colour": "#43952A", + "description": "feature request" + }, "invalid_question": { "name": "type:invalid question", "colour": "#CF2E1F", @@ -82,8 +87,21 @@ } ] }, + "feature": { + "requires": 1, + "conditions": [ + { + "type": "titleMatches", + "pattern": "/feature/i" + }, + { + "type": "descriptionMatches", + "pattern": "/Describe the feature/i" + } + ] + }, "without_playground": { - "requires": 5, + "requires": 6, "conditions": [ { "type": "descriptionMatches", @@ -97,6 +115,10 @@ "type": "descriptionMatches", "pattern": "/^((?!question).)*$/is" }, + { + "type": "descriptionMatches", + "pattern": "/^((?!Describe the feature).)*$/is" + }, { "type": "titleMatches", "pattern": "/^((?!critical|urgent).)*$/s" From 9b2181199d88ed6f74650d73fa9d20264dd134c0 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 19 Oct 2020 14:49:42 +0800 Subject: [PATCH 695/881] Fix soft delete with OrCondition, close #3627 --- clause/where.go | 37 ++++++++++++++----------------------- finisher_api.go | 2 ++ soft_delete.go | 13 +++++++++++++ tests/count_test.go | 2 +- tests/sql_builder_test.go | 6 +++--- 5 files changed, 33 insertions(+), 27 deletions(-) diff --git a/clause/where.go b/clause/where.go index a3774e1c..00b1a40e 100644 --- a/clause/where.go +++ b/clause/where.go @@ -26,17 +26,22 @@ func (where Where) Build(builder Builder) { } } + buildExprs(where.Exprs, builder, " AND ") +} + +func buildExprs(exprs []Expression, builder Builder, joinCond string) { wrapInParentheses := false - for idx, expr := range where.Exprs { + + for idx, expr := range exprs { if idx > 0 { if v, ok := expr.(OrConditions); ok && len(v.Exprs) == 1 { builder.WriteString(" OR ") } else { - builder.WriteString(" AND ") + builder.WriteString(joinCond) } } - if len(where.Exprs) > 1 { + if len(exprs) > 1 { switch v := expr.(type) { case OrConditions: if len(v.Exprs) == 1 { @@ -97,19 +102,10 @@ type AndConditions struct { func (and AndConditions) Build(builder Builder) { if len(and.Exprs) > 1 { builder.WriteByte('(') - } - for idx, c := range and.Exprs { - if idx > 0 { - if orConditions, ok := c.(OrConditions); ok && len(orConditions.Exprs) == 1 { - builder.WriteString(" OR ") - } else { - builder.WriteString(" AND ") - } - } - c.Build(builder) - } - if len(and.Exprs) > 1 { + buildExprs(and.Exprs, builder, " AND ") builder.WriteByte(')') + } else { + buildExprs(and.Exprs, builder, " AND ") } } @@ -127,15 +123,10 @@ type OrConditions struct { func (or OrConditions) Build(builder Builder) { if len(or.Exprs) > 1 { builder.WriteByte('(') - } - for idx, c := range or.Exprs { - if idx > 0 { - builder.WriteString(" OR ") - } - c.Build(builder) - } - if len(or.Exprs) > 1 { + buildExprs(or.Exprs, builder, " OR ") builder.WriteByte(')') + } else { + buildExprs(or.Exprs, builder, " OR ") } } diff --git a/finisher_api.go b/finisher_api.go index 63061553..2951fdef 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -154,6 +154,8 @@ func (tx *DB) assignInterfacesToValue(values ...interface{}) { tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value)) } } + } else if andCond, ok := expr.(clause.AndConditions); ok { + tx.assignInterfacesToValue(andCond.Exprs) } } case clause.Expression, map[string]string, map[interface{}]interface{}, map[string]interface{}: diff --git a/soft_delete.go b/soft_delete.go index b15a8148..b3280ff7 100644 --- a/soft_delete.go +++ b/soft_delete.go @@ -57,6 +57,19 @@ func (sd SoftDeleteQueryClause) MergeClause(*clause.Clause) { func (sd SoftDeleteQueryClause) ModifyStatement(stmt *Statement) { if _, ok := stmt.Clauses["soft_delete_enabled"]; !ok { + if c, ok := stmt.Clauses["WHERE"]; ok { + if where, ok := c.Expression.(clause.Where); ok && len(where.Exprs) > 1 { + for _, expr := range where.Exprs { + if orCond, ok := expr.(clause.OrConditions); ok && len(orCond.Exprs) == 1 { + where.Exprs = []clause.Expression{clause.And(where.Exprs...)} + c.Expression = where + stmt.Clauses["WHERE"] = c + break + } + } + } + } + stmt.AddClause(clause.Where{Exprs: []clause.Expression{ clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: sd.Field.DBName}, Value: nil}, }}) diff --git a/tests/count_test.go b/tests/count_test.go index 216fa3a1..0d348227 100644 --- a/tests/count_test.go +++ b/tests/count_test.go @@ -69,7 +69,7 @@ func TestCount(t *testing.T) { } var count4 int64 - if err := DB.Debug().Table("users").Joins("LEFT JOIN companies on companies.name = users.name").Where("users.name = ?", user1.Name).Count(&count4).Error; err != nil || count4 != 1 { + if err := DB.Table("users").Joins("LEFT JOIN companies on companies.name = users.name").Where("users.name = ?", user1.Name).Count(&count4).Error; err != nil || count4 != 1 { t.Errorf("count with join, got error: %v, count %v", err, count) } } diff --git a/tests/sql_builder_test.go b/tests/sql_builder_test.go index c0176fc3..acb08130 100644 --- a/tests/sql_builder_test.go +++ b/tests/sql_builder_test.go @@ -198,17 +198,17 @@ func TestCombineStringConditions(t *testing.T) { } sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Or("c = ? and d = ?", "c", "d").Find(&User{}).Statement.SQL.String() - if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) OR \(c = .+ and d = .+\) AND .users.\..deleted_at. IS NULL`).MatchString(sql) { + if !regexp.MustCompile(`WHERE \(\(a = .+ or b = .+\) OR \(c = .+ and d = .+\)\) AND .users.\..deleted_at. IS NULL`).MatchString(sql) { t.Fatalf("invalid sql generated, got %v", sql) } sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Or("c = ?", "c").Find(&User{}).Statement.SQL.String() - if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) OR c = .+ AND .users.\..deleted_at. IS NULL`).MatchString(sql) { + if !regexp.MustCompile(`WHERE \(\(a = .+ or b = .+\) OR c = .+\) AND .users.\..deleted_at. IS NULL`).MatchString(sql) { t.Fatalf("invalid sql generated, got %v", sql) } sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Or("c = ? and d = ?", "c", "d").Or("e = ? and f = ?", "e", "f").Find(&User{}).Statement.SQL.String() - if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) OR \(c = .+ and d = .+\) OR \(e = .+ and f = .+\) AND .users.\..deleted_at. IS NULL`).MatchString(sql) { + if !regexp.MustCompile(`WHERE \(\(a = .+ or b = .+\) OR \(c = .+ and d = .+\) OR \(e = .+ and f = .+\)\) AND .users.\..deleted_at. IS NULL`).MatchString(sql) { t.Fatalf("invalid sql generated, got %v", sql) } From 33a11767eafce30831d105a6b64cc7b54a279352 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 20 Oct 2020 19:13:15 +0800 Subject: [PATCH 696/881] Upgrade test go.mod dependencies --- tests/go.mod | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/go.mod b/tests/go.mod index 9b36f1ed..87d221ca 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -7,12 +7,10 @@ require ( github.com/jinzhu/now v1.1.1 github.com/lib/pq v1.6.0 gorm.io/driver/mysql v1.0.2 - gorm.io/driver/postgres v1.0.2 + gorm.io/driver/postgres v1.0.3 gorm.io/driver/sqlite v1.1.3 - gorm.io/driver/sqlserver v1.0.4 + gorm.io/driver/sqlserver v1.0.5 gorm.io/gorm v1.20.2 ) replace gorm.io/gorm => ../ - -replace github.com/jackc/pgx/v4 => github.com/jinzhu/pgx/v4 v4.8.2 From bdb30da0a7af4329238ba2a17b46860aa4d18a65 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 21 Oct 2020 15:47:46 +0800 Subject: [PATCH 697/881] Fix copy lock for prepared statement, close #3642, #3607 --- gorm.go | 1 + prepare_stmt.go | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/gorm.go b/gorm.go index e5c4a8a4..affa8e69 100644 --- a/gorm.go +++ b/gorm.go @@ -117,6 +117,7 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) { preparedStmt := &PreparedStmtDB{ ConnPool: db.ConnPool, Stmts: map[string]*sql.Stmt{}, + Mux: &sync.RWMutex{}, PreparedSQL: make([]string, 0, 100), } db.cacheStore.Store("preparedStmt", preparedStmt) diff --git a/prepare_stmt.go b/prepare_stmt.go index 14a6aaec..eddee1f2 100644 --- a/prepare_stmt.go +++ b/prepare_stmt.go @@ -9,7 +9,7 @@ import ( type PreparedStmtDB struct { Stmts map[string]*sql.Stmt PreparedSQL []string - Mux sync.RWMutex + Mux *sync.RWMutex ConnPool } From 635dcc9ad4faa02cf625b050a7b439bd44292407 Mon Sep 17 00:00:00 2001 From: Michelle Date: Wed, 21 Oct 2020 12:35:33 +0200 Subject: [PATCH 698/881] add gorm ColumnType interface, remove sql one (#3647) --- migrator.go | 14 ++++++++++---- migrator/migrator.go | 15 ++++++++++----- 2 files changed, 20 insertions(+), 9 deletions(-) diff --git a/migrator.go b/migrator.go index 162fe680..ac06a144 100644 --- a/migrator.go +++ b/migrator.go @@ -1,8 +1,6 @@ package gorm import ( - "database/sql" - "gorm.io/gorm/clause" "gorm.io/gorm/schema" ) @@ -24,6 +22,14 @@ type ViewOption struct { Query *DB } +type ColumnType interface { + Name() string + DatabaseTypeName() string + Length() (length int64, ok bool) + DecimalSize() (precision int64, scale int64, ok bool) + Nullable() (nullable bool, ok bool) +} + type Migrator interface { // AutoMigrate AutoMigrate(dst ...interface{}) error @@ -42,10 +48,10 @@ type Migrator interface { AddColumn(dst interface{}, field string) error DropColumn(dst interface{}, field string) error AlterColumn(dst interface{}, field string) error - MigrateColumn(dst interface{}, field *schema.Field, columnType *sql.ColumnType) error + MigrateColumn(dst interface{}, field *schema.Field, columnType ColumnType) error HasColumn(dst interface{}, field string) bool RenameColumn(dst interface{}, oldName, field string) error - ColumnTypes(dst interface{}) ([]*sql.ColumnType, error) + ColumnTypes(dst interface{}) ([]ColumnType, error) // Views CreateView(name string, option ViewOption) error diff --git a/migrator/migrator.go b/migrator/migrator.go index f390ff9f..ca8e63ca 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -2,7 +2,6 @@ package migrator import ( "context" - "database/sql" "fmt" "reflect" "regexp" @@ -92,7 +91,7 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { columnTypes, _ := m.DB.Migrator().ColumnTypes(value) for _, field := range stmt.Schema.FieldsByDBName { - var foundColumn *sql.ColumnType + var foundColumn gorm.ColumnType for _, columnType := range columnTypes { if columnType.Name() == field.DBName { @@ -352,7 +351,7 @@ func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error }) } -func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnType *sql.ColumnType) error { +func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnType gorm.ColumnType) error { // found, smart migrate fullDataType := strings.ToLower(m.DB.Migrator().FullDataTypeOf(field).SQL) realDataType := strings.ToLower(columnType.DatabaseTypeName()) @@ -395,12 +394,18 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy return nil } -func (m Migrator) ColumnTypes(value interface{}) (columnTypes []*sql.ColumnType, err error) { +func (m Migrator) ColumnTypes(value interface{}) (columnTypes []gorm.ColumnType, err error) { + columnTypes = make([]gorm.ColumnType, 0) err = m.RunWithValue(value, func(stmt *gorm.Statement) error { rows, err := m.DB.Session(&gorm.Session{}).Table(stmt.Table).Limit(1).Rows() if err == nil { defer rows.Close() - columnTypes, err = rows.ColumnTypes() + rawColumnTypes, err := rows.ColumnTypes() + if err == nil { + for _, c := range rawColumnTypes { + columnTypes = append(columnTypes, c) + } + } } return err }) From 5fee5b1b24227e6bda03caa4c27cb05b4a81b717 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 21 Oct 2020 20:15:49 +0800 Subject: [PATCH 699/881] Add option tag support for index --- migrator/migrator.go | 12 +++++++++++- schema/index.go | 5 +++++ schema/index_test.go | 5 +++-- tests/go.mod | 2 +- 4 files changed, 20 insertions(+), 4 deletions(-) diff --git a/migrator/migrator.go b/migrator/migrator.go index ca8e63ca..c564cb67 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -188,7 +188,13 @@ func (m Migrator) CreateTable(values ...interface{}) error { if idx.Class != "" { createTableSQL += idx.Class + " " } - createTableSQL += "INDEX ? ?," + createTableSQL += "INDEX ? ?" + + if idx.Option != "" { + createTableSQL += " " + idx.Option + } + + createTableSQL += "," values = append(values, clause.Expr{SQL: idx.Name}, tx.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt)) } } @@ -543,6 +549,10 @@ func (m Migrator) CreateIndex(value interface{}, name string) error { createIndexSQL += " USING " + idx.Type } + if idx.Option != "" { + createIndexSQL += " " + idx.Option + } + return m.DB.Exec(createIndexSQL, values...).Error } diff --git a/schema/index.go b/schema/index.go index fb7ea501..b54e08ad 100644 --- a/schema/index.go +++ b/schema/index.go @@ -12,6 +12,7 @@ type Index struct { Type string // btree, hash, gist, spgist, gin, and brin Where string Comment string + Option string // WITH PARSER parser_name Fields []IndexOption } @@ -45,6 +46,9 @@ func (schema *Schema) ParseIndexes() map[string]Index { if idx.Comment == "" { idx.Comment = index.Comment } + if idx.Option == "" { + idx.Option = index.Option + } idx.Fields = append(idx.Fields, index.Fields...) sort.Slice(idx.Fields, func(i, j int) bool { @@ -119,6 +123,7 @@ func parseFieldIndexes(field *Field) (indexes []Index) { Type: settings["TYPE"], Where: settings["WHERE"], Comment: settings["COMMENT"], + Option: settings["OPTION"], Fields: []IndexOption{{ Field: field, Expression: settings["EXPRESSION"], diff --git a/schema/index_test.go b/schema/index_test.go index dc1fb43b..bc6bb8b6 100644 --- a/schema/index_test.go +++ b/schema/index_test.go @@ -15,7 +15,7 @@ type UserIndex struct { Name4 string `gorm:"uniqueIndex"` Name5 int64 `gorm:"index:,class:FULLTEXT,comment:hello \\, world,where:age > 10"` Name6 int64 `gorm:"index:profile,comment:hello \\, world,where:age > 10"` - Age int64 `gorm:"index:profile,expression:ABS(age)"` + Age int64 `gorm:"index:profile,expression:ABS(age),option:WITH PARSER parser_name"` OID int64 `gorm:"index:idx_id;index:idx_oid,unique"` MemberNumber string `gorm:"index:idx_id,priority:1"` } @@ -63,6 +63,7 @@ func TestParseIndex(t *testing.T) { Name: "profile", Comment: "hello , world", Where: "age > 10", + Option: "WITH PARSER parser_name", Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name6"}}, { Field: &schema.Field{Name: "Age"}, Expression: "ABS(age)", @@ -87,7 +88,7 @@ func TestParseIndex(t *testing.T) { t.Fatalf("Failed to found index %v from parsed indices %+v", k, indices) } - for _, name := range []string{"Name", "Class", "Type", "Where", "Comment"} { + for _, name := range []string{"Name", "Class", "Type", "Where", "Comment", "Option"} { if reflect.ValueOf(result).FieldByName(name).Interface() != reflect.ValueOf(v).FieldByName(name).Interface() { t.Errorf( "index %v %v should equal, expects %v, got %v", diff --git a/tests/go.mod b/tests/go.mod index 87d221ca..ddb1773b 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -7,7 +7,7 @@ require ( github.com/jinzhu/now v1.1.1 github.com/lib/pq v1.6.0 gorm.io/driver/mysql v1.0.2 - gorm.io/driver/postgres v1.0.3 + gorm.io/driver/postgres v1.0.4 gorm.io/driver/sqlite v1.1.3 gorm.io/driver/sqlserver v1.0.5 gorm.io/gorm v1.20.2 From 231aba53c58fcb9ca0e3a70375eba88b337ad4cc Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 22 Oct 2020 11:28:43 +0800 Subject: [PATCH 700/881] Fix count with order by --- finisher_api.go | 9 +++++++++ tests/count_test.go | 5 +++++ 2 files changed, 14 insertions(+) diff --git a/finisher_api.go b/finisher_api.go index 2951fdef..30616284 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -326,6 +326,15 @@ func (db *DB) Count(count *int64) (tx *DB) { defer tx.Statement.AddClause(clause.Select{}) } + if orderByClause, ok := db.Statement.Clauses["ORDER BY"]; ok { + if _, ok := db.Statement.Clauses["GROUP BY"]; !ok { + delete(db.Statement.Clauses, "ORDER BY") + defer func() { + db.Statement.Clauses["ORDER BY"] = orderByClause + }() + } + } + tx.Statement.Dest = count tx.callbacks.Query().Execute(tx) if tx.RowsAffected != 1 { diff --git a/tests/count_test.go b/tests/count_test.go index 0d348227..41bad71d 100644 --- a/tests/count_test.go +++ b/tests/count_test.go @@ -70,6 +70,11 @@ func TestCount(t *testing.T) { var count4 int64 if err := DB.Table("users").Joins("LEFT JOIN companies on companies.name = users.name").Where("users.name = ?", user1.Name).Count(&count4).Error; err != nil || count4 != 1 { + t.Errorf("count with join, got error: %v, count %v", err, count4) + } + + var count5 int64 + if err := DB.Table("users").Where("users.name = ?", user1.Name).Order("name").Count(&count5).Error; err != nil || count5 != 1 { t.Errorf("count with join, got error: %v, count %v", err, count) } } From 6d90d09cb86f5e57aacaee925d510cedaf839cae Mon Sep 17 00:00:00 2001 From: qifengzhang007 <15087404+qifengzhang007@users.noreply.github.com> Date: Thu, 22 Oct 2020 14:09:09 +0800 Subject: [PATCH 701/881] =?UTF-8?q?Recorder=E8=BF=BD=E8=B8=AA=E5=87=BD?= =?UTF-8?q?=E6=95=B0trace=E5=9C=A8finish=5Fapi=E6=96=87=E4=BB=B6358?= =?UTF-8?q?=E8=A1=8Cscan=E5=87=BD=E6=95=B0=E6=89=80=E5=9C=A8=E7=9A=84371?= =?UTF-8?q?=E8=A1=8C=E8=A2=AB=E8=B0=83=E7=94=A8=E6=97=B6=EF=BC=8CBeginAt?= =?UTF-8?q?=20=E6=B2=A1=E6=9C=89=E8=B5=8B=E5=80=BC=EF=BC=8C=E9=BB=98?= =?UTF-8?q?=E8=AE=A4=E5=80=BC0001-0:0:0=E5=AF=BC=E8=87=B4=E8=BF=BD?= =?UTF-8?q?=E8=B8=AA=E6=97=A5=E5=BF=97=E6=98=BE=E7=A4=BA=E7=9A=84sql?= =?UTF-8?q?=E8=80=97=E6=97=B6=E6=97=A0=E9=99=90=E5=A4=A7.=20(#3657)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: 张奇峰 <10515935zwj> --- logger/logger.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/logger/logger.go b/logger/logger.go index 6782736c..11619c92 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -64,7 +64,7 @@ var ( LogLevel: Warn, Colorful: true, }) - Recorder = traceRecorder{Interface: Default} + Recorder = traceRecorder{Interface: Default, BeginAt: time.Now()} ) func New(writer Writer, config Config) Interface { @@ -173,7 +173,7 @@ type traceRecorder struct { } func (l traceRecorder) New() *traceRecorder { - return &traceRecorder{Interface: l.Interface} + return &traceRecorder{Interface: l.Interface, BeginAt: time.Now()} } func (l *traceRecorder) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) { From 0aef8acc11c783808d1986e03b5e665f0c60fda4 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 22 Oct 2020 14:00:10 +0800 Subject: [PATCH 702/881] Add smart auto migrate tests --- migrator/migrator.go | 6 +++--- tests/go.mod | 6 +++--- tests/migrate_test.go | 16 +++++++++------- 3 files changed, 15 insertions(+), 13 deletions(-) diff --git a/migrator/migrator.go b/migrator/migrator.go index c564cb67..c455a294 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -370,9 +370,9 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy alterColumn = true } else { // has size in data type and not equal - matches := regexp.MustCompile(`[^\d](\d+)[^\d]`).FindAllStringSubmatch(realDataType, -1) - matches2 := regexp.MustCompile(`[^\d]*(\d+)[^\d]`).FindAllStringSubmatch(fullDataType, -1) - if (len(matches) == 1 && matches[0][1] != fmt.Sprint(field.Size)) && (len(matches2) == 1 && matches2[0][1] != fmt.Sprint(length)) { + matches := regexp.MustCompile(`[^\d](\d+)[^\d]?`).FindAllStringSubmatch(realDataType, -1) + matches2 := regexp.MustCompile(`[^\d]*(\d+)[^\d]?`).FindAllStringSubmatch(fullDataType, -1) + if (len(matches) == 1 && matches[0][1] != fmt.Sprint(field.Size) || !field.PrimaryKey) && (len(matches2) == 1 && matches2[0][1] != fmt.Sprint(length)) { alterColumn = true } } diff --git a/tests/go.mod b/tests/go.mod index ddb1773b..3fa011f1 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -6,11 +6,11 @@ require ( github.com/google/uuid v1.1.1 github.com/jinzhu/now v1.1.1 github.com/lib/pq v1.6.0 - gorm.io/driver/mysql v1.0.2 - gorm.io/driver/postgres v1.0.4 + gorm.io/driver/mysql v1.0.3 + gorm.io/driver/postgres v1.0.5 gorm.io/driver/sqlite v1.1.3 gorm.io/driver/sqlserver v1.0.5 - gorm.io/gorm v1.20.2 + gorm.io/gorm v1.20.4 ) replace gorm.io/gorm => ../ diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 4cc8a7c3..275fe634 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -48,11 +48,13 @@ func TestMigrate(t *testing.T) { } func TestSmartMigrateColumn(t *testing.T) { + fullSupported := map[string]bool{"mysql": true, "postgres": true}[DB.Dialector.Name()] + type UserMigrateColumn struct { ID uint Name string Salary float64 - Birthday time.Time + Birthday time.Time `gorm:"precision:4"` } DB.Migrator().DropTable(&UserMigrateColumn{}) @@ -78,15 +80,15 @@ func TestSmartMigrateColumn(t *testing.T) { for _, columnType := range columnTypes { switch columnType.Name() { case "name": - if length, _ := columnType.Length(); length != 0 && length != 128 { + if length, _ := columnType.Length(); (fullSupported || length != 0) && length != 128 { t.Fatalf("name's length should be 128, but got %v", length) } case "salary": - if precision, o, _ := columnType.DecimalSize(); precision != 0 && precision != 2 { + if precision, o, _ := columnType.DecimalSize(); (fullSupported || precision != 0) && precision != 2 { t.Fatalf("salary's precision should be 2, but got %v %v", precision, o) } case "birthday": - if precision, _, _ := columnType.DecimalSize(); precision != 0 && precision != 2 { + if precision, _, _ := columnType.DecimalSize(); (fullSupported || precision != 0) && precision != 2 { t.Fatalf("birthday's precision should be 2, but got %v", precision) } } @@ -111,15 +113,15 @@ func TestSmartMigrateColumn(t *testing.T) { for _, columnType := range columnTypes { switch columnType.Name() { case "name": - if length, _ := columnType.Length(); length != 0 && length != 256 { + if length, _ := columnType.Length(); (fullSupported || length != 0) && length != 256 { t.Fatalf("name's length should be 128, but got %v", length) } case "salary": - if precision, _, _ := columnType.DecimalSize(); precision != 0 && precision != 3 { + if precision, _, _ := columnType.DecimalSize(); (fullSupported || precision != 0) && precision != 3 { t.Fatalf("salary's precision should be 2, but got %v", precision) } case "birthday": - if precision, _, _ := columnType.DecimalSize(); precision != 0 && precision != 3 { + if precision, _, _ := columnType.DecimalSize(); (fullSupported || precision != 0) && precision != 3 { t.Fatalf("birthday's precision should be 2, but got %v", precision) } } From db2630cb3a02edcc92678ed78e49d1e85d268224 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 22 Oct 2020 17:32:39 +0800 Subject: [PATCH 703/881] Fix data race problem when using Scan, close #3662 --- finisher_api.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/finisher_api.go b/finisher_api.go index 30616284..857f9419 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -365,9 +365,13 @@ func (db *DB) Rows() (*sql.Rows, error) { // Scan scan value to a struct func (db *DB) Scan(dest interface{}) (tx *DB) { - currentLogger, newLogger := db.Logger, logger.Recorder.New() + config := *db.Config + currentLogger, newLogger := config.Logger, logger.Recorder.New() + config.Logger = newLogger + tx = db.getInstance() - tx.Logger = newLogger + tx.Config = &config + if rows, err := tx.Rows(); err != nil { tx.AddError(err) } else { From dd92f8bdc0ba926a538dce7a84fd3b630d45c168 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 23 Oct 2020 11:01:45 +0800 Subject: [PATCH 704/881] Allow create table for other database/schema #3640 --- migrator/migrator.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/migrator/migrator.go b/migrator/migrator.go index c455a294..9493a00c 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -32,6 +32,7 @@ func (m Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement) error stmt := &gorm.Statement{DB: m.DB} if m.DB.Statement != nil { stmt.Table = m.DB.Statement.Table + stmt.TableExpr = m.DB.Statement.TableExpr } if table, ok := value.(string); ok { @@ -161,6 +162,10 @@ func (m Migrator) CreateTable(values ...interface{}) error { hasPrimaryKeyInDataType bool ) + if stmt.TableExpr != nil { + values[0] = *stmt.TableExpr + } + for _, dbName := range stmt.Schema.DBNames { field := stmt.Schema.FieldsByDBName[dbName] createTableSQL += "? ?" From cb591a71299532f881a104cdb0abf7ae5b794a6f Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 23 Oct 2020 18:40:05 +0800 Subject: [PATCH 705/881] Fix panic when using FirstOrCreate with soft delete, close #3671 --- schema/field.go | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/schema/field.go b/schema/field.go index e7f5b708..b303fb30 100644 --- a/schema/field.go +++ b/schema/field.go @@ -762,13 +762,15 @@ func (field *Field) setupValuerAndSetter() { // pointer scanner field.Set = func(value reflect.Value, v interface{}) (err error) { reflectV := reflect.ValueOf(v) - if reflectV.Type().AssignableTo(field.FieldType) { + if !reflectV.IsValid() { + field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) + } else if reflectV.Type().AssignableTo(field.FieldType) { field.ReflectValueOf(value).Set(reflectV) } else if reflectV.Kind() == reflect.Ptr { - if reflectV.IsNil() { + if reflectV.IsNil() || !reflectV.IsValid() { field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) } else { - err = field.Set(value, reflectV.Elem().Interface()) + return field.Set(value, reflectV.Elem().Interface()) } } else { fieldValue := field.ReflectValueOf(value) From d011ebe7afbce397db6bf50a7aa12855cb74877f Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 26 Oct 2020 10:17:25 +0800 Subject: [PATCH 706/881] Fix clone statement for Unscoped, UpdatingColumn, close #3681 --- statement.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/statement.go b/statement.go index 38d35926..567df869 100644 --- a/statement.go +++ b/statement.go @@ -408,6 +408,7 @@ func (stmt *Statement) clone() *Statement { TableExpr: stmt.TableExpr, Table: stmt.Table, Model: stmt.Model, + Unscoped: stmt.Unscoped, Dest: stmt.Dest, ReflectValue: stmt.ReflectValue, Clauses: map[string]clause.Clause{}, @@ -419,6 +420,7 @@ func (stmt *Statement) clone() *Statement { Schema: stmt.Schema, Context: stmt.Context, RaiseErrorOnNotFound: stmt.RaiseErrorOnNotFound, + UpdatingColumn: stmt.UpdatingColumn, } for k, c := range stmt.Clauses { From 4009ec58163b97294633edc19f5d792546cd612c Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 27 Oct 2020 18:14:36 +0800 Subject: [PATCH 707/881] Fix call hook methods when updating with struct --- callbacks/callmethod.go | 2 +- statement.go | 36 +++++++++++++++++++++++++++++------- tests/go.mod | 2 +- tests/hooks_test.go | 16 +++++++++++++--- 4 files changed, 44 insertions(+), 12 deletions(-) diff --git a/callbacks/callmethod.go b/callbacks/callmethod.go index 0160f354..b81fc915 100644 --- a/callbacks/callmethod.go +++ b/callbacks/callmethod.go @@ -8,7 +8,7 @@ import ( func callMethod(db *gorm.DB, fc func(value interface{}, tx *gorm.DB) bool) { tx := db.Session(&gorm.Session{}) - if called := fc(db.Statement.Dest, tx); !called { + if called := fc(db.Statement.ReflectValue.Interface(), tx); !called { switch db.Statement.ReflectValue.Kind() { case reflect.Slice, reflect.Array: db.Statement.CurDestIndex = 0 diff --git a/statement.go b/statement.go index 567df869..82ebdd91 100644 --- a/statement.go +++ b/statement.go @@ -451,6 +451,27 @@ func (stmt *Statement) SetColumn(name string, value interface{}) { v[name] = value } else if stmt.Schema != nil { if field := stmt.Schema.LookUpField(name); field != nil { + destValue := reflect.ValueOf(stmt.Dest) + for destValue.Kind() == reflect.Ptr { + destValue = destValue.Elem() + } + + if stmt.ReflectValue != destValue { + if !destValue.CanAddr() { + destValueCanAddr := reflect.New(destValue.Type()) + destValueCanAddr.Elem().Set(destValue) + stmt.Dest = destValueCanAddr.Interface() + destValue = destValueCanAddr.Elem() + } + + switch destValue.Kind() { + case reflect.Struct: + field.Set(destValue, value) + default: + stmt.AddError(ErrInvalidData) + } + } + switch stmt.ReflectValue.Kind() { case reflect.Slice, reflect.Array: field.Set(stmt.ReflectValue.Index(stmt.CurDestIndex), value) @@ -467,11 +488,7 @@ func (stmt *Statement) SetColumn(name string, value interface{}) { // Changed check model changed or not when updating func (stmt *Statement) Changed(fields ...string) bool { - modelValue := reflect.ValueOf(stmt.Model) - for modelValue.Kind() == reflect.Ptr { - modelValue = modelValue.Elem() - } - + modelValue := stmt.ReflectValue switch modelValue.Kind() { case reflect.Slice, reflect.Array: modelValue = stmt.ReflectValue.Index(stmt.CurDestIndex) @@ -488,8 +505,13 @@ func (stmt *Statement) Changed(fields ...string) bool { return !utils.AssertEqual(fv, fieldValue) } } else { - changedValue, _ := field.ValueOf(stmt.ReflectValue) - return !utils.AssertEqual(changedValue, fieldValue) + destValue := reflect.ValueOf(stmt.Dest) + for destValue.Kind() == reflect.Ptr { + destValue = destValue.Elem() + } + + changedValue, zero := field.ValueOf(destValue) + return !zero && !utils.AssertEqual(changedValue, fieldValue) } } return false diff --git a/tests/go.mod b/tests/go.mod index 3fa011f1..55495de3 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -10,7 +10,7 @@ require ( gorm.io/driver/postgres v1.0.5 gorm.io/driver/sqlite v1.1.3 gorm.io/driver/sqlserver v1.0.5 - gorm.io/gorm v1.20.4 + gorm.io/gorm v1.20.5 ) replace gorm.io/gorm => ../ diff --git a/tests/hooks_test.go b/tests/hooks_test.go index 3612857b..d8b1770e 100644 --- a/tests/hooks_test.go +++ b/tests/hooks_test.go @@ -354,10 +354,20 @@ func TestSetColumn(t *testing.T) { AssertEqual(t, result, product) - // Code changed, price not selected, price should not change - DB.Model(&product).Select("code").Updates(map[string]interface{}{"name": "L1214"}) + // Select to change Code, but nothing updated, price should not change + DB.Model(&product).Select("code").Updates(Product3{Name: "L1214", Code: "L1213"}) - if product.Price != 220 || product.Code != "L1213" { + if product.Price != 220 || product.Code != "L1213" || product.Name != "Product New3" { + t.Errorf("invalid data after update, got %+v", product) + } + + DB.Model(&product).Updates(Product3{Code: "L1214"}) + if product.Price != 270 || product.Code != "L1214" { + t.Errorf("invalid data after update, got %+v", product) + } + + DB.Model(&product).UpdateColumns(Product3{Code: "L1215"}) + if product.Price != 270 || product.Code != "L1215" { t.Errorf("invalid data after update, got %+v", product) } From a8141b6cc92b15d7d6f7936942749a5e044f9c9a Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 30 Oct 2020 18:15:07 +0800 Subject: [PATCH 708/881] Fix DeletedAt marshal and unmarshal, close #3693 --- soft_delete.go | 2 +- tests/soft_delete_test.go | 12 ++++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/soft_delete.go b/soft_delete.go index b3280ff7..f3272246 100644 --- a/soft_delete.go +++ b/soft_delete.go @@ -31,7 +31,7 @@ func (n DeletedAt) MarshalJSON() ([]byte, error) { func (n *DeletedAt) UnmarshalJSON(b []byte) error { err := json.Unmarshal(b, &n.Time) - if err == nil { + if err == nil && !n.Time.IsZero() { n.Valid = true } return err diff --git a/tests/soft_delete_test.go b/tests/soft_delete_test.go index 40d46fd8..c77675f7 100644 --- a/tests/soft_delete_test.go +++ b/tests/soft_delete_test.go @@ -1,6 +1,7 @@ package tests_test import ( + "encoding/json" "errors" "testing" @@ -42,3 +43,14 @@ func TestSoftDelete(t *testing.T) { t.Errorf("Can't find permanently deleted record") } } + +func TestDeletedAtUnMarshal(t *testing.T) { + expected := &gorm.Model{} + b, _ := json.Marshal(expected) + + result := &gorm.Model{} + _ = json.Unmarshal(b, result) + if result.DeletedAt != expected.DeletedAt { + t.Errorf("Failed, result.DeletedAt: %v is not same as expected.DeletedAt: %v", result.DeletedAt, expected.DeletedAt) + } +} From 3ebdcbdb180b9b89e7f270c22640e5ae4ba22f5b Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 30 Oct 2020 19:08:20 +0800 Subject: [PATCH 709/881] Marshal invalid DeletedAt as null, fix #3693 --- soft_delete.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/soft_delete.go b/soft_delete.go index f3272246..b68cee43 100644 --- a/soft_delete.go +++ b/soft_delete.go @@ -26,7 +26,10 @@ func (n DeletedAt) Value() (driver.Value, error) { } func (n DeletedAt) MarshalJSON() ([]byte, error) { - return json.Marshal(n.Time) + if n.Valid { + return json.Marshal(n.Time) + } + return json.Marshal(nil) } func (n *DeletedAt) UnmarshalJSON(b []byte) error { From 57b033e2dd17b89d171570475b706c5bc671f52f Mon Sep 17 00:00:00 2001 From: Amit Basuri Date: Mon, 2 Nov 2020 07:33:39 +0530 Subject: [PATCH 710/881] Marshalling zero valued Deleted at to nullhttps://github.com/go-gorm/gorm/issues/3693 (#3695) --- soft_delete.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/soft_delete.go b/soft_delete.go index b68cee43..284129a1 100644 --- a/soft_delete.go +++ b/soft_delete.go @@ -33,8 +33,12 @@ func (n DeletedAt) MarshalJSON() ([]byte, error) { } func (n *DeletedAt) UnmarshalJSON(b []byte) error { + if string(b) == "null" { + n.Valid = false + return nil + } err := json.Unmarshal(b, &n.Time) - if err == nil && !n.Time.IsZero() { + if err == nil { n.Valid = true } return err From c915471169b7e6696edfa9bfc2c8e7b816e70ad6 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 3 Nov 2020 10:30:05 +0800 Subject: [PATCH 711/881] Support Expression for OrderBy clause --- clause/expression.go | 7 ++++--- clause/order_by.go | 21 +++++++++++++-------- clause/order_by_test.go | 8 ++++++++ tests/query_test.go | 10 ++++++++++ 4 files changed, 35 insertions(+), 11 deletions(-) diff --git a/clause/expression.go b/clause/expression.go index 5822a314..725a4909 100644 --- a/clause/expression.go +++ b/clause/expression.go @@ -19,8 +19,9 @@ type NegationExpressionBuilder interface { // Expr raw expression type Expr struct { - SQL string - Vars []interface{} + SQL string + Vars []interface{} + WithoutParentheses bool } // Build build raw expression @@ -32,7 +33,7 @@ func (expr Expr) Build(builder Builder) { for _, v := range []byte(expr.SQL) { if v == '?' && len(expr.Vars) > idx { - if afterParenthesis { + if afterParenthesis || expr.WithoutParentheses { if _, ok := expr.Vars[idx].(driver.Valuer); ok { builder.AddVar(builder, expr.Vars[idx]) } else { diff --git a/clause/order_by.go b/clause/order_by.go index a8a9539a..41218025 100644 --- a/clause/order_by.go +++ b/clause/order_by.go @@ -7,7 +7,8 @@ type OrderByColumn struct { } type OrderBy struct { - Columns []OrderByColumn + Columns []OrderByColumn + Expression Expression } // Name where clause name @@ -17,14 +18,18 @@ func (orderBy OrderBy) Name() string { // Build build where clause func (orderBy OrderBy) Build(builder Builder) { - for idx, column := range orderBy.Columns { - if idx > 0 { - builder.WriteByte(',') - } + if orderBy.Expression != nil { + orderBy.Expression.Build(builder) + } else { + for idx, column := range orderBy.Columns { + if idx > 0 { + builder.WriteByte(',') + } - builder.WriteQuoted(column.Column) - if column.Desc { - builder.WriteString(" DESC") + builder.WriteQuoted(column.Column) + if column.Desc { + builder.WriteString(" DESC") + } } } } diff --git a/clause/order_by_test.go b/clause/order_by_test.go index 2ea2d192..8fd1e2a8 100644 --- a/clause/order_by_test.go +++ b/clause/order_by_test.go @@ -39,6 +39,14 @@ func TestOrderBy(t *testing.T) { }, "SELECT * FROM `users` ORDER BY `name`", nil, }, + { + []clause.Interface{ + clause.Select{}, clause.From{}, clause.OrderBy{ + Expression: clause.Expr{SQL: "FIELD(id, ?)", Vars: []interface{}{[]int{1, 2, 3}}, WithoutParentheses: true}, + }, + }, + "SELECT * FROM `users` ORDER BY FIELD(id, ?,?,?)", []interface{}{1, 2, 3}, + }, } for idx, result := range results { diff --git a/tests/query_test.go b/tests/query_test.go index bb9aa26d..dc2907e6 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -12,6 +12,7 @@ import ( "time" "gorm.io/gorm" + "gorm.io/gorm/clause" . "gorm.io/gorm/utils/tests" ) @@ -659,6 +660,15 @@ func TestOrder(t *testing.T) { if !regexp.MustCompile("SELECT \\* FROM .*users.* ORDER BY age desc,name").MatchString(result.Statement.SQL.String()) { t.Fatalf("Build Order condition, but got %v", result.Statement.SQL.String()) } + + stmt := dryDB.Clauses(clause.OrderBy{ + Expression: clause.Expr{SQL: "FIELD(id,?)", Vars: []interface{}{[]int{1, 2, 3}}, WithoutParentheses: true}, + }).Find(&User{}).Statement + + explainedSQL := dryDB.Dialector.Explain(stmt.SQL.String(), stmt.Vars...) + if !regexp.MustCompile("SELECT \\* FROM .*users.* ORDER BY FIELD\\(id,1,2,3\\)").MatchString(explainedSQL) { + t.Fatalf("Build Order condition, but got %v", explainedSQL) + } } func TestLimit(t *testing.T) { From 560d303e71eb75dc77a115f0d0cba26b645b172f Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 4 Nov 2020 11:03:22 +0800 Subject: [PATCH 712/881] Fix Scan with soft delete, close #3712 --- callbacks/query.go | 214 +++++++++++++++++++------------------- callbacks/row.go | 4 +- tests/soft_delete_test.go | 18 ++++ 3 files changed, 126 insertions(+), 110 deletions(-) diff --git a/callbacks/query.go b/callbacks/query.go index 0703b92e..8613e46d 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -13,15 +13,7 @@ import ( func Query(db *gorm.DB) { if db.Error == nil { - if db.Statement.Schema != nil && !db.Statement.Unscoped { - for _, c := range db.Statement.Schema.QueryClauses { - db.Statement.AddClause(c) - } - } - - if db.Statement.SQL.String() == "" { - BuildQuerySQL(db) - } + BuildQuerySQL(db) if !db.DryRun && db.Error == nil { rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) @@ -37,131 +29,139 @@ func Query(db *gorm.DB) { } func BuildQuerySQL(db *gorm.DB) { - db.Statement.SQL.Grow(100) - clauseSelect := clause.Select{Distinct: db.Statement.Distinct} - - if db.Statement.ReflectValue.Kind() == reflect.Struct && db.Statement.ReflectValue.Type() == db.Statement.Schema.ModelType { - var conds []clause.Expression - for _, primaryField := range db.Statement.Schema.PrimaryFields { - if v, isZero := primaryField.ValueOf(db.Statement.ReflectValue); !isZero { - conds = append(conds, clause.Eq{Column: clause.Column{Table: db.Statement.Table, Name: primaryField.DBName}, Value: v}) - } - } - - if len(conds) > 0 { - db.Statement.AddClause(clause.Where{Exprs: conds}) + if db.Statement.Schema != nil && !db.Statement.Unscoped { + for _, c := range db.Statement.Schema.QueryClauses { + db.Statement.AddClause(c) } } - if len(db.Statement.Selects) > 0 { - clauseSelect.Columns = make([]clause.Column, len(db.Statement.Selects)) - for idx, name := range db.Statement.Selects { - if db.Statement.Schema == nil { - clauseSelect.Columns[idx] = clause.Column{Name: name, Raw: true} - } else if f := db.Statement.Schema.LookUpField(name); f != nil { - clauseSelect.Columns[idx] = clause.Column{Name: f.DBName} - } else { - clauseSelect.Columns[idx] = clause.Column{Name: name, Raw: true} + if db.Statement.SQL.String() == "" { + db.Statement.SQL.Grow(100) + clauseSelect := clause.Select{Distinct: db.Statement.Distinct} + + if db.Statement.ReflectValue.Kind() == reflect.Struct && db.Statement.ReflectValue.Type() == db.Statement.Schema.ModelType { + var conds []clause.Expression + for _, primaryField := range db.Statement.Schema.PrimaryFields { + if v, isZero := primaryField.ValueOf(db.Statement.ReflectValue); !isZero { + conds = append(conds, clause.Eq{Column: clause.Column{Table: db.Statement.Table, Name: primaryField.DBName}, Value: v}) + } + } + + if len(conds) > 0 { + db.Statement.AddClause(clause.Where{Exprs: conds}) } } - } else if db.Statement.Schema != nil && len(db.Statement.Omits) > 0 { - selectColumns, _ := db.Statement.SelectAndOmitColumns(false, false) - clauseSelect.Columns = make([]clause.Column, 0, len(db.Statement.Schema.DBNames)) - for _, dbName := range db.Statement.Schema.DBNames { - if v, ok := selectColumns[dbName]; (ok && v) || !ok { - clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{Name: dbName}) + + if len(db.Statement.Selects) > 0 { + clauseSelect.Columns = make([]clause.Column, len(db.Statement.Selects)) + for idx, name := range db.Statement.Selects { + if db.Statement.Schema == nil { + clauseSelect.Columns[idx] = clause.Column{Name: name, Raw: true} + } else if f := db.Statement.Schema.LookUpField(name); f != nil { + clauseSelect.Columns[idx] = clause.Column{Name: f.DBName} + } else { + clauseSelect.Columns[idx] = clause.Column{Name: name, Raw: true} + } + } + } else if db.Statement.Schema != nil && len(db.Statement.Omits) > 0 { + selectColumns, _ := db.Statement.SelectAndOmitColumns(false, false) + clauseSelect.Columns = make([]clause.Column, 0, len(db.Statement.Schema.DBNames)) + for _, dbName := range db.Statement.Schema.DBNames { + if v, ok := selectColumns[dbName]; (ok && v) || !ok { + clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{Name: dbName}) + } + } + } else if db.Statement.Schema != nil && db.Statement.ReflectValue.IsValid() { + smallerStruct := false + switch db.Statement.ReflectValue.Kind() { + case reflect.Struct: + smallerStruct = db.Statement.ReflectValue.Type() != db.Statement.Schema.ModelType + case reflect.Slice: + smallerStruct = db.Statement.ReflectValue.Type().Elem() != db.Statement.Schema.ModelType } - } - } else if db.Statement.Schema != nil && db.Statement.ReflectValue.IsValid() { - smallerStruct := false - switch db.Statement.ReflectValue.Kind() { - case reflect.Struct: - smallerStruct = db.Statement.ReflectValue.Type() != db.Statement.Schema.ModelType - case reflect.Slice: - smallerStruct = db.Statement.ReflectValue.Type().Elem() != db.Statement.Schema.ModelType - } - if smallerStruct { - stmt := gorm.Statement{DB: db} - // smaller struct - if err := stmt.Parse(db.Statement.Dest); err == nil && stmt.Schema.ModelType != db.Statement.Schema.ModelType { - clauseSelect.Columns = make([]clause.Column, len(stmt.Schema.DBNames)) + if smallerStruct { + stmt := gorm.Statement{DB: db} + // smaller struct + if err := stmt.Parse(db.Statement.Dest); err == nil && stmt.Schema.ModelType != db.Statement.Schema.ModelType { + clauseSelect.Columns = make([]clause.Column, len(stmt.Schema.DBNames)) - for idx, dbName := range stmt.Schema.DBNames { - clauseSelect.Columns[idx] = clause.Column{Name: dbName} + for idx, dbName := range stmt.Schema.DBNames { + clauseSelect.Columns[idx] = clause.Column{Name: dbName} + } } } } - } - // inline joins - if len(db.Statement.Joins) != 0 { - if len(db.Statement.Selects) == 0 && db.Statement.Schema != nil { - clauseSelect.Columns = make([]clause.Column, len(db.Statement.Schema.DBNames)) - for idx, dbName := range db.Statement.Schema.DBNames { - clauseSelect.Columns[idx] = clause.Column{Table: db.Statement.Table, Name: dbName} + // inline joins + if len(db.Statement.Joins) != 0 { + if len(db.Statement.Selects) == 0 && db.Statement.Schema != nil { + clauseSelect.Columns = make([]clause.Column, len(db.Statement.Schema.DBNames)) + for idx, dbName := range db.Statement.Schema.DBNames { + clauseSelect.Columns[idx] = clause.Column{Table: db.Statement.Table, Name: dbName} + } } - } - joins := []clause.Join{} - for _, join := range db.Statement.Joins { - if db.Statement.Schema == nil { - joins = append(joins, clause.Join{ - Expression: clause.Expr{SQL: join.Name, Vars: join.Conds}, - }) - } else if relation, ok := db.Statement.Schema.Relationships.Relations[join.Name]; ok { - tableAliasName := relation.Name - - for _, s := range relation.FieldSchema.DBNames { - clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ - Table: tableAliasName, - Name: s, - Alias: tableAliasName + "__" + s, + joins := []clause.Join{} + for _, join := range db.Statement.Joins { + if db.Statement.Schema == nil { + joins = append(joins, clause.Join{ + Expression: clause.Expr{SQL: join.Name, Vars: join.Conds}, }) - } + } else if relation, ok := db.Statement.Schema.Relationships.Relations[join.Name]; ok { + tableAliasName := relation.Name - exprs := make([]clause.Expression, len(relation.References)) - for idx, ref := range relation.References { - if ref.OwnPrimaryKey { - exprs[idx] = clause.Eq{ - Column: clause.Column{Table: clause.CurrentTable, Name: ref.PrimaryKey.DBName}, - Value: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName}, - } - } else { - if ref.PrimaryValue == "" { + for _, s := range relation.FieldSchema.DBNames { + clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ + Table: tableAliasName, + Name: s, + Alias: tableAliasName + "__" + s, + }) + } + + exprs := make([]clause.Expression, len(relation.References)) + for idx, ref := range relation.References { + if ref.OwnPrimaryKey { exprs[idx] = clause.Eq{ - Column: clause.Column{Table: clause.CurrentTable, Name: ref.ForeignKey.DBName}, - Value: clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName}, + Column: clause.Column{Table: clause.CurrentTable, Name: ref.PrimaryKey.DBName}, + Value: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName}, } } else { - exprs[idx] = clause.Eq{ - Column: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName}, - Value: ref.PrimaryValue, + if ref.PrimaryValue == "" { + exprs[idx] = clause.Eq{ + Column: clause.Column{Table: clause.CurrentTable, Name: ref.ForeignKey.DBName}, + Value: clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName}, + } + } else { + exprs[idx] = clause.Eq{ + Column: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName}, + Value: ref.PrimaryValue, + } } } } - } - joins = append(joins, clause.Join{ - Type: clause.LeftJoin, - Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName}, - ON: clause.Where{Exprs: exprs}, - }) - } else { - joins = append(joins, clause.Join{ - Expression: clause.Expr{SQL: join.Name, Vars: join.Conds}, - }) + joins = append(joins, clause.Join{ + Type: clause.LeftJoin, + Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName}, + ON: clause.Where{Exprs: exprs}, + }) + } else { + joins = append(joins, clause.Join{ + Expression: clause.Expr{SQL: join.Name, Vars: join.Conds}, + }) + } } + + db.Statement.AddClause(clause.From{Joins: joins}) + } else { + db.Statement.AddClauseIfNotExists(clause.From{}) } - db.Statement.AddClause(clause.From{Joins: joins}) - } else { - db.Statement.AddClauseIfNotExists(clause.From{}) + db.Statement.AddClauseIfNotExists(clauseSelect) + + db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR") } - - db.Statement.AddClauseIfNotExists(clauseSelect) - - db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR") } func Preload(db *gorm.DB) { diff --git a/callbacks/row.go b/callbacks/row.go index 4f985d7b..10e880e1 100644 --- a/callbacks/row.go +++ b/callbacks/row.go @@ -6,9 +6,7 @@ import ( func RowQuery(db *gorm.DB) { if db.Error == nil { - if db.Statement.SQL.String() == "" { - BuildQuerySQL(db) - } + BuildQuerySQL(db) if !db.DryRun { if isRows, ok := db.InstanceGet("rows"); ok && isRows.(bool) { diff --git a/tests/soft_delete_test.go b/tests/soft_delete_test.go index c77675f7..283a4c34 100644 --- a/tests/soft_delete_test.go +++ b/tests/soft_delete_test.go @@ -14,10 +14,16 @@ func TestSoftDelete(t *testing.T) { DB.Save(&user) var count int64 + var age uint + if DB.Model(&User{}).Where("name = ?", user.Name).Count(&count).Error != nil || count != 1 { t.Errorf("Count soft deleted record, expects: %v, got: %v", 1, count) } + if DB.Model(&User{}).Select("age").Where("name = ?", user.Name).Scan(&age).Error != nil || age != user.Age { + t.Errorf("Age soft deleted record, expects: %v, got: %v", 0, age) + } + if err := DB.Delete(&user).Error; err != nil { t.Fatalf("No error should happen when soft delete user, but got %v", err) } @@ -26,18 +32,30 @@ func TestSoftDelete(t *testing.T) { t.Errorf("Can't find a soft deleted record") } + count = 0 if DB.Model(&User{}).Where("name = ?", user.Name).Count(&count).Error != nil || count != 0 { t.Errorf("Count soft deleted record, expects: %v, got: %v", 0, count) } + age = 0 + if DB.Model(&User{}).Select("age").Where("name = ?", user.Name).Scan(&age).Error != nil || age != 0 { + t.Errorf("Age soft deleted record, expects: %v, got: %v", 0, age) + } + if err := DB.Unscoped().First(&User{}, "name = ?", user.Name).Error; err != nil { t.Errorf("Should find soft deleted record with Unscoped, but got err %s", err) } + count = 0 if DB.Unscoped().Model(&User{}).Where("name = ?", user.Name).Count(&count).Error != nil || count != 1 { t.Errorf("Count soft deleted record, expects: %v, count: %v", 1, count) } + age = 0 + if DB.Unscoped().Model(&User{}).Select("age").Where("name = ?", user.Name).Scan(&age).Error != nil || age != user.Age { + t.Errorf("Age soft deleted record, expects: %v, got: %v", 0, age) + } + DB.Unscoped().Delete(&user) if err := DB.Unscoped().First(&User{}, "name = ?", user.Name).Error; !errors.Is(err, gorm.ErrRecordNotFound) { t.Errorf("Can't find permanently deleted record") From fcf2ab6c0ee201e95ce9d30b69f33b507e8e45ff Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 5 Nov 2020 11:20:08 +0800 Subject: [PATCH 713/881] Add deleted_at check when soft deleting, fix #3720 --- callbacks/delete.go | 2 +- soft_delete.go | 6 ++++++ tests/delete_test.go | 2 +- tests/soft_delete_test.go | 6 ++++++ 4 files changed, 14 insertions(+), 2 deletions(-) diff --git a/callbacks/delete.go b/callbacks/delete.go index 85f11f4b..0f4bcd6b 100644 --- a/callbacks/delete.go +++ b/callbacks/delete.go @@ -135,7 +135,7 @@ func Delete(db *gorm.DB) { db.Statement.Build("DELETE", "FROM", "WHERE") } - if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok { + if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok && db.Error == nil { db.AddError(gorm.ErrMissingWhereClause) return } diff --git a/soft_delete.go b/soft_delete.go index 284129a1..cb56035d 100644 --- a/soft_delete.go +++ b/soft_delete.go @@ -124,6 +124,12 @@ func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) { } } + if _, ok := stmt.Clauses["WHERE"]; !stmt.DB.AllowGlobalUpdate && !ok { + stmt.DB.AddError(ErrMissingWhereClause) + } else { + SoftDeleteQueryClause{Field: sd.Field}.ModifyStatement(stmt) + } + stmt.AddClauseIfNotExists(clause.Update{}) stmt.Build("UPDATE", "SET", "WHERE") } diff --git a/tests/delete_test.go b/tests/delete_test.go index ecd5ec39..954c7097 100644 --- a/tests/delete_test.go +++ b/tests/delete_test.go @@ -49,7 +49,7 @@ func TestDelete(t *testing.T) { t.Errorf("errors happened when delete: %v", err) } - if err := DB.Delete(User{}).Error; err != gorm.ErrMissingWhereClause { + if err := DB.Delete(&User{}).Error; err != gorm.ErrMissingWhereClause { t.Errorf("errors happened when delete: %v", err) } diff --git a/tests/soft_delete_test.go b/tests/soft_delete_test.go index 283a4c34..f1ea8a51 100644 --- a/tests/soft_delete_test.go +++ b/tests/soft_delete_test.go @@ -3,6 +3,7 @@ package tests_test import ( "encoding/json" "errors" + "regexp" "testing" "gorm.io/gorm" @@ -28,6 +29,11 @@ func TestSoftDelete(t *testing.T) { t.Fatalf("No error should happen when soft delete user, but got %v", err) } + sql := DB.Session(&gorm.Session{DryRun: true}).Delete(&user).Statement.SQL.String() + if !regexp.MustCompile(`UPDATE .users. SET .deleted_at.=.* WHERE .users.\..id. = .* AND .users.\..deleted_at. IS NULL`).MatchString(sql) { + t.Fatalf("invalid sql generated, got %v", sql) + } + if DB.First(&User{}, "name = ?", user.Name).Error == nil { t.Errorf("Can't find a soft deleted record") } From 85e9f66d2652a4a4c422f22c3e7bf24fd7a2c33c Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 5 Nov 2020 11:43:21 +0800 Subject: [PATCH 714/881] Fix create index for other database/schema, close #3698 --- migrator/migrator.go | 47 +++++++++++++++++++++++--------------------- 1 file changed, 25 insertions(+), 22 deletions(-) diff --git a/migrator/migrator.go b/migrator/migrator.go index 9493a00c..016ebfc7 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -158,14 +158,10 @@ func (m Migrator) CreateTable(values ...interface{}) error { if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) { var ( createTableSQL = "CREATE TABLE ? (" - values = []interface{}{clause.Table{Name: stmt.Table}} + values = []interface{}{m.CurrentTable(stmt)} hasPrimaryKeyInDataType bool ) - if stmt.TableExpr != nil { - values[0] = *stmt.TableExpr - } - for _, dbName := range stmt.Schema.DBNames { field := stmt.Schema.FieldsByDBName[dbName] createTableSQL += "? ?" @@ -243,7 +239,7 @@ func (m Migrator) DropTable(values ...interface{}) error { for i := len(values) - 1; i >= 0; i-- { tx := m.DB.Session(&gorm.Session{}) if err := m.RunWithValue(values[i], func(stmt *gorm.Statement) error { - return tx.Exec("DROP TABLE IF EXISTS ?", clause.Table{Name: stmt.Table}).Error + return tx.Exec("DROP TABLE IF EXISTS ?", m.CurrentTable(stmt)).Error }); err != nil { return err } @@ -263,30 +259,30 @@ func (m Migrator) HasTable(value interface{}) bool { } func (m Migrator) RenameTable(oldName, newName interface{}) error { - var oldTable, newTable string + var oldTable, newTable interface{} if v, ok := oldName.(string); ok { - oldTable = v + oldTable = clause.Table{Name: v} } else { stmt := &gorm.Statement{DB: m.DB} if err := stmt.Parse(oldName); err == nil { - oldTable = stmt.Table + oldTable = m.CurrentTable(stmt) } else { return err } } if v, ok := newName.(string); ok { - newTable = v + newTable = clause.Table{Name: v} } else { stmt := &gorm.Statement{DB: m.DB} if err := stmt.Parse(newName); err == nil { - newTable = stmt.Table + newTable = m.CurrentTable(stmt) } else { return err } } - return m.DB.Exec("ALTER TABLE ? RENAME TO ?", clause.Table{Name: oldTable}, clause.Table{Name: newTable}).Error + return m.DB.Exec("ALTER TABLE ? RENAME TO ?", oldTable, newTable).Error } func (m Migrator) AddColumn(value interface{}, field string) error { @@ -294,7 +290,7 @@ func (m Migrator) AddColumn(value interface{}, field string) error { if field := stmt.Schema.LookUpField(field); field != nil { return m.DB.Exec( "ALTER TABLE ? ADD ? ?", - clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, m.DB.Migrator().FullDataTypeOf(field), + m.CurrentTable(stmt), clause.Column{Name: field.DBName}, m.DB.Migrator().FullDataTypeOf(field), ).Error } return fmt.Errorf("failed to look up field with name: %s", field) @@ -308,7 +304,7 @@ func (m Migrator) DropColumn(value interface{}, name string) error { } return m.DB.Exec( - "ALTER TABLE ? DROP COLUMN ?", clause.Table{Name: stmt.Table}, clause.Column{Name: name}, + "ALTER TABLE ? DROP COLUMN ?", m.CurrentTable(stmt), clause.Column{Name: name}, ).Error }) } @@ -319,7 +315,7 @@ func (m Migrator) AlterColumn(value interface{}, field string) error { fileType := clause.Expr{SQL: m.DataTypeOf(field)} return m.DB.Exec( "ALTER TABLE ? ALTER COLUMN ? TYPE ?", - clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, fileType, + m.CurrentTable(stmt), clause.Column{Name: field.DBName}, fileType, ).Error } @@ -357,7 +353,7 @@ func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error return m.DB.Exec( "ALTER TABLE ? RENAME COLUMN ? TO ?", - clause.Table{Name: stmt.Table}, clause.Column{Name: oldName}, clause.Column{Name: newName}, + m.CurrentTable(stmt), clause.Column{Name: oldName}, clause.Column{Name: newName}, ).Error }) } @@ -459,14 +455,14 @@ func (m Migrator) CreateConstraint(value interface{}, name string) error { if chk, ok := checkConstraints[name]; ok { return m.DB.Exec( "ALTER TABLE ? ADD CONSTRAINT ? CHECK (?)", - clause.Table{Name: stmt.Table}, clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint}, + m.CurrentTable(stmt), clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint}, ).Error } for _, rel := range stmt.Schema.Relationships.Relations { if constraint := rel.ParseConstraint(); constraint != nil && constraint.Name == name { sql, values := buildConstraint(constraint) - return m.DB.Exec("ALTER TABLE ? ADD "+sql, append([]interface{}{clause.Table{Name: stmt.Table}}, values...)...).Error + return m.DB.Exec("ALTER TABLE ? ADD "+sql, append([]interface{}{m.CurrentTable(stmt)}, values...)...).Error } } @@ -495,7 +491,7 @@ func (m Migrator) DropConstraint(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { return m.DB.Exec( "ALTER TABLE ? DROP CONSTRAINT ?", - clause.Table{Name: stmt.Table}, clause.Column{Name: name}, + m.CurrentTable(stmt), clause.Column{Name: name}, ).Error }) } @@ -542,7 +538,7 @@ func (m Migrator) CreateIndex(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { if idx := stmt.Schema.LookIndex(name); idx != nil { opts := m.DB.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt) - values := []interface{}{clause.Column{Name: idx.Name}, clause.Table{Name: stmt.Table}, opts} + values := []interface{}{clause.Column{Name: idx.Name}, m.CurrentTable(stmt), opts} createIndexSQL := "CREATE " if idx.Class != "" { @@ -571,7 +567,7 @@ func (m Migrator) DropIndex(value interface{}, name string) error { name = idx.Name } - return m.DB.Exec("DROP INDEX ? ON ?", clause.Column{Name: name}, clause.Table{Name: stmt.Table}).Error + return m.DB.Exec("DROP INDEX ? ON ?", clause.Column{Name: name}, m.CurrentTable(stmt)).Error }) } @@ -596,7 +592,7 @@ func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error return m.RunWithValue(value, func(stmt *gorm.Statement) error { return m.DB.Exec( "ALTER TABLE ? RENAME INDEX ? TO ?", - clause.Table{Name: stmt.Table}, clause.Column{Name: oldName}, clause.Column{Name: newName}, + m.CurrentTable(stmt), clause.Column{Name: oldName}, clause.Column{Name: newName}, ).Error }) } @@ -701,3 +697,10 @@ func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []i } return } + +func (m Migrator) CurrentTable(stmt *gorm.Statement) interface{} { + if stmt.TableExpr != nil { + return *stmt.TableExpr + } + return clause.Table{Name: stmt.Table} +} From 832abda7a49a134530f6b7fd734c9111f7fbc74a Mon Sep 17 00:00:00 2001 From: LeoZhan Date: Sun, 8 Nov 2020 09:41:43 +0800 Subject: [PATCH 715/881] refactor: simplify the writing instead of using struct literal (#3728) --- clause/expression.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/clause/expression.go b/clause/expression.go index 725a4909..40265ac6 100644 --- a/clause/expression.go +++ b/clause/expression.go @@ -202,7 +202,7 @@ func (eq Eq) Build(builder Builder) { } func (eq Eq) NegationBuild(builder Builder) { - Neq{eq.Column, eq.Value}.Build(builder) + Neq(eq).Build(builder) } // Neq not equal to for where @@ -220,7 +220,7 @@ func (neq Neq) Build(builder Builder) { } func (neq Neq) NegationBuild(builder Builder) { - Eq{neq.Column, neq.Value}.Build(builder) + Eq(neq).Build(builder) } // Gt greater than for where @@ -233,7 +233,7 @@ func (gt Gt) Build(builder Builder) { } func (gt Gt) NegationBuild(builder Builder) { - Lte{gt.Column, gt.Value}.Build(builder) + Lte(gt).Build(builder) } // Gte greater than or equal to for where @@ -246,7 +246,7 @@ func (gte Gte) Build(builder Builder) { } func (gte Gte) NegationBuild(builder Builder) { - Lt{gte.Column, gte.Value}.Build(builder) + Lt(gte).Build(builder) } // Lt less than for where @@ -259,7 +259,7 @@ func (lt Lt) Build(builder Builder) { } func (lt Lt) NegationBuild(builder Builder) { - Gte{lt.Column, lt.Value}.Build(builder) + Gte(lt).Build(builder) } // Lte less than or equal to for where @@ -272,7 +272,7 @@ func (lte Lte) Build(builder Builder) { } func (lte Lte) NegationBuild(builder Builder) { - Gt{lte.Column, lte.Value}.Build(builder) + Gt(lte).Build(builder) } // Like whether string matches regular expression From 1e241aa6455fd821102bfce366d47a646b71161e Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 10 Nov 2020 18:38:24 +0800 Subject: [PATCH 716/881] Reduce GC alloc --- callbacks/associations.go | 10 +++++----- callbacks/create.go | 11 ++++------- callbacks/helper.go | 4 ++-- callbacks/preload.go | 4 ++-- gorm.go | 1 + scan.go | 26 +++++++++++++------------- schema/schema.go | 2 +- schema/utils.go | 2 +- statement.go | 9 +++++---- 9 files changed, 34 insertions(+), 35 deletions(-) diff --git a/callbacks/associations.go b/callbacks/associations.go index 64d79f24..1e6f62c5 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -46,7 +46,7 @@ func SaveBeforeAssociations(db *gorm.DB) { fieldType = reflect.PtrTo(fieldType) } - elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 0) + elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) for i := 0; i < db.Statement.ReflectValue.Len(); i++ { obj := db.Statement.ReflectValue.Index(i) @@ -109,7 +109,7 @@ func SaveAfterAssociations(db *gorm.DB) { fieldType = reflect.PtrTo(fieldType) } - elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 0) + elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) for i := 0; i < db.Statement.ReflectValue.Len(); i++ { obj := db.Statement.ReflectValue.Index(i) @@ -181,7 +181,7 @@ func SaveAfterAssociations(db *gorm.DB) { if !isPtr { fieldType = reflect.PtrTo(fieldType) } - elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 0) + elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) appendToElems := func(v reflect.Value) { if _, zero := rel.Field.ValueOf(v); !zero { f := reflect.Indirect(rel.Field.ReflectValueOf(v)) @@ -241,8 +241,8 @@ func SaveAfterAssociations(db *gorm.DB) { if !isPtr { fieldType = reflect.PtrTo(fieldType) } - elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 0) - joins := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(rel.JoinTable.ModelType)), 0, 0) + elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) + joins := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(rel.JoinTable.ModelType)), 0, 10) objs := []reflect.Value{} appendToJoins := func(obj reflect.Value, elem reflect.Value) { diff --git a/callbacks/create.go b/callbacks/create.go index 8e2454e8..67f3ab14 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -57,7 +57,7 @@ func Create(config *Config) func(db *gorm.DB) { db.RowsAffected, _ = result.RowsAffected() if db.RowsAffected > 0 { if db.Statement.Schema != nil && db.Statement.Schema.PrioritizedPrimaryField != nil && db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue { - if insertID, err := result.LastInsertId(); err == nil { + if insertID, err := result.LastInsertId(); err == nil && insertID > 0 { switch db.Statement.ReflectValue.Kind() { case reflect.Slice, reflect.Array: if config.LastInsertIDReversed { @@ -87,11 +87,8 @@ func Create(config *Config) func(db *gorm.DB) { } } case reflect.Struct: - if insertID > 0 { - if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue); isZero { - - db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID) - } + if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue); isZero { + db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID) } } } else { @@ -253,7 +250,7 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { switch stmt.ReflectValue.Kind() { case reflect.Slice, reflect.Array: - stmt.SQL.Grow(stmt.ReflectValue.Len() * 15) + stmt.SQL.Grow(stmt.ReflectValue.Len() * 18) values.Values = make([][]interface{}, stmt.ReflectValue.Len()) defaultValueFieldsHavingValue := map[*schema.Field][]interface{}{} if stmt.ReflectValue.Len() == 0 { diff --git a/callbacks/helper.go b/callbacks/helper.go index 09ec4582..3ac63fa1 100644 --- a/callbacks/helper.go +++ b/callbacks/helper.go @@ -12,7 +12,7 @@ func ConvertMapToValuesForCreate(stmt *gorm.Statement, mapValue map[string]inter values.Columns = make([]clause.Column, 0, len(mapValue)) selectColumns, restricted := stmt.SelectAndOmitColumns(true, false) - var keys []string + var keys = make([]string, 0, len(mapValue)) for k := range mapValue { keys = append(keys, k) } @@ -41,7 +41,7 @@ func ConvertMapToValuesForCreate(stmt *gorm.Statement, mapValue map[string]inter // ConvertSliceOfMapToValuesForCreate convert slice of map to values func ConvertSliceOfMapToValuesForCreate(stmt *gorm.Statement, mapValues []map[string]interface{}) (values clause.Values) { var ( - columns = []string{} + columns = make([]string, 0, len(mapValues)) result = map[string][]interface{}{} selectColumns, restricted = stmt.SelectAndOmitColumns(true, false) ) diff --git a/callbacks/preload.go b/callbacks/preload.go index aec10ec5..d60079e4 100644 --- a/callbacks/preload.go +++ b/callbacks/preload.go @@ -112,7 +112,7 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { case reflect.Struct: switch rel.Type { case schema.HasMany, schema.Many2Many: - rel.Field.Set(reflectValue, reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 0).Interface()) + rel.Field.Set(reflectValue, reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface()) default: rel.Field.Set(reflectValue, reflect.New(rel.Field.FieldType).Interface()) } @@ -120,7 +120,7 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { for i := 0; i < reflectValue.Len(); i++ { switch rel.Type { case schema.HasMany, schema.Many2Many: - rel.Field.Set(reflectValue.Index(i), reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 0).Interface()) + rel.Field.Set(reflectValue.Index(i), reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface()) default: rel.Field.Set(reflectValue.Index(i), reflect.New(rel.Field.FieldType).Interface()) } diff --git a/gorm.go b/gorm.go index affa8e69..2dfbb855 100644 --- a/gorm.go +++ b/gorm.go @@ -286,6 +286,7 @@ func (db *DB) getInstance() *DB { ConnPool: db.Statement.ConnPool, Context: db.Statement.Context, Clauses: map[string]clause.Clause{}, + Vars: make([]interface{}, 0, 8), } } else { // with clone statement diff --git a/scan.go b/scan.go index 8d737b17..c9c8f442 100644 --- a/scan.go +++ b/scan.go @@ -106,7 +106,7 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { reflectValueType = reflectValueType.Elem() } - db.Statement.ReflectValue.Set(reflect.MakeSlice(db.Statement.ReflectValue.Type(), 0, 0)) + db.Statement.ReflectValue.Set(reflect.MakeSlice(db.Statement.ReflectValue.Type(), 0, 20)) if Schema != nil { if reflectValueType != Schema.ModelType && reflectValueType.Kind() == reflect.Struct { @@ -117,13 +117,13 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { if field := Schema.LookUpField(column); field != nil && field.Readable { fields[idx] = field } else if names := strings.Split(column, "__"); len(names) > 1 { - if len(joinFields) == 0 { - joinFields = make([][2]*schema.Field, len(columns)) - } - if rel, ok := Schema.Relationships.Relations[names[0]]; ok { if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { fields[idx] = field + + if len(joinFields) == 0 { + joinFields = make([][2]*schema.Field, len(columns)) + } joinFields[idx] = [2]*schema.Field{rel.Field, field} continue } @@ -138,9 +138,9 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { // pluck values into slice of data isPluck := false if len(fields) == 1 { - if _, ok := reflect.New(reflectValueType).Interface().(sql.Scanner); ok { - isPluck = true - } else if reflectValueType.Kind() != reflect.Struct || reflectValueType.ConvertibleTo(schema.TimeReflectType) { + if _, ok := reflect.New(reflectValueType).Interface().(sql.Scanner); ok || // is scanner + reflectValueType.Kind() != reflect.Struct || // is not struct + Schema.ModelType.ConvertibleTo(schema.TimeReflectType) { // is time isPluck = true } } @@ -149,9 +149,9 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { initialized = false db.RowsAffected++ - elem := reflect.New(reflectValueType).Elem() + elem := reflect.New(reflectValueType) if isPluck { - db.AddError(rows.Scan(elem.Addr().Interface())) + db.AddError(rows.Scan(elem.Interface())) } else { for idx, field := range fields { if field != nil { @@ -181,9 +181,9 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { } if isPtr { - db.Statement.ReflectValue.Set(reflect.Append(db.Statement.ReflectValue, elem.Addr())) - } else { db.Statement.ReflectValue.Set(reflect.Append(db.Statement.ReflectValue, elem)) + } else { + db.Statement.ReflectValue.Set(reflect.Append(db.Statement.ReflectValue, elem.Elem())) } } case reflect.Struct: @@ -216,8 +216,8 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { field.Set(db.Statement.ReflectValue, values[idx]) } else if names := strings.Split(column, "__"); len(names) > 1 { if rel, ok := Schema.Relationships.Relations[names[0]]; ok { - relValue := rel.Field.ReflectValueOf(db.Statement.ReflectValue) if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { + relValue := rel.Field.ReflectValueOf(db.Statement.ReflectValue) value := reflect.ValueOf(values[idx]).Elem() if relValue.Kind() == reflect.Ptr && relValue.IsNil() { diff --git a/schema/schema.go b/schema/schema.go index cffc19a7..05db641f 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -50,7 +50,7 @@ func (schema Schema) String() string { } func (schema Schema) MakeSlice() reflect.Value { - slice := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(schema.ModelType)), 0, 0) + slice := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(schema.ModelType)), 0, 20) results := reflect.New(slice.Type()) results.Elem().Set(slice) return results diff --git a/schema/utils.go b/schema/utils.go index 55cbdeb4..6e5fd528 100644 --- a/schema/utils.go +++ b/schema/utils.go @@ -61,7 +61,7 @@ func removeSettingFromTag(tag reflect.StructTag, names ...string) reflect.Struct // GetRelationsValues get relations's values from a reflect value func GetRelationsValues(reflectValue reflect.Value, rels []*Relationship) (reflectResults reflect.Value) { for _, rel := range rels { - reflectResults = reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(rel.FieldSchema.ModelType)), 0, 0) + reflectResults = reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(rel.FieldSchema.ModelType)), 0, 1) appendToResults := func(value reflect.Value) { if _, isZero := rel.Field.ValueOf(value); !isZero { diff --git a/statement.go b/statement.go index 82ebdd91..7c0af59c 100644 --- a/statement.go +++ b/statement.go @@ -239,12 +239,12 @@ func (stmt *Statement) AddClauseIfNotExists(v clause.Interface) { } // BuildCondition build condition -func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) (conds []clause.Expression) { +func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []clause.Expression { if s, ok := query.(string); ok { // if it is a number, then treats it as primary key if _, err := strconv.Atoi(s); err != nil { if s == "" && len(args) == 0 { - return + return nil } else if len(args) == 0 || (len(args) > 0 && strings.Contains(s, "?")) { // looks like a where condition return []clause.Expression{clause.Expr{SQL: s, Vars: args}} @@ -257,6 +257,7 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) (c } } + conds := make([]clause.Expression, 0, 4) args = append([]interface{}{query}, args...) for _, arg := range args { if valuer, ok := arg.(driver.Valuer); ok { @@ -358,7 +359,7 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) (c if len(values) > 0 { conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: values}) } - return + return conds } } @@ -367,7 +368,7 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) (c } } - return + return conds } // Build build sql with clauses names From c1bb8e4551a5b371fbc637802a56e15b421f31f7 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 16 Nov 2020 11:20:13 +0800 Subject: [PATCH 717/881] Should not display the record not found error when using FirstOrXXX, close #3748 --- finisher_api.go | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/finisher_api.go b/finisher_api.go index 857f9419..2e7e5f4e 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -186,7 +186,11 @@ func (tx *DB) assignInterfacesToValue(values ...interface{}) { } func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { - if tx = db.First(dest, conds...); errors.Is(tx.Error, ErrRecordNotFound) { + queryTx := db.Limit(1).Order(clause.OrderByColumn{ + Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, + }) + + if tx = queryTx.Find(dest, conds...); queryTx.RowsAffected == 0 { if c, ok := tx.Statement.Clauses["WHERE"]; ok { if where, ok := c.Expression.(clause.Where); ok { tx.assignInterfacesToValue(where.Exprs) @@ -197,7 +201,6 @@ func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { if len(tx.Statement.attrs) > 0 { tx.assignInterfacesToValue(tx.Statement.attrs...) } - tx.Error = nil } // initialize with attrs, conds @@ -208,9 +211,11 @@ func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { } func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) { - if tx = db.First(dest, conds...); errors.Is(tx.Error, ErrRecordNotFound) { - tx.Error = nil + queryTx := db.Limit(1).Order(clause.OrderByColumn{ + Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, + }) + if tx = queryTx.Find(dest, conds...); queryTx.RowsAffected == 0 { if c, ok := tx.Statement.Clauses["WHERE"]; ok { if where, ok := c.Expression.(clause.Where); ok { tx.assignInterfacesToValue(where.Exprs) From a9f54d53fbb4cfdda6a635369229379fb73bd694 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 16 Nov 2020 12:23:13 +0800 Subject: [PATCH 718/881] Don't preload when there are any error happened --- callbacks/query.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/callbacks/query.go b/callbacks/query.go index 8613e46d..92f711f5 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -206,7 +206,9 @@ func Preload(db *gorm.DB) { } } - preload(db, rels, db.Statement.Preloads[name]) + if db.Error == nil { + preload(db, rels, db.Statement.Preloads[name]) + } } } } From a4c0c6b400586283cfd2ec74d1bb8c5c0a5dd4fb Mon Sep 17 00:00:00 2001 From: alresvor Date: Mon, 16 Nov 2020 15:16:15 +0800 Subject: [PATCH 719/881] cache converted name (#3736) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit BenchmarkToName-8 2322307 521 ns/op 88 B/op 5 allocs/op ↓ BenchmarkToName-8 19997366 55.0 ns/op 0 B/op 0 allocs/op --- schema/naming.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/schema/naming.go b/schema/naming.go index dbc71e04..e3b2104a 100644 --- a/schema/naming.go +++ b/schema/naming.go @@ -95,7 +95,7 @@ func toDBName(name string) string { if name == "" { return "" } else if v, ok := smap.Load(name); ok { - return fmt.Sprint(v) + return v.(string) } var ( @@ -134,6 +134,7 @@ func toDBName(name string) string { } else { buf.WriteByte(value[len(value)-1]) } - - return buf.String() + ret := buf.String() + smap.Store(name, ret) + return ret } From 62be27d3cafd48d3dcb348bd1d17a5be31867f13 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 16 Nov 2020 20:22:08 +0800 Subject: [PATCH 720/881] Add OnConflict UpdateAll support --- callbacks/create.go | 33 ++++++++++++++++++--------------- clause/on_conflict.go | 1 + finisher_api.go | 4 +++- tests/upsert_test.go | 10 ++++++++++ 4 files changed, 32 insertions(+), 16 deletions(-) diff --git a/callbacks/create.go b/callbacks/create.go index 67f3ab14..ad91ebc3 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -329,26 +329,29 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { } } - if stmt.UpdatingColumn { - if stmt.Schema != nil && len(values.Columns) > 1 { - columns := make([]string, 0, len(values.Columns)-1) - for _, column := range values.Columns { - if field := stmt.Schema.LookUpField(column.Name); field != nil { - if !field.PrimaryKey && !field.HasDefaultValue && field.AutoCreateTime == 0 { - columns = append(columns, column.Name) + if c, ok := stmt.Clauses["ON CONFLICT"]; ok { + if onConflict, _ := c.Expression.(clause.OnConflict); onConflict.UpdateAll { + if stmt.Schema != nil && len(values.Columns) > 1 { + columns := make([]string, 0, len(values.Columns)-1) + for _, column := range values.Columns { + if field := stmt.Schema.LookUpField(column.Name); field != nil { + if !field.PrimaryKey && !field.HasDefaultValue && field.AutoCreateTime == 0 { + columns = append(columns, column.Name) + } } } - } - onConflict := clause.OnConflict{ - Columns: make([]clause.Column, len(stmt.Schema.PrimaryFieldDBNames)), - DoUpdates: clause.AssignmentColumns(columns), - } + onConflict := clause.OnConflict{ + Columns: make([]clause.Column, len(stmt.Schema.PrimaryFieldDBNames)), + DoUpdates: clause.AssignmentColumns(columns), + } - for idx, field := range stmt.Schema.PrimaryFields { - onConflict.Columns[idx] = clause.Column{Name: field.DBName} + for idx, field := range stmt.Schema.PrimaryFields { + onConflict.Columns[idx] = clause.Column{Name: field.DBName} + } + + stmt.AddClause(onConflict) } - stmt.AddClause(onConflict) } } diff --git a/clause/on_conflict.go b/clause/on_conflict.go index 47f69fc9..47fe169c 100644 --- a/clause/on_conflict.go +++ b/clause/on_conflict.go @@ -5,6 +5,7 @@ type OnConflict struct { Where Where DoNothing bool DoUpdates Set + UpdateAll bool } func (OnConflict) Name() string { diff --git a/finisher_api.go b/finisher_api.go index 2e7e5f4e..67423b23 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -29,7 +29,9 @@ func (db *DB) Save(value interface{}) (tx *DB) { reflectValue := reflect.Indirect(reflect.ValueOf(value)) switch reflectValue.Kind() { case reflect.Slice, reflect.Array: - tx.Statement.UpdatingColumn = true + if _, ok := tx.Statement.Clauses["ON CONFLICT"]; !ok { + tx = tx.Clauses(clause.OnConflict{UpdateAll: true}) + } tx.callbacks.Create().Execute(tx) case reflect.Struct: if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil { diff --git a/tests/upsert_test.go b/tests/upsert_test.go index ba7c1a9d..0ba8b9f0 100644 --- a/tests/upsert_test.go +++ b/tests/upsert_test.go @@ -41,6 +41,16 @@ func TestUpsert(t *testing.T) { } else if langs[0].Name != "upsert-new" { t.Errorf("should update name on conflict, but got name %+v", langs[0].Name) } + + lang = Language{Code: "upsert", Name: "Upsert-Newname"} + if err := DB.Clauses(clause.OnConflict{UpdateAll: true}).Create(&lang).Error; err != nil { + t.Fatalf("failed to upsert, got %v", err) + } + + var result Language + if err := DB.Find(&result, "code = ?", lang.Code).Error; err != nil || result.Name != lang.Name { + t.Fatalf("failed to upsert, got name %v", result.Name) + } } func TestUpsertSlice(t *testing.T) { From a8db54afd665dafe763e0d2d881d57fb602fd30d Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 16 Nov 2020 21:42:30 +0800 Subject: [PATCH 721/881] Add CreateInBatches supports --- finisher_api.go | 23 +++++++++++++++++++++++ tests/create_test.go | 26 ++++++++++++++++++++++++++ 2 files changed, 49 insertions(+) diff --git a/finisher_api.go b/finisher_api.go index 67423b23..c9e2a3b2 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -21,6 +21,29 @@ func (db *DB) Create(value interface{}) (tx *DB) { return } +// CreateInBatches insert the value in batches into database +func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) { + reflectValue := reflect.Indirect(reflect.ValueOf(value)) + + switch reflectValue.Kind() { + case reflect.Slice, reflect.Array: + tx = db.getInstance() + for i := 0; i < reflectValue.Len(); i += batchSize { + tx.AddError(tx.Transaction(func(tx *DB) error { + ends := i + batchSize + if ends > reflectValue.Len() { + ends = reflectValue.Len() + } + + return tx.Create(reflectValue.Slice(i, ends).Interface()).Error + })) + } + default: + return db.Create(value) + } + return +} + // Save update value in database, if the value doesn't have primary key, will insert it func (db *DB) Save(value interface{}) (tx *DB) { tx = db.getInstance() diff --git a/tests/create_test.go b/tests/create_test.go index 00674eec..8d005d0b 100644 --- a/tests/create_test.go +++ b/tests/create_test.go @@ -40,6 +40,32 @@ func TestCreate(t *testing.T) { } } +func TestCreateInBatches(t *testing.T) { + users := []User{ + *GetUser("create_in_batches_1", Config{Account: true, Pets: 2, Toys: 3, Company: true, Manager: true, Team: 0, Languages: 1, Friends: 1}), + *GetUser("create_in_batches_2", Config{Account: false, Pets: 2, Toys: 4, Company: false, Manager: false, Team: 1, Languages: 3, Friends: 5}), + *GetUser("create_in_batches_3", Config{Account: true, Pets: 0, Toys: 3, Company: true, Manager: false, Team: 4, Languages: 0, Friends: 1}), + *GetUser("create_in_batches_4", Config{Account: true, Pets: 3, Toys: 0, Company: false, Manager: true, Team: 0, Languages: 3, Friends: 0}), + *GetUser("create_in_batches_5", Config{Account: false, Pets: 0, Toys: 3, Company: true, Manager: false, Team: 1, Languages: 3, Friends: 1}), + *GetUser("create_in_batches_6", Config{Account: true, Pets: 4, Toys: 3, Company: false, Manager: true, Team: 1, Languages: 3, Friends: 0}), + } + + DB.CreateInBatches(&users, 2) + + for _, user := range users { + if user.ID == 0 { + t.Fatalf("failed to fill user's ID, got %v", user.ID) + } else { + var newUser User + if err := DB.Where("id = ?", user.ID).Preload(clause.Associations).First(&newUser).Error; err != nil { + t.Fatalf("errors happened when query: %v", err) + } else { + CheckUser(t, newUser, user) + } + } + } +} + func TestCreateFromMap(t *testing.T) { if err := DB.Model(&User{}).Create(map[string]interface{}{"Name": "create_from_map", "Age": 18}).Error; err != nil { t.Fatalf("failed to create data from map, got error: %v", err) From 320f33061caf42da9397101157a91323043d4c0a Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 17 Nov 2020 11:19:04 +0800 Subject: [PATCH 722/881] Fix FindInBatches to modify the query conditions, close #3734 --- finisher_api.go | 21 +++++++++++++++------ tests/query_test.go | 13 +++++++++++++ 2 files changed, 28 insertions(+), 6 deletions(-) diff --git a/finisher_api.go b/finisher_api.go index c9e2a3b2..211e2f8f 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -140,13 +140,18 @@ func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) { } // FindInBatches find records in batches -func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, batch int) error) (tx *DB) { - tx = db.Session(&Session{WithConditions: true}) - rowsAffected := int64(0) - batch := 0 +func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, batch int) error) *DB { + var ( + tx = db.Order(clause.OrderByColumn{ + Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, + }).Session(&Session{WithConditions: true}) + queryDB = tx + rowsAffected int64 + batch int + ) for { - result := tx.Limit(batchSize).Offset(batch * batchSize).Find(dest) + result := queryDB.Limit(batchSize).Find(dest) rowsAffected += result.RowsAffected batch++ @@ -156,11 +161,15 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat if tx.Error != nil || int(result.RowsAffected) < batchSize { break + } else { + resultsValue := reflect.Indirect(reflect.ValueOf(dest)) + primaryValue, _ := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(resultsValue.Index(resultsValue.Len() - 1)) + queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue}) } } tx.RowsAffected = rowsAffected - return + return tx } func (tx *DB) assignInterfacesToValue(values ...interface{}) { diff --git a/tests/query_test.go b/tests/query_test.go index dc2907e6..bb77dfae 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -260,6 +260,13 @@ func TestFindInBatches(t *testing.T) { if result := DB.Where("name = ?", users[0].Name).FindInBatches(&results, 2, func(tx *gorm.DB, batch int) error { totalBatch += batch + for idx := range results { + results[idx].Name = results[idx].Name + "_new" + } + if err := tx.Save(results).Error; err != nil { + t.Errorf("failed to save users, got error %v", err) + } + if tx.RowsAffected != 2 { t.Errorf("Incorrect affected rows, expects: 2, got %v", tx.RowsAffected) } @@ -276,6 +283,12 @@ func TestFindInBatches(t *testing.T) { if totalBatch != 6 { t.Errorf("incorrect total batch, expects: %v, got %v", 6, totalBatch) } + + var count int64 + DB.Model(&User{}).Where("name = ?", "find_in_batches_new").Count(&count) + if count != 6 { + t.Errorf("incorrect count after update, expects: %v, got %v", 6, count) + } } func TestFillSmallerStruct(t *testing.T) { From f5c2126c29e375955b4db406fe6c6440f5c46b8e Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 17 Nov 2020 13:14:34 +0800 Subject: [PATCH 723/881] Fix FindInBatches tests --- callbacks/create.go | 2 ++ tests/query_test.go | 15 ++++++++------- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/callbacks/create.go b/callbacks/create.go index ad91ebc3..aec0afe9 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -55,6 +55,7 @@ func Create(config *Config) func(db *gorm.DB) { if err == nil { db.RowsAffected, _ = result.RowsAffected() + if db.RowsAffected > 0 { if db.Statement.Schema != nil && db.Statement.Schema.PrioritizedPrimaryField != nil && db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue { if insertID, err := result.LastInsertId(); err == nil && insertID > 0 { @@ -138,6 +139,7 @@ func CreateWithReturning(db *gorm.DB) { } if !db.DryRun && db.Error == nil { + db.RowsAffected = 0 rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) if err == nil { diff --git a/tests/query_test.go b/tests/query_test.go index bb77dfae..20968c7e 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -260,13 +260,6 @@ func TestFindInBatches(t *testing.T) { if result := DB.Where("name = ?", users[0].Name).FindInBatches(&results, 2, func(tx *gorm.DB, batch int) error { totalBatch += batch - for idx := range results { - results[idx].Name = results[idx].Name + "_new" - } - if err := tx.Save(results).Error; err != nil { - t.Errorf("failed to save users, got error %v", err) - } - if tx.RowsAffected != 2 { t.Errorf("Incorrect affected rows, expects: 2, got %v", tx.RowsAffected) } @@ -275,6 +268,14 @@ func TestFindInBatches(t *testing.T) { t.Errorf("Incorrect users length, expects: 2, got %v", len(results)) } + for idx := range results { + results[idx].Name = results[idx].Name + "_new" + } + + if err := tx.Save(results).Error; err != nil { + t.Errorf("failed to save users, got error %v", err) + } + return nil }); result.Error != nil || result.RowsAffected != 6 { t.Errorf("Failed to batch find, got error %v, rows affected: %v", result.Error, result.RowsAffected) From f6e1786ca28f671b8d045524e5ec3b1cbfd1b1e9 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 17 Nov 2020 15:19:58 +0800 Subject: [PATCH 724/881] Add skip hooks support --- callbacks/create.go | 4 ++-- gorm.go | 11 +++++++++-- tests/hooks_test.go | 5 +++++ 3 files changed, 16 insertions(+), 4 deletions(-) diff --git a/callbacks/create.go b/callbacks/create.go index aec0afe9..a58549a5 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -10,7 +10,7 @@ import ( ) func BeforeCreate(db *gorm.DB) { - if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeCreate) { + if db.Error == nil && db.Statement.Schema != nil && !db.Statement.UpdatingColumn && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeCreate) { callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { if db.Statement.Schema.BeforeSave { if i, ok := value.(BeforeSaveInterface); ok { @@ -203,7 +203,7 @@ func CreateWithReturning(db *gorm.DB) { } func AfterCreate(db *gorm.DB) { - if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) { + if db.Error == nil && db.Statement.Schema != nil && !db.Statement.UpdatingColumn && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) { callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { if db.Statement.Schema.AfterSave { if i, ok := value.(AfterSaveInterface); ok { diff --git a/gorm.go b/gorm.go index 2dfbb855..3bf2479a 100644 --- a/gorm.go +++ b/gorm.go @@ -64,6 +64,7 @@ type Session struct { DryRun bool PrepareStmt bool WithConditions bool + SkipHooks bool SkipDefaultTransaction bool AllowGlobalUpdate bool FullSaveAssociations bool @@ -169,15 +170,17 @@ func (db *DB) Session(config *Session) *DB { txConfig.FullSaveAssociations = true } - if config.Context != nil { + if config.Context != nil || config.PrepareStmt || config.SkipHooks { tx.Statement = tx.Statement.clone() tx.Statement.DB = tx + } + + if config.Context != nil { tx.Statement.Context = config.Context } if config.PrepareStmt { if v, ok := db.cacheStore.Load("preparedStmt"); ok { - tx.Statement = tx.Statement.clone() preparedStmt := v.(*PreparedStmtDB) tx.Statement.ConnPool = &PreparedStmtDB{ ConnPool: db.Config.ConnPool, @@ -189,6 +192,10 @@ func (db *DB) Session(config *Session) *DB { } } + if config.SkipHooks { + tx.Statement.UpdatingColumn = true + } + if config.WithConditions { tx.clone = 2 } diff --git a/tests/hooks_test.go b/tests/hooks_test.go index d8b1770e..7e3ae4e4 100644 --- a/tests/hooks_test.go +++ b/tests/hooks_test.go @@ -371,6 +371,11 @@ func TestSetColumn(t *testing.T) { t.Errorf("invalid data after update, got %+v", product) } + DB.Model(&product).Session(&gorm.Session{SkipHooks: true}).Updates(Product3{Code: "L1216"}) + if product.Price != 270 || product.Code != "L1216" { + t.Errorf("invalid data after update, got %+v", product) + } + var result2 Product3 DB.First(&result2, product.ID) From 26504f5caeb8c31dff62e8ddab68cee6b85a6580 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 17 Nov 2020 15:41:17 +0800 Subject: [PATCH 725/881] Use NewDB to replace WithConditions for Session --- association.go | 4 ++-- callbacks/associations.go | 14 +++++++------- callbacks/callmethod.go | 2 +- callbacks/delete.go | 4 ++-- callbacks/preload.go | 2 +- finisher_api.go | 8 ++++---- gorm.go | 9 ++++----- migrator.go | 2 +- migrator/migrator.go | 8 ++++---- statement.go | 2 +- tests/count_test.go | 2 +- tests/hooks_test.go | 7 +++++++ 12 files changed, 35 insertions(+), 29 deletions(-) diff --git a/association.go b/association.go index 140ae6ac..0f2102f7 100644 --- a/association.go +++ b/association.go @@ -417,7 +417,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ appendToRelations(reflectValue.Index(i), reflect.Indirect(reflect.ValueOf(values[i])), clear) // TODO support save slice data, sql with case? - association.Error = association.DB.Session(&Session{}).Select(selectedSaveColumns).Model(nil).Updates(reflectValue.Index(i).Addr().Interface()).Error + association.Error = association.DB.Session(&Session{NewDB: true}).Select(selectedSaveColumns).Model(nil).Updates(reflectValue.Index(i).Addr().Interface()).Error } case reflect.Struct: // clear old data @@ -439,7 +439,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ } if len(values) > 0 { - association.Error = association.DB.Session(&Session{}).Select(selectedSaveColumns).Model(nil).Updates(reflectValue.Addr().Interface()).Error + association.Error = association.DB.Session(&Session{NewDB: true}).Select(selectedSaveColumns).Model(nil).Updates(reflectValue.Addr().Interface()).Error } } diff --git a/callbacks/associations.go b/callbacks/associations.go index 1e6f62c5..1702f442 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -66,7 +66,7 @@ func SaveBeforeAssociations(db *gorm.DB) { } if elems.Len() > 0 { - if db.AddError(db.Session(&gorm.Session{}).Clauses(onConflictOption(db.Statement, rel.FieldSchema, nil)).Create(elems.Interface()).Error) == nil { + if db.AddError(db.Session(&gorm.Session{NewDB: true}).Clauses(onConflictOption(db.Statement, rel.FieldSchema, nil)).Create(elems.Interface()).Error) == nil { for i := 0; i < elems.Len(); i++ { setupReferences(objs[i], elems.Index(i)) } @@ -79,7 +79,7 @@ func SaveBeforeAssociations(db *gorm.DB) { rv = rv.Addr() } - if db.AddError(db.Session(&gorm.Session{}).Clauses(onConflictOption(db.Statement, rel.FieldSchema, nil)).Create(rv.Interface()).Error) == nil { + if db.AddError(db.Session(&gorm.Session{NewDB: true}).Clauses(onConflictOption(db.Statement, rel.FieldSchema, nil)).Create(rv.Interface()).Error) == nil { setupReferences(db.Statement.ReflectValue, rv) } } @@ -141,7 +141,7 @@ func SaveAfterAssociations(db *gorm.DB) { assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) } - db.AddError(db.Session(&gorm.Session{}).Clauses( + db.AddError(db.Session(&gorm.Session{NewDB: true}).Clauses( onConflictOption(db.Statement, rel.FieldSchema, assignmentColumns), ).Create(elems.Interface()).Error) } @@ -163,7 +163,7 @@ func SaveAfterAssociations(db *gorm.DB) { assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) } - db.AddError(db.Session(&gorm.Session{}).Clauses( + db.AddError(db.Session(&gorm.Session{NewDB: true}).Clauses( onConflictOption(db.Statement, rel.FieldSchema, assignmentColumns), ).Create(f.Interface()).Error) } @@ -224,7 +224,7 @@ func SaveAfterAssociations(db *gorm.DB) { assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) } - db.AddError(db.Session(&gorm.Session{}).Clauses( + db.AddError(db.Session(&gorm.Session{NewDB: true}).Clauses( onConflictOption(db.Statement, rel.FieldSchema, assignmentColumns), ).Create(elems.Interface()).Error) } @@ -291,7 +291,7 @@ func SaveAfterAssociations(db *gorm.DB) { } if elems.Len() > 0 { - db.AddError(db.Session(&gorm.Session{}).Clauses(onConflictOption(db.Statement, rel.FieldSchema, nil)).Create(elems.Interface()).Error) + db.AddError(db.Session(&gorm.Session{NewDB: true}).Clauses(onConflictOption(db.Statement, rel.FieldSchema, nil)).Create(elems.Interface()).Error) for i := 0; i < elems.Len(); i++ { appendToJoins(objs[i], elems.Index(i)) @@ -299,7 +299,7 @@ func SaveAfterAssociations(db *gorm.DB) { } if joins.Len() > 0 { - db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{DoNothing: true}).Create(joins.Interface()).Error) + db.AddError(db.Session(&gorm.Session{NewDB: true}).Clauses(clause.OnConflict{DoNothing: true}).Create(joins.Interface()).Error) } } } diff --git a/callbacks/callmethod.go b/callbacks/callmethod.go index b81fc915..bcaa03f3 100644 --- a/callbacks/callmethod.go +++ b/callbacks/callmethod.go @@ -7,7 +7,7 @@ import ( ) func callMethod(db *gorm.DB, fc func(value interface{}, tx *gorm.DB) bool) { - tx := db.Session(&gorm.Session{}) + tx := db.Session(&gorm.Session{NewDB: true}) if called := fc(db.Statement.ReflectValue.Interface(), tx); !called { switch db.Statement.ReflectValue.Kind() { case reflect.Slice, reflect.Array: diff --git a/callbacks/delete.go b/callbacks/delete.go index 0f4bcd6b..4a289e0c 100644 --- a/callbacks/delete.go +++ b/callbacks/delete.go @@ -34,7 +34,7 @@ func DeleteBeforeAssociations(db *gorm.DB) { case schema.HasOne, schema.HasMany: queryConds := rel.ToQueryConditions(db.Statement.ReflectValue) modelValue := reflect.New(rel.FieldSchema.ModelType).Interface() - tx := db.Session(&gorm.Session{}).Model(modelValue) + tx := db.Session(&gorm.Session{NewDB: true}).Model(modelValue) withoutConditions := false if len(db.Statement.Selects) > 0 { @@ -71,7 +71,7 @@ func DeleteBeforeAssociations(db *gorm.DB) { relForeignKeys []string modelValue = reflect.New(rel.JoinTable.ModelType).Interface() table = rel.JoinTable.Table - tx = db.Session(&gorm.Session{}).Model(modelValue).Table(table) + tx = db.Session(&gorm.Session{NewDB: true}).Model(modelValue).Table(table) ) for _, ref := range rel.References { diff --git a/callbacks/preload.go b/callbacks/preload.go index d60079e4..e1dfdace 100644 --- a/callbacks/preload.go +++ b/callbacks/preload.go @@ -13,7 +13,7 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { var ( reflectValue = db.Statement.ReflectValue rel = rels[len(rels)-1] - tx = db.Session(&gorm.Session{}) + tx = db.Session(&gorm.Session{NewDB: true}) relForeignKeys []string relForeignFields []*schema.Field foreignFields []*schema.Field diff --git a/finisher_api.go b/finisher_api.go index 211e2f8f..d1390a15 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -78,7 +78,7 @@ func (db *DB) Save(value interface{}) (tx *DB) { if tx.Error == nil && tx.RowsAffected == 0 && !tx.DryRun && !selectedUpdate { result := reflect.New(tx.Statement.Schema.ModelType).Interface() - if err := tx.Session(&Session{WithConditions: true}).First(result).Error; errors.Is(err, ErrRecordNotFound) { + if err := tx.Session(&Session{}).First(result).Error; errors.Is(err, ErrRecordNotFound) { return tx.Create(value) } } @@ -144,7 +144,7 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat var ( tx = db.Order(clause.OrderByColumn{ Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, - }).Session(&Session{WithConditions: true}) + }).Session(&Session{}) queryDB = tx rowsAffected int64 batch int @@ -480,7 +480,7 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er } }() - err = fc(db.Session(&Session{WithConditions: true})) + err = fc(db.Session(&Session{})) } else { tx := db.Begin(opts...) @@ -506,7 +506,7 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er func (db *DB) Begin(opts ...*sql.TxOptions) *DB { var ( // clone statement - tx = db.Session(&Session{WithConditions: true, Context: db.Statement.Context}) + tx = db.Session(&Session{Context: db.Statement.Context}) opt *sql.TxOptions err error ) diff --git a/gorm.go b/gorm.go index 3bf2479a..f7c18b08 100644 --- a/gorm.go +++ b/gorm.go @@ -63,7 +63,7 @@ type DB struct { type Session struct { DryRun bool PrepareStmt bool - WithConditions bool + NewDB bool SkipHooks bool SkipDefaultTransaction bool AllowGlobalUpdate bool @@ -196,7 +196,7 @@ func (db *DB) Session(config *Session) *DB { tx.Statement.UpdatingColumn = true } - if config.WithConditions { + if !config.NewDB { tx.clone = 2 } @@ -217,14 +217,13 @@ func (db *DB) Session(config *Session) *DB { // WithContext change current instance db's context to ctx func (db *DB) WithContext(ctx context.Context) *DB { - return db.Session(&Session{WithConditions: true, Context: ctx}) + return db.Session(&Session{Context: ctx}) } // Debug start debug mode func (db *DB) Debug() (tx *DB) { return db.Session(&Session{ - WithConditions: true, - Logger: db.Logger.LogMode(logger.Info), + Logger: db.Logger.LogMode(logger.Info), }) } diff --git a/migrator.go b/migrator.go index ac06a144..28ac35e7 100644 --- a/migrator.go +++ b/migrator.go @@ -7,7 +7,7 @@ import ( // Migrator returns migrator func (db *DB) Migrator() Migrator { - return db.Dialector.Migrator(db.Session(&Session{WithConditions: true})) + return db.Dialector.Migrator(db.Session(&Session{})) } // AutoMigrate run auto migration for given models diff --git a/migrator/migrator.go b/migrator/migrator.go index 016ebfc7..5de820a8 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -82,7 +82,7 @@ func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) { // AutoMigrate func (m Migrator) AutoMigrate(values ...interface{}) error { for _, value := range m.ReorderModels(values, true) { - tx := m.DB.Session(&gorm.Session{}) + tx := m.DB.Session(&gorm.Session{NewDB: true}) if !tx.Migrator().HasTable(value) { if err := tx.Migrator().CreateTable(value); err != nil { return err @@ -154,7 +154,7 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { func (m Migrator) CreateTable(values ...interface{}) error { for _, value := range m.ReorderModels(values, false) { - tx := m.DB.Session(&gorm.Session{}) + tx := m.DB.Session(&gorm.Session{NewDB: true}) if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) { var ( createTableSQL = "CREATE TABLE ? (" @@ -237,7 +237,7 @@ func (m Migrator) CreateTable(values ...interface{}) error { func (m Migrator) DropTable(values ...interface{}) error { values = m.ReorderModels(values, false) for i := len(values) - 1; i >= 0; i-- { - tx := m.DB.Session(&gorm.Session{}) + tx := m.DB.Session(&gorm.Session{NewDB: true}) if err := m.RunWithValue(values[i], func(stmt *gorm.Statement) error { return tx.Exec("DROP TABLE IF EXISTS ?", m.CurrentTable(stmt)).Error }); err != nil { @@ -404,7 +404,7 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy func (m Migrator) ColumnTypes(value interface{}) (columnTypes []gorm.ColumnType, err error) { columnTypes = make([]gorm.ColumnType, 0) err = m.RunWithValue(value, func(stmt *gorm.Statement) error { - rows, err := m.DB.Session(&gorm.Session{}).Table(stmt.Table).Limit(1).Rows() + rows, err := m.DB.Session(&gorm.Session{NewDB: true}).Table(stmt.Table).Limit(1).Rows() if err == nil { defer rows.Close() rawColumnTypes, err := rows.ColumnTypes() diff --git a/statement.go b/statement.go index 7c0af59c..3f46ae0a 100644 --- a/statement.go +++ b/statement.go @@ -190,7 +190,7 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) { writer.WriteString("(NULL)") } case *DB: - subdb := v.Session(&Session{Logger: logger.Discard, DryRun: true, WithConditions: true}).getInstance() + subdb := v.Session(&Session{Logger: logger.Discard, DryRun: true}).getInstance() subdb.Statement.Vars = append(subdb.Statement.Vars, stmt.Vars...) subdb.callbacks.Query().Execute(subdb) writer.WriteString(subdb.Statement.SQL.String()) diff --git a/tests/count_test.go b/tests/count_test.go index 41bad71d..55fb71e2 100644 --- a/tests/count_test.go +++ b/tests/count_test.go @@ -41,7 +41,7 @@ func TestCount(t *testing.T) { t.Errorf("multiple count in chain should works") } - tx := DB.Model(&User{}).Where("name = ?", user1.Name).Session(&gorm.Session{WithConditions: true}) + tx := DB.Model(&User{}).Where("name = ?", user1.Name).Session(&gorm.Session{}) tx.Count(&count1) tx.Or("name in ?", []string{user2.Name, user3.Name}).Count(&count2) if count1 != 1 || count2 != 3 { diff --git a/tests/hooks_test.go b/tests/hooks_test.go index 7e3ae4e4..fe3f7d08 100644 --- a/tests/hooks_test.go +++ b/tests/hooks_test.go @@ -380,6 +380,13 @@ func TestSetColumn(t *testing.T) { DB.First(&result2, product.ID) AssertEqual(t, result2, product) + + product2 := Product3{Name: "Product", Price: 0} + DB.Session(&gorm.Session{SkipHooks: true}).Create(&product2) + + if product2.Price != 0 { + t.Errorf("invalid price after create without hooks, got %+v", product2) + } } func TestHooksForSlice(t *testing.T) { From 9df9f7688bd67062fa9f178cbd2179a1372c992f Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 17 Nov 2020 17:49:43 +0800 Subject: [PATCH 726/881] Change UpdatingColumn to SkipHooks --- callbacks/create.go | 4 ++-- callbacks/delete.go | 4 ++-- callbacks/query.go | 2 +- callbacks/update.go | 8 ++++---- finisher_api.go | 4 ++-- gorm.go | 2 +- statement.go | 4 ++-- 7 files changed, 14 insertions(+), 14 deletions(-) diff --git a/callbacks/create.go b/callbacks/create.go index a58549a5..3ca56d73 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -10,7 +10,7 @@ import ( ) func BeforeCreate(db *gorm.DB) { - if db.Error == nil && db.Statement.Schema != nil && !db.Statement.UpdatingColumn && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeCreate) { + if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeCreate) { callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { if db.Statement.Schema.BeforeSave { if i, ok := value.(BeforeSaveInterface); ok { @@ -203,7 +203,7 @@ func CreateWithReturning(db *gorm.DB) { } func AfterCreate(db *gorm.DB) { - if db.Error == nil && db.Statement.Schema != nil && !db.Statement.UpdatingColumn && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) { + if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) { callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { if db.Statement.Schema.AfterSave { if i, ok := value.(AfterSaveInterface); ok { diff --git a/callbacks/delete.go b/callbacks/delete.go index 4a289e0c..867aa697 100644 --- a/callbacks/delete.go +++ b/callbacks/delete.go @@ -10,7 +10,7 @@ import ( ) func BeforeDelete(db *gorm.DB) { - if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.BeforeDelete { + if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.BeforeDelete { callMethod(db, func(value interface{}, tx *gorm.DB) bool { if i, ok := value.(BeforeDeleteInterface); ok { db.AddError(i.BeforeDelete(tx)) @@ -153,7 +153,7 @@ func Delete(db *gorm.DB) { } func AfterDelete(db *gorm.DB) { - if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.AfterDelete { + if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.AfterDelete { callMethod(db, func(value interface{}, tx *gorm.DB) bool { if i, ok := value.(AfterDeleteInterface); ok { db.AddError(i.AfterDelete(tx)) diff --git a/callbacks/query.go b/callbacks/query.go index 92f711f5..89f02f58 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -214,7 +214,7 @@ func Preload(db *gorm.DB) { } func AfterQuery(db *gorm.DB) { - if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.AfterFind { + if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.AfterFind { callMethod(db, func(value interface{}, tx *gorm.DB) bool { if i, ok := value.(AfterFindInterface); ok { db.AddError(i.AfterFind(tx)) diff --git a/callbacks/update.go b/callbacks/update.go index 46f59157..c8f3922e 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -29,7 +29,7 @@ func SetupUpdateReflectValue(db *gorm.DB) { } func BeforeUpdate(db *gorm.DB) { - if db.Error == nil && db.Statement.Schema != nil && !db.Statement.UpdatingColumn && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) { + if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) { callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { if db.Statement.Schema.BeforeSave { if i, ok := value.(BeforeSaveInterface); ok { @@ -87,7 +87,7 @@ func Update(db *gorm.DB) { } func AfterUpdate(db *gorm.DB) { - if db.Error == nil && db.Statement.Schema != nil && !db.Statement.UpdatingColumn && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) { + if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) { callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { if db.Statement.Schema.AfterSave { if i, ok := value.(AfterSaveInterface); ok { @@ -198,7 +198,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { } } - if !stmt.UpdatingColumn && stmt.Schema != nil { + if !stmt.SkipHooks && stmt.Schema != nil { for _, dbName := range stmt.Schema.DBNames { field := stmt.Schema.LookUpField(dbName) if field.AutoUpdateTime > 0 && value[field.Name] == nil && value[field.DBName] == nil { @@ -228,7 +228,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { if !field.PrimaryKey || (!updatingValue.CanAddr() || stmt.Dest != stmt.Model) { if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { value, isZero := field.ValueOf(updatingValue) - if !stmt.UpdatingColumn { + if !stmt.SkipHooks { if field.AutoUpdateTime > 0 { if field.AutoUpdateTime == schema.UnixNanosecond { value = stmt.DB.NowFunc().UnixNano() diff --git a/finisher_api.go b/finisher_api.go index d1390a15..1efa2e46 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -307,7 +307,7 @@ func (db *DB) Updates(values interface{}) (tx *DB) { func (db *DB) UpdateColumn(column string, value interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.Dest = map[string]interface{}{column: value} - tx.Statement.UpdatingColumn = true + tx.Statement.SkipHooks = true tx.callbacks.Update().Execute(tx) return } @@ -315,7 +315,7 @@ func (db *DB) UpdateColumn(column string, value interface{}) (tx *DB) { func (db *DB) UpdateColumns(values interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.Dest = values - tx.Statement.UpdatingColumn = true + tx.Statement.SkipHooks = true tx.callbacks.Update().Execute(tx) return } diff --git a/gorm.go b/gorm.go index f7c18b08..59e4fd6c 100644 --- a/gorm.go +++ b/gorm.go @@ -193,7 +193,7 @@ func (db *DB) Session(config *Session) *DB { } if config.SkipHooks { - tx.Statement.UpdatingColumn = true + tx.Statement.SkipHooks = true } if !config.NewDB { diff --git a/statement.go b/statement.go index 3f46ae0a..27edf9da 100644 --- a/statement.go +++ b/statement.go @@ -37,7 +37,7 @@ type Statement struct { Schema *schema.Schema Context context.Context RaiseErrorOnNotFound bool - UpdatingColumn bool + SkipHooks bool SQL strings.Builder Vars []interface{} CurDestIndex int @@ -421,7 +421,7 @@ func (stmt *Statement) clone() *Statement { Schema: stmt.Schema, Context: stmt.Context, RaiseErrorOnNotFound: stmt.RaiseErrorOnNotFound, - UpdatingColumn: stmt.UpdatingColumn, + SkipHooks: stmt.SkipHooks, } for k, c := range stmt.Clauses { From 694e42d6a1de36adba2702088be5aa5658072f7f Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 17 Nov 2020 19:11:16 +0800 Subject: [PATCH 727/881] Fix clause.IN with only one value of multiple rows --- clause/expression.go | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/clause/expression.go b/clause/expression.go index 40265ac6..b30c46b0 100644 --- a/clause/expression.go +++ b/clause/expression.go @@ -160,8 +160,13 @@ func (in IN) Build(builder Builder) { case 0: builder.WriteString(" IN (NULL)") case 1: - builder.WriteString(" = ") - builder.AddVar(builder, in.Values...) + if _, ok := in.Values[0].([]interface{}); !ok { + builder.WriteString(" = ") + builder.AddVar(builder, in.Values[0]) + break + } + + fallthrough default: builder.WriteString(" IN (") builder.AddVar(builder, in.Values...) @@ -173,9 +178,14 @@ func (in IN) NegationBuild(builder Builder) { switch len(in.Values) { case 0: case 1: - builder.WriteQuoted(in.Column) - builder.WriteString(" <> ") - builder.AddVar(builder, in.Values...) + if _, ok := in.Values[0].([]interface{}); !ok { + builder.WriteQuoted(in.Column) + builder.WriteString(" <> ") + builder.AddVar(builder, in.Values[0]) + break + } + + fallthrough default: builder.WriteQuoted(in.Column) builder.WriteString(" NOT IN (") From 50df9da6a1821cfd5bc5100dcbd007ad9defa1d8 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 17 Nov 2020 20:24:08 +0800 Subject: [PATCH 728/881] Allow to skip associations when creating join table for many2many, close #3605 --- callbacks/associations.go | 4 +++- tests/associations_many2many_test.go | 22 ++++++++++++++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/callbacks/associations.go b/callbacks/associations.go index 1702f442..ce91c2ee 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -291,7 +291,9 @@ func SaveAfterAssociations(db *gorm.DB) { } if elems.Len() > 0 { - db.AddError(db.Session(&gorm.Session{NewDB: true}).Clauses(onConflictOption(db.Statement, rel.FieldSchema, nil)).Create(elems.Interface()).Error) + if v, ok := selectColumns[rel.Name+".*"]; !ok || v { + db.AddError(db.Session(&gorm.Session{NewDB: true}).Clauses(onConflictOption(db.Statement, rel.FieldSchema, nil)).Create(elems.Interface()).Error) + } for i := 0; i < elems.Len(); i++ { appendToJoins(objs[i], elems.Index(i)) diff --git a/tests/associations_many2many_test.go b/tests/associations_many2many_test.go index 2ecf7b66..1ddd3b85 100644 --- a/tests/associations_many2many_test.go +++ b/tests/associations_many2many_test.go @@ -93,6 +93,28 @@ func TestMany2ManyAssociation(t *testing.T) { AssertAssociationCount(t, user2, "Languages", 0, "after clear") } +func TestMany2ManyOmitAssociations(t *testing.T) { + var user = *GetUser("many2many_omit_associations", Config{Languages: 2}) + + if err := DB.Omit("Languages.*").Create(&user).Error; err == nil { + t.Fatalf("should raise error when create users without languages reference") + } + + if err := DB.Create(&user.Languages).Error; err != nil { + t.Fatalf("no error should happen when create languages, but got %v", err) + } + + if err := DB.Omit("Languages.*").Create(&user).Error; err != nil { + t.Fatalf("no error should happen when create user when languages exists, but got %v", err) + } + + // Find + var languages []Language + if DB.Model(&user).Association("Languages").Find(&languages); len(languages) != 2 { + t.Errorf("languages count should be %v, but got %v", 2, len(languages)) + } +} + func TestMany2ManyAssociationForSlice(t *testing.T) { var users = []User{ *GetUser("slice-many2many-1", Config{Languages: 2}), From 54b80b18bcc796b1f03f6ea3495f1322c59988f0 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 17 Nov 2020 21:49:40 +0800 Subject: [PATCH 729/881] Allow to omit fields in associations, close #3752 --- callbacks/associations.go | 53 +++++++++++++++++++++++------- tests/associations_has_one_test.go | 14 ++++++++ 2 files changed, 55 insertions(+), 12 deletions(-) diff --git a/callbacks/associations.go b/callbacks/associations.go index ce91c2ee..ea90780c 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -2,6 +2,7 @@ package callbacks import ( "reflect" + "strings" "gorm.io/gorm" "gorm.io/gorm/clause" @@ -66,7 +67,7 @@ func SaveBeforeAssociations(db *gorm.DB) { } if elems.Len() > 0 { - if db.AddError(db.Session(&gorm.Session{NewDB: true}).Clauses(onConflictOption(db.Statement, rel.FieldSchema, nil)).Create(elems.Interface()).Error) == nil { + if saveAssociations(db, selectColumns, rel.Name, onConflictOption(db.Statement, rel.FieldSchema, nil), elems.Interface()) == nil { for i := 0; i < elems.Len(); i++ { setupReferences(objs[i], elems.Index(i)) } @@ -79,7 +80,7 @@ func SaveBeforeAssociations(db *gorm.DB) { rv = rv.Addr() } - if db.AddError(db.Session(&gorm.Session{NewDB: true}).Clauses(onConflictOption(db.Statement, rel.FieldSchema, nil)).Create(rv.Interface()).Error) == nil { + if saveAssociations(db, selectColumns, rel.Name, onConflictOption(db.Statement, rel.FieldSchema, nil), rv.Interface()) == nil { setupReferences(db.Statement.ReflectValue, rv) } } @@ -141,9 +142,7 @@ func SaveAfterAssociations(db *gorm.DB) { assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) } - db.AddError(db.Session(&gorm.Session{NewDB: true}).Clauses( - onConflictOption(db.Statement, rel.FieldSchema, assignmentColumns), - ).Create(elems.Interface()).Error) + saveAssociations(db, selectColumns, rel.Name, onConflictOption(db.Statement, rel.FieldSchema, assignmentColumns), elems.Interface()) } case reflect.Struct: if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero { @@ -163,9 +162,7 @@ func SaveAfterAssociations(db *gorm.DB) { assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) } - db.AddError(db.Session(&gorm.Session{NewDB: true}).Clauses( - onConflictOption(db.Statement, rel.FieldSchema, assignmentColumns), - ).Create(f.Interface()).Error) + saveAssociations(db, selectColumns, rel.Name, onConflictOption(db.Statement, rel.FieldSchema, assignmentColumns), f.Interface()) } } } @@ -224,9 +221,7 @@ func SaveAfterAssociations(db *gorm.DB) { assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) } - db.AddError(db.Session(&gorm.Session{NewDB: true}).Clauses( - onConflictOption(db.Statement, rel.FieldSchema, assignmentColumns), - ).Create(elems.Interface()).Error) + saveAssociations(db, selectColumns, rel.Name, onConflictOption(db.Statement, rel.FieldSchema, assignmentColumns), elems.Interface()) } } @@ -292,7 +287,7 @@ func SaveAfterAssociations(db *gorm.DB) { if elems.Len() > 0 { if v, ok := selectColumns[rel.Name+".*"]; !ok || v { - db.AddError(db.Session(&gorm.Session{NewDB: true}).Clauses(onConflictOption(db.Statement, rel.FieldSchema, nil)).Create(elems.Interface()).Error) + saveAssociations(db, selectColumns, rel.Name, onConflictOption(db.Statement, rel.FieldSchema, nil), elems.Interface()) } for i := 0; i < elems.Len(); i++ { @@ -335,3 +330,37 @@ func onConflictOption(stmt *gorm.Statement, s *schema.Schema, defaultUpdatingCol return clause.OnConflict{DoNothing: true} } + +func saveAssociations(db *gorm.DB, selectColumns map[string]bool, refName string, onConflict clause.OnConflict, values interface{}) error { + var selects, omits []string + refName = refName + "." + + for name, ok := range selectColumns { + columnName := "" + if strings.HasPrefix(name, refName) { + columnName = strings.TrimPrefix(name, refName) + } else if strings.HasPrefix(name, clause.Associations) { + columnName = name + } + + if columnName != "" { + if ok { + selects = append(selects, columnName) + } else { + omits = append(omits, columnName) + } + } + } + + tx := db.Session(&gorm.Session{NewDB: true}).Clauses(onConflict) + + if len(selects) > 0 { + tx = tx.Select(selects) + } + + if len(omits) > 0 { + tx = tx.Omit(omits...) + } + + return db.AddError(tx.Create(values).Error) +} diff --git a/tests/associations_has_one_test.go b/tests/associations_has_one_test.go index f487bd9e..a4fc8c4f 100644 --- a/tests/associations_has_one_test.go +++ b/tests/associations_has_one_test.go @@ -83,6 +83,20 @@ func TestHasOneAssociation(t *testing.T) { AssertAssociationCount(t, user2, "Account", 0, "after clear") } +func TestHasOneAssociationWithSelect(t *testing.T) { + var user = *GetUser("hasone", Config{Account: true}) + + DB.Omit("Account.Number").Create(&user) + + AssertAssociationCount(t, user, "Account", 1, "") + + var account Account + DB.Model(&user).Association("Account").Find(&account) + if account.Number != "" { + t.Errorf("account's number should not be saved") + } +} + func TestHasOneAssociationForSlice(t *testing.T) { var users = []User{ *GetUser("slice-hasone-1", Config{Account: true}), From a1a30c38de195d7af91db243bc8503c88ccb951c Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 18 Nov 2020 19:06:49 +0800 Subject: [PATCH 730/881] Allow to omit fields when upsert associations, close #3762 --- callbacks/associations.go | 27 +++++++++++++++++---------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/callbacks/associations.go b/callbacks/associations.go index ea90780c..0fa47868 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -67,7 +67,7 @@ func SaveBeforeAssociations(db *gorm.DB) { } if elems.Len() > 0 { - if saveAssociations(db, selectColumns, rel.Name, onConflictOption(db.Statement, rel.FieldSchema, nil), elems.Interface()) == nil { + if saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, nil) == nil { for i := 0; i < elems.Len(); i++ { setupReferences(objs[i], elems.Index(i)) } @@ -80,7 +80,7 @@ func SaveBeforeAssociations(db *gorm.DB) { rv = rv.Addr() } - if saveAssociations(db, selectColumns, rel.Name, onConflictOption(db.Statement, rel.FieldSchema, nil), rv.Interface()) == nil { + if saveAssociations(db, rel, rv.Interface(), selectColumns, restricted, nil) == nil { setupReferences(db.Statement.ReflectValue, rv) } } @@ -142,7 +142,7 @@ func SaveAfterAssociations(db *gorm.DB) { assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) } - saveAssociations(db, selectColumns, rel.Name, onConflictOption(db.Statement, rel.FieldSchema, assignmentColumns), elems.Interface()) + saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, assignmentColumns) } case reflect.Struct: if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero { @@ -162,7 +162,7 @@ func SaveAfterAssociations(db *gorm.DB) { assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) } - saveAssociations(db, selectColumns, rel.Name, onConflictOption(db.Statement, rel.FieldSchema, assignmentColumns), f.Interface()) + saveAssociations(db, rel, f.Interface(), selectColumns, restricted, assignmentColumns) } } } @@ -221,7 +221,7 @@ func SaveAfterAssociations(db *gorm.DB) { assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) } - saveAssociations(db, selectColumns, rel.Name, onConflictOption(db.Statement, rel.FieldSchema, assignmentColumns), elems.Interface()) + saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, assignmentColumns) } } @@ -287,7 +287,7 @@ func SaveAfterAssociations(db *gorm.DB) { if elems.Len() > 0 { if v, ok := selectColumns[rel.Name+".*"]; !ok || v { - saveAssociations(db, selectColumns, rel.Name, onConflictOption(db.Statement, rel.FieldSchema, nil), elems.Interface()) + saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, nil) } for i := 0; i < elems.Len(); i++ { @@ -302,10 +302,14 @@ func SaveAfterAssociations(db *gorm.DB) { } } -func onConflictOption(stmt *gorm.Statement, s *schema.Schema, defaultUpdatingColumns []string) clause.OnConflict { +func onConflictOption(stmt *gorm.Statement, s *schema.Schema, selectColumns map[string]bool, restricted bool, defaultUpdatingColumns []string) clause.OnConflict { if stmt.DB.FullSaveAssociations { defaultUpdatingColumns = make([]string, 0, len(s.DBNames)) for _, dbName := range s.DBNames { + if v, ok := selectColumns[dbName]; (ok && !v) || (!ok && restricted) { + continue + } + if !s.LookUpField(dbName).PrimaryKey { defaultUpdatingColumns = append(defaultUpdatingColumns, dbName) } @@ -331,9 +335,12 @@ func onConflictOption(stmt *gorm.Statement, s *schema.Schema, defaultUpdatingCol return clause.OnConflict{DoNothing: true} } -func saveAssociations(db *gorm.DB, selectColumns map[string]bool, refName string, onConflict clause.OnConflict, values interface{}) error { - var selects, omits []string - refName = refName + "." +func saveAssociations(db *gorm.DB, rel *schema.Relationship, values interface{}, selectColumns map[string]bool, restricted bool, defaultUpdatingColumns []string) error { + var ( + selects, omits []string + onConflict = onConflictOption(db.Statement, rel.FieldSchema, selectColumns, restricted, defaultUpdatingColumns) + refName = rel.Name + "." + ) for name, ok := range selectColumns { columnName := "" From e7f45d5b0112fdce04b479d27f60c8dd8c66f3c0 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 19 Nov 2020 10:45:17 +0800 Subject: [PATCH 731/881] Add error check for Transaction --- finisher_api.go | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/finisher_api.go b/finisher_api.go index 1efa2e46..f2aed8da 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -472,7 +472,7 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil { // nested transaction - db.SavePoint(fmt.Sprintf("sp%p", fc)) + err = db.SavePoint(fmt.Sprintf("sp%p", fc)).Error defer func() { // Make sure to rollback when panic, Block error or Commit error if panicked || err != nil { @@ -480,7 +480,9 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er } }() - err = fc(db.Session(&Session{})) + if err == nil { + err = fc(db.Session(&Session{})) + } } else { tx := db.Begin(opts...) @@ -491,7 +493,9 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er } }() - err = fc(tx) + if err = tx.Error; err == nil { + err = fc(tx) + } if err == nil { err = tx.Commit().Error From d66af581b4b6467b9f09a1eade855b29394d0150 Mon Sep 17 00:00:00 2001 From: Deviller Date: Thu, 19 Nov 2020 14:24:34 +0300 Subject: [PATCH 732/881] Fix Association.Replace() error returning (#3766) * Fix Association.Replace() error returning * Fallback to gorm.Model at TestAssociationNotNullClear() --- association.go | 4 ++-- tests/associations_test.go | 36 ++++++++++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 2 deletions(-) diff --git a/association.go b/association.go index 0f2102f7..7adb8c91 100644 --- a/association.go +++ b/association.go @@ -118,7 +118,7 @@ func (association *Association) Replace(values ...interface{}) error { if _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields); len(pvs) > 0 { column, values := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs) - tx.Where(clause.IN{Column: column, Values: values}).UpdateColumns(updateMap) + association.Error = tx.Where(clause.IN{Column: column, Values: values}).UpdateColumns(updateMap).Error } case schema.Many2Many: var ( @@ -154,7 +154,7 @@ func (association *Association) Replace(values ...interface{}) error { tx.Where(clause.Not(clause.IN{Column: relColumn, Values: relValues})) } - tx.Delete(modelValue) + association.Error = tx.Delete(modelValue).Error } } return association.Error diff --git a/tests/associations_test.go b/tests/associations_test.go index c1a4e2b2..f470338f 100644 --- a/tests/associations_test.go +++ b/tests/associations_test.go @@ -3,6 +3,7 @@ package tests_test import ( "testing" + "gorm.io/gorm" . "gorm.io/gorm/utils/tests" ) @@ -32,6 +33,41 @@ func TestInvalidAssociation(t *testing.T) { } } +func TestAssociationNotNullClear(t *testing.T) { + type Profile struct { + gorm.Model + Number string + MemberID uint `gorm:"not null"` + } + + type Member struct { + gorm.Model + Profiles []Profile + } + + DB.Migrator().DropTable(&Member{}, &Profile{}) + + if err := DB.AutoMigrate(&Member{}, &Profile{}); err != nil { + t.Fatalf("Failed to migrate, got error: %v", err) + } + + member := &Member{ + Profiles: []Profile{{ + Number: "1", + }, { + Number: "2", + }}, + } + + if err := DB.Create(&member).Error; err != nil { + t.Fatalf("Failed to create test data, got error: %v", err) + } + + if err := DB.Model(member).Association("Profiles").Clear(); err == nil { + t.Fatalf("No error occured during clearind not null association") + } +} + func TestForeignKeyConstraints(t *testing.T) { type Profile struct { ID uint From e3b4e0418f2c9c4670bf21f6d9d63caa5a0903ce Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 20 Nov 2020 15:11:02 +0800 Subject: [PATCH 733/881] Inherit SkipHooks option when preloading associations, close #3772 --- callbacks/preload.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/callbacks/preload.go b/callbacks/preload.go index e1dfdace..c2304af8 100644 --- a/callbacks/preload.go +++ b/callbacks/preload.go @@ -13,7 +13,7 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { var ( reflectValue = db.Statement.ReflectValue rel = rels[len(rels)-1] - tx = db.Session(&gorm.Session{NewDB: true}) + tx = db.Session(&gorm.Session{NewDB: true, SkipHooks: db.Statement.SkipHooks}) relForeignKeys []string relForeignFields []*schema.Field foreignFields []*schema.Field From 47ffd0bef4947fff1ba6ef4bd61b0c82f289ad20 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luis=20Guillermo=20G=C3=B3mez?= <44306301+luisgomez29@users.noreply.github.com> Date: Fri, 20 Nov 2020 02:38:25 -0500 Subject: [PATCH 734/881] Select all fields in SQL queries avoiding the SELECT * FROM (#3731) * Select all fields in SQL queries avoiding the SELECT * FROM * Select table name with fields in SQL queries * Use QueryFields to execute the SQL query with all fields of the table --- callbacks/query.go | 35 ++++--- gorm.go | 7 ++ tests/multi_primary_keys_test.go | 4 +- tests/query_test.go | 160 +++++++++++++++++++++++++++++++ tests/table_test.go | 57 +++++++++++ 5 files changed, 250 insertions(+), 13 deletions(-) diff --git a/callbacks/query.go b/callbacks/query.go index 89f02f58..5274c246 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -68,26 +68,39 @@ func BuildQuerySQL(db *gorm.DB) { clauseSelect.Columns = make([]clause.Column, 0, len(db.Statement.Schema.DBNames)) for _, dbName := range db.Statement.Schema.DBNames { if v, ok := selectColumns[dbName]; (ok && v) || !ok { - clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{Name: dbName}) + clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{Table: db.Statement.Table, Name: dbName}) } } } else if db.Statement.Schema != nil && db.Statement.ReflectValue.IsValid() { - smallerStruct := false - switch db.Statement.ReflectValue.Kind() { - case reflect.Struct: - smallerStruct = db.Statement.ReflectValue.Type() != db.Statement.Schema.ModelType - case reflect.Slice: - smallerStruct = db.Statement.ReflectValue.Type().Elem() != db.Statement.Schema.ModelType - } + if !db.QueryFields { + smallerStruct := false + switch db.Statement.ReflectValue.Kind() { + case reflect.Struct: + smallerStruct = db.Statement.ReflectValue.Type() != db.Statement.Schema.ModelType + case reflect.Slice: + smallerStruct = db.Statement.ReflectValue.Type().Elem() != db.Statement.Schema.ModelType + } - if smallerStruct { + if smallerStruct { + stmt := gorm.Statement{DB: db} + // smaller struct + if err := stmt.Parse(db.Statement.Dest); err == nil && stmt.Schema.ModelType != db.Statement.Schema.ModelType { + clauseSelect.Columns = make([]clause.Column, len(stmt.Schema.DBNames)) + + for idx, dbName := range stmt.Schema.DBNames { + clauseSelect.Columns[idx] = clause.Column{Name: dbName} + } + } + } + } else { + // Execute the query with all the fields of the table stmt := gorm.Statement{DB: db} // smaller struct - if err := stmt.Parse(db.Statement.Dest); err == nil && stmt.Schema.ModelType != db.Statement.Schema.ModelType { + if err := stmt.Parse(db.Statement.Dest); err == nil { clauseSelect.Columns = make([]clause.Column, len(stmt.Schema.DBNames)) for idx, dbName := range stmt.Schema.DBNames { - clauseSelect.Columns[idx] = clause.Column{Name: dbName} + clauseSelect.Columns[idx] = clause.Column{Table: db.Statement.Table, Name: dbName} } } } diff --git a/gorm.go b/gorm.go index 59e4fd6c..1947b4df 100644 --- a/gorm.go +++ b/gorm.go @@ -36,6 +36,8 @@ type Config struct { DisableForeignKeyConstraintWhenMigrating bool // AllowGlobalUpdate allow global update AllowGlobalUpdate bool + // QueryFields executes the SQL query with all fields of the table + QueryFields bool // ClauseBuilders clause builder ClauseBuilders map[string]clause.ClauseBuilder @@ -68,6 +70,7 @@ type Session struct { SkipDefaultTransaction bool AllowGlobalUpdate bool FullSaveAssociations bool + QueryFields bool Context context.Context Logger logger.Interface NowFunc func() time.Time @@ -204,6 +207,10 @@ func (db *DB) Session(config *Session) *DB { tx.Config.DryRun = true } + if config.QueryFields { + tx.Config.QueryFields = true + } + if config.Logger != nil { tx.Config.Logger = config.Logger } diff --git a/tests/multi_primary_keys_test.go b/tests/multi_primary_keys_test.go index 68da8a88..dcc90cd9 100644 --- a/tests/multi_primary_keys_test.go +++ b/tests/multi_primary_keys_test.go @@ -140,7 +140,7 @@ func TestManyToManyWithCustomizedForeignKeys(t *testing.T) { } if name := DB.Dialector.Name(); name == "postgres" { - t.Skip("skip postgers due to it only allow unique constraint matching given keys") + t.Skip("skip postgres due to it only allow unique constraint matching given keys") } DB.Migrator().DropTable(&Blog{}, &Tag{}, "blog_tags", "locale_blog_tags", "shared_blog_tags") @@ -265,7 +265,7 @@ func TestManyToManyWithCustomizedForeignKeys2(t *testing.T) { } if name := DB.Dialector.Name(); name == "postgres" { - t.Skip("skip postgers due to it only allow unique constraint matching given keys") + t.Skip("skip postgres due to it only allow unique constraint matching given keys") } DB.Migrator().DropTable(&Blog{}, &Tag{}, "blog_tags", "locale_blog_tags", "shared_blog_tags") diff --git a/tests/query_test.go b/tests/query_test.go index 20968c7e..c4162bdc 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -348,6 +348,39 @@ func TestFillSmallerStruct(t *testing.T) { } } +func TestFillSmallerStructWithAllFields(t *testing.T) { + user := User{Name: "SmallerUser", Age: 100} + DB.Save(&user) + type SimpleUser struct { + ID int64 + Name string + UpdatedAt time.Time + CreatedAt time.Time + } + var simpleUsers []SimpleUser + dryDB := DB.Session(&gorm.Session{DryRun: true, QueryFields: true}) + + result := dryDB.Model(&User{}).Find(&simpleUsers, user.ID) + if !regexp.MustCompile("SELECT .users.*id.*users.*name.*users.*updated_at.*users.*created_at.* FROM .*users").MatchString(result.Statement.SQL.String()) { + t.Fatalf("SQL should include selected names, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Model(&User{}).Find(&User{}, user.ID) + if regexp.MustCompile("SELECT \\* FROM .*users").MatchString(result.Statement.SQL.String()) { + t.Fatalf("SQL should not include a * wildcard, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Model(&User{}).Find(&[]User{}, user.ID) + if regexp.MustCompile("SELECT \\* FROM .*users").MatchString(result.Statement.SQL.String()) { + t.Fatalf("SQL should not include a * wildcard, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Model(&User{}).Find(&[]*User{}, user.ID) + if regexp.MustCompile("SELECT \\* FROM .*users").MatchString(result.Statement.SQL.String()) { + t.Fatalf("SQL should not include a * wildcard, but got %v", result.Statement.SQL.String()) + } +} + func TestNot(t *testing.T) { dryDB := DB.Session(&gorm.Session{DryRun: true}) @@ -392,6 +425,53 @@ func TestNot(t *testing.T) { } } +func TestNotWithAllFields(t *testing.T) { + dryDB := DB.Session(&gorm.Session{DryRun: true, QueryFields: true}) + userQuery := "SELECT .*users.*id.*users.*created_at.*users.*updated_at.*users.*deleted_at.*users.*name" + + ".*users.*age.*users.*birthday.*users.*company_id.*users.*manager_id.*users.*active.* FROM .*users.* " + + result := dryDB.Not(map[string]interface{}{"users.name": "jinzhu"}).Find(&User{}) + + if !regexp.MustCompile(userQuery + "WHERE .*users.*name.* <> .+").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Where("users.name = ?", "jinzhu1").Not("users.name = ?", "jinzhu2").Find(&User{}) + if !regexp.MustCompile(userQuery + "WHERE .*users.*name.* = .+ AND NOT .*users.*name.* = .+").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Where(map[string]interface{}{"users.name": []string{"jinzhu", "jinzhu 2"}}).Find(&User{}) + if !regexp.MustCompile(userQuery + "WHERE .*users.*name.* IN \\(.+,.+\\)").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Not("users.name = ?", "jinzhu").Find(&User{}) + if !regexp.MustCompile(userQuery + "WHERE NOT .*users.*name.* = .+").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Not(map[string]interface{}{"users.name": []string{"jinzhu", "jinzhu 2"}}).Find(&User{}) + if !regexp.MustCompile(userQuery + "WHERE .*users.*name.* NOT IN \\(.+,.+\\)").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Not([]int64{1, 2}).First(&User{}) + if !regexp.MustCompile(userQuery + "WHERE .*users.*id.* NOT IN \\(.+,.+\\)").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Not([]int64{}).First(&User{}) + if !regexp.MustCompile(userQuery + "WHERE .users.\\..deleted_at. IS NULL ORDER BY").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Not(User{Name: "jinzhu", Age: 18}).First(&User{}) + if !regexp.MustCompile(userQuery + "WHERE .*users.*..*name.* <> .+ AND .*users.*..*age.* <> .+").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) + } +} + func TestOr(t *testing.T) { dryDB := DB.Session(&gorm.Session{DryRun: true}) @@ -411,6 +491,27 @@ func TestOr(t *testing.T) { } } +func TestOrWithAllFields(t *testing.T) { + dryDB := DB.Session(&gorm.Session{DryRun: true, QueryFields: true}) + userQuery := "SELECT .*users.*id.*users.*created_at.*users.*updated_at.*users.*deleted_at.*users.*name" + + ".*users.*age.*users.*birthday.*users.*company_id.*users.*manager_id.*users.*active.* FROM .*users.* " + + result := dryDB.Where("role = ?", "admin").Or("role = ?", "super_admin").Find(&User{}) + if !regexp.MustCompile(userQuery + "WHERE .*role.* = .+ OR .*role.* = .+").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build OR condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Where("users.name = ?", "jinzhu").Or(User{Name: "jinzhu 2", Age: 18}).Find(&User{}) + if !regexp.MustCompile(userQuery + "WHERE .*users.*name.* = .+ OR \\(.*users.*name.* AND .*users.*age.*\\)").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build OR condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Where("users.name = ?", "jinzhu").Or(map[string]interface{}{"name": "jinzhu 2", "age": 18}).Find(&User{}) + if !regexp.MustCompile(userQuery + "WHERE .*users.*name.* = .+ OR \\(.*age.* AND .*name.*\\)").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build OR condition, but got %v", result.Statement.SQL.String()) + } +} + func TestPluck(t *testing.T) { users := []*User{ GetUser("pluck-user1", Config{}), @@ -543,6 +644,30 @@ func TestOmit(t *testing.T) { } } +func TestOmitWithAllFields(t *testing.T) { + user := User{Name: "OmitUser1", Age: 20} + DB.Save(&user) + + var userResult User + DB.Session(&gorm.Session{QueryFields: true}).Where("users.name = ?", user.Name).Omit("name").Find(&userResult) + if userResult.ID == 0 { + t.Errorf("Should not have ID because only selected name, %+v", userResult.ID) + } + + if userResult.Name != "" || userResult.Age != 20 { + t.Errorf("User Name should be omitted, got %v, Age should be ok, got %v", userResult.Name, userResult.Age) + } + + dryDB := DB.Session(&gorm.Session{DryRun: true, QueryFields: true}) + userQuery := "SELECT .*users.*id.*users.*created_at.*users.*updated_at.*users.*deleted_at.*users.*birthday" + + ".*users.*company_id.*users.*manager_id.*users.*active.* FROM .*users.* " + + result := dryDB.Omit("name, age").Find(&User{}) + if !regexp.MustCompile(userQuery).MatchString(result.Statement.SQL.String()) { + t.Fatalf("SQL must include table name and selected fields, got %v", result.Statement.SQL.String()) + } +} + func TestPluckWithSelect(t *testing.T) { users := []User{ {Name: "pluck_with_select_1", Age: 25}, @@ -685,6 +810,31 @@ func TestOrder(t *testing.T) { } } +func TestOrderWithAllFields(t *testing.T) { + dryDB := DB.Session(&gorm.Session{DryRun: true, QueryFields: true}) + userQuery := "SELECT .*users.*id.*users.*created_at.*users.*updated_at.*users.*deleted_at.*users.*name.*users.*age" + + ".*users.*birthday.*users.*company_id.*users.*manager_id.*users.*active.* FROM .*users.* " + + result := dryDB.Order("users.age desc, users.name").Find(&User{}) + if !regexp.MustCompile(userQuery + "users.age desc, users.name").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build Order condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Order("users.age desc").Order("users.name").Find(&User{}) + if !regexp.MustCompile(userQuery + "ORDER BY users.age desc,users.name").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build Order condition, but got %v", result.Statement.SQL.String()) + } + + stmt := dryDB.Clauses(clause.OrderBy{ + Expression: clause.Expr{SQL: "FIELD(id,?)", Vars: []interface{}{[]int{1, 2, 3}}, WithoutParentheses: true}, + }).Find(&User{}).Statement + + explainedSQL := dryDB.Dialector.Explain(stmt.SQL.String(), stmt.Vars...) + if !regexp.MustCompile(userQuery + "ORDER BY FIELD\\(id,1,2,3\\)").MatchString(explainedSQL) { + t.Fatalf("Build Order condition, but got %v", explainedSQL) + } +} + func TestLimit(t *testing.T) { users := []User{ {Name: "LimitUser1", Age: 1}, @@ -892,3 +1042,13 @@ func TestQueryWithTableAndConditions(t *testing.T) { t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String()) } } + +func TestQueryWithTableAndConditionsAndAllFields(t *testing.T) { + result := DB.Session(&gorm.Session{DryRun: true, QueryFields: true}).Table("user").Find(&User{}, User{Name: "jinzhu"}) + userQuery := "SELECT .*user.*id.*user.*created_at.*user.*updated_at.*user.*deleted_at.*user.*name.*user.*age" + + ".*user.*birthday.*user.*company_id.*user.*manager_id.*user.*active.* FROM .user. " + + if !regexp.MustCompile(userQuery + `WHERE .user.\..name. = .+ AND .user.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) { + t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String()) + } +} diff --git a/tests/table_test.go b/tests/table_test.go index 647b5e19..0c6b3eb0 100644 --- a/tests/table_test.go +++ b/tests/table_test.go @@ -68,3 +68,60 @@ func TestTable(t *testing.T) { AssertEqual(t, r.Statement.Vars, []interface{}{2, 4, 1, 3}) } + +func TestTableWithAllFields(t *testing.T) { + dryDB := DB.Session(&gorm.Session{DryRun: true, QueryFields: true}) + userQuery := "SELECT .*user.*id.*user.*created_at.*user.*updated_at.*user.*deleted_at.*user.*name.*user.*age" + + ".*user.*birthday.*user.*company_id.*user.*manager_id.*user.*active.* " + + r := dryDB.Table("`user`").Find(&User{}).Statement + if !regexp.MustCompile(userQuery + "FROM `user`").MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + + r = dryDB.Table("user as u").Select("name").Find(&User{}).Statement + if !regexp.MustCompile("SELECT .name. FROM user as u WHERE .u.\\..deleted_at. IS NULL").MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + + r = dryDB.Table("gorm.user").Select("name").Find(&User{}).Statement + if !regexp.MustCompile("SELECT .name. FROM .gorm.\\..user. WHERE .user.\\..deleted_at. IS NULL").MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + + r = dryDB.Select("name").Find(&UserWithTable{}).Statement + if !regexp.MustCompile("SELECT .name. FROM .gorm.\\..user. WHERE .user.\\..deleted_at. IS NULL").MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + + r = dryDB.Create(&UserWithTable{}).Statement + if DB.Dialector.Name() != "sqlite" { + if !regexp.MustCompile(`INSERT INTO .gorm.\..user. (.*name.*) VALUES (.*)`).MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + } else { + if !regexp.MustCompile(`INSERT INTO .user. (.*name.*) VALUES (.*)`).MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + } + + userQueryCharacter := "SELECT .*u.*id.*u.*created_at.*u.*updated_at.*u.*deleted_at.*u.*name.*u.*age.*u.*birthday" + + ".*u.*company_id.*u.*manager_id.*u.*active.* " + + r = dryDB.Table("(?) as u", DB.Model(&User{}).Select("name")).Find(&User{}).Statement + if !regexp.MustCompile(userQueryCharacter + "FROM \\(SELECT .name. FROM .users. WHERE .users.\\..deleted_at. IS NULL\\) as u WHERE .u.\\..deleted_at. IS NULL").MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + + r = dryDB.Table("(?) as u, (?) as p", DB.Model(&User{}).Select("name"), DB.Model(&Pet{}).Select("name")).Find(&User{}).Statement + if !regexp.MustCompile(userQueryCharacter + "FROM \\(SELECT .name. FROM .users. WHERE .users.\\..deleted_at. IS NULL\\) as u, \\(SELECT .name. FROM .pets. WHERE .pets.\\..deleted_at. IS NULL\\) as p WHERE .u.\\..deleted_at. IS NULL").MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + + r = dryDB.Where("name = ?", 1).Table("(?) as u, (?) as p", DB.Model(&User{}).Select("name").Where("name = ?", 2), DB.Model(&Pet{}).Where("name = ?", 4).Select("name")).Where("name = ?", 3).Find(&User{}).Statement + if !regexp.MustCompile(userQueryCharacter + "FROM \\(SELECT .name. FROM .users. WHERE name = .+ AND .users.\\..deleted_at. IS NULL\\) as u, \\(SELECT .name. FROM .pets. WHERE name = .+ AND .pets.\\..deleted_at. IS NULL\\) as p WHERE name = .+ AND name = .+ AND .u.\\..deleted_at. IS NULL").MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + + AssertEqual(t, r.Statement.Vars, []interface{}{2, 4, 1, 3}) +} From dec874851285805dc82d29f4e9ed360cb99c3345 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 20 Nov 2020 15:44:39 +0800 Subject: [PATCH 735/881] Refactor QueryFields Option --- callbacks/query.go | 25 +++++++------------------ 1 file changed, 7 insertions(+), 18 deletions(-) diff --git a/callbacks/query.go b/callbacks/query.go index 5274c246..aa4629a2 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -72,31 +72,20 @@ func BuildQuerySQL(db *gorm.DB) { } } } else if db.Statement.Schema != nil && db.Statement.ReflectValue.IsValid() { - if !db.QueryFields { - smallerStruct := false + queryFields := db.QueryFields + if !queryFields { switch db.Statement.ReflectValue.Kind() { case reflect.Struct: - smallerStruct = db.Statement.ReflectValue.Type() != db.Statement.Schema.ModelType + queryFields = db.Statement.ReflectValue.Type() != db.Statement.Schema.ModelType case reflect.Slice: - smallerStruct = db.Statement.ReflectValue.Type().Elem() != db.Statement.Schema.ModelType + queryFields = db.Statement.ReflectValue.Type().Elem() != db.Statement.Schema.ModelType } + } - if smallerStruct { - stmt := gorm.Statement{DB: db} - // smaller struct - if err := stmt.Parse(db.Statement.Dest); err == nil && stmt.Schema.ModelType != db.Statement.Schema.ModelType { - clauseSelect.Columns = make([]clause.Column, len(stmt.Schema.DBNames)) - - for idx, dbName := range stmt.Schema.DBNames { - clauseSelect.Columns[idx] = clause.Column{Name: dbName} - } - } - } - } else { - // Execute the query with all the fields of the table + if queryFields { stmt := gorm.Statement{DB: db} // smaller struct - if err := stmt.Parse(db.Statement.Dest); err == nil { + if err := stmt.Parse(db.Statement.Dest); err == nil && (db.QueryFields || stmt.Schema.ModelType != db.Statement.Schema.ModelType) { clauseSelect.Columns = make([]clause.Column, len(stmt.Schema.DBNames)) for idx, dbName := range stmt.Schema.DBNames { From 6186a4daa7ad61fdfb7750db68ba30c3391cc614 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 20 Nov 2020 16:56:52 +0800 Subject: [PATCH 736/881] allow SkipHooks when preload & save associations --- callbacks/associations.go | 2 +- callbacks/preload.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/callbacks/associations.go b/callbacks/associations.go index 0fa47868..e6669600 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -359,7 +359,7 @@ func saveAssociations(db *gorm.DB, rel *schema.Relationship, values interface{}, } } - tx := db.Session(&gorm.Session{NewDB: true}).Clauses(onConflict) + tx := db.Session(&gorm.Session{NewDB: true}).Clauses(onConflict).Session(&gorm.Session{SkipHooks: db.Statement.SkipHooks}) if len(selects) > 0 { tx = tx.Select(selects) diff --git a/callbacks/preload.go b/callbacks/preload.go index c2304af8..682427c9 100644 --- a/callbacks/preload.go +++ b/callbacks/preload.go @@ -13,7 +13,7 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { var ( reflectValue = db.Statement.ReflectValue rel = rels[len(rels)-1] - tx = db.Session(&gorm.Session{NewDB: true, SkipHooks: db.Statement.SkipHooks}) + tx = db.Session(&gorm.Session{NewDB: true}).Model(nil).Session(&gorm.Session{SkipHooks: db.Statement.SkipHooks}) relForeignKeys []string relForeignFields []*schema.Field foreignFields []*schema.Field From 66e8a72bf1b04a6b256c94708da68ddab498a5aa Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 23 Nov 2020 11:24:07 +0800 Subject: [PATCH 737/881] Support NameReplace for NamingStrategy, close #3779 --- schema/naming.go | 21 +++++++++++++-------- schema/naming_test.go | 12 ++++++++++-- 2 files changed, 23 insertions(+), 10 deletions(-) diff --git a/schema/naming.go b/schema/naming.go index e3b2104a..63296967 100644 --- a/schema/naming.go +++ b/schema/naming.go @@ -24,19 +24,20 @@ type Namer interface { type NamingStrategy struct { TablePrefix string SingularTable bool + NameReplacer *strings.Replacer } // TableName convert string to table name func (ns NamingStrategy) TableName(str string) string { if ns.SingularTable { - return ns.TablePrefix + toDBName(str) + return ns.TablePrefix + ns.toDBName(str) } - return ns.TablePrefix + inflection.Plural(toDBName(str)) + return ns.TablePrefix + inflection.Plural(ns.toDBName(str)) } // ColumnName convert string to column name func (ns NamingStrategy) ColumnName(table, column string) string { - return toDBName(column) + return ns.toDBName(column) } // JoinTableName convert string to join table name @@ -46,14 +47,14 @@ func (ns NamingStrategy) JoinTableName(str string) string { } if ns.SingularTable { - return ns.TablePrefix + toDBName(str) + return ns.TablePrefix + ns.toDBName(str) } - return ns.TablePrefix + inflection.Plural(toDBName(str)) + return ns.TablePrefix + inflection.Plural(ns.toDBName(str)) } // RelationshipFKName generate fk name for relation func (ns NamingStrategy) RelationshipFKName(rel Relationship) string { - return strings.Replace(fmt.Sprintf("fk_%s_%s", rel.Schema.Table, toDBName(rel.Name)), ".", "_", -1) + return strings.Replace(fmt.Sprintf("fk_%s_%s", rel.Schema.Table, ns.toDBName(rel.Name)), ".", "_", -1) } // CheckerName generate checker name @@ -63,7 +64,7 @@ func (ns NamingStrategy) CheckerName(table, column string) string { // IndexName generate index name func (ns NamingStrategy) IndexName(table, column string) string { - idxName := fmt.Sprintf("idx_%v_%v", table, toDBName(column)) + idxName := fmt.Sprintf("idx_%v_%v", table, ns.toDBName(column)) idxName = strings.Replace(idxName, ".", "_", -1) if utf8.RuneCountInString(idxName) > 64 { @@ -91,13 +92,17 @@ func init() { commonInitialismsReplacer = strings.NewReplacer(commonInitialismsForReplacer...) } -func toDBName(name string) string { +func (ns NamingStrategy) toDBName(name string) string { if name == "" { return "" } else if v, ok := smap.Load(name); ok { return v.(string) } + if ns.NameReplacer != nil { + name = ns.NameReplacer.Replace(name) + } + var ( value = commonInitialismsReplacer.Replace(name) buf strings.Builder diff --git a/schema/naming_test.go b/schema/naming_test.go index 26b0dcf6..b7a32160 100644 --- a/schema/naming_test.go +++ b/schema/naming_test.go @@ -1,6 +1,7 @@ package schema import ( + "strings" "testing" ) @@ -26,9 +27,10 @@ func TestToDBName(t *testing.T) { "ThisIsActuallyATestSoWeMayBeAbleToUseThisCodeInGormPackageAlsoIdCanBeUsedAtTheEndAsID": "this_is_actually_a_test_so_we_may_be_able_to_use_this_code_in_gorm_package_also_id_can_be_used_at_the_end_as_id", } + ns := NamingStrategy{} for key, value := range maps { - if toDBName(key) != value { - t.Errorf("%v toName should equal %v, but got %v", key, value, toDBName(key)) + if ns.toDBName(key) != value { + t.Errorf("%v toName should equal %v, but got %v", key, value, ns.toDBName(key)) } } } @@ -37,6 +39,7 @@ func TestNamingStrategy(t *testing.T) { var ns = NamingStrategy{ TablePrefix: "public.", SingularTable: true, + NameReplacer: strings.NewReplacer("CID", "Cid"), } idxName := ns.IndexName("public.table", "name") @@ -63,4 +66,9 @@ func TestNamingStrategy(t *testing.T) { if tableName != "public.company" { t.Errorf("invalid table name generated, got %v", tableName) } + + columdName := ns.ColumnName("", "NameCID") + if columdName != "name_cid" { + t.Errorf("invalid column name generated, got %v", columdName) + } } From 557b874ee3c9a6df9ffc5cd4a4bf2d89d3e788d5 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 25 Nov 2020 14:55:53 +0800 Subject: [PATCH 738/881] Fix check field's precision --- migrator/migrator.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/migrator/migrator.go b/migrator/migrator.go index 5de820a8..084d430f 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -381,7 +381,7 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy // check precision if precision, _, ok := columnType.DecimalSize(); ok && int64(field.Precision) != precision { - if strings.Contains(fullDataType, fmt.Sprint(field.Precision)) { + if strings.Contains(m.DataTypeOf(field), fmt.Sprint(field.Precision)) { alterColumn = true } } From 6950007d6a68f6e5bd3f2295152a0e8f148451cc Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 27 Nov 2020 14:32:20 +0800 Subject: [PATCH 739/881] Fix failed to parse relations when using goroutinue, close #3790 commit ee0ec43e8dfa85c1c1a562c2d0d47776cf8abd92 Author: Jinzhu Date: Fri Nov 27 14:31:57 2020 +0800 Fix failed to parse relations when using goroutinue, close #3790 commit 590e73ff95d8af6bd14f0a0da687dd7d12e5f94e Author: rokeyzhao Date: Thu Nov 26 20:27:55 2020 +0800 test: no cache preload in goroutine --- schema/field.go | 2 +- schema/relationship.go | 2 +- schema/schema.go | 31 +++++++++++++++++++++++++++++-- tests/go.mod | 1 + tests/preload_suits_test.go | 6 +++--- tests/preload_test.go | 19 +++++++++++++++++++ 6 files changed, 54 insertions(+), 7 deletions(-) diff --git a/schema/field.go b/schema/field.go index b303fb30..86b4a061 100644 --- a/schema/field.go +++ b/schema/field.go @@ -330,7 +330,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { cacheStore := &sync.Map{} cacheStore.Store(embeddedCacheKey, true) - if field.EmbeddedSchema, err = Parse(fieldValue.Interface(), cacheStore, embeddedNamer{Table: schema.Table, Namer: schema.namer}); err != nil { + if field.EmbeddedSchema, err = getOrParse(fieldValue.Interface(), cacheStore, embeddedNamer{Table: schema.Table, Namer: schema.namer}); err != nil { schema.err = err } diff --git a/schema/relationship.go b/schema/relationship.go index 35af111f..9cfc10be 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -71,7 +71,7 @@ func (schema *Schema) parseRelation(field *Field) { cacheStore = field.OwnerSchema.cacheStore } - if relation.FieldSchema, err = Parse(fieldValue, cacheStore, schema.namer); err != nil { + if relation.FieldSchema, err = getOrParse(fieldValue, cacheStore, schema.namer); err != nil { schema.err = err return } diff --git a/schema/schema.go b/schema/schema.go index 05db641f..89392643 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -38,6 +38,7 @@ type Schema struct { BeforeSave, AfterSave bool AfterFind bool err error + initialized chan struct{} namer Namer cacheStore *sync.Map } @@ -89,7 +90,9 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) } if v, ok := cacheStore.Load(modelType); ok { - return v.(*Schema), nil + s := v.(*Schema) + <-s.initialized + return s, nil } modelValue := reflect.New(modelType) @@ -110,6 +113,7 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) Relationships: Relationships{Relations: map[string]*Relationship{}}, cacheStore: cacheStore, namer: namer, + initialized: make(chan struct{}), } defer func() { @@ -219,7 +223,7 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) } } - if _, loaded := cacheStore.LoadOrStore(modelType, schema); !loaded { + if s, loaded := cacheStore.LoadOrStore(modelType, schema); !loaded { if _, embedded := schema.cacheStore.Load(embeddedCacheKey); !embedded { for _, field := range schema.Fields { if field.DataType == "" && (field.Creatable || field.Updatable || field.Readable) { @@ -245,8 +249,31 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, fc.DeleteClauses(field)...) } } + close(schema.initialized) } + } else { + return s.(*Schema), nil } return schema, schema.err } + +func getOrParse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) { + modelType := reflect.ValueOf(dest).Type() + for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr { + modelType = modelType.Elem() + } + + if modelType.Kind() != reflect.Struct { + if modelType.PkgPath() == "" { + return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) + } + return nil, fmt.Errorf("%w: %v.%v", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name()) + } + + if v, ok := cacheStore.Load(modelType); ok { + return v.(*Schema), nil + } + + return Parse(dest, cacheStore, namer) +} diff --git a/tests/go.mod b/tests/go.mod index 55495de3..fa293987 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -6,6 +6,7 @@ require ( github.com/google/uuid v1.1.1 github.com/jinzhu/now v1.1.1 github.com/lib/pq v1.6.0 + github.com/stretchr/testify v1.5.1 gorm.io/driver/mysql v1.0.3 gorm.io/driver/postgres v1.0.5 gorm.io/driver/sqlite v1.1.3 diff --git a/tests/preload_suits_test.go b/tests/preload_suits_test.go index d40309e7..0ef8890b 100644 --- a/tests/preload_suits_test.go +++ b/tests/preload_suits_test.go @@ -5,6 +5,7 @@ import ( "encoding/json" "reflect" "sort" + "sync/atomic" "testing" "gorm.io/gorm" @@ -1497,10 +1498,9 @@ func TestPreloadManyToManyCallbacks(t *testing.T) { } DB.Save(&lvl) - called := 0 - + var called int64 DB.Callback().Query().After("gorm:query").Register("TestPreloadManyToManyCallbacks", func(_ *gorm.DB) { - called = called + 1 + atomic.AddInt64(&called, 1) }) DB.Preload("Level2s").First(&Level1{}, "id = ?", lvl.ID) diff --git a/tests/preload_test.go b/tests/preload_test.go index d9035661..4b31b12c 100644 --- a/tests/preload_test.go +++ b/tests/preload_test.go @@ -5,6 +5,7 @@ import ( "regexp" "sort" "strconv" + "sync" "testing" "gorm.io/gorm" @@ -212,3 +213,21 @@ func TestPreloadEmptyData(t *testing.T) { t.Errorf("json marshal is not empty slice, got %v", string(r)) } } + +func TestPreloadGoroutine(t *testing.T) { + var wg sync.WaitGroup + + wg.Add(10) + for i := 0; i < 10; i++ { + go func() { + defer wg.Done() + var user2 []User + tx := DB.Where("id = ?", 1).Session(&gorm.Session{}) + + if err := tx.Preload("Team").Find(&user2).Error; err != nil { + t.Error(err) + } + }() + } + wg.Wait() +} From 0f77500917e619b0c52880e59487f1e2eef005ab Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 27 Nov 2020 17:04:56 +0800 Subject: [PATCH 740/881] Waiting for schema to be initialized, close #3790 --- schema/schema.go | 1 + 1 file changed, 1 insertion(+) diff --git a/schema/schema.go b/schema/schema.go index 89392643..da4be305 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -252,6 +252,7 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) close(schema.initialized) } } else { + <-s.(*Schema).initialized return s.(*Schema), nil } From acedbb8310221ac1d943c34f81d55fea95901f63 Mon Sep 17 00:00:00 2001 From: Dakatan Date: Mon, 30 Nov 2020 11:09:08 +0900 Subject: [PATCH 741/881] Fix Scan int32, uint32 (#3801) --- scan.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scan.go b/scan.go index c9c8f442..89849d98 100644 --- a/scan.go +++ b/scan.go @@ -84,7 +84,7 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { scanIntoMap(mapValue, values, columns) *dest = append(*dest, mapValue) } - case *int, *int64, *uint, *uint64, *float32, *float64, *string, *time.Time: + case *int, *int32, *int64, *uint, *uint32, *uint64, *float32, *float64, *string, *time.Time: for initialized || rows.Next() { initialized = false db.RowsAffected++ From 41e52f343af753cb173cbb3ddd092b034151428a Mon Sep 17 00:00:00 2001 From: SmallTianTian Date: Wed, 2 Dec 2020 14:00:16 +0800 Subject: [PATCH 742/881] fix: scan more base type and sql.NullXXX (#3813) --- scan.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/scan.go b/scan.go index 89849d98..0416489d 100644 --- a/scan.go +++ b/scan.go @@ -84,7 +84,12 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { scanIntoMap(mapValue, values, columns) *dest = append(*dest, mapValue) } - case *int, *int32, *int64, *uint, *uint32, *uint64, *float32, *float64, *string, *time.Time: + case *int, *int8, *int16, *int32, *int64, + *uint, *uint8, *uint16, *uint32, *uint64, *uintptr, + *float32, *float64, + *bool, *string, *time.Time, + *sql.NullInt32, *sql.NullInt64, *sql.NullFloat64, + *sql.NullBool, *sql.NullString, *sql.NullTime: for initialized || rows.Next() { initialized = false db.RowsAffected++ From 0c12a4c360e1f8b8569ffc9c29111a9abf58b492 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 2 Dec 2020 14:59:50 +0800 Subject: [PATCH 743/881] Add CreateBatchSize option --- finisher_api.go | 27 +++++++++++++++++++++------ gorm.go | 7 +++++++ tests/create_test.go | 34 +++++++++++++++++++++++++++++++++- tests/go.mod | 4 ++-- 4 files changed, 63 insertions(+), 9 deletions(-) diff --git a/finisher_api.go b/finisher_api.go index f2aed8da..fc7a73be 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -15,6 +15,10 @@ import ( // Create insert the value into database func (db *DB) Create(value interface{}) (tx *DB) { + if db.CreateBatchSize > 0 { + return db.CreateInBatches(value, db.CreateBatchSize) + } + tx = db.getInstance() tx.Statement.Dest = value tx.callbacks.Create().Execute(tx) @@ -27,19 +31,30 @@ func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) { switch reflectValue.Kind() { case reflect.Slice, reflect.Array: + var rowsAffected int64 tx = db.getInstance() - for i := 0; i < reflectValue.Len(); i += batchSize { - tx.AddError(tx.Transaction(func(tx *DB) error { + tx.AddError(tx.Transaction(func(tx *DB) error { + for i := 0; i < reflectValue.Len(); i += batchSize { ends := i + batchSize if ends > reflectValue.Len() { ends = reflectValue.Len() } - return tx.Create(reflectValue.Slice(i, ends).Interface()).Error - })) - } + subtx := tx.getInstance() + subtx.Statement.Dest = reflectValue.Slice(i, ends).Interface() + subtx.callbacks.Create().Execute(subtx) + if subtx.Error != nil { + return subtx.Error + } + rowsAffected += subtx.RowsAffected + } + return nil + })) + tx.RowsAffected = rowsAffected default: - return db.Create(value) + tx = db.getInstance() + tx.Statement.Dest = value + tx.callbacks.Create().Execute(tx) } return } diff --git a/gorm.go b/gorm.go index 1947b4df..ae1cf2c9 100644 --- a/gorm.go +++ b/gorm.go @@ -38,6 +38,8 @@ type Config struct { AllowGlobalUpdate bool // QueryFields executes the SQL query with all fields of the table QueryFields bool + // CreateBatchSize default create batch size + CreateBatchSize int // ClauseBuilders clause builder ClauseBuilders map[string]clause.ClauseBuilder @@ -74,6 +76,7 @@ type Session struct { Context context.Context Logger logger.Interface NowFunc func() time.Time + CreateBatchSize int } // Open initialize db session based on dialector @@ -161,6 +164,10 @@ func (db *DB) Session(config *Session) *DB { } ) + if config.CreateBatchSize > 0 { + tx.Config.CreateBatchSize = config.CreateBatchSize + } + if config.SkipDefaultTransaction { tx.Config.SkipDefaultTransaction = true } diff --git a/tests/create_test.go b/tests/create_test.go index 8d005d0b..170c8546 100644 --- a/tests/create_test.go +++ b/tests/create_test.go @@ -50,7 +50,39 @@ func TestCreateInBatches(t *testing.T) { *GetUser("create_in_batches_6", Config{Account: true, Pets: 4, Toys: 3, Company: false, Manager: true, Team: 1, Languages: 3, Friends: 0}), } - DB.CreateInBatches(&users, 2) + result := DB.CreateInBatches(&users, 2) + if result.RowsAffected != int64(len(users)) { + t.Errorf("affected rows should be %v, but got %v", len(users), result.RowsAffected) + } + + for _, user := range users { + if user.ID == 0 { + t.Fatalf("failed to fill user's ID, got %v", user.ID) + } else { + var newUser User + if err := DB.Where("id = ?", user.ID).Preload(clause.Associations).First(&newUser).Error; err != nil { + t.Fatalf("errors happened when query: %v", err) + } else { + CheckUser(t, newUser, user) + } + } + } +} + +func TestCreateInBatchesWithDefaultSize(t *testing.T) { + users := []User{ + *GetUser("create_with_default_batch_size_1", Config{Account: true, Pets: 2, Toys: 3, Company: true, Manager: true, Team: 0, Languages: 1, Friends: 1}), + *GetUser("create_with_default_batch_sizs_2", Config{Account: false, Pets: 2, Toys: 4, Company: false, Manager: false, Team: 1, Languages: 3, Friends: 5}), + *GetUser("create_with_default_batch_sizs_3", Config{Account: true, Pets: 0, Toys: 3, Company: true, Manager: false, Team: 4, Languages: 0, Friends: 1}), + *GetUser("create_with_default_batch_sizs_4", Config{Account: true, Pets: 3, Toys: 0, Company: false, Manager: true, Team: 0, Languages: 3, Friends: 0}), + *GetUser("create_with_default_batch_sizs_5", Config{Account: false, Pets: 0, Toys: 3, Company: true, Manager: false, Team: 1, Languages: 3, Friends: 1}), + *GetUser("create_with_default_batch_sizs_6", Config{Account: true, Pets: 4, Toys: 3, Company: false, Manager: true, Team: 1, Languages: 3, Friends: 0}), + } + + result := DB.Session(&gorm.Session{CreateBatchSize: 2}).Create(&users) + if result.RowsAffected != int64(len(users)) { + t.Errorf("affected rows should be %v, but got %v", len(users), result.RowsAffected) + } for _, user := range users { if user.ID == 0 { diff --git a/tests/go.mod b/tests/go.mod index fa293987..03283a53 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -9,9 +9,9 @@ require ( github.com/stretchr/testify v1.5.1 gorm.io/driver/mysql v1.0.3 gorm.io/driver/postgres v1.0.5 - gorm.io/driver/sqlite v1.1.3 + gorm.io/driver/sqlite v1.1.4 gorm.io/driver/sqlserver v1.0.5 - gorm.io/gorm v1.20.5 + gorm.io/gorm v1.20.7 ) replace gorm.io/gorm => ../ From 51568ba4ab0da8fd382af023f8400c366b70bf88 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 2 Dec 2020 17:27:07 +0800 Subject: [PATCH 744/881] Delete select clause after Count, close #3814 --- finisher_api.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/finisher_api.go b/finisher_api.go index fc7a73be..d36dc754 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -377,7 +377,7 @@ func (db *DB) Count(count *int64) (tx *DB) { } tx.Statement.AddClause(clause.Select{Expression: expr}) - defer tx.Statement.AddClause(clause.Select{}) + defer delete(tx.Statement.Clauses, "SELECT") } if orderByClause, ok := db.Statement.Clauses["ORDER BY"]; ok { From f2321ca164c0e5fd6cdcd5727152b39f2062ca6b Mon Sep 17 00:00:00 2001 From: Andrei Baibaratsky Date: Thu, 3 Dec 2020 08:00:26 +0100 Subject: [PATCH 745/881] Fixed creation of associated records with composite primary keys (go-gorm#3817) (#3818) --- callbacks/associations.go | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/callbacks/associations.go b/callbacks/associations.go index e6669600..9e767e5e 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -318,12 +318,8 @@ func onConflictOption(stmt *gorm.Statement, s *schema.Schema, selectColumns map[ if len(defaultUpdatingColumns) > 0 { var columns []clause.Column - if s.PrioritizedPrimaryField != nil { - columns = []clause.Column{{Name: s.PrioritizedPrimaryField.DBName}} - } else { - for _, dbName := range s.PrimaryFieldDBNames { - columns = append(columns, clause.Column{Name: dbName}) - } + for _, dbName := range s.PrimaryFieldDBNames { + columns = append(columns, clause.Column{Name: dbName}) } return clause.OnConflict{ From 61d3a4d6ea93e865738bf949657ff0ffcc8a7f97 Mon Sep 17 00:00:00 2001 From: Andy Bursavich Date: Thu, 3 Dec 2020 19:28:38 -0800 Subject: [PATCH 746/881] Fix schema initialization paths (#3825) * Fix schema initialization paths The initialized channel was only closed if the schema's cacheStore did not contain the embeddedCacheKey and there were no errors parsing relations. If the key existed or an error occurred, it would not be closed. This could leave other goroutines waiting for synchronization that will never occur. Additionally, the other code paths that wait for initialization to complete did not return the possible error. * Unnest common schema initialization This makes the common code path less deeply nested and the flow control easier to follow. --- schema/schema.go | 57 ++++++++++++++++++++++++------------------------ 1 file changed, 29 insertions(+), 28 deletions(-) diff --git a/schema/schema.go b/schema/schema.go index da4be305..8d9368da 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -92,7 +92,7 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) if v, ok := cacheStore.Load(modelType); ok { s := v.(*Schema) <-s.initialized - return s, nil + return s, s.err } modelValue := reflect.New(modelType) @@ -223,37 +223,38 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) } } - if s, loaded := cacheStore.LoadOrStore(modelType, schema); !loaded { - if _, embedded := schema.cacheStore.Load(embeddedCacheKey); !embedded { - for _, field := range schema.Fields { - if field.DataType == "" && (field.Creatable || field.Updatable || field.Readable) { - if schema.parseRelation(field); schema.err != nil { - return schema, schema.err - } - } + if v, loaded := cacheStore.LoadOrStore(modelType, schema); loaded { + s := v.(*Schema) + <-s.initialized + return s, s.err + } - fieldValue := reflect.New(field.IndirectFieldType) - if fc, ok := fieldValue.Interface().(CreateClausesInterface); ok { - field.Schema.CreateClauses = append(field.Schema.CreateClauses, fc.CreateClauses(field)...) - } - - if fc, ok := fieldValue.Interface().(QueryClausesInterface); ok { - field.Schema.QueryClauses = append(field.Schema.QueryClauses, fc.QueryClauses(field)...) - } - - if fc, ok := fieldValue.Interface().(UpdateClausesInterface); ok { - field.Schema.UpdateClauses = append(field.Schema.UpdateClauses, fc.UpdateClauses(field)...) - } - - if fc, ok := fieldValue.Interface().(DeleteClausesInterface); ok { - field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, fc.DeleteClauses(field)...) + defer close(schema.initialized) + if _, embedded := schema.cacheStore.Load(embeddedCacheKey); !embedded { + for _, field := range schema.Fields { + if field.DataType == "" && (field.Creatable || field.Updatable || field.Readable) { + if schema.parseRelation(field); schema.err != nil { + return schema, schema.err } } - close(schema.initialized) + + fieldValue := reflect.New(field.IndirectFieldType) + if fc, ok := fieldValue.Interface().(CreateClausesInterface); ok { + field.Schema.CreateClauses = append(field.Schema.CreateClauses, fc.CreateClauses(field)...) + } + + if fc, ok := fieldValue.Interface().(QueryClausesInterface); ok { + field.Schema.QueryClauses = append(field.Schema.QueryClauses, fc.QueryClauses(field)...) + } + + if fc, ok := fieldValue.Interface().(UpdateClausesInterface); ok { + field.Schema.UpdateClauses = append(field.Schema.UpdateClauses, fc.UpdateClauses(field)...) + } + + if fc, ok := fieldValue.Interface().(DeleteClausesInterface); ok { + field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, fc.DeleteClauses(field)...) + } } - } else { - <-s.(*Schema).initialized - return s.(*Schema), nil } return schema, schema.err From f6550419088d21a98cf5f3c8dc3bfc30e46e1cb1 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 6 Dec 2020 11:06:52 +0800 Subject: [PATCH 747/881] Allow overwrite ignored field's permission, close #3829 --- schema/schema.go | 2 +- statement.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/schema/schema.go b/schema/schema.go index 8d9368da..e36ed7b6 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -161,7 +161,7 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) } } - if _, ok := schema.FieldsByName[field.Name]; !ok { + if of, ok := schema.FieldsByName[field.Name]; !ok || of.TagSettings["-"] == "-" { schema.FieldsByName[field.Name] = field } diff --git a/statement.go b/statement.go index 27edf9da..a0da0c6d 100644 --- a/statement.go +++ b/statement.go @@ -576,7 +576,7 @@ func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) ( } if stmt.Schema != nil { - for _, field := range stmt.Schema.Fields { + for _, field := range stmt.Schema.FieldsByName { name := field.DBName if name == "" { name = field.Name From 1ef1f0bfe46cb18cf8453738e40d6c1c72c3621c Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 6 Dec 2020 14:04:37 +0800 Subject: [PATCH 748/881] Fix Count with complicated Select, close #3826 --- chainable_api.go | 15 ++++++--------- finisher_api.go | 41 ++++++++++++++++++++++++++--------------- tests/count_test.go | 44 ++++++++++++++++++++++++++++++++++++++++++++ tests/query_test.go | 2 +- 4 files changed, 77 insertions(+), 25 deletions(-) diff --git a/chainable_api.go b/chainable_api.go index c3a02d20..dca12b08 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -93,10 +93,12 @@ func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) { } delete(tx.Statement.Clauses, "SELECT") case string: - fields := strings.FieldsFunc(v, utils.IsValidDBNameChar) - - // normal field names - if len(fields) == 1 || (len(fields) == 3 && strings.ToUpper(fields[1]) == "AS") { + if (strings.Contains(v, " ?") || strings.Contains(v, "(?")) && len(args) > 0 { + tx.Statement.AddClause(clause.Select{ + Distinct: db.Statement.Distinct, + Expression: clause.Expr{SQL: v, Vars: args}, + }) + } else { tx.Statement.Selects = []string{v} for _, arg := range args { @@ -115,11 +117,6 @@ func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) { } delete(tx.Statement.Clauses, "SELECT") - } else { - tx.Statement.AddClause(clause.Select{ - Distinct: db.Statement.Distinct, - Expression: clause.Expr{SQL: v, Vars: args}, - }) } default: tx.AddError(fmt.Errorf("unsupported select args %v %v", query, args)) diff --git a/finisher_api.go b/finisher_api.go index d36dc754..98a877f2 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -355,29 +355,38 @@ func (db *DB) Count(count *int64) (tx *DB) { }() } + if selectClause, ok := db.Statement.Clauses["SELECT"]; ok { + defer func() { + db.Statement.Clauses["SELECT"] = selectClause + }() + } else { + defer delete(tx.Statement.Clauses, "SELECT") + } + if len(tx.Statement.Selects) == 0 { tx.Statement.AddClause(clause.Select{Expression: clause.Expr{SQL: "count(1)"}}) - defer delete(tx.Statement.Clauses, "SELECT") } else if !strings.Contains(strings.ToLower(tx.Statement.Selects[0]), "count(") { expr := clause.Expr{SQL: "count(1)"} if len(tx.Statement.Selects) == 1 { dbName := tx.Statement.Selects[0] - if tx.Statement.Parse(tx.Statement.Model) == nil { - if f := tx.Statement.Schema.LookUpField(dbName); f != nil { - dbName = f.DBName + fields := strings.FieldsFunc(dbName, utils.IsValidDBNameChar) + if len(fields) == 1 || (len(fields) == 3 && strings.ToUpper(fields[1]) == "AS") { + if tx.Statement.Parse(tx.Statement.Model) == nil { + if f := tx.Statement.Schema.LookUpField(dbName); f != nil { + dbName = f.DBName + } } - } - if tx.Statement.Distinct { - expr = clause.Expr{SQL: "COUNT(DISTINCT(?))", Vars: []interface{}{clause.Column{Name: dbName}}} - } else { - expr = clause.Expr{SQL: "COUNT(?)", Vars: []interface{}{clause.Column{Name: dbName}}} + if tx.Statement.Distinct { + expr = clause.Expr{SQL: "COUNT(DISTINCT(?))", Vars: []interface{}{clause.Column{Name: dbName}}} + } else { + expr = clause.Expr{SQL: "COUNT(?)", Vars: []interface{}{clause.Column{Name: dbName}}} + } } } tx.Statement.AddClause(clause.Select{Expression: expr}) - defer delete(tx.Statement.Clauses, "SELECT") } if orderByClause, ok := db.Statement.Clauses["ORDER BY"]; ok { @@ -457,11 +466,13 @@ func (db *DB) Pluck(column string, dest interface{}) (tx *DB) { tx.AddError(ErrModelValueRequired) } - fields := strings.FieldsFunc(column, utils.IsValidDBNameChar) - tx.Statement.AddClauseIfNotExists(clause.Select{ - Distinct: tx.Statement.Distinct, - Columns: []clause.Column{{Name: column, Raw: len(fields) != 1}}, - }) + if len(tx.Statement.Selects) != 1 { + fields := strings.FieldsFunc(column, utils.IsValidDBNameChar) + tx.Statement.AddClauseIfNotExists(clause.Select{ + Distinct: tx.Statement.Distinct, + Columns: []clause.Column{{Name: column, Raw: len(fields) != 1}}, + }) + } tx.Statement.Dest = dest tx.callbacks.Query().Execute(tx) return diff --git a/tests/count_test.go b/tests/count_test.go index 55fb71e2..ffe675d9 100644 --- a/tests/count_test.go +++ b/tests/count_test.go @@ -3,6 +3,8 @@ package tests_test import ( "fmt" "regexp" + "sort" + "strings" "testing" "gorm.io/gorm" @@ -77,4 +79,46 @@ func TestCount(t *testing.T) { if err := DB.Table("users").Where("users.name = ?", user1.Name).Order("name").Count(&count5).Error; err != nil || count5 != 1 { t.Errorf("count with join, got error: %v, count %v", err, count) } + + var count6 int64 + if err := DB.Model(&User{}).Where("name in ?", []string{user1.Name, user2.Name, user3.Name}).Select( + "(CASE WHEN name=? THEN ? ELSE ? END) as name", "count-1", "main", "other", + ).Count(&count6).Find(&users).Error; err != nil || count6 != 3 { + t.Fatalf(fmt.Sprintf("Count should work, but got err %v", err)) + } + + expects := []User{User{Name: "main"}, {Name: "other"}, {Name: "other"}} + sort.SliceStable(users, func(i, j int) bool { + return strings.Compare(users[i].Name, users[j].Name) < 0 + }) + + AssertEqual(t, users, expects) + + var count7 int64 + if err := DB.Model(&User{}).Where("name in ?", []string{user1.Name, user2.Name, user3.Name}).Select( + "(CASE WHEN name=? THEN ? ELSE ? END) as name, age", "count-1", "main", "other", + ).Count(&count7).Find(&users).Error; err != nil || count7 != 3 { + t.Fatalf(fmt.Sprintf("Count should work, but got err %v", err)) + } + + expects = []User{User{Name: "main", Age: 18}, {Name: "other", Age: 18}, {Name: "other", Age: 18}} + sort.SliceStable(users, func(i, j int) bool { + return strings.Compare(users[i].Name, users[j].Name) < 0 + }) + + AssertEqual(t, users, expects) + + var count8 int64 + if err := DB.Model(&User{}).Where("name in ?", []string{user1.Name, user2.Name, user3.Name}).Select( + "(CASE WHEN age=18 THEN 1 ELSE 2 END) as age", "name", + ).Count(&count8).Find(&users).Error; err != nil || count8 != 3 { + t.Fatalf(fmt.Sprintf("Count should work, but got err %v", err)) + } + + expects = []User{User{Name: "count-1", Age: 1}, {Name: "count-2", Age: 1}, {Name: "count-3", Age: 1}} + sort.SliceStable(users, func(i, j int) bool { + return strings.Compare(users[i].Name, users[j].Name) < 0 + }) + + AssertEqual(t, users, expects) } diff --git a/tests/query_test.go b/tests/query_test.go index c4162bdc..af8bbf07 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -677,7 +677,7 @@ func TestPluckWithSelect(t *testing.T) { DB.Create(&users) var userAges []int - err := DB.Model(&User{}).Where("name like ?", "pluck_with_select%").Select("age + 1 as user_age").Pluck("user_age", &userAges).Error + err := DB.Model(&User{}).Where("name like ?", "pluck_with_select%").Select("age + 1 as user_age").Pluck("user_age", &userAges).Error if err != nil { t.Fatalf("got error when pluck user_age: %v", err) } From 6a0fca21952b1852bece7aa4479099adbb205f56 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 6 Dec 2020 18:07:12 +0800 Subject: [PATCH 749/881] Return error for invalid relations definition, close #3830 --- schema/relationship.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/schema/relationship.go b/schema/relationship.go index 9cfc10be..19945e0f 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -362,7 +362,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl gue schema.guessRelation(relation, field, guessEmbeddedHas) // case guessEmbeddedHas: default: - schema.err = fmt.Errorf("invalid field found for struct %v's field %v, need to define a foreign key for relations or it need to implement the Valuer/Scanner interface", schema, field.Name) + schema.err = fmt.Errorf("invalid field found for struct %v's field %v, need to define a valid foreign key for relations or it need to implement the Valuer/Scanner interface", schema, field.Name) } } @@ -427,7 +427,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl gue } } } else if len(primaryFields) == 0 { - if len(foreignFields) == 1 { + if len(foreignFields) == 1 && primarySchema.PrioritizedPrimaryField != nil { primaryFields = append(primaryFields, primarySchema.PrioritizedPrimaryField) } else if len(primarySchema.PrimaryFields) == len(foreignFields) { primaryFields = append(primaryFields, primarySchema.PrimaryFields...) From e1952924e2a844eca52e5030f7b46b78de6ec135 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 7 Dec 2020 10:31:06 +0800 Subject: [PATCH 750/881] Support named Joins, close #3833 --- callbacks/query.go | 4 ++-- tests/joins_test.go | 16 +++++++++++----- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/callbacks/query.go b/callbacks/query.go index aa4629a2..ebb09d6b 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -108,7 +108,7 @@ func BuildQuerySQL(db *gorm.DB) { for _, join := range db.Statement.Joins { if db.Statement.Schema == nil { joins = append(joins, clause.Join{ - Expression: clause.Expr{SQL: join.Name, Vars: join.Conds}, + Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds}, }) } else if relation, ok := db.Statement.Schema.Relationships.Relations[join.Name]; ok { tableAliasName := relation.Name @@ -150,7 +150,7 @@ func BuildQuerySQL(db *gorm.DB) { }) } else { joins = append(joins, clause.Join{ - Expression: clause.Expr{SQL: join.Name, Vars: join.Conds}, + Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds}, }) } } diff --git a/tests/joins_test.go b/tests/joins_test.go index f78ddf67..46611f5f 100644 --- a/tests/joins_test.go +++ b/tests/joins_test.go @@ -61,35 +61,41 @@ func TestJoinConds(t *testing.T) { DB.Save(&user) var users1 []User - DB.Joins("left join pets on pets.user_id = users.id").Where("users.name = ?", user.Name).Find(&users1) + DB.Joins("inner join pets on pets.user_id = users.id").Where("users.name = ?", user.Name).Find(&users1) if len(users1) != 3 { t.Errorf("should find two users using left join, but got %v", len(users1)) } var users2 []User - DB.Joins("left join pets on pets.user_id = users.id AND pets.name = ?", user.Pets[0].Name).Where("users.name = ?", user.Name).First(&users2) + DB.Joins("inner join pets on pets.user_id = users.id AND pets.name = ?", user.Pets[0].Name).Where("users.name = ?", user.Name).First(&users2) if len(users2) != 1 { t.Errorf("should find one users using left join with conditions, but got %v", len(users2)) } var users3 []User - DB.Joins("left join pets on pets.user_id = users.id AND pets.name = ?", user.Pets[0].Name).Joins("join accounts on accounts.user_id = users.id AND accounts.number = ?", user.Account.Number).Where("users.name = ?", user.Name).First(&users3) + DB.Joins("inner join pets on pets.user_id = users.id AND pets.name = ?", user.Pets[0].Name).Joins("join accounts on accounts.user_id = users.id AND accounts.number = ?", user.Account.Number).Where("users.name = ?", user.Name).First(&users3) if len(users3) != 1 { t.Errorf("should find one users using multiple left join conditions, but got %v", len(users3)) } var users4 []User - DB.Joins("left join pets on pets.user_id = users.id AND pets.name = ?", user.Pets[0].Name).Joins("join accounts on accounts.user_id = users.id AND accounts.number = ?", user.Account.Number+"non-exist").Where("users.name = ?", user.Name).First(&users4) + DB.Joins("inner join pets on pets.user_id = users.id AND pets.name = ?", user.Pets[0].Name).Joins("join accounts on accounts.user_id = users.id AND accounts.number = ?", user.Account.Number+"non-exist").Where("users.name = ?", user.Name).First(&users4) if len(users4) != 0 { t.Errorf("should find no user when searching with unexisting credit card, but got %v", len(users4)) } var users5 []User - db5 := DB.Joins("left join pets on pets.user_id = users.id AND pets.name = ?", user.Pets[0].Name).Joins("join accounts on accounts.user_id = users.id AND accounts.number = ?", user.Account.Number).Where(User{Model: gorm.Model{ID: 1}}).Where(Account{Model: gorm.Model{ID: 1}}).Not(Pet{Model: gorm.Model{ID: 1}}).Find(&users5) + db5 := DB.Joins("inner join pets on pets.user_id = users.id AND pets.name = ?", user.Pets[0].Name).Joins("join accounts on accounts.user_id = users.id AND accounts.number = ?", user.Account.Number).Where(User{Model: gorm.Model{ID: 1}}).Where(Account{Model: gorm.Model{ID: 1}}).Not(Pet{Model: gorm.Model{ID: 1}}).Find(&users5) if db5.Error != nil { t.Errorf("Should not raise error for join where identical fields in different tables. Error: %s", db5.Error.Error()) } + var users6 []User + DB.Joins("inner join pets on pets.user_id = users.id AND pets.name = @Name", user.Pets[0]).Where("users.name = ?", user.Name).First(&users6) + if len(users6) != 1 { + t.Errorf("should find one users using left join with conditions, but got %v", len(users6)) + } + dryDB := DB.Session(&gorm.Session{DryRun: true}) stmt := dryDB.Joins("left join pets on pets.user_id = users.id AND pets.name = ?", user.Pets[0].Name).Joins("join accounts on accounts.user_id = users.id AND accounts.number = ?", user.Account.Number).Where(User{Model: gorm.Model{ID: 1}}).Where(Account{Model: gorm.Model{ID: 1}}).Not(Pet{Model: gorm.Model{ID: 1}}).Find(&users5).Statement From 51b5208599e4e4f27205d3eadf0ca561fc8ee6bb Mon Sep 17 00:00:00 2001 From: vellotis Date: Fri, 11 Dec 2020 08:07:23 +0200 Subject: [PATCH 751/881] Fix building of `clause.Eq` and `clause.Neq` expressions that fail to handle `(*T)(nil)` use cases correctly (#3848) * Update tests to cover building `clause.Eq` and `clause.Neq` when value could be a nil pointer of a primitive * Fix use cases for `clause.Eq` and `clause.Neq` when value is nil pointer of a primitive type --- clause/expression.go | 13 +++++++++-- clause/expression_test.go | 49 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 60 insertions(+), 2 deletions(-) diff --git a/clause/expression.go b/clause/expression.go index b30c46b0..3844d66b 100644 --- a/clause/expression.go +++ b/clause/expression.go @@ -203,7 +203,7 @@ type Eq struct { func (eq Eq) Build(builder Builder) { builder.WriteQuoted(eq.Column) - if eq.Value == nil { + if eqNil(eq.Value) { builder.WriteString(" IS NULL") } else { builder.WriteString(" = ") @@ -221,7 +221,7 @@ type Neq Eq func (neq Neq) Build(builder Builder) { builder.WriteQuoted(neq.Column) - if neq.Value == nil { + if eqNil(neq.Value) { builder.WriteString(" IS NOT NULL") } else { builder.WriteString(" <> ") @@ -299,3 +299,12 @@ func (like Like) NegationBuild(builder Builder) { builder.WriteString(" NOT LIKE ") builder.AddVar(builder, like.Value) } + +func eqNil(value interface{}) bool { + return value == nil || eqNilReflect(value) +} + +func eqNilReflect(value interface{}) bool { + reflectValue := reflect.ValueOf(value) + return reflectValue.Kind() == reflect.Ptr && reflectValue.IsNil() +} diff --git a/clause/expression_test.go b/clause/expression_test.go index 83082486..9e3d7bad 100644 --- a/clause/expression_test.go +++ b/clause/expression_test.go @@ -101,3 +101,52 @@ func TestNamedExpr(t *testing.T) { }) } } + +func TestExpression(t *testing.T) { + column := "column-name" + results := []struct { + Expressions []clause.Expression + Result string + }{{ + Expressions: []clause.Expression{ + clause.Eq{Column: column, Value: "column-value"}, + }, + Result: "`column-name` = ?", + },{ + Expressions: []clause.Expression{ + clause.Eq{Column: column, Value: nil}, + clause.Eq{Column: column, Value: (*string)(nil)}, + clause.Eq{Column: column, Value: (*int)(nil)}, + clause.Eq{Column: column, Value: (*bool)(nil)}, + clause.Eq{Column: column, Value: (interface{})(nil)}, + }, + Result: "`column-name` IS NULL", + },{ + Expressions: []clause.Expression{ + clause.Neq{Column: column, Value: "column-value"}, + }, + Result: "`column-name` <> ?", + },{ + Expressions: []clause.Expression{ + clause.Neq{Column: column, Value: nil}, + clause.Neq{Column: column, Value: (*string)(nil)}, + clause.Neq{Column: column, Value: (*int)(nil)}, + clause.Neq{Column: column, Value: (*bool)(nil)}, + clause.Neq{Column: column, Value: (interface{})(nil)}, + }, + Result: "`column-name` IS NOT NULL", + }} + + for idx, result := range results { + for idy, expression := range result.Expressions { + t.Run(fmt.Sprintf("case #%v.%v", idx, idy), func(t *testing.T) { + user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) + stmt := &gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}} + expression.Build(stmt) + if stmt.SQL.String() != result.Result { + t.Errorf("generated SQL is not equal, expects %v, but got %v", result.Result, stmt.SQL.String()) + } + }) + } + } +} From 21c3f05aa2a6e36b63fa9b8d7f1b6f198bfcdc41 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 14 Dec 2020 18:30:43 +0800 Subject: [PATCH 752/881] Use transaction's conn when preparing statement --- prepare_stmt.go | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/prepare_stmt.go b/prepare_stmt.go index eddee1f2..dbf21118 100644 --- a/prepare_stmt.go +++ b/prepare_stmt.go @@ -25,7 +25,7 @@ func (db *PreparedStmtDB) Close() { db.Mux.Unlock() } -func (db *PreparedStmtDB) prepare(ctx context.Context, query string) (*sql.Stmt, error) { +func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, query string) (*sql.Stmt, error) { db.Mux.RLock() if stmt, ok := db.Stmts[query]; ok { db.Mux.RUnlock() @@ -40,7 +40,7 @@ func (db *PreparedStmtDB) prepare(ctx context.Context, query string) (*sql.Stmt, return stmt, nil } - stmt, err := db.ConnPool.PrepareContext(ctx, query) + stmt, err := conn.PrepareContext(ctx, query) if err == nil { db.Stmts[query] = stmt db.PreparedSQL = append(db.PreparedSQL, query) @@ -59,7 +59,7 @@ func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (Conn } func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args ...interface{}) (result sql.Result, err error) { - stmt, err := db.prepare(ctx, query) + stmt, err := db.prepare(ctx, db.ConnPool, query) if err == nil { result, err = stmt.ExecContext(ctx, args...) if err != nil { @@ -73,7 +73,7 @@ func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args .. } func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) { - stmt, err := db.prepare(ctx, query) + stmt, err := db.prepare(ctx, db.ConnPool, query) if err == nil { rows, err = stmt.QueryContext(ctx, args...) if err != nil { @@ -87,7 +87,7 @@ func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args . } func (db *PreparedStmtDB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { - stmt, err := db.prepare(ctx, query) + stmt, err := db.prepare(ctx, db.ConnPool, query) if err == nil { return stmt.QueryRowContext(ctx, args...) } @@ -114,7 +114,7 @@ func (tx *PreparedStmtTX) Rollback() error { } func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ...interface{}) (result sql.Result, err error) { - stmt, err := tx.PreparedStmtDB.prepare(ctx, query) + stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, query) if err == nil { result, err = tx.Tx.StmtContext(ctx, stmt).ExecContext(ctx, args...) if err != nil { @@ -128,7 +128,7 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args .. } func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) { - stmt, err := tx.PreparedStmtDB.prepare(ctx, query) + stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, query) if err == nil { rows, err = tx.Tx.Stmt(stmt).QueryContext(ctx, args...) if err != nil { @@ -142,7 +142,7 @@ func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args . } func (tx *PreparedStmtTX) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { - stmt, err := tx.PreparedStmtDB.prepare(ctx, query) + stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, query) if err == nil { return tx.Tx.StmtContext(ctx, stmt).QueryRowContext(ctx, args...) } From 14a0976dd4d4dcf12c10b4ce1431f5d54c31fde3 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 15 Dec 2020 10:39:20 +0800 Subject: [PATCH 753/881] populate the DeletedAt field when soft delete, fix #3855 --- soft_delete.go | 4 +++- statement.go | 16 ++++++++++++++-- tests/delete_test.go | 2 +- tests/soft_delete_test.go | 5 +++++ 4 files changed, 23 insertions(+), 4 deletions(-) diff --git a/soft_delete.go b/soft_delete.go index cb56035d..bdbf03c2 100644 --- a/soft_delete.go +++ b/soft_delete.go @@ -104,7 +104,9 @@ func (sd SoftDeleteDeleteClause) MergeClause(*clause.Clause) { func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) { if stmt.SQL.String() == "" { - stmt.AddClause(clause.Set{{Column: clause.Column{Name: sd.Field.DBName}, Value: stmt.DB.NowFunc()}}) + curTime := stmt.DB.NowFunc() + stmt.AddClause(clause.Set{{Column: clause.Column{Name: sd.Field.DBName}, Value: curTime}}) + stmt.SetColumn(sd.Field.DBName, curTime, true) if stmt.Schema != nil { _, queryValues := schema.GetIdentityFieldValuesMap(stmt.ReflectValue, stmt.Schema.PrimaryFields) diff --git a/statement.go b/statement.go index a0da0c6d..355a5f0b 100644 --- a/statement.go +++ b/statement.go @@ -447,9 +447,15 @@ func (stmt *Statement) clone() *Statement { // Helpers // SetColumn set column's value -func (stmt *Statement) SetColumn(name string, value interface{}) { +// stmt.SetColumn("Name", "jinzhu") // Hooks Method +// stmt.SetColumn("Name", "jinzhu", true) // Callbacks Method +func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks ...bool) { if v, ok := stmt.Dest.(map[string]interface{}); ok { v[name] = value + } else if v, ok := stmt.Dest.([]map[string]interface{}); ok { + for _, m := range v { + m[name] = value + } } else if stmt.Schema != nil { if field := stmt.Schema.LookUpField(name); field != nil { destValue := reflect.ValueOf(stmt.Dest) @@ -475,7 +481,13 @@ func (stmt *Statement) SetColumn(name string, value interface{}) { switch stmt.ReflectValue.Kind() { case reflect.Slice, reflect.Array: - field.Set(stmt.ReflectValue.Index(stmt.CurDestIndex), value) + if len(fromCallbacks) > 0 { + for i := 0; i < stmt.ReflectValue.Len(); i++ { + field.Set(stmt.ReflectValue.Index(i), value) + } + } else { + field.Set(stmt.ReflectValue.Index(stmt.CurDestIndex), value) + } case reflect.Struct: field.Set(stmt.ReflectValue, value) } diff --git a/tests/delete_test.go b/tests/delete_test.go index 954c7097..37e29fbe 100644 --- a/tests/delete_test.go +++ b/tests/delete_test.go @@ -45,7 +45,7 @@ func TestDelete(t *testing.T) { } } - if err := DB.Delete(users[0]).Error; err != nil { + if err := DB.Delete(&users[0]).Error; err != nil { t.Errorf("errors happened when delete: %v", err) } diff --git a/tests/soft_delete_test.go b/tests/soft_delete_test.go index f1ea8a51..0dfe24d5 100644 --- a/tests/soft_delete_test.go +++ b/tests/soft_delete_test.go @@ -1,6 +1,7 @@ package tests_test import ( + "database/sql" "encoding/json" "errors" "regexp" @@ -29,6 +30,10 @@ func TestSoftDelete(t *testing.T) { t.Fatalf("No error should happen when soft delete user, but got %v", err) } + if sql.NullTime(user.DeletedAt).Time.IsZero() { + t.Fatalf("user's deleted at is zero") + } + sql := DB.Session(&gorm.Session{DryRun: true}).Delete(&user).Statement.SQL.String() if !regexp.MustCompile(`UPDATE .users. SET .deleted_at.=.* WHERE .users.\..id. = .* AND .users.\..deleted_at. IS NULL`).MatchString(sql) { t.Fatalf("invalid sql generated, got %v", sql) From 0f00493c505145aedd451115d2d0f8c9dcbe5980 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 15 Dec 2020 11:18:29 +0800 Subject: [PATCH 754/881] Continue to update tracking fields even not selected with Select, but skip them if omited with Omit, fix #3856 --- callbacks/create.go | 2 +- callbacks/update.go | 26 ++++++++++++-------------- tests/update_test.go | 4 +++- 3 files changed, 16 insertions(+), 16 deletions(-) diff --git a/callbacks/create.go b/callbacks/create.go index 3ca56d73..052f3344 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -244,7 +244,7 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { for _, db := range stmt.Schema.DBNames { if field := stmt.Schema.FieldsByDBName[db]; !field.HasDefaultValue || field.DefaultValueInterface != nil { - if v, ok := selectColumns[db]; (ok && v) || (!ok && !restricted) { + if v, ok := selectColumns[db]; (ok && v) || (!ok && (!restricted || field.AutoCreateTime > 0 || field.AutoUpdateTime > 0)) { values.Columns = append(values.Columns, clause.Column{Name: db}) } } diff --git a/callbacks/update.go b/callbacks/update.go index c8f3922e..db5b52fb 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -202,7 +202,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { for _, dbName := range stmt.Schema.DBNames { field := stmt.Schema.LookUpField(dbName) if field.AutoUpdateTime > 0 && value[field.Name] == nil && value[field.DBName] == nil { - if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { + if v, ok := selectColumns[field.DBName]; (ok && v) || !ok { now := stmt.DB.NowFunc() assignValue(field, now) @@ -226,21 +226,19 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { for _, dbName := range stmt.Schema.DBNames { field := stmt.Schema.LookUpField(dbName) if !field.PrimaryKey || (!updatingValue.CanAddr() || stmt.Dest != stmt.Model) { - if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { + if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && (!restricted || (!stmt.SkipHooks && field.AutoUpdateTime > 0))) { value, isZero := field.ValueOf(updatingValue) - if !stmt.SkipHooks { - if field.AutoUpdateTime > 0 { - if field.AutoUpdateTime == schema.UnixNanosecond { - value = stmt.DB.NowFunc().UnixNano() - } else if field.AutoUpdateTime == schema.UnixMillisecond { - value = stmt.DB.NowFunc().UnixNano() / 1e6 - } else if field.GORMDataType == schema.Time { - value = stmt.DB.NowFunc() - } else { - value = stmt.DB.NowFunc().Unix() - } - isZero = false + if !stmt.SkipHooks && field.AutoUpdateTime > 0 { + if field.AutoUpdateTime == schema.UnixNanosecond { + value = stmt.DB.NowFunc().UnixNano() + } else if field.AutoUpdateTime == schema.UnixMillisecond { + value = stmt.DB.NowFunc().UnixNano() / 1e6 + } else if field.GORMDataType == schema.Time { + value = stmt.DB.NowFunc() + } else { + value = stmt.DB.NowFunc().Unix() } + isZero = false } if ok || !isZero { diff --git a/tests/update_test.go b/tests/update_test.go index a660647c..df709cff 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -466,7 +466,9 @@ func TestSelectWithUpdateColumn(t *testing.T) { var result2 User DB.First(&result2, user.ID) - AssertEqual(t, lastUpdatedAt, result2.UpdatedAt) + if lastUpdatedAt.Format(time.RFC3339Nano) == result2.UpdatedAt.Format(time.RFC3339Nano) { + t.Errorf("UpdatedAt should be changed") + } if result2.Name == user.Name || result2.Age != user.Age { t.Errorf("Should only update users with name column") From 6848ae872f1c139adb617d2311307e93b826b96a Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 15 Dec 2020 15:35:11 +0800 Subject: [PATCH 755/881] Fix gorm.Expr with SubQuery, fix #3857 --- statement.go | 11 +---------- tests/create_test.go | 24 ++++++++++++++++++++++++ 2 files changed, 25 insertions(+), 10 deletions(-) diff --git a/statement.go b/statement.go index 355a5f0b..707e4aef 100644 --- a/statement.go +++ b/statement.go @@ -165,16 +165,7 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) { case Valuer: stmt.AddVar(writer, v.GormValue(stmt.Context, stmt.DB)) case clause.Expr: - var varStr strings.Builder - var sql = v.SQL - for _, arg := range v.Vars { - stmt.Vars = append(stmt.Vars, arg) - stmt.DB.Dialector.BindVarTo(&varStr, stmt, arg) - sql = strings.Replace(sql, "?", varStr.String(), 1) - varStr.Reset() - } - - writer.WriteString(sql) + v.Build(stmt) case driver.Valuer: stmt.Vars = append(stmt.Vars, v) stmt.DB.Dialector.BindVarTo(writer, stmt, v) diff --git a/tests/create_test.go b/tests/create_test.go index 170c8546..bd968ea8 100644 --- a/tests/create_test.go +++ b/tests/create_test.go @@ -2,6 +2,7 @@ package tests_test import ( "errors" + "regexp" "testing" "time" @@ -493,3 +494,26 @@ func TestFirstOrCreateWithPrimaryKey(t *testing.T) { t.Errorf("invalid primary key after creating, got %v, %v", companies[0].ID, companies[1].ID) } } + +func TestCreateFromSubQuery(t *testing.T) { + user := User{Name: "jinzhu"} + + DB.Create(&user) + + subQuery := DB.Table("users").Where("name=?", user.Name).Select("id") + + result := DB.Session(&gorm.Session{DryRun: true}).Model(&Pet{}).Create([]map[string]interface{}{ + { + "name": "cat", + "user_id": gorm.Expr("(?)", DB.Table("(?) as tmp", subQuery).Select("@uid:=id")), + }, + { + "name": "dog", + "user_id": gorm.Expr("@uid"), + }, + }) + + if !regexp.MustCompile(`INSERT INTO .pets. \(.name.,.user_id.\) .*VALUES \(.+,\(SELECT @uid:=id FROM \(SELECT id FROM .users. WHERE name=.+\) as tmp\)\),\(.+,@uid\)`).MatchString(result.Statement.SQL.String()) { + t.Errorf("invalid insert SQL, got %v", result.Statement.SQL.String()) + } +} From 468152d45b66ab30091624f32f5b989204e04c40 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 16 Dec 2020 19:33:35 +0800 Subject: [PATCH 756/881] Add DisableNestedTransaction support --- finisher_api.go | 16 +++++----- gorm.go | 31 +++++++++++-------- tests/transaction_test.go | 63 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 91 insertions(+), 19 deletions(-) diff --git a/finisher_api.go b/finisher_api.go index 98a877f2..03bcd20f 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -498,13 +498,15 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil { // nested transaction - err = db.SavePoint(fmt.Sprintf("sp%p", fc)).Error - defer func() { - // Make sure to rollback when panic, Block error or Commit error - if panicked || err != nil { - db.RollbackTo(fmt.Sprintf("sp%p", fc)) - } - }() + if !db.DisableNestedTransaction { + err = db.SavePoint(fmt.Sprintf("sp%p", fc)).Error + defer func() { + // Make sure to rollback when panic, Block error or Commit error + if panicked || err != nil { + db.RollbackTo(fmt.Sprintf("sp%p", fc)) + } + }() + } if err == nil { err = fc(db.Session(&Session{})) diff --git a/gorm.go b/gorm.go index ae1cf2c9..ae94daf4 100644 --- a/gorm.go +++ b/gorm.go @@ -34,6 +34,8 @@ type Config struct { DisableAutomaticPing bool // DisableForeignKeyConstraintWhenMigrating DisableForeignKeyConstraintWhenMigrating bool + // DisableNestedTransaction disable nested transaction + DisableNestedTransaction bool // AllowGlobalUpdate allow global update AllowGlobalUpdate bool // QueryFields executes the SQL query with all fields of the table @@ -65,18 +67,19 @@ type DB struct { // Session session config when create session with Session() method type Session struct { - DryRun bool - PrepareStmt bool - NewDB bool - SkipHooks bool - SkipDefaultTransaction bool - AllowGlobalUpdate bool - FullSaveAssociations bool - QueryFields bool - Context context.Context - Logger logger.Interface - NowFunc func() time.Time - CreateBatchSize int + DryRun bool + PrepareStmt bool + NewDB bool + SkipHooks bool + SkipDefaultTransaction bool + DisableNestedTransaction bool + AllowGlobalUpdate bool + FullSaveAssociations bool + QueryFields bool + Context context.Context + Logger logger.Interface + NowFunc func() time.Time + CreateBatchSize int } // Open initialize db session based on dialector @@ -206,6 +209,10 @@ func (db *DB) Session(config *Session) *DB { tx.Statement.SkipHooks = true } + if config.DisableNestedTransaction { + txConfig.DisableNestedTransaction = true + } + if !config.NewDB { tx.clone = 2 } diff --git a/tests/transaction_test.go b/tests/transaction_test.go index 334600b8..c17fea3b 100644 --- a/tests/transaction_test.go +++ b/tests/transaction_test.go @@ -283,6 +283,69 @@ func TestNestedTransactionWithBlock(t *testing.T) { } } +func TestDisabledNestedTransaction(t *testing.T) { + var ( + user = *GetUser("transaction-nested", Config{}) + user1 = *GetUser("transaction-nested-1", Config{}) + user2 = *GetUser("transaction-nested-2", Config{}) + ) + + if err := DB.Session(&gorm.Session{DisableNestedTransaction: true}).Transaction(func(tx *gorm.DB) error { + tx.Create(&user) + + if err := tx.First(&User{}, "name = ?", user.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } + + if err := tx.Transaction(func(tx1 *gorm.DB) error { + tx1.Create(&user1) + + if err := tx1.First(&User{}, "name = ?", user1.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } + + return errors.New("rollback") + }); err == nil { + t.Fatalf("nested transaction should returns error") + } + + if err := tx.First(&User{}, "name = ?", user1.Name).Error; err != nil { + t.Fatalf("Should not rollback record if disabled nested transaction support") + } + + if err := tx.Transaction(func(tx2 *gorm.DB) error { + tx2.Create(&user2) + + if err := tx2.First(&User{}, "name = ?", user2.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } + + return nil + }); err != nil { + t.Fatalf("nested transaction returns error: %v", err) + } + + if err := tx.First(&User{}, "name = ?", user2.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } + return nil + }); err != nil { + t.Fatalf("no error should return, but got %v", err) + } + + if err := DB.First(&User{}, "name = ?", user.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } + + if err := DB.First(&User{}, "name = ?", user1.Name).Error; err != nil { + t.Fatalf("Should not rollback record if disabled nested transaction support") + } + + if err := DB.First(&User{}, "name = ?", user2.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } +} + func TestTransactionOnClosedConn(t *testing.T) { DB, err := OpenTestConnection() if err != nil { From 77bf4aecc6e5a156aff47b26a0dbb0dd4a31382a Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 18 Dec 2020 13:25:52 +0800 Subject: [PATCH 757/881] Create associations w/o nested transaction option --- callbacks/associations.go | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/callbacks/associations.go b/callbacks/associations.go index 9e767e5e..f5c9e4be 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -296,7 +296,10 @@ func SaveAfterAssociations(db *gorm.DB) { } if joins.Len() > 0 { - db.AddError(db.Session(&gorm.Session{NewDB: true}).Clauses(clause.OnConflict{DoNothing: true}).Create(joins.Interface()).Error) + db.AddError(db.Session(&gorm.Session{NewDB: true}).Clauses(clause.OnConflict{DoNothing: true}).Session(&gorm.Session{ + SkipHooks: db.Statement.SkipHooks, + DisableNestedTransaction: true, + }).Create(joins.Interface()).Error) } } } @@ -355,7 +358,10 @@ func saveAssociations(db *gorm.DB, rel *schema.Relationship, values interface{}, } } - tx := db.Session(&gorm.Session{NewDB: true}).Clauses(onConflict).Session(&gorm.Session{SkipHooks: db.Statement.SkipHooks}) + tx := db.Session(&gorm.Session{NewDB: true}).Clauses(onConflict).Session(&gorm.Session{ + SkipHooks: db.Statement.SkipHooks, + DisableNestedTransaction: true, + }) if len(selects) > 0 { tx = tx.Select(selects) From 59730417aabd5b510d66d9d923d265a6fc0195a0 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 23 Dec 2020 17:31:47 +0800 Subject: [PATCH 758/881] Fix auto migrate field with customized field type, close https://github.com/go-gorm/mysql/issues/20 --- migrator/migrator.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/migrator/migrator.go b/migrator/migrator.go index 084d430f..a475d307 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -381,7 +381,7 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy // check precision if precision, _, ok := columnType.DecimalSize(); ok && int64(field.Precision) != precision { - if strings.Contains(m.DataTypeOf(field), fmt.Sprint(field.Precision)) { + if regexp.MustCompile(fmt.Sprintf("[^0-9]%d[^0-9]", field.Precision)).MatchString(m.DataTypeOf(field)) { alterColumn = true } } From ad8a5c0d1ace1b9608fdaaae920fe17ebb5cf32a Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 25 Dec 2020 16:35:25 +0800 Subject: [PATCH 759/881] Add QueryFields mode when query many2many relations --- association.go | 2 +- tests/go.mod | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/association.go b/association.go index 7adb8c91..d93ff8ca 100644 --- a/association.go +++ b/association.go @@ -470,7 +470,7 @@ func (association *Association) buildCondition() *DB { tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars}) } - tx.Clauses(clause.From{Joins: []clause.Join{{ + tx = tx.Session(&Session{QueryFields: true}).Clauses(clause.From{Joins: []clause.Join{{ Table: clause.Table{Name: association.Relationship.JoinTable.Table}, ON: clause.Where{Exprs: queryConds}, }}}) diff --git a/tests/go.mod b/tests/go.mod index 03283a53..f6912a0f 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -8,10 +8,10 @@ require ( github.com/lib/pq v1.6.0 github.com/stretchr/testify v1.5.1 gorm.io/driver/mysql v1.0.3 - gorm.io/driver/postgres v1.0.5 + gorm.io/driver/postgres v1.0.6 gorm.io/driver/sqlite v1.1.4 gorm.io/driver/sqlserver v1.0.5 - gorm.io/gorm v1.20.7 + gorm.io/gorm v1.20.8 ) replace gorm.io/gorm => ../ From ade0bd6d60950e0d64d2c34c7b0b2370a10abcf8 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 28 Dec 2020 10:40:30 +0800 Subject: [PATCH 760/881] Fix SELECT with sql expression in some cases, close #3889 --- chainable_api.go | 2 +- tests/query_test.go | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/chainable_api.go b/chainable_api.go index dca12b08..58b9336f 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -93,7 +93,7 @@ func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) { } delete(tx.Statement.Clauses, "SELECT") case string: - if (strings.Contains(v, " ?") || strings.Contains(v, "(?")) && len(args) > 0 { + if strings.Count(v, "?") >= len(args) && len(args) > 0 { tx.Statement.AddClause(clause.Select{ Distinct: db.Statement.Distinct, Expression: clause.Expr{SQL: v, Vars: args}, diff --git a/tests/query_test.go b/tests/query_test.go index af8bbf07..f1234d0a 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -612,11 +612,15 @@ func TestSelect(t *testing.T) { t.Fatalf("Build Select with slice, but got %v", r.Statement.SQL.String()) } + // SELECT COALESCE(age,'42') FROM users; r = dryDB.Table("users").Select("COALESCE(age,?)", 42).Find(&User{}) if !regexp.MustCompile(`SELECT COALESCE\(age,.*\) FROM .*users.*`).MatchString(r.Statement.SQL.String()) { t.Fatalf("Build Select with func, but got %v", r.Statement.SQL.String()) } - // SELECT COALESCE(age,'42') FROM users; + + if _, err := DB.Table("users").Select("COALESCE(age,?)", "42").Rows(); err != nil { + t.Fatalf("Failed, got error: %v", err) + } r = dryDB.Select("u.*").Table("users as u").First(&User{}, user.ID) if !regexp.MustCompile(`SELECT u\.\* FROM .*users.*`).MatchString(r.Statement.SQL.String()) { From 8bf50a55927dbc74bd2168233f94dd957064bf8d Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 28 Dec 2020 17:58:12 +0800 Subject: [PATCH 761/881] Fix parse relations if only specfied References, close #3890 --- schema/relationship.go | 14 +++++++++++++- schema/relationship_test.go | 38 +++++++++++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+), 1 deletion(-) diff --git a/schema/relationship.go b/schema/relationship.go index 19945e0f..18f04e1f 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -396,7 +396,19 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl gue } } } else { - for _, primaryField := range primarySchema.PrimaryFields { + var primaryFields []*Field + + if len(relation.primaryKeys) > 0 { + for _, primaryKey := range relation.primaryKeys { + if f := primarySchema.LookUpField(primaryKey); f != nil { + primaryFields = append(primaryFields, f) + } + } + } else { + primaryFields = primarySchema.PrimaryFields + } + + for _, primaryField := range primaryFields { lookUpName := primarySchema.Name + primaryField.Name if gl == guessBelongs { lookUpName = field.Name + primaryField.Name diff --git a/schema/relationship_test.go b/schema/relationship_test.go index 7d7fd9c9..af2897b8 100644 --- a/schema/relationship_test.go +++ b/schema/relationship_test.go @@ -55,6 +55,25 @@ func TestBelongsToOverrideReferences(t *testing.T) { }) } +func TestBelongsToWithOnlyReferences(t *testing.T) { + type Profile struct { + gorm.Model + Refer string + Name string + } + + type User struct { + gorm.Model + Profile Profile `gorm:"References:Refer"` + ProfileRefer int + } + + checkStructRelation(t, &User{}, Relation{ + Name: "Profile", Type: schema.BelongsTo, Schema: "User", FieldSchema: "Profile", + References: []Reference{{"Refer", "Profile", "ProfileRefer", "User", "", false}}, + }) +} + func TestSelfReferentialBelongsToOverrideReferences(t *testing.T) { type User struct { ID int32 `gorm:"primaryKey"` @@ -106,6 +125,25 @@ func TestHasOneOverrideReferences(t *testing.T) { }) } +func TestHasOneWithOnlyReferences(t *testing.T) { + type Profile struct { + gorm.Model + Name string + UserRefer uint + } + + type User struct { + gorm.Model + Refer string + Profile Profile `gorm:"References:Refer"` + } + + checkStructRelation(t, &User{}, Relation{ + Name: "Profile", Type: schema.HasOne, Schema: "User", FieldSchema: "Profile", + References: []Reference{{"Refer", "User", "UserRefer", "Profile", "", true}}, + }) +} + func TestHasManyOverrideForeignKey(t *testing.T) { type Profile struct { gorm.Model From 065787c54ef80199482ef3d245de213e7f751423 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 28 Dec 2020 18:20:42 +0800 Subject: [PATCH 762/881] Compatible with with foreign key with ID suffix #3890 --- schema/relationship.go | 15 ++++++++++++--- schema/relationship_test.go | 38 +++++++++++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+), 3 deletions(-) diff --git a/schema/relationship.go b/schema/relationship.go index 18f04e1f..4580fa53 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -414,9 +414,18 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl gue lookUpName = field.Name + primaryField.Name } - if f := foreignSchema.LookUpField(lookUpName); f != nil { - foreignFields = append(foreignFields, f) - primaryFields = append(primaryFields, primaryField) + lookUpNames := []string{lookUpName} + if len(primaryFields) == 1 { + lookUpNames = append(lookUpNames, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID") + lookUpNames = append(lookUpNames, strings.TrimSuffix(lookUpName, primaryField.Name)+"Id") + } + + for _, name := range lookUpNames { + if f := foreignSchema.LookUpField(name); f != nil { + foreignFields = append(foreignFields, f) + primaryFields = append(primaryFields, primaryField) + break + } } } } diff --git a/schema/relationship_test.go b/schema/relationship_test.go index af2897b8..887e1341 100644 --- a/schema/relationship_test.go +++ b/schema/relationship_test.go @@ -74,6 +74,25 @@ func TestBelongsToWithOnlyReferences(t *testing.T) { }) } +func TestBelongsToWithOnlyReferences2(t *testing.T) { + type Profile struct { + gorm.Model + Refer string + Name string + } + + type User struct { + gorm.Model + Profile Profile `gorm:"References:Refer"` + ProfileID int + } + + checkStructRelation(t, &User{}, Relation{ + Name: "Profile", Type: schema.BelongsTo, Schema: "User", FieldSchema: "Profile", + References: []Reference{{"Refer", "Profile", "ProfileID", "User", "", false}}, + }) +} + func TestSelfReferentialBelongsToOverrideReferences(t *testing.T) { type User struct { ID int32 `gorm:"primaryKey"` @@ -144,6 +163,25 @@ func TestHasOneWithOnlyReferences(t *testing.T) { }) } +func TestHasOneWithOnlyReferences2(t *testing.T) { + type Profile struct { + gorm.Model + Name string + UserID uint + } + + type User struct { + gorm.Model + Refer string + Profile Profile `gorm:"References:Refer"` + } + + checkStructRelation(t, &User{}, Relation{ + Name: "Profile", Type: schema.HasOne, Schema: "User", FieldSchema: "Profile", + References: []Reference{{"Refer", "User", "UserID", "Profile", "", true}}, + }) +} + func TestHasManyOverrideForeignKey(t *testing.T) { type Profile struct { gorm.Model From 6c0ee2700a1282fe0e2eb669cf57641f01fcf9bc Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 30 Dec 2020 10:42:13 +0800 Subject: [PATCH 763/881] Allow to use Valuer with Eq expression, #3899 --- clause/expression.go | 4 ++++ clause/expression_test.go | 11 ++++++----- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/clause/expression.go b/clause/expression.go index 3844d66b..7a4c09f4 100644 --- a/clause/expression.go +++ b/clause/expression.go @@ -301,6 +301,10 @@ func (like Like) NegationBuild(builder Builder) { } func eqNil(value interface{}) bool { + if valuer, ok := value.(driver.Valuer); ok { + value, _ = valuer.Value() + } + return value == nil || eqNilReflect(value) } diff --git a/clause/expression_test.go b/clause/expression_test.go index 9e3d7bad..4472bdb1 100644 --- a/clause/expression_test.go +++ b/clause/expression_test.go @@ -105,28 +105,29 @@ func TestNamedExpr(t *testing.T) { func TestExpression(t *testing.T) { column := "column-name" results := []struct { - Expressions []clause.Expression - Result string + Expressions []clause.Expression + Result string }{{ Expressions: []clause.Expression{ clause.Eq{Column: column, Value: "column-value"}, }, Result: "`column-name` = ?", - },{ + }, { Expressions: []clause.Expression{ clause.Eq{Column: column, Value: nil}, clause.Eq{Column: column, Value: (*string)(nil)}, clause.Eq{Column: column, Value: (*int)(nil)}, clause.Eq{Column: column, Value: (*bool)(nil)}, clause.Eq{Column: column, Value: (interface{})(nil)}, + clause.Eq{Column: column, Value: sql.NullString{String: "", Valid: false}}, }, Result: "`column-name` IS NULL", - },{ + }, { Expressions: []clause.Expression{ clause.Neq{Column: column, Value: "column-value"}, }, Result: "`column-name` <> ?", - },{ + }, { Expressions: []clause.Expression{ clause.Neq{Column: column, Value: nil}, clause.Neq{Column: column, Value: (*string)(nil)}, From 79864af9ffee6e12051f6bbdfaab31df77f3bc61 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 30 Dec 2020 11:16:40 +0800 Subject: [PATCH 764/881] Allow customize auto increment increment --- callbacks/create.go | 4 +- schema/field.go | 92 ++++++++++++++++++++++++--------------------- 2 files changed, 51 insertions(+), 45 deletions(-) diff --git a/callbacks/create.go b/callbacks/create.go index 052f3344..9166eb67 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -71,7 +71,7 @@ func Create(config *Config) func(db *gorm.DB) { _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv) if isZero { db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID) - insertID-- + insertID -= db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement } } } else { @@ -83,7 +83,7 @@ func Create(config *Config) func(db *gorm.DB) { if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv); isZero { db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID) - insertID++ + insertID += db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement } } } diff --git a/schema/field.go b/schema/field.go index 86b4a061..17cc6c43 100644 --- a/schema/field.go +++ b/schema/field.go @@ -37,55 +37,57 @@ const ( ) type Field struct { - Name string - DBName string - BindNames []string - DataType DataType - GORMDataType DataType - PrimaryKey bool - AutoIncrement bool - Creatable bool - Updatable bool - Readable bool - HasDefaultValue bool - AutoCreateTime TimeType - AutoUpdateTime TimeType - DefaultValue string - DefaultValueInterface interface{} - NotNull bool - Unique bool - Comment string - Size int - Precision int - Scale int - FieldType reflect.Type - IndirectFieldType reflect.Type - StructField reflect.StructField - Tag reflect.StructTag - TagSettings map[string]string - Schema *Schema - EmbeddedSchema *Schema - OwnerSchema *Schema - ReflectValueOf func(reflect.Value) reflect.Value - ValueOf func(reflect.Value) (value interface{}, zero bool) - Set func(reflect.Value, interface{}) error + Name string + DBName string + BindNames []string + DataType DataType + GORMDataType DataType + PrimaryKey bool + AutoIncrement bool + AutoIncrementIncrement int64 + Creatable bool + Updatable bool + Readable bool + HasDefaultValue bool + AutoCreateTime TimeType + AutoUpdateTime TimeType + DefaultValue string + DefaultValueInterface interface{} + NotNull bool + Unique bool + Comment string + Size int + Precision int + Scale int + FieldType reflect.Type + IndirectFieldType reflect.Type + StructField reflect.StructField + Tag reflect.StructTag + TagSettings map[string]string + Schema *Schema + EmbeddedSchema *Schema + OwnerSchema *Schema + ReflectValueOf func(reflect.Value) reflect.Value + ValueOf func(reflect.Value) (value interface{}, zero bool) + Set func(reflect.Value, interface{}) error } func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { var err error field := &Field{ - Name: fieldStruct.Name, - BindNames: []string{fieldStruct.Name}, - FieldType: fieldStruct.Type, - IndirectFieldType: fieldStruct.Type, - StructField: fieldStruct, - Creatable: true, - Updatable: true, - Readable: true, - Tag: fieldStruct.Tag, - TagSettings: ParseTagSetting(fieldStruct.Tag.Get("gorm"), ";"), - Schema: schema, + Name: fieldStruct.Name, + BindNames: []string{fieldStruct.Name}, + FieldType: fieldStruct.Type, + IndirectFieldType: fieldStruct.Type, + StructField: fieldStruct, + Creatable: true, + Updatable: true, + Readable: true, + Tag: fieldStruct.Tag, + TagSettings: ParseTagSetting(fieldStruct.Tag.Get("gorm"), ";"), + Schema: schema, + AutoIncrementIncrement: 1, } for field.IndirectFieldType.Kind() == reflect.Ptr { @@ -149,6 +151,10 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { field.HasDefaultValue = true } + if num, ok := field.TagSettings["AUTOINCREMENTINCREMENT"]; ok { + field.AutoIncrementIncrement, _ = strconv.ParseInt(num, 10, 64) + } + if v, ok := field.TagSettings["DEFAULT"]; ok { field.HasDefaultValue = true field.DefaultValue = v From 1b8cb07cf29e1154778bcf063ddbeb095d4f93e5 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 30 Dec 2020 17:42:27 +0800 Subject: [PATCH 765/881] Allow Where select fields when searching with struct --- statement.go | 26 +++++++++++++++++++++----- tests/query_test.go | 24 ++++++++++++++++++++++++ 2 files changed, 45 insertions(+), 5 deletions(-) diff --git a/statement.go b/statement.go index 707e4aef..9433f4a7 100644 --- a/statement.go +++ b/statement.go @@ -250,7 +250,7 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] conds := make([]clause.Expression, 0, 4) args = append([]interface{}{query}, args...) - for _, arg := range args { + for idx, arg := range args { if valuer, ok := arg.(driver.Valuer); ok { arg, _ = valuer.Value() } @@ -310,11 +310,22 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] default: reflectValue := reflect.Indirect(reflect.ValueOf(arg)) if s, err := schema.Parse(arg, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil { + selectedColumns := map[string]bool{} + if idx == 0 { + for _, v := range args[1:] { + if vs, ok := v.(string); ok { + selectedColumns[vs] = true + } + } + } + restricted := len(selectedColumns) != 0 + switch reflectValue.Kind() { case reflect.Struct: for _, field := range s.Fields { - if field.Readable { - if v, isZero := field.ValueOf(reflectValue); !isZero { + selected := selectedColumns[field.DBName] || selectedColumns[field.Name] + if selected || (!restricted && field.Readable) { + if v, isZero := field.ValueOf(reflectValue); !isZero || selected { if field.DBName != "" { conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v}) } else if field.DataType != "" { @@ -326,8 +337,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] case reflect.Slice, reflect.Array: for i := 0; i < reflectValue.Len(); i++ { for _, field := range s.Fields { - if field.Readable { - if v, isZero := field.ValueOf(reflectValue.Index(i)); !isZero { + selected := selectedColumns[field.DBName] || selectedColumns[field.Name] + if selected || (!restricted && field.Readable) { + if v, isZero := field.ValueOf(reflectValue.Index(i)); !isZero || selected { if field.DBName != "" { conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v}) } else if field.DataType != "" { @@ -338,6 +350,10 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] } } } + + if restricted { + break + } } else if len(conds) == 0 { if len(args) == 1 { switch reflectValue.Kind() { diff --git a/tests/query_test.go b/tests/query_test.go index f1234d0a..50522f71 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -921,6 +921,30 @@ func TestSearchWithMap(t *testing.T) { } } +func TestSearchWithStruct(t *testing.T) { + dryRunDB := DB.Session(&gorm.Session{DryRun: true}) + + result := dryRunDB.Where(User{Name: "jinzhu"}).Find(&User{}) + if !regexp.MustCompile(`WHERE .users.\..name. = .{1,3} AND .users.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) { + t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String()) + } + + result = dryRunDB.Where(User{Name: "jinzhu", Age: 18}).Find(&User{}) + if !regexp.MustCompile(`WHERE .users.\..name. = .{1,3} AND .users.\..age. = .{1,3} AND .users.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) { + t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String()) + } + + result = dryRunDB.Where(User{Name: "jinzhu"}, "name", "Age").Find(&User{}) + if !regexp.MustCompile(`WHERE .users.\..name. = .{1,3} AND .users.\..age. = .{1,3} AND .users.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) { + t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String()) + } + + result = dryRunDB.Where(User{Name: "jinzhu"}, "age").Find(&User{}) + if !regexp.MustCompile(`WHERE .users.\..age. = .{1,3} AND .users.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) { + t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String()) + } +} + func TestSubQuery(t *testing.T) { users := []User{ {Name: "subquery_1", Age: 10}, From 9b8d3b3a0f5ed987fc8cee9b19f8a00edd6e49db Mon Sep 17 00:00:00 2001 From: Philip Sahli Date: Mon, 4 Jan 2021 04:30:05 +0100 Subject: [PATCH 766/881] fix typo (#3911) --- clause/clause.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/clause/clause.go b/clause/clause.go index d413d0ee..828d2cf2 100644 --- a/clause/clause.go +++ b/clause/clause.go @@ -7,7 +7,7 @@ type Interface interface { MergeClause(*Clause) } -// ClauseBuilder clause builder, allows to custmize how to build clause +// ClauseBuilder clause builder, allows to customize how to build clause type ClauseBuilder func(Clause, Builder) type Writer interface { From 60b769c2c8ab57eee310d86de11ec6c65b7b21d8 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 4 Jan 2021 15:13:56 +0800 Subject: [PATCH 767/881] OnConflict UpdateAll includes fields that specified default values via tag --- callbacks/create.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/callbacks/create.go b/callbacks/create.go index 9166eb67..7bc45a6c 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -337,7 +337,7 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { columns := make([]string, 0, len(values.Columns)-1) for _, column := range values.Columns { if field := stmt.Schema.LookUpField(column.Name); field != nil { - if !field.PrimaryKey && !field.HasDefaultValue && field.AutoCreateTime == 0 { + if !field.PrimaryKey && (!field.HasDefaultValue || field.DefaultValueInterface != nil) && field.AutoCreateTime == 0 { columns = append(columns, column.Name) } } From 00a785cd68d4ec24e84a191afccd725f8f62c196 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 5 Jan 2021 18:01:51 +0800 Subject: [PATCH 768/881] Don't use invalid value to build conditions, close #3912 --- statement.go | 89 ++++++++++++++++++++++++++-------------------------- 1 file changed, 45 insertions(+), 44 deletions(-) diff --git a/statement.go b/statement.go index 9433f4a7..5dd3a584 100644 --- a/statement.go +++ b/statement.go @@ -308,38 +308,24 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] } } default: - reflectValue := reflect.Indirect(reflect.ValueOf(arg)) - if s, err := schema.Parse(arg, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil { - selectedColumns := map[string]bool{} - if idx == 0 { - for _, v := range args[1:] { - if vs, ok := v.(string); ok { - selectedColumns[vs] = true - } - } - } - restricted := len(selectedColumns) != 0 - - switch reflectValue.Kind() { - case reflect.Struct: - for _, field := range s.Fields { - selected := selectedColumns[field.DBName] || selectedColumns[field.Name] - if selected || (!restricted && field.Readable) { - if v, isZero := field.ValueOf(reflectValue); !isZero || selected { - if field.DBName != "" { - conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v}) - } else if field.DataType != "" { - conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v}) - } + if reflectValue := reflect.Indirect(reflect.ValueOf(arg)); reflectValue.IsValid() { + if s, err := schema.Parse(arg, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil { + selectedColumns := map[string]bool{} + if idx == 0 { + for _, v := range args[1:] { + if vs, ok := v.(string); ok { + selectedColumns[vs] = true } } } - case reflect.Slice, reflect.Array: - for i := 0; i < reflectValue.Len(); i++ { + restricted := len(selectedColumns) != 0 + + switch reflectValue.Kind() { + case reflect.Struct: for _, field := range s.Fields { selected := selectedColumns[field.DBName] || selectedColumns[field.Name] if selected || (!restricted && field.Readable) { - if v, isZero := field.ValueOf(reflectValue.Index(i)); !isZero || selected { + if v, isZero := field.ValueOf(reflectValue); !isZero || selected { if field.DBName != "" { conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v}) } else if field.DataType != "" { @@ -348,29 +334,44 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] } } } - } - } - - if restricted { - break - } - } else if len(conds) == 0 { - if len(args) == 1 { - switch reflectValue.Kind() { case reflect.Slice, reflect.Array: - values := make([]interface{}, reflectValue.Len()) for i := 0; i < reflectValue.Len(); i++ { - values[i] = reflectValue.Index(i).Interface() + for _, field := range s.Fields { + selected := selectedColumns[field.DBName] || selectedColumns[field.Name] + if selected || (!restricted && field.Readable) { + if v, isZero := field.ValueOf(reflectValue.Index(i)); !isZero || selected { + if field.DBName != "" { + conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v}) + } else if field.DataType != "" { + conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v}) + } + } + } + } } - - if len(values) > 0 { - conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: values}) - } - return conds } - } - conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: args}) + if restricted { + break + } + } else if len(conds) == 0 { + if len(args) == 1 { + switch reflectValue.Kind() { + case reflect.Slice, reflect.Array: + values := make([]interface{}, reflectValue.Len()) + for i := 0; i < reflectValue.Len(); i++ { + values[i] = reflectValue.Index(i).Interface() + } + + if len(values) > 0 { + conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: values}) + } + return conds + } + } + + conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: args}) + } } } } From 53b3ebdd1d6a06bb0dcafffdaaf0883fad84a216 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 5 Jan 2021 21:01:16 +0800 Subject: [PATCH 769/881] Add invalid data error when building conditions --- statement.go | 97 ++++++++++++++++++++++++++-------------------------- 1 file changed, 49 insertions(+), 48 deletions(-) diff --git a/statement.go b/statement.go index 5dd3a584..3617d7ed 100644 --- a/statement.go +++ b/statement.go @@ -308,24 +308,38 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] } } default: - if reflectValue := reflect.Indirect(reflect.ValueOf(arg)); reflectValue.IsValid() { - if s, err := schema.Parse(arg, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil { - selectedColumns := map[string]bool{} - if idx == 0 { - for _, v := range args[1:] { - if vs, ok := v.(string); ok { - selectedColumns[vs] = true + reflectValue := reflect.Indirect(reflect.ValueOf(arg)) + if s, err := schema.Parse(arg, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil { + selectedColumns := map[string]bool{} + if idx == 0 { + for _, v := range args[1:] { + if vs, ok := v.(string); ok { + selectedColumns[vs] = true + } + } + } + restricted := len(selectedColumns) != 0 + + switch reflectValue.Kind() { + case reflect.Struct: + for _, field := range s.Fields { + selected := selectedColumns[field.DBName] || selectedColumns[field.Name] + if selected || (!restricted && field.Readable) { + if v, isZero := field.ValueOf(reflectValue); !isZero || selected { + if field.DBName != "" { + conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v}) + } else if field.DataType != "" { + conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v}) + } } } } - restricted := len(selectedColumns) != 0 - - switch reflectValue.Kind() { - case reflect.Struct: + case reflect.Slice, reflect.Array: + for i := 0; i < reflectValue.Len(); i++ { for _, field := range s.Fields { selected := selectedColumns[field.DBName] || selectedColumns[field.Name] if selected || (!restricted && field.Readable) { - if v, isZero := field.ValueOf(reflectValue); !isZero || selected { + if v, isZero := field.ValueOf(reflectValue.Index(i)); !isZero || selected { if field.DBName != "" { conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v}) } else if field.DataType != "" { @@ -334,44 +348,31 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] } } } - case reflect.Slice, reflect.Array: - for i := 0; i < reflectValue.Len(); i++ { - for _, field := range s.Fields { - selected := selectedColumns[field.DBName] || selectedColumns[field.Name] - if selected || (!restricted && field.Readable) { - if v, isZero := field.ValueOf(reflectValue.Index(i)); !isZero || selected { - if field.DBName != "" { - conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v}) - } else if field.DataType != "" { - conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v}) - } - } - } - } - } } - - if restricted { - break - } - } else if len(conds) == 0 { - if len(args) == 1 { - switch reflectValue.Kind() { - case reflect.Slice, reflect.Array: - values := make([]interface{}, reflectValue.Len()) - for i := 0; i < reflectValue.Len(); i++ { - values[i] = reflectValue.Index(i).Interface() - } - - if len(values) > 0 { - conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: values}) - } - return conds - } - } - - conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: args}) } + + if restricted { + break + } + } else if !reflectValue.IsValid() { + stmt.AddError(ErrInvalidData) + } else if len(conds) == 0 { + if len(args) == 1 { + switch reflectValue.Kind() { + case reflect.Slice, reflect.Array: + values := make([]interface{}, reflectValue.Len()) + for i := 0; i < reflectValue.Len(); i++ { + values[i] = reflectValue.Index(i).Interface() + } + + if len(values) > 0 { + conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: values}) + } + return conds + } + } + + conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: args}) } } } From 6d260a86bdcaf3076edbd60b4870dabcffe92396 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 5 Jan 2021 21:12:31 +0800 Subject: [PATCH 770/881] Fix Set/Get settings when saving associations, close #3908 --- callbacks/associations.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/callbacks/associations.go b/callbacks/associations.go index f5c9e4be..7b01247e 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -363,6 +363,11 @@ func saveAssociations(db *gorm.DB, rel *schema.Relationship, values interface{}, DisableNestedTransaction: true, }) + db.Statement.Settings.Range(func(k, v interface{}) bool { + tx.Statement.Settings.Store(k, v) + return true + }) + if len(selects) > 0 { tx = tx.Select(selects) } From 435bf7086589a69361f5063348ec38768149d071 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 5 Jan 2021 21:31:51 +0800 Subject: [PATCH 771/881] Add OnConflict OnConstraint support, close #3882 --- clause/on_conflict.go | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/clause/on_conflict.go b/clause/on_conflict.go index 47fe169c..5ecd8e93 100644 --- a/clause/on_conflict.go +++ b/clause/on_conflict.go @@ -1,11 +1,12 @@ package clause type OnConflict struct { - Columns []Column - Where Where - DoNothing bool - DoUpdates Set - UpdateAll bool + Columns []Column + Where Where + OnConstraint string + DoNothing bool + DoUpdates Set + UpdateAll bool } func (OnConflict) Name() string { @@ -31,6 +32,12 @@ func (onConflict OnConflict) Build(builder Builder) { builder.WriteByte(' ') } + if onConflict.OnConstraint != "" { + builder.WriteString("ON CONSTRAINT ") + builder.WriteString(onConflict.OnConstraint) + builder.WriteByte(' ') + } + if onConflict.DoNothing { builder.WriteString("DO NOTHING") } else { From 5e72cd9a2b276c0addc5f102b0a444798481576a Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 6 Jan 2021 14:42:42 +0800 Subject: [PATCH 772/881] Add ErrPrimaryKeyRequired if schema has no primary key defined --- finisher_api.go | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/finisher_api.go b/finisher_api.go index 03bcd20f..73424dc2 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -178,8 +178,13 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat break } else { resultsValue := reflect.Indirect(reflect.ValueOf(dest)) - primaryValue, _ := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(resultsValue.Index(resultsValue.Len() - 1)) - queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue}) + if result.Statement.Schema.PrioritizedPrimaryField == nil { + tx.AddError(ErrPrimaryKeyRequired) + break + } else { + primaryValue, _ := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(resultsValue.Index(resultsValue.Len() - 1)) + queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue}) + } } } From bf0fd9bef62ee91509abe995d3317f2138f869e8 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 6 Jan 2021 16:07:19 +0800 Subject: [PATCH 773/881] Fix logger check LogLevel --- logger/logger.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/logger/logger.go b/logger/logger.go index 11619c92..1206cf90 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -135,7 +135,7 @@ func (l logger) Error(ctx context.Context, msg string, data ...interface{}) { // Trace print sql message func (l logger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) { - if l.LogLevel > 0 { + if l.LogLevel > Silent { elapsed := time.Since(begin) switch { case err != nil && l.LogLevel >= Error: @@ -153,7 +153,7 @@ func (l logger) Trace(ctx context.Context, begin time.Time, fc func() (string, i } else { l.Printf(l.traceWarnStr, utils.FileWithLineNum(), slowLog, float64(elapsed.Nanoseconds())/1e6, rows, sql) } - case l.LogLevel >= Info: + default: sql, rows := fc() if rows == -1 { l.Printf(l.traceStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, "-", sql) From a5bfe2f39dab84fb3a51b3e6893469f4867c235d Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 7 Jan 2021 11:45:40 +0800 Subject: [PATCH 774/881] Keep Error for new Session --- gorm.go | 1 + 1 file changed, 1 insertion(+) diff --git a/gorm.go b/gorm.go index ae94daf4..488e74e7 100644 --- a/gorm.go +++ b/gorm.go @@ -163,6 +163,7 @@ func (db *DB) Session(config *Session) *DB { tx = &DB{ Config: &txConfig, Statement: db.Statement, + Error: db.Error, clone: 1, } ) From d888c799d774872162d8580dfe2feb986a87fb8b Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 8 Jan 2021 18:47:06 +0800 Subject: [PATCH 775/881] Change UpdatedAt to current time when doing OnConflict UpdateAll --- callbacks/create.go | 5 +++++ finisher_api.go | 2 +- tests/update_test.go | 12 ++++++++++++ 3 files changed, 18 insertions(+), 1 deletion(-) diff --git a/callbacks/create.go b/callbacks/create.go index 7bc45a6c..634f402b 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -278,6 +278,11 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { field.Set(rv, curTime) values.Values[i][idx], _ = field.ValueOf(rv) } + } else if field.AutoUpdateTime > 0 { + if _, ok := stmt.DB.InstanceGet("gorm:update_track_time"); ok { + field.Set(rv, curTime) + values.Values[0][idx], _ = field.ValueOf(rv) + } } } diff --git a/finisher_api.go b/finisher_api.go index 73424dc2..7dfb72c6 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -70,7 +70,7 @@ func (db *DB) Save(value interface{}) (tx *DB) { if _, ok := tx.Statement.Clauses["ON CONFLICT"]; !ok { tx = tx.Clauses(clause.OnConflict{UpdateAll: true}) } - tx.callbacks.Create().Execute(tx) + tx.callbacks.Create().Execute(tx.InstanceSet("gorm:update_track_time", true)) case reflect.Struct: if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil { for _, pf := range tx.Statement.Schema.PrimaryFields { diff --git a/tests/update_test.go b/tests/update_test.go index df709cff..be3e6fc9 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -606,6 +606,18 @@ func TestSave(t *testing.T) { t.Fatalf("failed to find updated user") } + user2 := *GetUser("save2", Config{}) + DB.Create(&user2) + + time.Sleep(time.Second) + user1UpdatedAt := result.UpdatedAt + var users = []*User{&result, &user2} + DB.Save(&users) + + if user1UpdatedAt == result.UpdatedAt { + t.Fatalf("user's updated at should be changed, expects: %+v, got: %+v", user1UpdatedAt, result.UpdatedAt) + } + dryDB := DB.Session(&gorm.Session{DryRun: true}) stmt := dryDB.Save(&user).Statement if !regexp.MustCompile("WHERE .id. = [^ ]+$").MatchString(stmt.SQL.String()) { From f9131e309d0464e409f6107556297469e7dbf8fb Mon Sep 17 00:00:00 2001 From: Qt Date: Sun, 10 Jan 2021 10:15:48 +0800 Subject: [PATCH 776/881] reduce DB's Use method complexity and make it easier to understand (#3930) --- gorm.go | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/gorm.go b/gorm.go index 488e74e7..355a0e55 100644 --- a/gorm.go +++ b/gorm.go @@ -380,15 +380,14 @@ func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interfac return nil } -func (db *DB) Use(plugin Plugin) (err error) { +func (db *DB) Use(plugin Plugin) error { name := plugin.Name() - if _, ok := db.Plugins[name]; !ok { - if err = plugin.Initialize(db); err == nil { - db.Plugins[name] = plugin - } - } else { + if _, ok := db.Plugins[name]; ok { return ErrRegistered } - - return err + if err := plugin.Initialize(db); err != nil { + return err + } + db.Plugins[name] = plugin + return nil } From 7ebb320f3ec98333603e213bcda6fb0d13a2c412 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 10 Jan 2021 14:58:54 +0800 Subject: [PATCH 777/881] Allow customize join table's table in callback --- callbacks/preload.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/callbacks/preload.go b/callbacks/preload.go index 682427c9..5c56d851 100644 --- a/callbacks/preload.go +++ b/callbacks/preload.go @@ -49,7 +49,7 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { } joinResults := rel.JoinTable.MakeSlice().Elem() - column, values := schema.ToQueryValues(rel.JoinTable.Table, joinForeignKeys, joinForeignValues) + column, values := schema.ToQueryValues(clause.CurrentTable, joinForeignKeys, joinForeignValues) db.AddError(tx.Where(clause.IN{Column: column, Values: values}).Find(joinResults.Addr().Interface()).Error) // convert join identity map to relation identity map From 7302c8a136ea18ea184bc966329f26cdcaec0dc9 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 10 Jan 2021 15:27:53 +0800 Subject: [PATCH 778/881] Fix tests and logger --- logger/logger.go | 2 +- tests/update_test.go | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/logger/logger.go b/logger/logger.go index 1206cf90..cd6bf57f 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -153,7 +153,7 @@ func (l logger) Trace(ctx context.Context, begin time.Time, fc func() (string, i } else { l.Printf(l.traceWarnStr, utils.FileWithLineNum(), slowLog, float64(elapsed.Nanoseconds())/1e6, rows, sql) } - default: + case l.LogLevel == Info: sql, rows := fc() if rows == -1 { l.Printf(l.traceStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, "-", sql) diff --git a/tests/update_test.go b/tests/update_test.go index be3e6fc9..c6764207 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -148,13 +148,17 @@ func TestUpdates(t *testing.T) { CheckUser(t, user2, *users[1]) // update with struct + time.Sleep(1 * time.Second) DB.Table("users").Where("name in ?", []string{users[1].Name}).Updates(User{Name: "updates_02_newname"}) var user3 User if err := DB.First(&user3, "name = ?", "updates_02_newname").Error; err != nil { t.Errorf("User2's name should be updated") } - AssertEqual(t, user2.UpdatedAt, user3.UpdatedAt) + + if user2.UpdatedAt.Format(time.RFC1123) == user3.UpdatedAt.Format(time.RFC1123) { + t.Errorf("User's updated at should be changed, old %v, new %v", user2.UpdatedAt.Format(time.RFC1123), user3.UpdatedAt.Format(time.RFC1123)) + } // update with gorm exprs if err := DB.Model(&user3).Updates(map[string]interface{}{"age": gorm.Expr("age + ?", 100)}).Error; err != nil { From fe553a7c1ac97b81dc3e70fb4cc96fbad1461f16 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 10 Jan 2021 16:46:06 +0800 Subject: [PATCH 779/881] Fix prepared statement in transaction mode can't be shared in normal operations, close #3927 --- gorm.go | 2 +- prepare_stmt.go | 37 +++++++++++++++++++++--------------- tests/prepared_stmt_test.go | 38 +++++++++++++++++++++++++++++++++++++ 3 files changed, 61 insertions(+), 16 deletions(-) diff --git a/gorm.go b/gorm.go index 355a0e55..88885407 100644 --- a/gorm.go +++ b/gorm.go @@ -126,7 +126,7 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) { preparedStmt := &PreparedStmtDB{ ConnPool: db.ConnPool, - Stmts: map[string]*sql.Stmt{}, + Stmts: map[string]Stmt{}, Mux: &sync.RWMutex{}, PreparedSQL: make([]string, 0, 100), } diff --git a/prepare_stmt.go b/prepare_stmt.go index dbf21118..78a8adb4 100644 --- a/prepare_stmt.go +++ b/prepare_stmt.go @@ -6,8 +6,13 @@ import ( "sync" ) +type Stmt struct { + *sql.Stmt + Transaction bool +} + type PreparedStmtDB struct { - Stmts map[string]*sql.Stmt + Stmts map[string]Stmt PreparedSQL []string Mux *sync.RWMutex ConnPool @@ -25,9 +30,9 @@ func (db *PreparedStmtDB) Close() { db.Mux.Unlock() } -func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, query string) (*sql.Stmt, error) { +func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) { db.Mux.RLock() - if stmt, ok := db.Stmts[query]; ok { + if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) { db.Mux.RUnlock() return stmt, nil } @@ -35,19 +40,21 @@ func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, query stri db.Mux.Lock() // double check - if stmt, ok := db.Stmts[query]; ok { + if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) { db.Mux.Unlock() return stmt, nil + } else if ok { + stmt.Close() } stmt, err := conn.PrepareContext(ctx, query) if err == nil { - db.Stmts[query] = stmt + db.Stmts[query] = Stmt{Stmt: stmt, Transaction: isTransaction} db.PreparedSQL = append(db.PreparedSQL, query) } db.Mux.Unlock() - return stmt, err + return db.Stmts[query], err } func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (ConnPool, error) { @@ -59,7 +66,7 @@ func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (Conn } func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args ...interface{}) (result sql.Result, err error) { - stmt, err := db.prepare(ctx, db.ConnPool, query) + stmt, err := db.prepare(ctx, db.ConnPool, false, query) if err == nil { result, err = stmt.ExecContext(ctx, args...) if err != nil { @@ -73,7 +80,7 @@ func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args .. } func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) { - stmt, err := db.prepare(ctx, db.ConnPool, query) + stmt, err := db.prepare(ctx, db.ConnPool, false, query) if err == nil { rows, err = stmt.QueryContext(ctx, args...) if err != nil { @@ -87,7 +94,7 @@ func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args . } func (db *PreparedStmtDB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { - stmt, err := db.prepare(ctx, db.ConnPool, query) + stmt, err := db.prepare(ctx, db.ConnPool, false, query) if err == nil { return stmt.QueryRowContext(ctx, args...) } @@ -114,9 +121,9 @@ func (tx *PreparedStmtTX) Rollback() error { } func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ...interface{}) (result sql.Result, err error) { - stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, query) + stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query) if err == nil { - result, err = tx.Tx.StmtContext(ctx, stmt).ExecContext(ctx, args...) + result, err = tx.Tx.StmtContext(ctx, stmt.Stmt).ExecContext(ctx, args...) if err != nil { tx.PreparedStmtDB.Mux.Lock() stmt.Close() @@ -128,9 +135,9 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args .. } func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) { - stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, query) + stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query) if err == nil { - rows, err = tx.Tx.Stmt(stmt).QueryContext(ctx, args...) + rows, err = tx.Tx.Stmt(stmt.Stmt).QueryContext(ctx, args...) if err != nil { tx.PreparedStmtDB.Mux.Lock() stmt.Close() @@ -142,9 +149,9 @@ func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args . } func (tx *PreparedStmtTX) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { - stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, query) + stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query) if err == nil { - return tx.Tx.StmtContext(ctx, stmt).QueryRowContext(ctx, args...) + return tx.Tx.StmtContext(ctx, stmt.Stmt).QueryRowContext(ctx, args...) } return &sql.Row{} } diff --git a/tests/prepared_stmt_test.go b/tests/prepared_stmt_test.go index 6b10b6dc..8730e547 100644 --- a/tests/prepared_stmt_test.go +++ b/tests/prepared_stmt_test.go @@ -50,3 +50,41 @@ func TestPreparedStmt(t *testing.T) { t.Fatalf("no error should happen but got %v", err) } } + +func TestPreparedStmtFromTransaction(t *testing.T) { + db := DB.Session(&gorm.Session{PrepareStmt: true, SkipDefaultTransaction: true}) + + tx := db.Begin() + defer func() { + if r := recover(); r != nil { + tx.Rollback() + } + }() + if err := tx.Error; err != nil { + t.Errorf("Failed to start transaction, got error %v\n", err) + } + + if err := tx.Where("name=?", "zzjin").Delete(&User{}).Error; err != nil { + tx.Rollback() + t.Errorf("Failed to run one transaction, got error %v\n", err) + } + + if err := tx.Create(&User{Name: "zzjin"}).Error; err != nil { + tx.Rollback() + t.Errorf("Failed to run one transaction, got error %v\n", err) + } + + if err := tx.Commit().Error; err != nil { + t.Errorf("Failed to commit transaction, got error %v\n", err) + } + + if result := db.Where("name=?", "zzjin").Delete(&User{}); result.Error != nil || result.RowsAffected != 1 { + t.Fatalf("Failed, got error: %v, rows affected: %v", result.Error, result.RowsAffected) + } + + tx2 := db.Begin() + if result := tx2.Where("name=?", "zzjin").Delete(&User{}); result.Error != nil || result.RowsAffected != 0 { + t.Fatalf("Failed, got error: %v, rows affected: %v", result.Error, result.RowsAffected) + } + tx2.Commit() +} From b864a5457a59ddfca3dae0f6b11de7443633392b Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 10 Jan 2021 17:32:17 +0800 Subject: [PATCH 780/881] Allow foreign key following the default naming conventions, close #3928 --- schema/relationship.go | 1 + 1 file changed, 1 insertion(+) diff --git a/schema/relationship.go b/schema/relationship.go index 4580fa53..ae0e0b2b 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -418,6 +418,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl gue if len(primaryFields) == 1 { lookUpNames = append(lookUpNames, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID") lookUpNames = append(lookUpNames, strings.TrimSuffix(lookUpName, primaryField.Name)+"Id") + lookUpNames = append(lookUpNames, schema.namer.ColumnName(foreignSchema.Table, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID")) } for _, name := range lookUpNames { From de850edb4f87ab713070fcf9788d0d702a644e56 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 11 Jan 2021 19:16:47 +0800 Subject: [PATCH 781/881] Fix Change UpdatedAt to current time when doing OnConflict UpdateAll --- callbacks/create.go | 2 +- tests/update_test.go | 21 ++++++++++++++++++--- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/callbacks/create.go b/callbacks/create.go index 634f402b..5656b861 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -281,7 +281,7 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { } else if field.AutoUpdateTime > 0 { if _, ok := stmt.DB.InstanceGet("gorm:update_track_time"); ok { field.Set(rv, curTime) - values.Values[0][idx], _ = field.ValueOf(rv) + values.Values[i][idx], _ = field.ValueOf(rv) } } } diff --git a/tests/update_test.go b/tests/update_test.go index c6764207..5ad1bb39 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -156,8 +156,8 @@ func TestUpdates(t *testing.T) { t.Errorf("User2's name should be updated") } - if user2.UpdatedAt.Format(time.RFC1123) == user3.UpdatedAt.Format(time.RFC1123) { - t.Errorf("User's updated at should be changed, old %v, new %v", user2.UpdatedAt.Format(time.RFC1123), user3.UpdatedAt.Format(time.RFC1123)) + if user2.UpdatedAt.Format(time.RFC1123Z) == user3.UpdatedAt.Format(time.RFC1123Z) { + t.Errorf("User's updated at should be changed, old %v, new %v", user2.UpdatedAt.Format(time.RFC1123Z), user3.UpdatedAt.Format(time.RFC1123Z)) } // update with gorm exprs @@ -615,13 +615,28 @@ func TestSave(t *testing.T) { time.Sleep(time.Second) user1UpdatedAt := result.UpdatedAt + user2UpdatedAt := user2.UpdatedAt var users = []*User{&result, &user2} DB.Save(&users) - if user1UpdatedAt == result.UpdatedAt { + if user1UpdatedAt.Format(time.RFC1123Z) == result.UpdatedAt.Format(time.RFC1123Z) { t.Fatalf("user's updated at should be changed, expects: %+v, got: %+v", user1UpdatedAt, result.UpdatedAt) } + if user2UpdatedAt.Format(time.RFC1123Z) == user2.UpdatedAt.Format(time.RFC1123Z) { + t.Fatalf("user's updated at should be changed, expects: %+v, got: %+v", user2UpdatedAt, user2.UpdatedAt) + } + + DB.First(&result) + if user1UpdatedAt.Format(time.RFC1123Z) == result.UpdatedAt.Format(time.RFC1123Z) { + t.Fatalf("user's updated at should be changed after reload, expects: %+v, got: %+v", user1UpdatedAt, result.UpdatedAt) + } + + DB.First(&user2) + if user2UpdatedAt.Format(time.RFC1123Z) == user2.UpdatedAt.Format(time.RFC1123Z) { + t.Fatalf("user2's updated at should be changed after reload, expects: %+v, got: %+v", user2UpdatedAt, user2.UpdatedAt) + } + dryDB := DB.Session(&gorm.Session{DryRun: true}) stmt := dryDB.Save(&user).Statement if !regexp.MustCompile("WHERE .id. = [^ ]+$").MatchString(stmt.SQL.String()) { From ce610a9560f3b8651c50f6355b0b8b6c9ad8d3bc Mon Sep 17 00:00:00 2001 From: Lisa Casner Date: Tue, 12 Jan 2021 21:05:05 -0800 Subject: [PATCH 782/881] title case schema name (#3940) --- schema/relationship.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/schema/relationship.go b/schema/relationship.go index ae0e0b2b..b2253035 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -219,7 +219,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel } for idx, ownField := range ownForeignFields { - joinFieldName := schema.Name + ownField.Name + joinFieldName := strings.Title(schema.Name) + ownField.Name if len(joinForeignKeys) > idx { joinFieldName = strings.Title(joinForeignKeys[idx]) } @@ -258,7 +258,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel } joinTableFields = append(joinTableFields, reflect.StructField{ - Name: schema.Name + field.Name, + Name: strings.Title(schema.Name) + field.Name, Type: schema.ModelType, Tag: `gorm:"-"`, }) From 79628be2c22a3d383dbe15d10796cad0b998d734 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 14 Jan 2021 16:01:23 +0800 Subject: [PATCH 783/881] Fix wrong RowsAffected if not data found --- finisher_api.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/finisher_api.go b/finisher_api.go index 7dfb72c6..7424a9cb 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -446,6 +446,8 @@ func (db *DB) Scan(dest interface{}) (tx *DB) { defer rows.Close() if rows.Next() { tx.ScanRows(rows, dest) + } else { + tx.RowsAffected = 0 } } From 59fa07953cf43385587677f106bb5e522621dca1 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 15 Jan 2021 17:15:59 +0800 Subject: [PATCH 784/881] Preload with settings, close #3945 --- callbacks/preload.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/callbacks/preload.go b/callbacks/preload.go index 5c56d851..3614346f 100644 --- a/callbacks/preload.go +++ b/callbacks/preload.go @@ -22,6 +22,11 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { inlineConds []interface{} ) + db.Statement.Settings.Range(func(k, v interface{}) bool { + tx.Statement.Settings.Store(k, v) + return true + }) + if len(rels) > 1 { reflectValue = schema.GetRelationsValues(reflectValue, rels[:len(rels)-1]) } From 4a15540504db9a7e1ecf69bb2a88bdb7097f6d1a Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 18 Jan 2021 11:43:42 +0800 Subject: [PATCH 785/881] SkipDefaultTransaction skip CreateInBatches transaction --- callbacks/transaction.go | 2 +- finisher_api.go | 12 ++++++++++-- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/callbacks/transaction.go b/callbacks/transaction.go index 3171b5bb..45c6ca11 100644 --- a/callbacks/transaction.go +++ b/callbacks/transaction.go @@ -9,7 +9,7 @@ func BeginTransaction(db *gorm.DB) { if tx := db.Begin(); tx.Error == nil { db.Statement.ConnPool = tx.Statement.ConnPool db.InstanceSet("gorm:started_transaction", true) - } else { + } else if tx.Error == gorm.ErrInvalidTransaction { tx.Error = nil } } diff --git a/finisher_api.go b/finisher_api.go index 7424a9cb..528f32be 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -33,7 +33,8 @@ func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) { case reflect.Slice, reflect.Array: var rowsAffected int64 tx = db.getInstance() - tx.AddError(tx.Transaction(func(tx *DB) error { + + callFc := func(tx *DB) error { for i := 0; i < reflectValue.Len(); i += batchSize { ends := i + batchSize if ends > reflectValue.Len() { @@ -49,7 +50,14 @@ func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) { rowsAffected += subtx.RowsAffected } return nil - })) + } + + if tx.SkipDefaultTransaction { + tx.AddError(callFc(tx.Session(&Session{}))) + } else { + tx.AddError(tx.Transaction(callFc)) + } + tx.RowsAffected = rowsAffected default: tx = db.getInstance() From 3d87575e7efd2b42d6f02b5b04a8179d49b46073 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 18 Jan 2021 19:43:04 +0800 Subject: [PATCH 786/881] make Count compatible with Select with Count func, close #3962 --- finisher_api.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/finisher_api.go b/finisher_api.go index 528f32be..e757bfe9 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -378,7 +378,7 @@ func (db *DB) Count(count *int64) (tx *DB) { if len(tx.Statement.Selects) == 0 { tx.Statement.AddClause(clause.Select{Expression: clause.Expr{SQL: "count(1)"}}) - } else if !strings.Contains(strings.ToLower(tx.Statement.Selects[0]), "count(") { + } else if !strings.HasPrefix(strings.TrimSpace(strings.ToLower(tx.Statement.Selects[0])), "count(") { expr := clause.Expr{SQL: "count(1)"} if len(tx.Statement.Selects) == 1 { From 6095dbf939a8de468378084eb4cbbe9d83fe7201 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 19 Jan 2021 15:40:04 +0800 Subject: [PATCH 787/881] Fix parse embedded relations, close #3964, #3965 --- migrator/migrator.go | 16 +++++++++------- schema/relationship.go | 12 ++++++------ schema/relationship_test.go | 36 ++++++++++++++++++++++++++++++++++++ 3 files changed, 51 insertions(+), 13 deletions(-) diff --git a/migrator/migrator.go b/migrator/migrator.go index a475d307..e25d427c 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -667,13 +667,15 @@ func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []i } orderedModelNamesMap[name] = true - dep := valuesMap[name] - for _, d := range dep.Depends { - if _, ok := valuesMap[d.Table]; ok { - insertIntoOrderedList(d.Table) - } else if autoAdd { - parseDependence(reflect.New(d.ModelType).Interface(), autoAdd) - insertIntoOrderedList(d.Table) + if autoAdd { + dep := valuesMap[name] + for _, d := range dep.Depends { + if _, ok := valuesMap[d.Table]; ok { + insertIntoOrderedList(d.Table) + } else { + parseDependence(reflect.New(d.ModelType).Interface(), autoAdd) + insertIntoOrderedList(d.Table) + } } } diff --git a/schema/relationship.go b/schema/relationship.go index b2253035..41e0b9bd 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -53,7 +53,7 @@ type Reference struct { OwnPrimaryKey bool } -func (schema *Schema) parseRelation(field *Field) { +func (schema *Schema) parseRelation(field *Field) *Relationship { var ( err error fieldValue = reflect.New(field.IndirectFieldType).Interface() @@ -67,13 +67,10 @@ func (schema *Schema) parseRelation(field *Field) { ) cacheStore := schema.cacheStore - if field.OwnerSchema != nil { - cacheStore = field.OwnerSchema.cacheStore - } if relation.FieldSchema, err = getOrParse(fieldValue, cacheStore, schema.namer); err != nil { schema.err = err - return + return nil } if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" { @@ -92,7 +89,8 @@ func (schema *Schema) parseRelation(field *Field) { } if relation.Type == "has" { - if relation.FieldSchema != relation.Schema && relation.Polymorphic == nil { + // don't add relations to embeded schema, which might be shared + if relation.FieldSchema != relation.Schema && relation.Polymorphic == nil && field.OwnerSchema == nil { relation.FieldSchema.Relationships.Relations["_"+relation.Schema.Name+"_"+relation.Name] = relation } @@ -117,6 +115,8 @@ func (schema *Schema) parseRelation(field *Field) { schema.Relationships.Many2Many = append(schema.Relationships.Many2Many, relation) } } + + return relation } // User has many Toys, its `Polymorphic` is `Owner`, Pet has one Toy, its `Polymorphic` is `Owner` diff --git a/schema/relationship_test.go b/schema/relationship_test.go index 887e1341..64d0c2a7 100644 --- a/schema/relationship_test.go +++ b/schema/relationship_test.go @@ -397,3 +397,39 @@ func TestMultipleMany2Many(t *testing.T) { }, ) } + +type CreatedByModel struct { + CreatedByID uint + CreatedBy *CreatedUser +} + +type CreatedUser struct { + gorm.Model + CreatedByModel +} + +func TestEmbeddedRelation(t *testing.T) { + checkStructRelation(t, &CreatedUser{}, Relation{ + Name: "CreatedBy", Type: schema.BelongsTo, Schema: "CreatedUser", FieldSchema: "CreatedUser", + References: []Reference{ + {"ID", "CreatedUser", "CreatedByID", "CreatedUser", "", false}, + }, + }) + + userSchema, err := schema.Parse(&CreatedUser{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("failed to parse schema, got error %v", err) + } + + if len(userSchema.Relationships.Relations) != 1 { + t.Fatalf("expects 1 relations, but got %v", len(userSchema.Relationships.Relations)) + } + + if createdByRel, ok := userSchema.Relationships.Relations["CreatedBy"]; ok { + if createdByRel.FieldSchema != userSchema { + t.Fatalf("expects same field schema, but got new %p, old %p", createdByRel.FieldSchema, userSchema) + } + } else { + t.Fatalf("expects created by relations, but not found") + } +} From 9790103e68e4072ada9b0cf17f2e00fc3ac036e8 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 19 Jan 2021 16:37:49 +0800 Subject: [PATCH 788/881] Fix Where with empty struct, close #3966 --- finisher_api.go | 30 +++++++++++++++++++++--------- 1 file changed, 21 insertions(+), 9 deletions(-) diff --git a/finisher_api.go b/finisher_api.go index e757bfe9..4a3c323b 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -116,7 +116,9 @@ func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) { Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, }) if len(conds) > 0 { - tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondition(conds[0], conds[1:]...)}) + if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 { + tx.Statement.AddClause(clause.Where{Exprs: exprs}) + } } tx.Statement.RaiseErrorOnNotFound = true tx.Statement.Dest = dest @@ -128,7 +130,9 @@ func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) { func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) { tx = db.Limit(1) if len(conds) > 0 { - tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondition(conds[0], conds[1:]...)}) + if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 { + tx.Statement.AddClause(clause.Where{Exprs: exprs}) + } } tx.Statement.RaiseErrorOnNotFound = true tx.Statement.Dest = dest @@ -143,7 +147,9 @@ func (db *DB) Last(dest interface{}, conds ...interface{}) (tx *DB) { Desc: true, }) if len(conds) > 0 { - tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondition(conds[0], conds[1:]...)}) + if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 { + tx.Statement.AddClause(clause.Where{Exprs: exprs}) + } } tx.Statement.RaiseErrorOnNotFound = true tx.Statement.Dest = dest @@ -155,7 +161,9 @@ func (db *DB) Last(dest interface{}, conds ...interface{}) (tx *DB) { func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) { tx = db.getInstance() if len(conds) > 0 { - tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondition(conds[0], conds[1:]...)}) + if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 { + tx.Statement.AddClause(clause.Where{Exprs: exprs}) + } } tx.Statement.Dest = dest tx.callbacks.Query().Execute(tx) @@ -221,8 +229,9 @@ func (tx *DB) assignInterfacesToValue(values ...interface{}) { } } case clause.Expression, map[string]string, map[interface{}]interface{}, map[string]interface{}: - exprs := tx.Statement.BuildCondition(value) - tx.assignInterfacesToValue(exprs) + if exprs := tx.Statement.BuildCondition(value); len(exprs) > 0 { + tx.assignInterfacesToValue(exprs) + } default: if s, err := schema.Parse(value, tx.cacheStore, tx.NamingStrategy); err == nil { reflectValue := reflect.Indirect(reflect.ValueOf(value)) @@ -239,8 +248,9 @@ func (tx *DB) assignInterfacesToValue(values ...interface{}) { } } } else if len(values) > 0 { - exprs := tx.Statement.BuildCondition(values[0], values[1:]...) - tx.assignInterfacesToValue(exprs) + if exprs := tx.Statement.BuildCondition(values[0], values[1:]...); len(exprs) > 0 { + tx.assignInterfacesToValue(exprs) + } return } } @@ -352,7 +362,9 @@ func (db *DB) UpdateColumns(values interface{}) (tx *DB) { func (db *DB) Delete(value interface{}, conds ...interface{}) (tx *DB) { tx = db.getInstance() if len(conds) > 0 { - tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondition(conds[0], conds[1:]...)}) + if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 { + tx.Statement.AddClause(clause.Where{Exprs: exprs}) + } } tx.Statement.Dest = value tx.callbacks.Delete().Execute(tx) From 35ebfe68740ef8d1ff3fde2037fbba34d802e287 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 20 Jan 2021 18:24:05 +0800 Subject: [PATCH 789/881] Support group conditions with single OR condition --- statement.go | 5 +++++ tests/query_test.go | 12 +++++++++++- 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/statement.go b/statement.go index 3617d7ed..de1b300f 100644 --- a/statement.go +++ b/statement.go @@ -261,6 +261,11 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] case *DB: if cs, ok := v.Statement.Clauses["WHERE"]; ok { if where, ok := cs.Expression.(clause.Where); ok { + if len(where.Exprs) == 1 { + if orConds, ok := where.Exprs[0].(clause.OrConditions); ok { + where.Exprs[0] = clause.AndConditions{Exprs: orConds.Exprs} + } + } conds = append(conds, clause.And(where.Exprs...)) } else if cs.Expression != nil { conds = append(conds, cs.Expression) diff --git a/tests/query_test.go b/tests/query_test.go index 50522f71..c6c7acb0 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -475,7 +475,17 @@ func TestNotWithAllFields(t *testing.T) { func TestOr(t *testing.T) { dryDB := DB.Session(&gorm.Session{DryRun: true}) - result := dryDB.Where("role = ?", "admin").Or("role = ?", "super_admin").Find(&User{}) + result := dryDB.Where("role = ?", "admin").Where(DB.Or("role = ?", "super_admin")).Find(&User{}) + if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*role.* = .+ AND .*role.* = .+").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build OR condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Where("role = ?", "admin").Where(DB.Or("role = ?", "super_admin").Or("role = ?", "admin")).Find(&User{}) + if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*role.* = .+ AND (.*role.* = .+ OR .*role.* = .+)").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build OR condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Where("role = ?", "admin").Or("role = ?", "super_admin").Find(&User{}) if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*role.* = .+ OR .*role.* = .+").MatchString(result.Statement.SQL.String()) { t.Fatalf("Build OR condition, but got %v", result.Statement.SQL.String()) } From f8bd4c4875a269b97a2175a0c719805692d0d210 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 24 Jan 2021 10:23:04 +0800 Subject: [PATCH 790/881] Don't create index if there are error exist, close #3976 --- migrator/migrator.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/migrator/migrator.go b/migrator/migrator.go index e25d427c..e8718d18 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -183,7 +183,9 @@ func (m Migrator) CreateTable(values ...interface{}) error { for _, idx := range stmt.Schema.ParseIndexes() { if m.CreateIndexAfterCreateTable { defer func(value interface{}, name string) { - errr = tx.Migrator().CreateIndex(value, name) + if errr == nil { + errr = tx.Migrator().CreateIndex(value, name) + } }(value, idx.Name) } else { if idx.Class != "" { From 59c01b7943a3be36e0d17bfd62a763cf8572f44c Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 25 Jan 2021 10:30:57 +0800 Subject: [PATCH 791/881] Make migrator works with dbresolver, close #3992 --- migrator/migrator.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/migrator/migrator.go b/migrator/migrator.go index e8718d18..c6d0947a 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -82,7 +82,7 @@ func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) { // AutoMigrate func (m Migrator) AutoMigrate(values ...interface{}) error { for _, value := range m.ReorderModels(values, true) { - tx := m.DB.Session(&gorm.Session{NewDB: true}) + tx := m.DB.Session(&gorm.Session{}) if !tx.Migrator().HasTable(value) { if err := tx.Migrator().CreateTable(value); err != nil { return err @@ -154,7 +154,7 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { func (m Migrator) CreateTable(values ...interface{}) error { for _, value := range m.ReorderModels(values, false) { - tx := m.DB.Session(&gorm.Session{NewDB: true}) + tx := m.DB.Session(&gorm.Session{}) if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) { var ( createTableSQL = "CREATE TABLE ? (" @@ -239,7 +239,7 @@ func (m Migrator) CreateTable(values ...interface{}) error { func (m Migrator) DropTable(values ...interface{}) error { values = m.ReorderModels(values, false) for i := len(values) - 1; i >= 0; i-- { - tx := m.DB.Session(&gorm.Session{NewDB: true}) + tx := m.DB.Session(&gorm.Session{}) if err := m.RunWithValue(values[i], func(stmt *gorm.Statement) error { return tx.Exec("DROP TABLE IF EXISTS ?", m.CurrentTable(stmt)).Error }); err != nil { @@ -406,7 +406,7 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy func (m Migrator) ColumnTypes(value interface{}) (columnTypes []gorm.ColumnType, err error) { columnTypes = make([]gorm.ColumnType, 0) err = m.RunWithValue(value, func(stmt *gorm.Statement) error { - rows, err := m.DB.Session(&gorm.Session{NewDB: true}).Table(stmt.Table).Limit(1).Rows() + rows, err := m.DB.Session(&gorm.Session{}).Table(stmt.Table).Limit(1).Rows() if err == nil { defer rows.Close() rawColumnTypes, err := rows.ColumnTypes() From 916338a9e178f01c3da62c817c3efa44f1d36c4d Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 26 Jan 2021 13:39:34 +0800 Subject: [PATCH 792/881] Test migrate constraints, close #3986 --- migrator/migrator.go | 95 +++++++++++++++++++++++++++++------------- schema/relationship.go | 6 +-- tests/migrate_test.go | 30 +++++++++++++ 3 files changed, 97 insertions(+), 34 deletions(-) diff --git a/migrator/migrator.go b/migrator/migrator.go index c6d0947a..91dd8e83 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -451,50 +451,80 @@ func buildConstraint(constraint *schema.Constraint) (sql string, results []inter return } +func (m Migrator) GuessConstraintAndTable(stmt *gorm.Statement, name string) (_ *schema.Constraint, _ *schema.Check, table string) { + if stmt.Schema == nil { + return nil, nil, stmt.Table + } + + checkConstraints := stmt.Schema.ParseCheckConstraints() + if chk, ok := checkConstraints[name]; ok { + return nil, &chk, stmt.Table + } + + getTable := func(rel *schema.Relationship) string { + switch rel.Type { + case schema.HasOne, schema.HasMany: + return rel.FieldSchema.Table + case schema.Many2Many: + return rel.JoinTable.Table + } + return stmt.Table + } + + for _, rel := range stmt.Schema.Relationships.Relations { + if constraint := rel.ParseConstraint(); constraint != nil && constraint.Name == name { + return constraint, nil, getTable(rel) + } + } + + if field := stmt.Schema.LookUpField(name); field != nil { + for _, cc := range checkConstraints { + if cc.Field == field { + return nil, &cc, stmt.Table + } + } + + for _, rel := range stmt.Schema.Relationships.Relations { + if constraint := rel.ParseConstraint(); constraint != nil && rel.Field == field { + return constraint, nil, getTable(rel) + } + } + } + return nil, nil, "" +} + func (m Migrator) CreateConstraint(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { - checkConstraints := stmt.Schema.ParseCheckConstraints() - if chk, ok := checkConstraints[name]; ok { + constraint, chk, table := m.GuessConstraintAndTable(stmt, name) + if chk != nil { return m.DB.Exec( "ALTER TABLE ? ADD CONSTRAINT ? CHECK (?)", m.CurrentTable(stmt), clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint}, ).Error } - for _, rel := range stmt.Schema.Relationships.Relations { - if constraint := rel.ParseConstraint(); constraint != nil && constraint.Name == name { - sql, values := buildConstraint(constraint) - return m.DB.Exec("ALTER TABLE ? ADD "+sql, append([]interface{}{m.CurrentTable(stmt)}, values...)...).Error + if constraint != nil { + var vars = []interface{}{clause.Table{Name: table}} + if stmt.TableExpr != nil { + vars[0] = stmt.TableExpr } + sql, values := buildConstraint(constraint) + return m.DB.Exec("ALTER TABLE ? ADD "+sql, append(vars, values...)...).Error } - err := fmt.Errorf("failed to create constraint with name %v", name) - if field := stmt.Schema.LookUpField(name); field != nil { - for _, cc := range checkConstraints { - if err = m.DB.Migrator().CreateIndex(value, cc.Name); err != nil { - return err - } - } - - for _, rel := range stmt.Schema.Relationships.Relations { - if constraint := rel.ParseConstraint(); constraint != nil && constraint.Field == field { - if err = m.DB.Migrator().CreateIndex(value, constraint.Name); err != nil { - return err - } - } - } - } - - return err + return nil }) } func (m Migrator) DropConstraint(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { - return m.DB.Exec( - "ALTER TABLE ? DROP CONSTRAINT ?", - m.CurrentTable(stmt), clause.Column{Name: name}, - ).Error + constraint, chk, table := m.GuessConstraintAndTable(stmt, name) + if constraint != nil { + name = constraint.Name + } else if chk != nil { + name = chk.Name + } + return m.DB.Exec("ALTER TABLE ? DROP CONSTRAINT ?", clause.Table{Name: table}, clause.Column{Name: name}).Error }) } @@ -502,9 +532,16 @@ func (m Migrator) HasConstraint(value interface{}, name string) bool { var count int64 m.RunWithValue(value, func(stmt *gorm.Statement) error { currentDatabase := m.DB.Migrator().CurrentDatabase() + constraint, chk, table := m.GuessConstraintAndTable(stmt, name) + if constraint != nil { + name = constraint.Name + } else if chk != nil { + name = chk.Name + } + return m.DB.Raw( "SELECT count(*) FROM INFORMATION_SCHEMA.table_constraints WHERE constraint_schema = ? AND table_name = ? AND constraint_name = ?", - currentDatabase, stmt.Table, name, + currentDatabase, table, name, ).Row().Scan(&count) }) diff --git a/schema/relationship.go b/schema/relationship.go index 41e0b9bd..9b7d803c 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -519,7 +519,7 @@ func (rel *Relationship) ParseConstraint() *Constraint { } for _, ref := range rel.References { - if ref.PrimaryKey != nil { + if ref.PrimaryKey != nil && (rel.JoinTable == nil || ref.OwnPrimaryKey) { constraint.ForeignKeys = append(constraint.ForeignKeys, ref.ForeignKey) constraint.References = append(constraint.References, ref.PrimaryKey) @@ -533,10 +533,6 @@ func (rel *Relationship) ParseConstraint() *Constraint { } } - if rel.JoinTable != nil { - return nil - } - return &constraint } diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 275fe634..ca28dfbc 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -323,3 +323,33 @@ func TestMigrateColumns(t *testing.T) { t.Fatalf("Found deleted column") } } + +func TestMigrateConstraint(t *testing.T) { + if DB.Dialector.Name() == "sqlite" { + t.Skip() + } + + names := []string{"Account", "fk_users_account", "Pets", "fk_users_pets", "Company", "fk_users_company", "Manager", "fk_users_manager", "Team", "fk_users_team", "Languages", "fk_users_languages"} + + for _, name := range names { + if !DB.Migrator().HasConstraint(&User{}, name) { + DB.Migrator().CreateConstraint(&User{}, name) + } + + if err := DB.Migrator().DropConstraint(&User{}, name); err != nil { + t.Fatalf("failed to drop constraint %v, got error %v", name, err) + } + + if DB.Migrator().HasConstraint(&User{}, name) { + t.Fatalf("constraint %v should been deleted", name) + } + + if err := DB.Migrator().CreateConstraint(&User{}, name); err != nil { + t.Fatalf("failed to create constraint %v, got error %v", name, err) + } + + if !DB.Migrator().HasConstraint(&User{}, name) { + t.Fatalf("failed to found constraint %v", name) + } + } +} From 08678106a4ebcd9d7de42a254b61a198a69504a4 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 26 Jan 2021 14:34:21 +0800 Subject: [PATCH 793/881] Support replace associations without the creation in association mode, close #3937 --- association.go | 28 +++++++++++++++++++++++++--- tests/associations_many2many_test.go | 5 +++++ tests/go.mod | 8 ++++---- 3 files changed, 34 insertions(+), 7 deletions(-) diff --git a/association.go b/association.go index d93ff8ca..4c55c7e1 100644 --- a/association.go +++ b/association.go @@ -66,7 +66,9 @@ func (association *Association) Append(values ...interface{}) error { func (association *Association) Replace(values ...interface{}) error { if association.Error == nil { // save associations - association.saveAssociation( /*clear*/ true, values...) + if association.saveAssociation( /*clear*/ true, values...); association.Error != nil { + return association.Error + } // set old associations's foreign key to null reflectValue := association.DB.Statement.ReflectValue @@ -378,11 +380,31 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ } selectedSaveColumns := []string{association.Relationship.Name} + omitColumns := []string{} + selectColumns, _ := association.DB.Statement.SelectAndOmitColumns(true, false) + for name, ok := range selectColumns { + columnName := "" + if strings.HasPrefix(name, association.Relationship.Name) { + columnName = strings.TrimPrefix(name, association.Relationship.Name) + } else if strings.HasPrefix(name, clause.Associations) { + columnName = name + } + + if columnName != "" { + if ok { + selectedSaveColumns = append(selectedSaveColumns, columnName) + } else { + omitColumns = append(omitColumns, columnName) + } + } + } + for _, ref := range association.Relationship.References { if !ref.OwnPrimaryKey { selectedSaveColumns = append(selectedSaveColumns, ref.ForeignKey.Name) } } + associationDB := association.DB.Session(&Session{}).Model(nil).Select(selectedSaveColumns).Session(&Session{}) switch reflectValue.Kind() { case reflect.Slice, reflect.Array: @@ -417,7 +439,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ appendToRelations(reflectValue.Index(i), reflect.Indirect(reflect.ValueOf(values[i])), clear) // TODO support save slice data, sql with case? - association.Error = association.DB.Session(&Session{NewDB: true}).Select(selectedSaveColumns).Model(nil).Updates(reflectValue.Index(i).Addr().Interface()).Error + association.Error = associationDB.Updates(reflectValue.Index(i).Addr().Interface()).Error } case reflect.Struct: // clear old data @@ -439,7 +461,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ } if len(values) > 0 { - association.Error = association.DB.Session(&Session{NewDB: true}).Select(selectedSaveColumns).Model(nil).Updates(reflectValue.Addr().Interface()).Error + association.Error = associationDB.Updates(reflectValue.Addr().Interface()).Error } } diff --git a/tests/associations_many2many_test.go b/tests/associations_many2many_test.go index 1ddd3b85..739d1682 100644 --- a/tests/associations_many2many_test.go +++ b/tests/associations_many2many_test.go @@ -113,6 +113,11 @@ func TestMany2ManyOmitAssociations(t *testing.T) { if DB.Model(&user).Association("Languages").Find(&languages); len(languages) != 2 { t.Errorf("languages count should be %v, but got %v", 2, len(languages)) } + + var newLang = Language{Code: "omitmany2many", Name: "omitmany2many"} + if err := DB.Model(&user).Omit("Languages.*").Association("Languages").Replace(&newLang); err == nil { + t.Errorf("should failed to insert languages due to constraint failed, error: %v", err) + } } func TestMany2ManyAssociationForSlice(t *testing.T) { diff --git a/tests/go.mod b/tests/go.mod index f6912a0f..67db5117 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -7,11 +7,11 @@ require ( github.com/jinzhu/now v1.1.1 github.com/lib/pq v1.6.0 github.com/stretchr/testify v1.5.1 - gorm.io/driver/mysql v1.0.3 - gorm.io/driver/postgres v1.0.6 + gorm.io/driver/mysql v1.0.4 + gorm.io/driver/postgres v1.0.7 gorm.io/driver/sqlite v1.1.4 - gorm.io/driver/sqlserver v1.0.5 - gorm.io/gorm v1.20.8 + gorm.io/driver/sqlserver v1.0.6 + gorm.io/gorm v1.20.12 ) replace gorm.io/gorm => ../ From 7f198ead0e716265acd3491925e340bfae758e95 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 26 Jan 2021 16:33:19 +0800 Subject: [PATCH 794/881] Refactor nested preloading associations, close #3970 --- callbacks/preload.go | 12 ++++++------ callbacks/query.go | 40 ++++++++++++++-------------------------- 2 files changed, 20 insertions(+), 32 deletions(-) diff --git a/callbacks/preload.go b/callbacks/preload.go index 3614346f..27e3c3dd 100644 --- a/callbacks/preload.go +++ b/callbacks/preload.go @@ -9,10 +9,9 @@ import ( "gorm.io/gorm/utils" ) -func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { +func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}) { var ( reflectValue = db.Statement.ReflectValue - rel = rels[len(rels)-1] tx = db.Session(&gorm.Session{NewDB: true}).Model(nil).Session(&gorm.Session{SkipHooks: db.Statement.SkipHooks}) relForeignKeys []string relForeignFields []*schema.Field @@ -27,10 +26,6 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { return true }) - if len(rels) > 1 { - reflectValue = schema.GetRelationsValues(reflectValue, rels[:len(rels)-1]) - } - if rel.JoinTable != nil { var joinForeignFields, joinRelForeignFields []*schema.Field var joinForeignKeys []string @@ -97,6 +92,11 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { } } + // nested preload + for p, pvs := range preloads { + tx = tx.Preload(p, pvs...) + } + reflectResults := rel.FieldSchema.MakeSlice().Elem() column, values := schema.ToQueryValues(clause.CurrentTable, relForeignKeys, foreignValues) diff --git a/callbacks/query.go b/callbacks/query.go index ebb09d6b..fff46d57 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -8,7 +8,6 @@ import ( "gorm.io/gorm" "gorm.io/gorm/clause" - "gorm.io/gorm/schema" ) func Query(db *gorm.DB) { @@ -168,48 +167,37 @@ func BuildQuerySQL(db *gorm.DB) { func Preload(db *gorm.DB) { if db.Error == nil && len(db.Statement.Preloads) > 0 { - preloadMap := map[string][]string{} + preloadMap := map[string]map[string][]interface{}{} for name := range db.Statement.Preloads { if name == clause.Associations { for _, rel := range db.Statement.Schema.Relationships.Relations { if rel.Schema == db.Statement.Schema { - preloadMap[rel.Name] = []string{rel.Name} + preloadMap[rel.Name] = nil } } } else { preloadFields := strings.Split(name, ".") - for idx := range preloadFields { - preloadMap[strings.Join(preloadFields[:idx+1], ".")] = preloadFields[:idx+1] + if _, ok := preloadMap[preloadFields[0]]; !ok { + preloadMap[preloadFields[0]] = map[string][]interface{}{} + } + + if value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), "."); value != "" { + preloadMap[preloadFields[0]][value] = db.Statement.Preloads[name] } } } - preloadNames := make([]string, len(preloadMap)) - idx := 0 + preloadNames := make([]string, 0, len(preloadMap)) for key := range preloadMap { - preloadNames[idx] = key - idx++ + preloadNames = append(preloadNames, key) } sort.Strings(preloadNames) for _, name := range preloadNames { - var ( - curSchema = db.Statement.Schema - preloadFields = preloadMap[name] - rels = make([]*schema.Relationship, len(preloadFields)) - ) - - for idx, preloadField := range preloadFields { - if rel := curSchema.Relationships.Relations[preloadField]; rel != nil { - rels[idx] = rel - curSchema = rel.FieldSchema - } else { - db.AddError(fmt.Errorf("%v: %w", name, gorm.ErrUnsupportedRelation)) - } - } - - if db.Error == nil { - preload(db, rels, db.Statement.Preloads[name]) + if rel := db.Statement.Schema.Relationships.Relations[name]; rel != nil { + preload(db, rel, db.Statement.Preloads[name], preloadMap[name]) + } else { + db.AddError(fmt.Errorf("%v: %w for schema %v", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name)) } } } From f6308ed223a12dfdbfd5cc01f90e338c58c21bce Mon Sep 17 00:00:00 2001 From: Manyanda Chitimbo Date: Wed, 27 Jan 2021 04:18:39 +0100 Subject: [PATCH 795/881] refactor: fix typo in tests.yml (#4005) --- .github/workflows/tests.yml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 4388c31d..f26caa86 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -26,7 +26,7 @@ jobs: - name: Check out code into the Go module directory uses: actions/checkout@v2 - - name: go mod pakcage cache + - name: go mod package cache uses: actions/cache@v2 with: path: ~/go/pkg/mod @@ -51,7 +51,7 @@ jobs: - name: Check out code into the Go module directory uses: actions/checkout@v2 - - name: go mod pakcage cache + - name: go mod package cache uses: actions/cache@v2 with: path: ~/go/pkg/mod @@ -95,7 +95,7 @@ jobs: uses: actions/checkout@v2 - - name: go mod pakcage cache + - name: go mod package cache uses: actions/cache@v2 with: path: ~/go/pkg/mod @@ -138,7 +138,7 @@ jobs: - name: Check out code into the Go module directory uses: actions/checkout@v2 - - name: go mod pakcage cache + - name: go mod package cache uses: actions/cache@v2 with: path: ~/go/pkg/mod @@ -181,7 +181,7 @@ jobs: - name: Check out code into the Go module directory uses: actions/checkout@v2 - - name: go mod pakcage cache + - name: go mod package cache uses: actions/cache@v2 with: path: ~/go/pkg/mod From ba590650241bbab942745cd97269fa30e1a965f8 Mon Sep 17 00:00:00 2001 From: rorschach Date: Tue, 26 Jan 2021 20:08:41 +0800 Subject: [PATCH 796/881] retrieving gorm object support pointer --- callbacks.go | 5 +++++ scan.go | 2 +- tests/scan_test.go | 6 ++++++ 3 files changed, 12 insertions(+), 1 deletion(-) diff --git a/callbacks.go b/callbacks.go index e21e0718..cb14aff1 100644 --- a/callbacks.go +++ b/callbacks.go @@ -94,6 +94,11 @@ func (p *processor) Execute(db *DB) { if stmt.Dest != nil { stmt.ReflectValue = reflect.ValueOf(stmt.Dest) for stmt.ReflectValue.Kind() == reflect.Ptr { + if stmt.ReflectValue.IsNil() { + stmt.ReflectValue.Set(reflect.New(stmt.ReflectValue.Type().Elem())) + break + } + stmt.ReflectValue = stmt.ReflectValue.Elem() } if !stmt.ReflectValue.IsValid() { diff --git a/scan.go b/scan.go index 0416489d..acd637a4 100644 --- a/scan.go +++ b/scan.go @@ -191,7 +191,7 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { db.Statement.ReflectValue.Set(reflect.Append(db.Statement.ReflectValue, elem.Elem())) } } - case reflect.Struct: + case reflect.Struct, reflect.Ptr: if db.Statement.ReflectValue.Type() != Schema.ModelType { Schema, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy) } diff --git a/tests/scan_test.go b/tests/scan_test.go index 785bb97e..86cb0399 100644 --- a/tests/scan_test.go +++ b/tests/scan_test.go @@ -28,6 +28,12 @@ func TestScan(t *testing.T) { t.Fatalf("Scan into struct should work, got %#v, should %#v", res, user3) } + var resPointer *result + DB.Table("users").Select("id, name, age").Where("id = ?", user3.ID).Scan(&resPointer) + if res.ID != user3.ID || res.Name != user3.Name || res.Age != int(user3.Age) { + t.Fatalf("Scan into struct should work, got %#v, should %#v", res, user3) + } + DB.Table("users").Select("id, name, age").Where("id = ?", user2.ID).Scan(&res) if res.ID != user2.ID || res.Name != user2.Name || res.Age != int(user2.Age) { t.Fatalf("Scan into struct should work, got %#v, should %#v", res, user2) From 81aa949105d6c19e830d3a63a827d561d3927e6a Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 27 Jan 2021 11:24:34 +0800 Subject: [PATCH 797/881] Remove the uncessary reflect.Ptr --- scan.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scan.go b/scan.go index acd637a4..0416489d 100644 --- a/scan.go +++ b/scan.go @@ -191,7 +191,7 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { db.Statement.ReflectValue.Set(reflect.Append(db.Statement.ReflectValue, elem.Elem())) } } - case reflect.Struct, reflect.Ptr: + case reflect.Struct: if db.Statement.ReflectValue.Type() != Schema.ModelType { Schema, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy) } From cc61202fe2df0630fdfbc6cc31b455a5d76a2494 Mon Sep 17 00:00:00 2001 From: Ben Date: Wed, 27 Jan 2021 11:50:15 +0800 Subject: [PATCH 798/881] retrieving gorm object support pointer (#4006) --- scan.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scan.go b/scan.go index 0416489d..acd637a4 100644 --- a/scan.go +++ b/scan.go @@ -191,7 +191,7 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { db.Statement.ReflectValue.Set(reflect.Append(db.Statement.ReflectValue, elem.Elem())) } } - case reflect.Struct: + case reflect.Struct, reflect.Ptr: if db.Statement.ReflectValue.Type() != Schema.ModelType { Schema, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy) } From 8500380e609be83dd7db46e1b29ee7ab69b6b2e2 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 27 Jan 2021 17:45:48 +0800 Subject: [PATCH 799/881] Add name checker test, close #4007 --- tests/postgres_test.go | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/tests/postgres_test.go b/tests/postgres_test.go index 85cd34d4..94077d1d 100644 --- a/tests/postgres_test.go +++ b/tests/postgres_test.go @@ -15,6 +15,7 @@ func TestPostgres(t *testing.T) { type Harumph struct { gorm.Model + Name string `gorm:"check:name_checker,name <> ''"` Test uuid.UUID `gorm:"type:uuid;not null;default:gen_random_uuid()"` Things pq.StringArray `gorm:"type:text[]"` } @@ -30,10 +31,17 @@ func TestPostgres(t *testing.T) { } harumph := Harumph{} - DB.Create(&harumph) + if err := DB.Create(&harumph).Error; err == nil { + t.Fatalf("should failed to create data, name can't be blank") + } + + harumph = Harumph{Name: "jinzhu"} + if err := DB.Create(&harumph).Error; err != nil { + t.Fatalf("should be able to create data, but got %v", err) + } var result Harumph - if err := DB.First(&result, "id = ?", harumph.ID).Error; err != nil { + if err := DB.First(&result, "id = ?", harumph.ID).Error; err != nil || harumph.Name != "jinzhu" { t.Errorf("No error should happen, but got %v", err) } } From 4267df02af864237917276cf3abb9473041a9db2 Mon Sep 17 00:00:00 2001 From: David Harkness Date: Wed, 27 Jan 2021 18:21:58 -0800 Subject: [PATCH 800/881] Fix typo in README (#4012) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 9c0aded0..a3eabe39 100644 --- a/README.md +++ b/README.md @@ -17,7 +17,7 @@ The fantastic ORM library for Golang, aims to be developer friendly. * Hooks (Before/After Create/Save/Update/Delete/Find) * Eager loading with `Preload`, `Joins` * Transactions, Nested Transactions, Save Point, RollbackTo to Saved Point -* Context, Prepared Statment Mode, DryRun Mode +* Context, Prepared Statement Mode, DryRun Mode * Batch Insert, FindInBatches, Find To Map * SQL Builder, Upsert, Locking, Optimizer/Index/Comment Hints, NamedArg, Search/Update/Create with SQL Expr * Composite Primary Key From 6e3ac74b7e10ec77bc5d973ce693f0648439b888 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 28 Jan 2021 20:17:19 +0800 Subject: [PATCH 801/881] Fix preloading all associations together with nested associations, close #4016 --- callbacks/query.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/callbacks/query.go b/callbacks/query.go index fff46d57..05b572f0 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -172,7 +172,7 @@ func Preload(db *gorm.DB) { if name == clause.Associations { for _, rel := range db.Statement.Schema.Relationships.Relations { if rel.Schema == db.Statement.Schema { - preloadMap[rel.Name] = nil + preloadMap[rel.Name] = map[string][]interface{}{} } } } else { From 7598204dc3b0439196b66505e2a7acdd0537ea31 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 29 Jan 2021 16:40:07 +0800 Subject: [PATCH 802/881] Support `FullSaveAssociations` for association mode, close #4010 --- association.go | 14 ++++++++++++-- gorm.go | 1 - 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/association.go b/association.go index 4c55c7e1..3a2942fd 100644 --- a/association.go +++ b/association.go @@ -385,7 +385,9 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ for name, ok := range selectColumns { columnName := "" if strings.HasPrefix(name, association.Relationship.Name) { - columnName = strings.TrimPrefix(name, association.Relationship.Name) + if columnName = strings.TrimPrefix(name, association.Relationship.Name); columnName == ".*" { + columnName = name + } } else if strings.HasPrefix(name, clause.Associations) { columnName = name } @@ -404,7 +406,15 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ selectedSaveColumns = append(selectedSaveColumns, ref.ForeignKey.Name) } } - associationDB := association.DB.Session(&Session{}).Model(nil).Select(selectedSaveColumns).Session(&Session{}) + + associationDB := association.DB.Session(&Session{}).Model(nil) + if !association.DB.FullSaveAssociations { + associationDB.Select(selectedSaveColumns) + } + if len(omitColumns) > 0 { + associationDB.Omit(omitColumns...) + } + associationDB = associationDB.Session(&Session{}) switch reflectValue.Kind() { case reflect.Slice, reflect.Array: diff --git a/gorm.go b/gorm.go index 88885407..1109e8cd 100644 --- a/gorm.go +++ b/gorm.go @@ -167,7 +167,6 @@ func (db *DB) Session(config *Session) *DB { clone: 1, } ) - if config.CreateBatchSize > 0 { tx.Config.CreateBatchSize = config.CreateBatchSize } From db0cc4d60bbc6ab7ce1fe72bcbf78dda3d8328e0 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 1 Feb 2021 10:37:12 +0800 Subject: [PATCH 803/881] Fix too long foreign key/checker names, close #4026 --- schema/naming.go | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/schema/naming.go b/schema/naming.go index 63296967..f6d15f5a 100644 --- a/schema/naming.go +++ b/schema/naming.go @@ -54,27 +54,30 @@ func (ns NamingStrategy) JoinTableName(str string) string { // RelationshipFKName generate fk name for relation func (ns NamingStrategy) RelationshipFKName(rel Relationship) string { - return strings.Replace(fmt.Sprintf("fk_%s_%s", rel.Schema.Table, ns.toDBName(rel.Name)), ".", "_", -1) + return ns.formatName("fk", rel.Schema.Table, ns.toDBName(rel.Name)) } // CheckerName generate checker name func (ns NamingStrategy) CheckerName(table, column string) string { - return strings.Replace(fmt.Sprintf("chk_%s_%s", table, column), ".", "_", -1) + return ns.formatName("chk", table, column) } // IndexName generate index name func (ns NamingStrategy) IndexName(table, column string) string { - idxName := fmt.Sprintf("idx_%v_%v", table, ns.toDBName(column)) - idxName = strings.Replace(idxName, ".", "_", -1) + return ns.formatName("idx", table, ns.toDBName(column)) +} - if utf8.RuneCountInString(idxName) > 64 { +func (ns NamingStrategy) formatName(prefix, table, name string) string { + formatedName := strings.Replace(fmt.Sprintf("%v_%v_%v", prefix, table, name), ".", "_", -1) + + if utf8.RuneCountInString(formatedName) > 64 { h := sha1.New() - h.Write([]byte(idxName)) + h.Write([]byte(formatedName)) bs := h.Sum(nil) - idxName = fmt.Sprintf("idx%v%v", table, column)[0:56] + string(bs)[:8] + formatedName = fmt.Sprintf("%v%v%v", prefix, table, name)[0:56] + string(bs)[:8] } - return idxName + return formatedName } var ( From 8f37cb01959201e1b53460c6e0a0b00d9f64d0f1 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 1 Feb 2021 10:42:13 +0800 Subject: [PATCH 804/881] Make has to be a const, close #4024 --- schema/relationship.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/schema/relationship.go b/schema/relationship.go index 9b7d803c..0eaace89 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -18,6 +18,7 @@ const ( HasMany RelationshipType = "has_many" // HasManyRel has many relationship BelongsTo RelationshipType = "belongs_to" // BelongsToRel belongs to relationship Many2Many RelationshipType = "many_to_many" // Many2ManyRel many to many relationship + has RelationshipType = "has" ) type Relationships struct { @@ -88,7 +89,7 @@ func (schema *Schema) parseRelation(field *Field) *Relationship { } } - if relation.Type == "has" { + if relation.Type == has { // don't add relations to embeded schema, which might be shared if relation.FieldSchema != relation.Schema && relation.Polymorphic == nil && field.OwnerSchema == nil { relation.FieldSchema.Relationships.Relations["_"+relation.Schema.Name+"_"+relation.Name] = relation @@ -176,7 +177,7 @@ func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Fi }) } - relation.Type = "has" + relation.Type = has } func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Field, many2many string) { @@ -476,7 +477,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl gue } if gl == guessHas || gl == guessEmbeddedHas { - relation.Type = "has" + relation.Type = has } else { relation.Type = BelongsTo } From 3d3208ed602cdf219cc0501a05bd9f00c6b4bd12 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 3 Feb 2021 16:27:49 +0800 Subject: [PATCH 805/881] initialize config plugins --- gorm.go | 8 ++++++++ tests/go.mod | 2 +- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/gorm.go b/gorm.go index 1109e8cd..6adf455a 100644 --- a/gorm.go +++ b/gorm.go @@ -106,6 +106,14 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) { if config.Plugins == nil { config.Plugins = map[string]Plugin{} + } else { + for _, p := range config.Plugins { + defer func(plugin Plugin) { + if errr := plugin.Initialize(db); errr != nil { + err = errr + } + }(p) + } } if config.cacheStore == nil { diff --git a/tests/go.mod b/tests/go.mod index 67db5117..20d7206a 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -8,7 +8,7 @@ require ( github.com/lib/pq v1.6.0 github.com/stretchr/testify v1.5.1 gorm.io/driver/mysql v1.0.4 - gorm.io/driver/postgres v1.0.7 + gorm.io/driver/postgres v1.0.8 gorm.io/driver/sqlite v1.1.4 gorm.io/driver/sqlserver v1.0.6 gorm.io/gorm v1.20.12 From ef5ef18d4ad7f234fab58540dc843d5356dd2280 Mon Sep 17 00:00:00 2001 From: heige Date: Sun, 7 Feb 2021 10:09:32 +0800 Subject: [PATCH 806/881] recommended to use magic const strings (#4059) --- logger/sql.go | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/logger/sql.go b/logger/sql.go index d080def2..3ef2a4e2 100644 --- a/logger/sql.go +++ b/logger/sql.go @@ -13,6 +13,12 @@ import ( "gorm.io/gorm/utils" ) +const ( + tmFmtWithMS = "2006-01-02 15:04:05.999" + tmFmtZero = "0000-00-00 00:00:00" + nullStr = "NULL" +) + func isPrintable(s []byte) bool { for _, r := range s { if !unicode.IsPrint(rune(r)) { @@ -34,26 +40,26 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a vars[idx] = strconv.FormatBool(v) case time.Time: if v.IsZero() { - vars[idx] = escaper + "0000-00-00 00:00:00" + escaper + vars[idx] = escaper + tmFmtZero + escaper } else { - vars[idx] = escaper + v.Format("2006-01-02 15:04:05.999") + escaper + vars[idx] = escaper + v.Format(tmFmtWithMS) + escaper } case *time.Time: if v != nil { if v.IsZero() { - vars[idx] = escaper + "0000-00-00 00:00:00" + escaper + vars[idx] = escaper + tmFmtZero + escaper } else { - vars[idx] = escaper + v.Format("2006-01-02 15:04:05.999") + escaper + vars[idx] = escaper + v.Format(tmFmtWithMS) + escaper } } else { - vars[idx] = "NULL" + vars[idx] = nullStr } case fmt.Stringer: reflectValue := reflect.ValueOf(v) if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) { vars[idx] = escaper + strings.Replace(fmt.Sprintf("%v", v), escaper, "\\"+escaper, -1) + escaper } else { - vars[idx] = "NULL" + vars[idx] = nullStr } case driver.Valuer: reflectValue := reflect.ValueOf(v) @@ -61,7 +67,7 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a r, _ := v.Value() convertParams(r, idx) } else { - vars[idx] = "NULL" + vars[idx] = nullStr } case []byte: if isPrintable(v) { @@ -78,7 +84,7 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a default: rv := reflect.ValueOf(v) if v == nil || !rv.IsValid() || rv.Kind() == reflect.Ptr && rv.IsNil() { - vars[idx] = "NULL" + vars[idx] = nullStr } else if valuer, ok := v.(driver.Valuer); ok { v, _ = valuer.Value() convertParams(v, idx) From e80853e7f5eb5313be1c41ae122b34335cbafcf7 Mon Sep 17 00:00:00 2001 From: heige Date: Sun, 7 Feb 2021 10:12:13 +0800 Subject: [PATCH 807/881] optimization check for ParseCheckConstraints (#4063) --- schema/check.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/schema/check.go b/schema/check.go index 7d31ec70..ec66bad2 100644 --- a/schema/check.go +++ b/schema/check.go @@ -5,6 +5,11 @@ import ( "strings" ) +var ( + // match English letters and midline + regEnLetterAndmidline = regexp.MustCompile("^[A-Za-z-_]+$") +) + type Check struct { Name string Constraint string // length(phone) >= 10 @@ -17,7 +22,7 @@ func (schema *Schema) ParseCheckConstraints() map[string]Check { for _, field := range schema.FieldsByDBName { if chk := field.TagSettings["CHECK"]; chk != "" { names := strings.Split(chk, ",") - if len(names) > 1 && regexp.MustCompile("^[A-Za-z-_]+$").MatchString(names[0]) { + if len(names) > 1 && regEnLetterAndmidline.MatchString(names[0]) { checks[names[0]] = Check{Name: names[0], Constraint: strings.Join(names[1:], ","), Field: field} } else { if names[0] == "" { From bb153384d1274fbe3bbc7d33c31cb1946e7fbe73 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 7 Feb 2021 11:18:09 +0800 Subject: [PATCH 808/881] Switch driver.Valuer, fmt.Stringer order when format SQL --- logger/sql.go | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/logger/sql.go b/logger/sql.go index 3ef2a4e2..4c5f92ed 100644 --- a/logger/sql.go +++ b/logger/sql.go @@ -54,13 +54,6 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a } else { vars[idx] = nullStr } - case fmt.Stringer: - reflectValue := reflect.ValueOf(v) - if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) { - vars[idx] = escaper + strings.Replace(fmt.Sprintf("%v", v), escaper, "\\"+escaper, -1) + escaper - } else { - vars[idx] = nullStr - } case driver.Valuer: reflectValue := reflect.ValueOf(v) if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) { @@ -69,6 +62,13 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a } else { vars[idx] = nullStr } + case fmt.Stringer: + reflectValue := reflect.ValueOf(v) + if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) { + vars[idx] = escaper + strings.Replace(fmt.Sprintf("%v", v), escaper, "\\"+escaper, -1) + escaper + } else { + vars[idx] = nullStr + } case []byte: if isPrintable(v) { vars[idx] = escaper + strings.Replace(string(v), escaper, "\\"+escaper, -1) + escaper From 4373aa01abbe34ae3546681b9ce9095af670f777 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 7 Feb 2021 12:44:59 +0800 Subject: [PATCH 809/881] Don't call AfterFind hooks if no record found, close #4048 --- callbacks/query.go | 2 +- tests/hooks_test.go | 9 +++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/callbacks/query.go b/callbacks/query.go index 05b572f0..5a97e1ad 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -204,7 +204,7 @@ func Preload(db *gorm.DB) { } func AfterQuery(db *gorm.DB) { - if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.AfterFind { + if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.AfterFind && db.RowsAffected > 0 { callMethod(db, func(value interface{}, tx *gorm.DB) bool { if i, ok := value.(AfterFindInterface); ok { db.AddError(i.AfterFind(tx)) diff --git a/tests/hooks_test.go b/tests/hooks_test.go index fe3f7d08..0e6ab2fe 100644 --- a/tests/hooks_test.go +++ b/tests/hooks_test.go @@ -133,6 +133,15 @@ func TestRunCallbacks(t *testing.T) { if DB.Where("Code = ?", "unique_code").First(&p).Error == nil { t.Fatalf("Can't find a deleted record") } + + beforeCallTimes := p.AfterFindCallTimes + if DB.Where("Code = ?", "unique_code").Find(&p).Error != nil { + t.Fatalf("Find don't raise error when record not found") + } + + if p.AfterFindCallTimes != beforeCallTimes { + t.Fatalf("AfterFind should not be called") + } } func TestCallbacksWithErrors(t *testing.T) { From deff0594eee29ae94d66ae476771522252f5b6a7 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 7 Feb 2021 14:24:11 +0800 Subject: [PATCH 810/881] Save associations based on creatable/updatable permission, close #4056 --- callbacks/associations.go | 444 +++++++++++++++++++------------------- callbacks/callbacks.go | 8 +- schema/schema.go | 2 + 3 files changed, 230 insertions(+), 224 deletions(-) diff --git a/callbacks/associations.go b/callbacks/associations.go index 7b01247e..28c769e7 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -9,79 +9,81 @@ import ( "gorm.io/gorm/schema" ) -func SaveBeforeAssociations(db *gorm.DB) { - if db.Error == nil && db.Statement.Schema != nil { - selectColumns, restricted := db.Statement.SelectAndOmitColumns(true, false) +func SaveBeforeAssociations(create bool) func(db *gorm.DB) { + return func(db *gorm.DB) { + if db.Error == nil && db.Statement.Schema != nil { + selectColumns, restricted := db.Statement.SelectAndOmitColumns(create, !create) - // Save Belongs To associations - for _, rel := range db.Statement.Schema.Relationships.BelongsTo { - if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) { - continue - } + // Save Belongs To associations + for _, rel := range db.Statement.Schema.Relationships.BelongsTo { + if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) { + continue + } - setupReferences := func(obj reflect.Value, elem reflect.Value) { - for _, ref := range rel.References { - if !ref.OwnPrimaryKey { - pv, _ := ref.PrimaryKey.ValueOf(elem) - db.AddError(ref.ForeignKey.Set(obj, pv)) + setupReferences := func(obj reflect.Value, elem reflect.Value) { + for _, ref := range rel.References { + if !ref.OwnPrimaryKey { + pv, _ := ref.PrimaryKey.ValueOf(elem) + db.AddError(ref.ForeignKey.Set(obj, pv)) - if dest, ok := db.Statement.Dest.(map[string]interface{}); ok { - dest[ref.ForeignKey.DBName] = pv - if _, ok := dest[rel.Name]; ok { - dest[rel.Name] = elem.Interface() + if dest, ok := db.Statement.Dest.(map[string]interface{}); ok { + dest[ref.ForeignKey.DBName] = pv + if _, ok := dest[rel.Name]; ok { + dest[rel.Name] = elem.Interface() + } } } } } - } - switch db.Statement.ReflectValue.Kind() { - case reflect.Slice, reflect.Array: - var ( - objs []reflect.Value - fieldType = rel.Field.FieldType - isPtr = fieldType.Kind() == reflect.Ptr - ) + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + var ( + objs []reflect.Value + fieldType = rel.Field.FieldType + isPtr = fieldType.Kind() == reflect.Ptr + ) - if !isPtr { - fieldType = reflect.PtrTo(fieldType) - } + if !isPtr { + fieldType = reflect.PtrTo(fieldType) + } - elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) - for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - obj := db.Statement.ReflectValue.Index(i) + elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) + for i := 0; i < db.Statement.ReflectValue.Len(); i++ { + obj := db.Statement.ReflectValue.Index(i) - if reflect.Indirect(obj).Kind() == reflect.Struct { - if _, zero := rel.Field.ValueOf(obj); !zero { // check belongs to relation value - rv := rel.Field.ReflectValueOf(obj) // relation reflect value - objs = append(objs, obj) - if isPtr { - elems = reflect.Append(elems, rv) - } else { - elems = reflect.Append(elems, rv.Addr()) + if reflect.Indirect(obj).Kind() == reflect.Struct { + if _, zero := rel.Field.ValueOf(obj); !zero { // check belongs to relation value + rv := rel.Field.ReflectValueOf(obj) // relation reflect value + objs = append(objs, obj) + if isPtr { + elems = reflect.Append(elems, rv) + } else { + elems = reflect.Append(elems, rv.Addr()) + } + } + } else { + break + } + } + + if elems.Len() > 0 { + if saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, nil) == nil { + for i := 0; i < elems.Len(); i++ { + setupReferences(objs[i], elems.Index(i)) } } - } else { - break } - } - - if elems.Len() > 0 { - if saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, nil) == nil { - for i := 0; i < elems.Len(); i++ { - setupReferences(objs[i], elems.Index(i)) + case reflect.Struct: + if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero { + rv := rel.Field.ReflectValueOf(db.Statement.ReflectValue) // relation reflect value + if rv.Kind() != reflect.Ptr { + rv = rv.Addr() } - } - } - case reflect.Struct: - if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero { - rv := rel.Field.ReflectValueOf(db.Statement.ReflectValue) // relation reflect value - if rv.Kind() != reflect.Ptr { - rv = rv.Addr() - } - if saveAssociations(db, rel, rv.Interface(), selectColumns, restricted, nil) == nil { - setupReferences(db.Statement.ReflectValue, rv) + if saveAssociations(db, rel, rv.Interface(), selectColumns, restricted, nil) == nil { + setupReferences(db.Statement.ReflectValue, rv) + } } } } @@ -89,53 +91,133 @@ func SaveBeforeAssociations(db *gorm.DB) { } } -func SaveAfterAssociations(db *gorm.DB) { - if db.Error == nil && db.Statement.Schema != nil { - selectColumns, restricted := db.Statement.SelectAndOmitColumns(true, false) +func SaveAfterAssociations(create bool) func(db *gorm.DB) { + return func(db *gorm.DB) { + if db.Error == nil && db.Statement.Schema != nil { + selectColumns, restricted := db.Statement.SelectAndOmitColumns(create, !create) - // Save Has One associations - for _, rel := range db.Statement.Schema.Relationships.HasOne { - if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) { - continue + // Save Has One associations + for _, rel := range db.Statement.Schema.Relationships.HasOne { + if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) { + continue + } + + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + var ( + fieldType = rel.Field.FieldType + isPtr = fieldType.Kind() == reflect.Ptr + ) + + if !isPtr { + fieldType = reflect.PtrTo(fieldType) + } + + elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) + + for i := 0; i < db.Statement.ReflectValue.Len(); i++ { + obj := db.Statement.ReflectValue.Index(i) + + if reflect.Indirect(obj).Kind() == reflect.Struct { + if _, zero := rel.Field.ValueOf(obj); !zero { + rv := rel.Field.ReflectValueOf(obj) + if rv.Kind() != reflect.Ptr { + rv = rv.Addr() + } + + for _, ref := range rel.References { + if ref.OwnPrimaryKey { + fv, _ := ref.PrimaryKey.ValueOf(obj) + db.AddError(ref.ForeignKey.Set(rv, fv)) + } else if ref.PrimaryValue != "" { + db.AddError(ref.ForeignKey.Set(rv, ref.PrimaryValue)) + } + } + + elems = reflect.Append(elems, rv) + } + } + } + + if elems.Len() > 0 { + assignmentColumns := []string{} + for _, ref := range rel.References { + assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) + } + + saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, assignmentColumns) + } + case reflect.Struct: + if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero { + f := rel.Field.ReflectValueOf(db.Statement.ReflectValue) + if f.Kind() != reflect.Ptr { + f = f.Addr() + } + + assignmentColumns := []string{} + for _, ref := range rel.References { + if ref.OwnPrimaryKey { + fv, _ := ref.PrimaryKey.ValueOf(db.Statement.ReflectValue) + ref.ForeignKey.Set(f, fv) + } else if ref.PrimaryValue != "" { + ref.ForeignKey.Set(f, ref.PrimaryValue) + } + assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) + } + + saveAssociations(db, rel, f.Interface(), selectColumns, restricted, assignmentColumns) + } + } } - switch db.Statement.ReflectValue.Kind() { - case reflect.Slice, reflect.Array: - var ( - fieldType = rel.Field.FieldType - isPtr = fieldType.Kind() == reflect.Ptr - ) + // Save Has Many associations + for _, rel := range db.Statement.Schema.Relationships.HasMany { + if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) { + continue + } + fieldType := rel.Field.IndirectFieldType.Elem() + isPtr := fieldType.Kind() == reflect.Ptr if !isPtr { fieldType = reflect.PtrTo(fieldType) } - elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) + appendToElems := func(v reflect.Value) { + if _, zero := rel.Field.ValueOf(v); !zero { + f := reflect.Indirect(rel.Field.ReflectValueOf(v)) - for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - obj := db.Statement.ReflectValue.Index(i) - - if reflect.Indirect(obj).Kind() == reflect.Struct { - if _, zero := rel.Field.ValueOf(obj); !zero { - rv := rel.Field.ReflectValueOf(obj) - if rv.Kind() != reflect.Ptr { - rv = rv.Addr() - } - + for i := 0; i < f.Len(); i++ { + elem := f.Index(i) for _, ref := range rel.References { if ref.OwnPrimaryKey { - fv, _ := ref.PrimaryKey.ValueOf(obj) - db.AddError(ref.ForeignKey.Set(rv, fv)) + pv, _ := ref.PrimaryKey.ValueOf(v) + ref.ForeignKey.Set(elem, pv) } else if ref.PrimaryValue != "" { - db.AddError(ref.ForeignKey.Set(rv, ref.PrimaryValue)) + ref.ForeignKey.Set(elem, ref.PrimaryValue) } } - elems = reflect.Append(elems, rv) + if isPtr { + elems = reflect.Append(elems, elem) + } else { + elems = reflect.Append(elems, elem.Addr()) + } } } } + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + for i := 0; i < db.Statement.ReflectValue.Len(); i++ { + obj := db.Statement.ReflectValue.Index(i) + if reflect.Indirect(obj).Kind() == reflect.Struct { + appendToElems(obj) + } + } + case reflect.Struct: + appendToElems(db.Statement.ReflectValue) + } + if elems.Len() > 0 { assignmentColumns := []string{} for _, ref := range rel.References { @@ -144,162 +226,84 @@ func SaveAfterAssociations(db *gorm.DB) { saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, assignmentColumns) } - case reflect.Struct: - if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero { - f := rel.Field.ReflectValueOf(db.Statement.ReflectValue) - if f.Kind() != reflect.Ptr { - f = f.Addr() - } + } - assignmentColumns := []string{} + // Save Many2Many associations + for _, rel := range db.Statement.Schema.Relationships.Many2Many { + if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) { + continue + } + + fieldType := rel.Field.IndirectFieldType.Elem() + isPtr := fieldType.Kind() == reflect.Ptr + if !isPtr { + fieldType = reflect.PtrTo(fieldType) + } + elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) + joins := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(rel.JoinTable.ModelType)), 0, 10) + objs := []reflect.Value{} + + appendToJoins := func(obj reflect.Value, elem reflect.Value) { + joinValue := reflect.New(rel.JoinTable.ModelType) for _, ref := range rel.References { if ref.OwnPrimaryKey { - fv, _ := ref.PrimaryKey.ValueOf(db.Statement.ReflectValue) - ref.ForeignKey.Set(f, fv) + fv, _ := ref.PrimaryKey.ValueOf(obj) + ref.ForeignKey.Set(joinValue, fv) } else if ref.PrimaryValue != "" { - ref.ForeignKey.Set(f, ref.PrimaryValue) + ref.ForeignKey.Set(joinValue, ref.PrimaryValue) + } else { + fv, _ := ref.PrimaryKey.ValueOf(elem) + ref.ForeignKey.Set(joinValue, fv) } - assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) } - - saveAssociations(db, rel, f.Interface(), selectColumns, restricted, assignmentColumns) + joins = reflect.Append(joins, joinValue) } - } - } - // Save Has Many associations - for _, rel := range db.Statement.Schema.Relationships.HasMany { - if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) { - continue - } + appendToElems := func(v reflect.Value) { + if _, zero := rel.Field.ValueOf(v); !zero { + f := reflect.Indirect(rel.Field.ReflectValueOf(v)) - fieldType := rel.Field.IndirectFieldType.Elem() - isPtr := fieldType.Kind() == reflect.Ptr - if !isPtr { - fieldType = reflect.PtrTo(fieldType) - } - elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) - appendToElems := func(v reflect.Value) { - if _, zero := rel.Field.ValueOf(v); !zero { - f := reflect.Indirect(rel.Field.ReflectValueOf(v)) + for i := 0; i < f.Len(); i++ { + elem := f.Index(i) - for i := 0; i < f.Len(); i++ { - elem := f.Index(i) - for _, ref := range rel.References { - if ref.OwnPrimaryKey { - pv, _ := ref.PrimaryKey.ValueOf(v) - ref.ForeignKey.Set(elem, pv) - } else if ref.PrimaryValue != "" { - ref.ForeignKey.Set(elem, ref.PrimaryValue) + objs = append(objs, v) + if isPtr { + elems = reflect.Append(elems, elem) + } else { + elems = reflect.Append(elems, elem.Addr()) } } + } + } - if isPtr { - elems = reflect.Append(elems, elem) - } else { - elems = reflect.Append(elems, elem.Addr()) + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + for i := 0; i < db.Statement.ReflectValue.Len(); i++ { + obj := db.Statement.ReflectValue.Index(i) + if reflect.Indirect(obj).Kind() == reflect.Struct { + appendToElems(obj) } } + case reflect.Struct: + appendToElems(db.Statement.ReflectValue) } - } - switch db.Statement.ReflectValue.Kind() { - case reflect.Slice, reflect.Array: - for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - obj := db.Statement.ReflectValue.Index(i) - if reflect.Indirect(obj).Kind() == reflect.Struct { - appendToElems(obj) + if elems.Len() > 0 { + if v, ok := selectColumns[rel.Name+".*"]; !ok || v { + saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, nil) + } + + for i := 0; i < elems.Len(); i++ { + appendToJoins(objs[i], elems.Index(i)) } } - case reflect.Struct: - appendToElems(db.Statement.ReflectValue) - } - if elems.Len() > 0 { - assignmentColumns := []string{} - for _, ref := range rel.References { - assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) + if joins.Len() > 0 { + db.AddError(db.Session(&gorm.Session{NewDB: true}).Clauses(clause.OnConflict{DoNothing: true}).Session(&gorm.Session{ + SkipHooks: db.Statement.SkipHooks, + DisableNestedTransaction: true, + }).Create(joins.Interface()).Error) } - - saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, assignmentColumns) - } - } - - // Save Many2Many associations - for _, rel := range db.Statement.Schema.Relationships.Many2Many { - if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) { - continue - } - - fieldType := rel.Field.IndirectFieldType.Elem() - isPtr := fieldType.Kind() == reflect.Ptr - if !isPtr { - fieldType = reflect.PtrTo(fieldType) - } - elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) - joins := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(rel.JoinTable.ModelType)), 0, 10) - objs := []reflect.Value{} - - appendToJoins := func(obj reflect.Value, elem reflect.Value) { - joinValue := reflect.New(rel.JoinTable.ModelType) - for _, ref := range rel.References { - if ref.OwnPrimaryKey { - fv, _ := ref.PrimaryKey.ValueOf(obj) - ref.ForeignKey.Set(joinValue, fv) - } else if ref.PrimaryValue != "" { - ref.ForeignKey.Set(joinValue, ref.PrimaryValue) - } else { - fv, _ := ref.PrimaryKey.ValueOf(elem) - ref.ForeignKey.Set(joinValue, fv) - } - } - joins = reflect.Append(joins, joinValue) - } - - appendToElems := func(v reflect.Value) { - if _, zero := rel.Field.ValueOf(v); !zero { - f := reflect.Indirect(rel.Field.ReflectValueOf(v)) - - for i := 0; i < f.Len(); i++ { - elem := f.Index(i) - - objs = append(objs, v) - if isPtr { - elems = reflect.Append(elems, elem) - } else { - elems = reflect.Append(elems, elem.Addr()) - } - } - } - } - - switch db.Statement.ReflectValue.Kind() { - case reflect.Slice, reflect.Array: - for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - obj := db.Statement.ReflectValue.Index(i) - if reflect.Indirect(obj).Kind() == reflect.Struct { - appendToElems(obj) - } - } - case reflect.Struct: - appendToElems(db.Statement.ReflectValue) - } - - if elems.Len() > 0 { - if v, ok := selectColumns[rel.Name+".*"]; !ok || v { - saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, nil) - } - - for i := 0; i < elems.Len(); i++ { - appendToJoins(objs[i], elems.Index(i)) - } - } - - if joins.Len() > 0 { - db.AddError(db.Session(&gorm.Session{NewDB: true}).Clauses(clause.OnConflict{DoNothing: true}).Session(&gorm.Session{ - SkipHooks: db.Statement.SkipHooks, - DisableNestedTransaction: true, - }).Create(joins.Interface()).Error) } } } diff --git a/callbacks/callbacks.go b/callbacks/callbacks.go index dda4b046..7bb27318 100644 --- a/callbacks/callbacks.go +++ b/callbacks/callbacks.go @@ -17,9 +17,9 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) { createCallback := db.Callback().Create() createCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction) createCallback.Register("gorm:before_create", BeforeCreate) - createCallback.Register("gorm:save_before_associations", SaveBeforeAssociations) + createCallback.Register("gorm:save_before_associations", SaveBeforeAssociations(true)) createCallback.Register("gorm:create", Create(config)) - createCallback.Register("gorm:save_after_associations", SaveAfterAssociations) + createCallback.Register("gorm:save_after_associations", SaveAfterAssociations(true)) createCallback.Register("gorm:after_create", AfterCreate) createCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) @@ -40,9 +40,9 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) { updateCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction) updateCallback.Register("gorm:setup_reflect_value", SetupUpdateReflectValue) updateCallback.Register("gorm:before_update", BeforeUpdate) - updateCallback.Register("gorm:save_before_associations", SaveBeforeAssociations) + updateCallback.Register("gorm:save_before_associations", SaveBeforeAssociations(false)) updateCallback.Register("gorm:update", Update) - updateCallback.Register("gorm:save_after_associations", SaveAfterAssociations) + updateCallback.Register("gorm:save_after_associations", SaveAfterAssociations(false)) updateCallback.Register("gorm:after_update", AfterUpdate) updateCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) diff --git a/schema/schema.go b/schema/schema.go index e36ed7b6..d08842e6 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -235,6 +235,8 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) if field.DataType == "" && (field.Creatable || field.Updatable || field.Readable) { if schema.parseRelation(field); schema.err != nil { return schema, schema.err + } else { + schema.FieldsByName[field.Name] = field } } From 883c32e59a0b56a3da972dfc8fb15b9fc281a1ba Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 7 Feb 2021 14:36:27 +0800 Subject: [PATCH 811/881] Support Unscoped when delete with selected associations, close #4062 --- callbacks/delete.go | 3 +++ tests/delete_test.go | 24 ++++++++++++++++++++++++ 2 files changed, 27 insertions(+) diff --git a/callbacks/delete.go b/callbacks/delete.go index 867aa697..128722a1 100644 --- a/callbacks/delete.go +++ b/callbacks/delete.go @@ -36,6 +36,9 @@ func DeleteBeforeAssociations(db *gorm.DB) { modelValue := reflect.New(rel.FieldSchema.ModelType).Interface() tx := db.Session(&gorm.Session{NewDB: true}).Model(modelValue) withoutConditions := false + if db.Statement.Unscoped { + tx = tx.Unscoped() + } if len(db.Statement.Selects) > 0 { var selects []string diff --git a/tests/delete_test.go b/tests/delete_test.go index 37e29fbe..abe85b0e 100644 --- a/tests/delete_test.go +++ b/tests/delete_test.go @@ -153,6 +153,30 @@ func TestDeleteWithAssociations(t *testing.T) { } } +func TestDeleteAssociationsWithUnscoped(t *testing.T) { + user := GetUser("unscoped_delete_with_associations", Config{Account: true, Pets: 2, Toys: 4, Company: true, Manager: true, Team: 1, Languages: 1, Friends: 1}) + + if err := DB.Create(user).Error; err != nil { + t.Fatalf("failed to create user, got error %v", err) + } + + if err := DB.Unscoped().Select(clause.Associations, "Pets.Toy").Delete(&user).Error; err != nil { + t.Fatalf("failed to delete user, got error %v", err) + } + + for key, value := range map[string]int64{"Account": 0, "Pets": 0, "Toys": 0, "Company": 1, "Manager": 1, "Team": 0, "Languages": 0, "Friends": 0} { + if count := DB.Unscoped().Model(&user).Association(key).Count(); count != value { + t.Errorf("user's %v expects: %v, got %v", key, value, count) + } + } + + for key, value := range map[string]int64{"Account": 0, "Pets": 0, "Toys": 0, "Company": 1, "Manager": 1, "Team": 0, "Languages": 0, "Friends": 0} { + if count := DB.Model(&user).Association(key).Count(); count != value { + t.Errorf("user's %v expects: %v, got %v", key, value, count) + } + } +} + func TestDeleteSliceWithAssociations(t *testing.T) { users := []User{ *GetUser("delete_slice_with_associations1", Config{Account: true, Pets: 4, Toys: 1, Company: true, Manager: true, Team: 1, Languages: 1, Friends: 4}), From 2ba612e80591c26ef512af629d2e1532fc48b5b9 Mon Sep 17 00:00:00 2001 From: yrong1997 Date: Tue, 9 Feb 2021 16:03:02 +0800 Subject: [PATCH 812/881] Add field tag to ignore migration (#4028) * Add field tag to ignore migration * Fix null value with space * refactor migration tag --- .gitignore | 1 + migrator/migrator.go | 2 +- schema/field.go | 24 +++++++++++++++++++----- tests/migrate_test.go | 22 ++++++++++++++-------- 4 files changed, 35 insertions(+), 14 deletions(-) diff --git a/.gitignore b/.gitignore index c14d6005..e1b9ecea 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,4 @@ TODO* documents coverage.txt _book +.idea diff --git a/migrator/migrator.go b/migrator/migrator.go index 91dd8e83..4e5051cf 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -396,7 +396,7 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy } } - if alterColumn { + if alterColumn && !field.IgnoreMigration { return m.DB.Migrator().AlterColumn(value, field.Name) } diff --git a/schema/field.go b/schema/field.go index 17cc6c43..5e792ed1 100644 --- a/schema/field.go +++ b/schema/field.go @@ -70,6 +70,7 @@ type Field struct { ReflectValueOf func(reflect.Value) reflect.Value ValueOf func(reflect.Value) (value interface{}, zero bool) Set func(reflect.Value, interface{}) error + IgnoreMigration bool } func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { @@ -189,6 +190,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } // default value is function or null or blank (primary keys) + field.DefaultValue = strings.TrimSpace(field.DefaultValue) skipParseDefaultValue := strings.Contains(field.DefaultValue, "(") && strings.Contains(field.DefaultValue, ")") || strings.ToLower(field.DefaultValue) == "null" || field.DefaultValue == "" switch reflect.Indirect(fieldValue).Kind() { @@ -295,11 +297,23 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } // setup permission - if _, ok := field.TagSettings["-"]; ok { - field.Creatable = false - field.Updatable = false - field.Readable = false - field.DataType = "" + if val, ok := field.TagSettings["-"]; ok { + val = strings.ToLower(strings.TrimSpace(val)) + switch val { + case "-": + field.Creatable = false + field.Updatable = false + field.Readable = false + field.DataType = "" + case "all": + field.Creatable = false + field.Updatable = false + field.Readable = false + field.DataType = "" + field.IgnoreMigration = true + case "migration": + field.IgnoreMigration = true + } } if v, ok := field.TagSettings["->"]; ok { diff --git a/tests/migrate_test.go b/tests/migrate_test.go index ca28dfbc..51843062 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -62,10 +62,11 @@ func TestSmartMigrateColumn(t *testing.T) { DB.AutoMigrate(&UserMigrateColumn{}) type UserMigrateColumn2 struct { - ID uint - Name string `gorm:"size:128"` - Salary float64 `gorm:"precision:2"` - Birthday time.Time `gorm:"precision:2"` + ID uint + Name string `gorm:"size:128"` + Salary float64 `gorm:"precision:2"` + Birthday time.Time `gorm:"precision:2"` + NameIgnoreMigration string `gorm:"size:100"` } if err := DB.Table("user_migrate_columns").AutoMigrate(&UserMigrateColumn2{}); err != nil { @@ -95,10 +96,11 @@ func TestSmartMigrateColumn(t *testing.T) { } type UserMigrateColumn3 struct { - ID uint - Name string `gorm:"size:256"` - Salary float64 `gorm:"precision:3"` - Birthday time.Time `gorm:"precision:3"` + ID uint + Name string `gorm:"size:256"` + Salary float64 `gorm:"precision:3"` + Birthday time.Time `gorm:"precision:3"` + NameIgnoreMigration string `gorm:"size:128;-:migration"` } if err := DB.Table("user_migrate_columns").AutoMigrate(&UserMigrateColumn3{}); err != nil { @@ -124,6 +126,10 @@ func TestSmartMigrateColumn(t *testing.T) { if precision, _, _ := columnType.DecimalSize(); (fullSupported || precision != 0) && precision != 3 { t.Fatalf("birthday's precision should be 2, but got %v", precision) } + case "name_ignore_migration": + if length, _ := columnType.Length(); (fullSupported || length != 0) && length != 100 { + t.Fatalf("name_ignore_migration's length should still be 100 but got %v", length) + } } } From df24821896fb65619c892241ecd00ac3e1acd789 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 9 Feb 2021 17:05:50 +0800 Subject: [PATCH 813/881] Fix SubQuery for raw sql --- statement.go | 6 ++++++ tests/query_test.go | 11 ++++++++++- 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/statement.go b/statement.go index de1b300f..6ea8c883 100644 --- a/statement.go +++ b/statement.go @@ -438,6 +438,12 @@ func (stmt *Statement) clone() *Statement { SkipHooks: stmt.SkipHooks, } + if stmt.SQL.Len() > 0 { + newStmt.SQL.WriteString(stmt.SQL.String()) + newStmt.Vars = make([]interface{}, 0, len(stmt.Vars)) + newStmt.Vars = append(newStmt.Vars, stmt.Vars...) + } + for k, c := range stmt.Clauses { newStmt.Clauses[k] = c } diff --git a/tests/query_test.go b/tests/query_test.go index c6c7acb0..8ed02c98 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -991,7 +991,16 @@ func TestSubQueryWithRaw(t *testing.T) { DB.Create(&users) var count int64 - err := DB.Raw("select count(*) from (?) tmp", + err := DB.Raw("select count(*) from (?) tmp", DB.Raw("select name from users where age >= ? and name in (?)", 10, []string{"subquery_raw_1", "subquery_raw_3"})).Scan(&count).Error + if err != nil { + t.Errorf("Expected to get no errors, but got %v", err) + } + + if count != 2 { + t.Errorf("Row count must be 1, instead got %d", count) + } + + err = DB.Raw("select count(*) from (?) tmp", DB.Table("users"). Select("name"). Where("age >= ? and name in (?)", 20, []string{"subquery_raw_1", "subquery_raw_3"}). From 84ea3ec0ccf5c5e7617d3df0c22f9769dc33f3be Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 9 Feb 2021 18:56:13 +0800 Subject: [PATCH 814/881] Fix sub query argument order with multiple raw SQL --- statement.go | 28 ++++++++++++++++++++++++++-- tests/query_test.go | 4 ++-- 2 files changed, 28 insertions(+), 4 deletions(-) diff --git a/statement.go b/statement.go index 6ea8c883..aac4f073 100644 --- a/statement.go +++ b/statement.go @@ -182,8 +182,32 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) { } case *DB: subdb := v.Session(&Session{Logger: logger.Discard, DryRun: true}).getInstance() - subdb.Statement.Vars = append(subdb.Statement.Vars, stmt.Vars...) - subdb.callbacks.Query().Execute(subdb) + if v.Statement.SQL.Len() > 0 { + var ( + vars = subdb.Statement.Vars + sql = v.Statement.SQL.String() + ) + + subdb.Statement.Vars = make([]interface{}, 0, len(vars)) + for _, vv := range vars { + subdb.Statement.Vars = append(subdb.Statement.Vars, vv) + bindvar := strings.Builder{} + v.Dialector.BindVarTo(&bindvar, subdb.Statement, vv) + sql = strings.Replace(sql, bindvar.String(), "?", 1) + } + + subdb.Statement.SQL.Reset() + subdb.Statement.Vars = stmt.Vars + if strings.Contains(sql, "@") { + clause.NamedExpr{SQL: sql, Vars: vars}.Build(subdb.Statement) + } else { + clause.Expr{SQL: sql, Vars: vars}.Build(subdb.Statement) + } + } else { + subdb.Statement.Vars = append(stmt.Vars, subdb.Statement.Vars...) + subdb.callbacks.Query().Execute(subdb) + } + writer.WriteString(subdb.Statement.SQL.String()) stmt.Vars = subdb.Statement.Vars default: diff --git a/tests/query_test.go b/tests/query_test.go index 8ed02c98..be6768b1 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -991,13 +991,13 @@ func TestSubQueryWithRaw(t *testing.T) { DB.Create(&users) var count int64 - err := DB.Raw("select count(*) from (?) tmp", DB.Raw("select name from users where age >= ? and name in (?)", 10, []string{"subquery_raw_1", "subquery_raw_3"})).Scan(&count).Error + err := DB.Raw("select count(*) from (?) tmp where 1 = ? AND name IN (?)", DB.Raw("select name from users where age >= ? and name in (?)", 10, []string{"subquery_raw_1", "subquery_raw_2", "subquery_raw_3"}), 1, DB.Raw("select name from users where age >= ? and name in (?)", 20, []string{"subquery_raw_1", "subquery_raw_2", "subquery_raw_3"})).Scan(&count).Error if err != nil { t.Errorf("Expected to get no errors, but got %v", err) } if count != 2 { - t.Errorf("Row count must be 1, instead got %d", count) + t.Errorf("Row count must be 2, instead got %d", count) } err = DB.Raw("select count(*) from (?) tmp", From a13b7a6acbb32b80ceac63de1ae3576bbb0cdb45 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 10 Feb 2021 14:11:29 +0800 Subject: [PATCH 815/881] Fix OnConflict where order for postgres, close #4073 --- clause/on_conflict.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/clause/on_conflict.go b/clause/on_conflict.go index 5ecd8e93..f0c3d7e7 100644 --- a/clause/on_conflict.go +++ b/clause/on_conflict.go @@ -26,12 +26,6 @@ func (onConflict OnConflict) Build(builder Builder) { builder.WriteString(`) `) } - if len(onConflict.Where.Exprs) > 0 { - builder.WriteString("WHERE ") - onConflict.Where.Build(builder) - builder.WriteByte(' ') - } - if onConflict.OnConstraint != "" { builder.WriteString("ON CONSTRAINT ") builder.WriteString(onConflict.OnConstraint) @@ -44,6 +38,12 @@ func (onConflict OnConflict) Build(builder Builder) { builder.WriteString("DO UPDATE SET ") onConflict.DoUpdates.Build(builder) } + + if len(onConflict.Where.Exprs) > 0 { + builder.WriteString("WHERE ") + onConflict.Where.Build(builder) + builder.WriteByte(' ') + } } // MergeClause merge onConflict clauses From 5744e29fbdc8391519d1a822cf149550bacbd43d Mon Sep 17 00:00:00 2001 From: Joel Nordell Date: Sat, 13 Feb 2021 18:16:24 -0600 Subject: [PATCH 816/881] Replacer interface for more flexible NamingStrategy (#4042) * Change NameReplacer to an interface, allowing custom Replacers. * Add NoLowerCase option to skip the snake_casing of names. * Move sync.Map from global variable into member of NamingStrategy. This maintains backward compatibility by making the smap optional - the NamingStrategy still works if it is nil. gorm.Open activates it by calling Init() if the given Namer is a schema.NamingStrategy. Also, this changes the key stored in the smap to be the original name, instead of the replaced name. * Refactor NamingStrategy tests to add more assertions about how and when Replacers get called. * Remove the name cache from NamingStrategy. --- schema/naming.go | 19 +++++---- schema/naming_test.go | 96 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 108 insertions(+), 7 deletions(-) diff --git a/schema/naming.go b/schema/naming.go index f6d15f5a..e10c9212 100644 --- a/schema/naming.go +++ b/schema/naming.go @@ -4,7 +4,6 @@ import ( "crypto/sha1" "fmt" "strings" - "sync" "unicode/utf8" "github.com/jinzhu/inflection" @@ -20,11 +19,17 @@ type Namer interface { IndexName(table, column string) string } +// Replacer replacer interface like strings.Replacer +type Replacer interface { + Replace(name string) string +} + // NamingStrategy tables, columns naming strategy type NamingStrategy struct { TablePrefix string SingularTable bool - NameReplacer *strings.Replacer + NameReplacer Replacer + NoLowerCase bool } // TableName convert string to table name @@ -42,7 +47,7 @@ func (ns NamingStrategy) ColumnName(table, column string) string { // JoinTableName convert string to join table name func (ns NamingStrategy) JoinTableName(str string) string { - if strings.ToLower(str) == str { + if !ns.NoLowerCase && strings.ToLower(str) == str { return ns.TablePrefix + str } @@ -81,7 +86,6 @@ func (ns NamingStrategy) formatName(prefix, table, name string) string { } var ( - smap sync.Map // https://github.com/golang/lint/blob/master/lint.go#L770 commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UID", "UI", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"} commonInitialismsReplacer *strings.Replacer @@ -98,14 +102,16 @@ func init() { func (ns NamingStrategy) toDBName(name string) string { if name == "" { return "" - } else if v, ok := smap.Load(name); ok { - return v.(string) } if ns.NameReplacer != nil { name = ns.NameReplacer.Replace(name) } + if ns.NoLowerCase { + return name + } + var ( value = commonInitialismsReplacer.Replace(name) buf strings.Builder @@ -143,6 +149,5 @@ func (ns NamingStrategy) toDBName(name string) string { buf.WriteByte(value[len(value)-1]) } ret := buf.String() - smap.Store(name, ret) return ret } diff --git a/schema/naming_test.go b/schema/naming_test.go index b7a32160..08f8d498 100644 --- a/schema/naming_test.go +++ b/schema/naming_test.go @@ -72,3 +72,99 @@ func TestNamingStrategy(t *testing.T) { t.Errorf("invalid column name generated, got %v", columdName) } } + +type CustomReplacer struct { + f func(string) string +} + +func (r CustomReplacer) Replace(name string) string { + return r.f(name) +} + +func TestCustomReplacer(t *testing.T) { + var ns = NamingStrategy{ + TablePrefix: "public.", + SingularTable: true, + NameReplacer: CustomReplacer{ + func(name string) string { + replaced := "REPLACED_" + strings.ToUpper(name) + return strings.NewReplacer("CID", "_Cid").Replace(replaced) + }, + }, + NoLowerCase: false, + } + + idxName := ns.IndexName("public.table", "name") + if idxName != "idx_public_table_replaced_name" { + t.Errorf("invalid index name generated, got %v", idxName) + } + + chkName := ns.CheckerName("public.table", "name") + if chkName != "chk_public_table_name" { + t.Errorf("invalid checker name generated, got %v", chkName) + } + + joinTable := ns.JoinTableName("user_languages") + if joinTable != "public.user_languages" { // Seems like a bug in NamingStrategy to skip the Replacer when the name is lowercase here. + t.Errorf("invalid join table generated, got %v", joinTable) + } + + joinTable2 := ns.JoinTableName("UserLanguage") + if joinTable2 != "public.replaced_userlanguage" { + t.Errorf("invalid join table generated, got %v", joinTable2) + } + + tableName := ns.TableName("Company") + if tableName != "public.replaced_company" { + t.Errorf("invalid table name generated, got %v", tableName) + } + + columdName := ns.ColumnName("", "NameCID") + if columdName != "replaced_name_cid" { + t.Errorf("invalid column name generated, got %v", columdName) + } +} + +func TestCustomReplacerWithNoLowerCase(t *testing.T) { + var ns = NamingStrategy{ + TablePrefix: "public.", + SingularTable: true, + NameReplacer: CustomReplacer{ + func(name string) string { + replaced := "REPLACED_" + strings.ToUpper(name) + return strings.NewReplacer("CID", "_Cid").Replace(replaced) + }, + }, + NoLowerCase: true, + } + + idxName := ns.IndexName("public.table", "name") + if idxName != "idx_public_table_REPLACED_NAME" { + t.Errorf("invalid index name generated, got %v", idxName) + } + + chkName := ns.CheckerName("public.table", "name") + if chkName != "chk_public_table_name" { + t.Errorf("invalid checker name generated, got %v", chkName) + } + + joinTable := ns.JoinTableName("user_languages") + if joinTable != "public.REPLACED_USER_LANGUAGES" { + t.Errorf("invalid join table generated, got %v", joinTable) + } + + joinTable2 := ns.JoinTableName("UserLanguage") + if joinTable2 != "public.REPLACED_USERLANGUAGE" { + t.Errorf("invalid join table generated, got %v", joinTable2) + } + + tableName := ns.TableName("Company") + if tableName != "public.REPLACED_COMPANY" { + t.Errorf("invalid table name generated, got %v", tableName) + } + + columdName := ns.ColumnName("", "NameCID") + if columdName != "REPLACED_NAME_Cid" { + t.Errorf("invalid column name generated, got %v", columdName) + } +} From 628a0ae707f230c67bca2e632fb302037c707705 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 15 Feb 2021 09:10:51 +0800 Subject: [PATCH 817/881] Fix foreign key & reference with same name, close #4081 --- schema/relationship.go | 20 ++++++++++++++++---- schema/relationship_test.go | 24 ++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 4 deletions(-) diff --git a/schema/relationship.go b/schema/relationship.go index 0eaace89..1aa2d11a 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -81,7 +81,7 @@ func (schema *Schema) parseRelation(field *Field) *Relationship { } else { switch field.IndirectFieldType.Kind() { case reflect.Struct: - schema.guessRelation(relation, field, guessBelongs) + schema.guessRelation(relation, field, guessGuess) case reflect.Slice: schema.guessRelation(relation, field, guessHas) default: @@ -341,20 +341,32 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel type guessLevel int const ( - guessBelongs guessLevel = iota + guessGuess guessLevel = iota + guessBelongs guessEmbeddedBelongs guessHas guessEmbeddedHas ) -func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl guessLevel) { +func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl guessLevel) { var ( primaryFields, foreignFields []*Field primarySchema, foreignSchema = schema, relation.FieldSchema + gl = cgl ) + if gl == guessGuess { + if field.Schema == relation.FieldSchema { + gl = guessBelongs + } else { + gl = guessHas + } + } + reguessOrErr := func() { - switch gl { + switch cgl { + case guessGuess: + schema.guessRelation(relation, field, guessBelongs) case guessBelongs: schema.guessRelation(relation, field, guessEmbeddedBelongs) case guessEmbeddedBelongs: diff --git a/schema/relationship_test.go b/schema/relationship_test.go index 64d0c2a7..a34777b7 100644 --- a/schema/relationship_test.go +++ b/schema/relationship_test.go @@ -433,3 +433,27 @@ func TestEmbeddedRelation(t *testing.T) { t.Fatalf("expects created by relations, but not found") } } + +func TestSameForeignKey(t *testing.T) { + type UserAux struct { + gorm.Model + Aux string + UUID string + } + + type User struct { + gorm.Model + Name string + UUID string + Aux *UserAux `gorm:"foreignkey:UUID;references:UUID"` + } + + checkStructRelation(t, &User{}, + Relation{ + Name: "Aux", Type: schema.HasOne, Schema: "User", FieldSchema: "UserAux", + References: []Reference{ + {"UUID", "User", "UUID", "UserAux", "", true}, + }, + }, + ) +} From 92a238945056cbbe204e096d98fd76e1e01ab61d Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 16 Feb 2021 08:35:19 +0800 Subject: [PATCH 818/881] Fix create duplicated constraint, close #4090 --- schema/relationship.go | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/schema/relationship.go b/schema/relationship.go index 1aa2d11a..606e722a 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -512,6 +512,24 @@ func (rel *Relationship) ParseConstraint() *Constraint { return nil } + if rel.Type == BelongsTo { + for _, r := range rel.FieldSchema.Relationships.Relations { + if r.FieldSchema == rel.Schema && len(rel.References) == len(r.References) { + matched := true + for idx, ref := range r.References { + if !(rel.References[idx].PrimaryKey == ref.PrimaryKey && rel.References[idx].ForeignKey == ref.ForeignKey && + rel.References[idx].PrimaryValue == ref.PrimaryValue) { + matched = false + } + } + + if matched { + return nil + } + } + } + } + var ( name string idx = strings.Index(str, ",") From 73d44a4f97c1e7ed703ca16eeb589525f15decb8 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 16 Feb 2021 08:39:04 +0800 Subject: [PATCH 819/881] Fix create duplicated constraint, close #4090 --- tests/migrate_test.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 51843062..16c48405 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -38,7 +38,6 @@ func TestMigrate(t *testing.T) { {"user_friends", "fk_user_friends_friends"}, {"accounts", "fk_users_account"}, {"users", "fk_users_team"}, - {"users", "fk_users_manager"}, {"users", "fk_users_company"}, } { if !DB.Migrator().HasConstraint(indexes[0], indexes[1]) { @@ -335,7 +334,7 @@ func TestMigrateConstraint(t *testing.T) { t.Skip() } - names := []string{"Account", "fk_users_account", "Pets", "fk_users_pets", "Company", "fk_users_company", "Manager", "fk_users_manager", "Team", "fk_users_team", "Languages", "fk_users_languages"} + names := []string{"Account", "fk_users_account", "Pets", "fk_users_pets", "Company", "fk_users_company", "Team", "fk_users_team", "Languages", "fk_users_languages"} for _, name := range names { if !DB.Migrator().HasConstraint(&User{}, name) { From 79225bfe48831236b060a019e15b473e20644b64 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 18 Feb 2021 10:53:29 +0800 Subject: [PATCH 820/881] Fix Omit/Select without Model value, close #4098 --- statement.go | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/statement.go b/statement.go index aac4f073..0cb2ca32 100644 --- a/statement.go +++ b/statement.go @@ -600,12 +600,14 @@ func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) ( // select columns for _, column := range stmt.Selects { - if column == "*" { + if stmt.Schema == nil { + results[column] = true + } else if column == "*" { notRestricted = true for _, dbName := range stmt.Schema.DBNames { results[dbName] = true } - } else if column == clause.Associations && stmt.Schema != nil { + } else if column == clause.Associations { for _, rel := range stmt.Schema.Relationships.Relations { results[rel.Name] = true } @@ -618,11 +620,11 @@ func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) ( // omit columns for _, omit := range stmt.Omits { - if omit == clause.Associations { - if stmt.Schema != nil { - for _, rel := range stmt.Schema.Relationships.Relations { - results[rel.Name] = false - } + if stmt.Schema == nil { + results[omit] = false + } else if omit == clause.Associations { + for _, rel := range stmt.Schema.Relationships.Relations { + results[rel.Name] = false } } else if field := stmt.Schema.LookUpField(omit); field != nil && field.DBName != "" { results[field.DBName] = false From 940da051a756e425d7069a51eec412835cb6bbb1 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 23 Feb 2021 19:35:20 +0800 Subject: [PATCH 821/881] Skip nested associations when create data with Select, close #4108 --- callbacks/associations.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/callbacks/associations.go b/callbacks/associations.go index 28c769e7..dc84e137 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -349,8 +349,6 @@ func saveAssociations(db *gorm.DB, rel *schema.Relationship, values interface{}, columnName := "" if strings.HasPrefix(name, refName) { columnName = strings.TrimPrefix(name, refName) - } else if strings.HasPrefix(name, clause.Associations) { - columnName = name } if columnName != "" { @@ -374,6 +372,8 @@ func saveAssociations(db *gorm.DB, rel *schema.Relationship, values interface{}, if len(selects) > 0 { tx = tx.Select(selects) + } else if len(selectColumns) > 0 && len(omits) == 0 { + tx = tx.Omit(clause.Associations) } if len(omits) > 0 { From 828e6b646bbe803d1a6b9d4aba0d8ff8b84d14f4 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 25 Feb 2021 18:49:01 +0800 Subject: [PATCH 822/881] Lazy call registered scopes --- callbacks.go | 12 ++++++++++-- statement.go | 5 +++++ 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/callbacks.go b/callbacks.go index cb14aff1..d1b8cd58 100644 --- a/callbacks.go +++ b/callbacks.go @@ -72,8 +72,10 @@ func (cs *callbacks) Raw() *processor { } func (p *processor) Execute(db *DB) { - curTime := time.Now() - stmt := db.Statement + var ( + curTime = time.Now() + stmt = db.Statement + ) if stmt.Model == nil { stmt.Model = stmt.Dest @@ -106,6 +108,12 @@ func (p *processor) Execute(db *DB) { } } + // call scopes + for _, scope := range stmt.scopes { + db = scope(db) + } + stmt.scopes = nil + for _, f := range p.fns { f(db) } diff --git a/statement.go b/statement.go index 0cb2ca32..a6ddece1 100644 --- a/statement.go +++ b/statement.go @@ -43,6 +43,7 @@ type Statement struct { CurDestIndex int attrs []interface{} assigns []interface{} + scopes []func(*DB) *DB } type join struct { @@ -481,6 +482,10 @@ func (stmt *Statement) clone() *Statement { copy(newStmt.Joins, stmt.Joins) } + for _, scope := range stmt.scopes { + stmt.scopes = append(stmt.scopes, scope) + } + stmt.Settings.Range(func(k, v interface{}) bool { newStmt.Settings.Store(k, v) return true From 6b7d18656d8af6565ea831830f06309c3f8c9c12 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 25 Feb 2021 20:06:26 +0800 Subject: [PATCH 823/881] Lazy call registered scopes --- chainable_api.go | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/chainable_api.go b/chainable_api.go index 58b9336f..5415f5bd 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -240,11 +240,10 @@ func (db *DB) Offset(offset int) (tx *DB) { // } // // db.Scopes(AmountGreaterThan1000, OrderStatus([]string{"paid", "shipped"})).Find(&orders) -func (db *DB) Scopes(funcs ...func(*DB) *DB) *DB { - for _, f := range funcs { - db = f(db) - } - return db +func (db *DB) Scopes(funcs ...func(*DB) *DB) (tx *DB) { + tx = db.getInstance() + tx.Statement.scopes = append(tx.Statement.scopes, funcs...) + return tx } // Preload preload associations with given conditions From ddeb143eb9726dd4aa5a10581280c9b4679c6b90 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 25 Feb 2021 22:01:59 +0800 Subject: [PATCH 824/881] Lazy call registered scopes --- statement.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/statement.go b/statement.go index a6ddece1..7580d965 100644 --- a/statement.go +++ b/statement.go @@ -482,8 +482,9 @@ func (stmt *Statement) clone() *Statement { copy(newStmt.Joins, stmt.Joins) } - for _, scope := range stmt.scopes { - stmt.scopes = append(stmt.scopes, scope) + if len(stmt.scopes) > 0 { + newStmt.scopes = make([]func(*DB) *DB, len(stmt.scopes)) + copy(newStmt.scopes, stmt.scopes) } stmt.Settings.Range(func(k, v interface{}) bool { From 189547f615919db93a70a7c48ffe4ad819d14962 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 26 Feb 2021 16:43:43 +0800 Subject: [PATCH 825/881] Fix new session with Begin, close #4120 --- finisher_api.go | 2 +- tests/transaction_test.go | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/finisher_api.go b/finisher_api.go index 4a3c323b..2d7409c7 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -565,7 +565,7 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er func (db *DB) Begin(opts ...*sql.TxOptions) *DB { var ( // clone statement - tx = db.Session(&Session{Context: db.Statement.Context}) + tx = db.getInstance().Session(&Session{Context: db.Statement.Context}) opt *sql.TxOptions err error ) diff --git a/tests/transaction_test.go b/tests/transaction_test.go index c17fea3b..4e4b6149 100644 --- a/tests/transaction_test.go +++ b/tests/transaction_test.go @@ -41,7 +41,8 @@ func TestTransaction(t *testing.T) { t.Fatalf("Should not find record after rollback, but got %v", err) } - tx2 := DB.Begin() + txDB := DB.Where("fake_name = ?", "fake_name") + tx2 := txDB.Session(&gorm.Session{NewDB: true}).Begin() user2 := *GetUser("transaction-2", Config{}) if err := tx2.Save(&user2).Error; err != nil { t.Fatalf("No error should raise, but got %v", err) From eb9a704fda14b74a49d9b9d4d965706c848415dd Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 26 Feb 2021 17:11:25 +0800 Subject: [PATCH 826/881] Fix update UpdatedAt when full saving associations, close #4115 --- callbacks/associations.go | 5 +++++ callbacks/create.go | 5 +++++ tests/update_has_one_test.go | 12 +++++++++++- 3 files changed, 21 insertions(+), 1 deletion(-) diff --git a/callbacks/associations.go b/callbacks/associations.go index dc84e137..2deb8ede 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -361,6 +361,7 @@ func saveAssociations(db *gorm.DB, rel *schema.Relationship, values interface{}, } tx := db.Session(&gorm.Session{NewDB: true}).Clauses(onConflict).Session(&gorm.Session{ + FullSaveAssociations: db.FullSaveAssociations, SkipHooks: db.Statement.SkipHooks, DisableNestedTransaction: true, }) @@ -370,6 +371,10 @@ func saveAssociations(db *gorm.DB, rel *schema.Relationship, values interface{}, return true }) + if tx.Statement.FullSaveAssociations { + tx = tx.InstanceSet("gorm:update_track_time", true) + } + if len(selects) > 0 { tx = tx.Select(selects) } else if len(selectColumns) > 0 && len(omits) == 0 { diff --git a/callbacks/create.go b/callbacks/create.go index 5656b861..10da731f 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -320,6 +320,11 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { field.Set(stmt.ReflectValue, curTime) values.Values[0][idx], _ = field.ValueOf(stmt.ReflectValue) } + } else if field.AutoUpdateTime > 0 { + if _, ok := stmt.DB.InstanceGet("gorm:update_track_time"); ok { + field.Set(stmt.ReflectValue, curTime) + values.Values[0][idx], _ = field.ValueOf(stmt.ReflectValue) + } } } diff --git a/tests/update_has_one_test.go b/tests/update_has_one_test.go index 54568546..a61629f8 100644 --- a/tests/update_has_one_test.go +++ b/tests/update_has_one_test.go @@ -2,6 +2,7 @@ package tests_test import ( "testing" + "time" "gorm.io/gorm" . "gorm.io/gorm/utils/tests" @@ -31,7 +32,10 @@ func TestUpdateHasOne(t *testing.T) { var user3 User DB.Preload("Account").Find(&user3, "id = ?", user.ID) + CheckUser(t, user2, user3) + var lastUpdatedAt = user2.Account.UpdatedAt + time.Sleep(time.Second) if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&user).Error; err != nil { t.Fatalf("errors happened when update: %v", err) @@ -39,7 +43,13 @@ func TestUpdateHasOne(t *testing.T) { var user4 User DB.Preload("Account").Find(&user4, "id = ?", user.ID) - CheckUser(t, user4, user) + + if lastUpdatedAt.Format(time.RFC3339) == user4.Account.UpdatedAt.Format(time.RFC3339) { + t.Fatalf("updated at should be updated, but not, old: %v, new %v", lastUpdatedAt.Format(time.RFC3339), user3.Account.UpdatedAt.Format(time.RFC3339)) + } else { + user.Account.UpdatedAt = user4.Account.UpdatedAt + CheckUser(t, user4, user) + } t.Run("Polymorphic", func(t *testing.T) { var pet = Pet{Name: "create"} From 3694ef4a2c72220ef2726115a1ee8de8a386219d Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 26 Feb 2021 17:30:00 +0800 Subject: [PATCH 827/881] Fix get current table --- migrator/migrator.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/migrator/migrator.go b/migrator/migrator.go index 4e5051cf..263c3ffc 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -490,7 +490,8 @@ func (m Migrator) GuessConstraintAndTable(stmt *gorm.Statement, name string) (_ } } } - return nil, nil, "" + + return nil, nil, stmt.Schema.Table } func (m Migrator) CreateConstraint(value interface{}, name string) error { From 01570995762405b43e6b34cb5ca655de5c90b83b Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 4 Mar 2021 17:14:08 +0800 Subject: [PATCH 828/881] Use functional options --- gorm.go | 41 ++++++++++++++++++++++++++++++----------- 1 file changed, 30 insertions(+), 11 deletions(-) diff --git a/gorm.go b/gorm.go index 6adf455a..024a8079 100644 --- a/gorm.go +++ b/gorm.go @@ -56,6 +56,26 @@ type Config struct { cacheStore *sync.Map } +func (c *Config) Apply(config *Config) error { + return nil +} + +func (c *Config) AfterInitialize(db *DB) error { + if db != nil { + for _, plugin := range c.Plugins { + if err := plugin.Initialize(db); err != nil { + return err + } + } + } + return nil +} + +type Option interface { + Apply(*Config) error + AfterInitialize(*DB) error +} + // DB GORM DB definition type DB struct { *Config @@ -83,9 +103,16 @@ type Session struct { } // Open initialize db session based on dialector -func Open(dialector Dialector, config *Config) (db *DB, err error) { - if config == nil { - config = &Config{} +func Open(dialector Dialector, opts ...Option) (db *DB, err error) { + config := &Config{} + + for _, opt := range opts { + if opt != nil { + if err := opt.Apply(config); err != nil { + return nil, err + } + defer opt.AfterInitialize(db) + } } if config.NamingStrategy == nil { @@ -106,14 +133,6 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) { if config.Plugins == nil { config.Plugins = map[string]Plugin{} - } else { - for _, p := range config.Plugins { - defer func(plugin Plugin) { - if errr := plugin.Initialize(db); errr != nil { - err = errr - } - }(p) - } } if config.cacheStore == nil { From 42999e980916d8a5ee257eb116a351bceace691f Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 4 Mar 2021 18:28:32 +0800 Subject: [PATCH 829/881] Fix overwrite preloading associations, close #4134 --- callbacks/query.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/callbacks/query.go b/callbacks/query.go index 5a97e1ad..658216df 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -172,7 +172,9 @@ func Preload(db *gorm.DB) { if name == clause.Associations { for _, rel := range db.Statement.Schema.Relationships.Relations { if rel.Schema == db.Statement.Schema { - preloadMap[rel.Name] = map[string][]interface{}{} + if _, ok := preloadMap[rel.Name]; !ok { + preloadMap[rel.Name] = map[string][]interface{}{} + } } } } else { From 90476fea7a2b6701829fa5b3ff6338021549ba3e Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 4 Mar 2021 18:40:47 +0800 Subject: [PATCH 830/881] Fix Join with slice IN, close #4133 --- clause/expression.go | 38 ++++++++++++++++++++++++++++++++++---- 1 file changed, 34 insertions(+), 4 deletions(-) diff --git a/clause/expression.go b/clause/expression.go index 7a4c09f4..f76ce138 100644 --- a/clause/expression.go +++ b/clause/expression.go @@ -78,9 +78,10 @@ type NamedExpr struct { // Build build raw expression func (expr NamedExpr) Build(builder Builder) { var ( - idx int - inName bool - namedMap = make(map[string]interface{}, len(expr.Vars)) + idx int + inName bool + afterParenthesis bool + namedMap = make(map[string]interface{}, len(expr.Vars)) ) for _, v := range expr.Vars { @@ -131,13 +132,42 @@ func (expr NamedExpr) Build(builder Builder) { inName = false } + afterParenthesis = false builder.WriteByte(v) } else if v == '?' && len(expr.Vars) > idx { - builder.AddVar(builder, expr.Vars[idx]) + if afterParenthesis { + if _, ok := expr.Vars[idx].(driver.Valuer); ok { + builder.AddVar(builder, expr.Vars[idx]) + } else { + switch rv := reflect.ValueOf(expr.Vars[idx]); rv.Kind() { + case reflect.Slice, reflect.Array: + if rv.Len() == 0 { + builder.AddVar(builder, nil) + } else { + for i := 0; i < rv.Len(); i++ { + if i > 0 { + builder.WriteByte(',') + } + builder.AddVar(builder, rv.Index(i).Interface()) + } + } + default: + builder.AddVar(builder, expr.Vars[idx]) + } + } + } else { + builder.AddVar(builder, expr.Vars[idx]) + } + idx++ } else if inName { name = append(name, v) } else { + if v == '(' { + afterParenthesis = true + } else { + afterParenthesis = false + } builder.WriteByte(v) } } From 664755270ddba77cc669de814afca71ae5575fce Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 4 Mar 2021 19:16:08 +0800 Subject: [PATCH 831/881] Don't override the from clauses, close #4129 --- callbacks/query.go | 5 +++++ tests/sql_builder_test.go | 45 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+) diff --git a/callbacks/query.go b/callbacks/query.go index 658216df..aaa19c03 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -104,6 +104,11 @@ func BuildQuerySQL(db *gorm.DB) { } joins := []clause.Join{} + + if fromClause, ok := db.Statement.Clauses["FROM"].Expression.(clause.From); ok { + joins = fromClause.Joins + } + for _, join := range db.Statement.Joins { if db.Statement.Schema == nil { joins = append(joins, clause.Join{ diff --git a/tests/sql_builder_test.go b/tests/sql_builder_test.go index acb08130..081b96c9 100644 --- a/tests/sql_builder_test.go +++ b/tests/sql_builder_test.go @@ -6,6 +6,7 @@ import ( "testing" "gorm.io/gorm" + "gorm.io/gorm/clause" . "gorm.io/gorm/utils/tests" ) @@ -242,3 +243,47 @@ func TestCombineStringConditions(t *testing.T) { t.Fatalf("invalid sql generated, got %v", sql) } } + +func TestFromWithJoins(t *testing.T) { + var result User + + newDB := DB.Session(&gorm.Session{NewDB: true, DryRun: true}).Table("users") + + newDB.Clauses( + clause.From{ + Tables: []clause.Table{{Name: "users"}}, + Joins: []clause.Join{ + { + Table: clause.Table{Name: "companies", Raw: false}, + ON: clause.Where{ + Exprs: []clause.Expression{ + clause.Eq{ + Column: clause.Column{ + Table: "users", + Name: "company_id", + }, + Value: clause.Column{ + Table: "companies", + Name: "id", + }, + }, + }, + }, + }, + }, + }, + ) + + newDB.Joins("inner join rgs on rgs.id = user.id") + + stmt := newDB.First(&result).Statement + str := stmt.SQL.String() + + if !strings.Contains(str, "rgs.id = user.id") { + t.Errorf("The second join condition is over written instead of combining") + } + + if !strings.Contains(str, "`users`.`company_id` = `companies`.`id`") && !strings.Contains(str, "\"users\".\"company_id\" = \"companies\".\"id\"") { + t.Errorf("The first join condition is over written instead of combining") + } +} From adf85d5b82fe8b3a9aa5ad627ee5268cc519ab4f Mon Sep 17 00:00:00 2001 From: Sivchari <55221074+sivchari@users.noreply.github.com> Date: Thu, 4 Mar 2021 20:44:15 +0900 Subject: [PATCH 832/881] change the method of initializing slice (#4097) * change the method of initializing slice and fixed the length to be specified as 0 * keep the association.go code in the var group * keep the association.go code in the var group * change to initializing in var group --- callbacks/associations.go | 10 +++++----- callbacks/delete.go | 8 ++++---- callbacks/preload.go | 9 +++++++-- schema/naming.go | 2 +- 4 files changed, 17 insertions(+), 12 deletions(-) diff --git a/callbacks/associations.go b/callbacks/associations.go index 2deb8ede..10819dcc 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -39,7 +39,7 @@ func SaveBeforeAssociations(create bool) func(db *gorm.DB) { switch db.Statement.ReflectValue.Kind() { case reflect.Slice, reflect.Array: var ( - objs []reflect.Value + objs = make([]reflect.Value, 0, db.Statement.ReflectValue.Len()) fieldType = rel.Field.FieldType isPtr = fieldType.Kind() == reflect.Ptr ) @@ -140,7 +140,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { } if elems.Len() > 0 { - assignmentColumns := []string{} + assignmentColumns := make([]string, 0, len(rel.References)) for _, ref := range rel.References { assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) } @@ -154,7 +154,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { f = f.Addr() } - assignmentColumns := []string{} + assignmentColumns := make([]string, 0, len(rel.References)) for _, ref := range rel.References { if ref.OwnPrimaryKey { fv, _ := ref.PrimaryKey.ValueOf(db.Statement.ReflectValue) @@ -219,7 +219,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { } if elems.Len() > 0 { - assignmentColumns := []string{} + assignmentColumns := make([]string, 0, len(rel.References)) for _, ref := range rel.References { assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) } @@ -324,7 +324,7 @@ func onConflictOption(stmt *gorm.Statement, s *schema.Schema, selectColumns map[ } if len(defaultUpdatingColumns) > 0 { - var columns []clause.Column + columns := make([]clause.Column, 0, len(s.PrimaryFieldDBNames)) for _, dbName := range s.PrimaryFieldDBNames { columns = append(columns, clause.Column{Name: dbName}) } diff --git a/callbacks/delete.go b/callbacks/delete.go index 128722a1..64dd7236 100644 --- a/callbacks/delete.go +++ b/callbacks/delete.go @@ -41,7 +41,7 @@ func DeleteBeforeAssociations(db *gorm.DB) { } if len(db.Statement.Selects) > 0 { - var selects []string + selects := make([]string, 0, len(db.Statement.Selects)) for _, s := range db.Statement.Selects { if s == clause.Associations { selects = append(selects, s) @@ -69,9 +69,9 @@ func DeleteBeforeAssociations(db *gorm.DB) { } case schema.Many2Many: var ( - queryConds []clause.Expression - foreignFields []*schema.Field - relForeignKeys []string + queryConds = make([]clause.Expression, 0, len(rel.References)) + foreignFields = make([]*schema.Field, 0, len(rel.References)) + relForeignKeys = make([]string, 0, len(rel.References)) modelValue = reflect.New(rel.JoinTable.ModelType).Interface() table = rel.JoinTable.Table tx = db.Session(&gorm.Session{NewDB: true}).Model(modelValue).Table(table) diff --git a/callbacks/preload.go b/callbacks/preload.go index 27e3c3dd..eafd407d 100644 --- a/callbacks/preload.go +++ b/callbacks/preload.go @@ -27,8 +27,13 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload }) if rel.JoinTable != nil { - var joinForeignFields, joinRelForeignFields []*schema.Field - var joinForeignKeys []string + + var ( + joinForeignFields = make([]*schema.Field, 0, len(rel.References)) + joinRelForeignFields = make([]*schema.Field, 0, len(rel.References)) + joinForeignKeys = make([]string, 0, len(rel.References)) + ) + for _, ref := range rel.References { if ref.OwnPrimaryKey { joinForeignKeys = append(joinForeignKeys, ref.ForeignKey.DBName) diff --git a/schema/naming.go b/schema/naming.go index e10c9212..0643d1bd 100644 --- a/schema/naming.go +++ b/schema/naming.go @@ -92,7 +92,7 @@ var ( ) func init() { - var commonInitialismsForReplacer []string + commonInitialismsForReplacer := make([]string, 0, len(commonInitialisms)) for _, initialism := range commonInitialisms { commonInitialismsForReplacer = append(commonInitialismsForReplacer, initialism, strings.Title(strings.ToLower(initialism))) } From 1476b2f7d443197f8cad869d7da3bd142cfc277d Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 4 Mar 2021 20:37:39 +0800 Subject: [PATCH 833/881] Fix apply config --- gorm.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/gorm.go b/gorm.go index 024a8079..a484b002 100644 --- a/gorm.go +++ b/gorm.go @@ -57,6 +57,9 @@ type Config struct { } func (c *Config) Apply(config *Config) error { + if config != c { + *config = *c + } return nil } From 294625759c63af2ea412369a13b8f4d3c76b4433 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 5 Mar 2021 14:12:55 +0800 Subject: [PATCH 834/881] Fix after initialize db callback --- gorm.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/gorm.go b/gorm.go index a484b002..f11fb9e1 100644 --- a/gorm.go +++ b/gorm.go @@ -114,7 +114,9 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) { if err := opt.Apply(config); err != nil { return nil, err } - defer opt.AfterInitialize(db) + defer func() { + opt.AfterInitialize(db) + }() } } From d6c23586ae435a124353d3c5dfa6f504c24c5c3c Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 5 Mar 2021 19:42:54 +0800 Subject: [PATCH 835/881] Revert "Don't override the from clauses, close #4129" close #4139 This reverts commit 664755270ddba77cc669de814afca71ae5575fce. --- callbacks/query.go | 5 ----- tests/sql_builder_test.go | 45 --------------------------------------- 2 files changed, 50 deletions(-) diff --git a/callbacks/query.go b/callbacks/query.go index aaa19c03..658216df 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -104,11 +104,6 @@ func BuildQuerySQL(db *gorm.DB) { } joins := []clause.Join{} - - if fromClause, ok := db.Statement.Clauses["FROM"].Expression.(clause.From); ok { - joins = fromClause.Joins - } - for _, join := range db.Statement.Joins { if db.Statement.Schema == nil { joins = append(joins, clause.Join{ diff --git a/tests/sql_builder_test.go b/tests/sql_builder_test.go index 081b96c9..acb08130 100644 --- a/tests/sql_builder_test.go +++ b/tests/sql_builder_test.go @@ -6,7 +6,6 @@ import ( "testing" "gorm.io/gorm" - "gorm.io/gorm/clause" . "gorm.io/gorm/utils/tests" ) @@ -243,47 +242,3 @@ func TestCombineStringConditions(t *testing.T) { t.Fatalf("invalid sql generated, got %v", sql) } } - -func TestFromWithJoins(t *testing.T) { - var result User - - newDB := DB.Session(&gorm.Session{NewDB: true, DryRun: true}).Table("users") - - newDB.Clauses( - clause.From{ - Tables: []clause.Table{{Name: "users"}}, - Joins: []clause.Join{ - { - Table: clause.Table{Name: "companies", Raw: false}, - ON: clause.Where{ - Exprs: []clause.Expression{ - clause.Eq{ - Column: clause.Column{ - Table: "users", - Name: "company_id", - }, - Value: clause.Column{ - Table: "companies", - Name: "id", - }, - }, - }, - }, - }, - }, - }, - ) - - newDB.Joins("inner join rgs on rgs.id = user.id") - - stmt := newDB.First(&result).Statement - str := stmt.SQL.String() - - if !strings.Contains(str, "rgs.id = user.id") { - t.Errorf("The second join condition is over written instead of combining") - } - - if !strings.Contains(str, "`users`.`company_id` = `companies`.`id`") && !strings.Contains(str, "\"users\".\"company_id\" = \"companies\".\"id\"") { - t.Errorf("The first join condition is over written instead of combining") - } -} From a948c846071f7e4fd264c6a95a81a0ef04293a28 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 5 Mar 2021 22:18:12 +0800 Subject: [PATCH 836/881] Revert "Revert "Don't override the from clauses, close #4129" close #4139" This reverts commit d6c23586ae435a124353d3c5dfa6f504c24c5c3c. --- callbacks/query.go | 6 ++++++ tests/sql_builder_test.go | 45 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+) diff --git a/callbacks/query.go b/callbacks/query.go index 658216df..1868c247 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -104,6 +104,11 @@ func BuildQuerySQL(db *gorm.DB) { } joins := []clause.Join{} + + if fromClause, ok := db.Statement.Clauses["FROM"].Expression.(clause.From); ok { + joins = fromClause.Joins + } + for _, join := range db.Statement.Joins { if db.Statement.Schema == nil { joins = append(joins, clause.Join{ @@ -154,6 +159,7 @@ func BuildQuerySQL(db *gorm.DB) { } } + db.Statement.Joins = nil db.Statement.AddClause(clause.From{Joins: joins}) } else { db.Statement.AddClauseIfNotExists(clause.From{}) diff --git a/tests/sql_builder_test.go b/tests/sql_builder_test.go index acb08130..081b96c9 100644 --- a/tests/sql_builder_test.go +++ b/tests/sql_builder_test.go @@ -6,6 +6,7 @@ import ( "testing" "gorm.io/gorm" + "gorm.io/gorm/clause" . "gorm.io/gorm/utils/tests" ) @@ -242,3 +243,47 @@ func TestCombineStringConditions(t *testing.T) { t.Fatalf("invalid sql generated, got %v", sql) } } + +func TestFromWithJoins(t *testing.T) { + var result User + + newDB := DB.Session(&gorm.Session{NewDB: true, DryRun: true}).Table("users") + + newDB.Clauses( + clause.From{ + Tables: []clause.Table{{Name: "users"}}, + Joins: []clause.Join{ + { + Table: clause.Table{Name: "companies", Raw: false}, + ON: clause.Where{ + Exprs: []clause.Expression{ + clause.Eq{ + Column: clause.Column{ + Table: "users", + Name: "company_id", + }, + Value: clause.Column{ + Table: "companies", + Name: "id", + }, + }, + }, + }, + }, + }, + }, + ) + + newDB.Joins("inner join rgs on rgs.id = user.id") + + stmt := newDB.First(&result).Statement + str := stmt.SQL.String() + + if !strings.Contains(str, "rgs.id = user.id") { + t.Errorf("The second join condition is over written instead of combining") + } + + if !strings.Contains(str, "`users`.`company_id` = `companies`.`id`") && !strings.Contains(str, "\"users\".\"company_id\" = \"companies\".\"id\"") { + t.Errorf("The first join condition is over written instead of combining") + } +} From 495ec4bd87e9fb7751e7d5d10f9feae7c671eef0 Mon Sep 17 00:00:00 2001 From: heige Date: Sun, 7 Mar 2021 10:56:32 +0800 Subject: [PATCH 837/881] invalid db error and value and invalid value length error (#4151) --- association.go | 3 +-- callbacks.go | 2 +- errors.go | 6 ++++++ gorm.go | 3 +-- 4 files changed, 9 insertions(+), 5 deletions(-) diff --git a/association.go b/association.go index 3a2942fd..572f1526 100644 --- a/association.go +++ b/association.go @@ -1,7 +1,6 @@ package gorm import ( - "errors" "fmt" "reflect" "strings" @@ -441,7 +440,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ break } - association.Error = errors.New("invalid association values, length doesn't match") + association.Error = ErrInvalidValueOfLength return } diff --git a/callbacks.go b/callbacks.go index d1b8cd58..5b878af0 100644 --- a/callbacks.go +++ b/callbacks.go @@ -104,7 +104,7 @@ func (p *processor) Execute(db *DB) { stmt.ReflectValue = stmt.ReflectValue.Elem() } if !stmt.ReflectValue.IsValid() { - db.AddError(fmt.Errorf("invalid value")) + db.AddError(ErrInvalidValue) } } diff --git a/errors.go b/errors.go index 08755083..5f464d2b 100644 --- a/errors.go +++ b/errors.go @@ -31,4 +31,10 @@ var ( ErrEmptySlice = errors.New("empty slice found") // ErrDryRunModeUnsupported dry run mode unsupported ErrDryRunModeUnsupported = errors.New("dry run mode unsupported") + // ErrInvaildDB invalid db + ErrInvaildDB = errors.New("invalid db") + // ErrInvalidValue invalid value + ErrInvalidValue = errors.New("invalid value") + // ErrInvalidValueOfLength invalid values do not match length + ErrInvalidValueOfLength = errors.New("invalid association values, length doesn't match") ) diff --git a/gorm.go b/gorm.go index f11fb9e1..0eb377e9 100644 --- a/gorm.go +++ b/gorm.go @@ -3,7 +3,6 @@ package gorm import ( "context" "database/sql" - "errors" "fmt" "sync" "time" @@ -331,7 +330,7 @@ func (db *DB) DB() (*sql.DB, error) { return sqldb, nil } - return nil, errors.New("invalid db") + return nil, ErrInvaildDB } func (db *DB) getInstance() *DB { From bc347758e55b1c95a7f4c1eccfc9775f1736b901 Mon Sep 17 00:00:00 2001 From: heige Date: Sun, 7 Mar 2021 10:57:22 +0800 Subject: [PATCH 838/881] for Config.cacheStore store PreparedStmtDB key (#4149) --- gorm.go | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/gorm.go b/gorm.go index 0eb377e9..53df4194 100644 --- a/gorm.go +++ b/gorm.go @@ -12,6 +12,9 @@ import ( "gorm.io/gorm/schema" ) +// for Config.cacheStore store PreparedStmtDB key +const preparedStmtDBKey = "preparedStmt" + // Config GORM config type Config struct { // GORM perform single create, update, delete operations in transactions by default to ensure database data integrity @@ -161,7 +164,7 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) { Mux: &sync.RWMutex{}, PreparedSQL: make([]string, 0, 100), } - db.cacheStore.Store("preparedStmt", preparedStmt) + db.cacheStore.Store(preparedStmtDBKey, preparedStmt) if config.PrepareStmt { db.ConnPool = preparedStmt @@ -224,7 +227,7 @@ func (db *DB) Session(config *Session) *DB { } if config.PrepareStmt { - if v, ok := db.cacheStore.Load("preparedStmt"); ok { + if v, ok := db.cacheStore.Load(preparedStmtDBKey); ok { preparedStmt := v.(*PreparedStmtDB) tx.Statement.ConnPool = &PreparedStmtDB{ ConnPool: db.Config.ConnPool, From a3abb5fedf1ae939c1383b13cfbaee3b9d6c9f7f Mon Sep 17 00:00:00 2001 From: Ratan Phayade Date: Sun, 7 Mar 2021 08:29:00 +0530 Subject: [PATCH 839/881] support named params in Select API (#4142) * adds support for named arguments in select * changes clause identifies and adds test --- chainable_api.go | 7 ++++++- clause/clause.go | 6 +++--- tests/query_test.go | 6 ++++++ 3 files changed, 15 insertions(+), 4 deletions(-) diff --git a/chainable_api.go b/chainable_api.go index 5415f5bd..12db6830 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -98,7 +98,12 @@ func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) { Distinct: db.Statement.Distinct, Expression: clause.Expr{SQL: v, Vars: args}, }) - } else { + } else if strings.Count(v, "@") > 0 && len(args) > 0 { + tx.Statement.AddClause(clause.Select{ + Distinct: db.Statement.Distinct, + Expression: clause.NamedExpr{SQL: v, Vars: args}, + }) + } else { tx.Statement.Selects = []string{v} for _, arg := range args { diff --git a/clause/clause.go b/clause/clause.go index 828d2cf2..de19f2e3 100644 --- a/clause/clause.go +++ b/clause/clause.go @@ -62,9 +62,9 @@ func (c Clause) Build(builder Builder) { } const ( - PrimaryKey string = "@@@py@@@" // primary key - CurrentTable string = "@@@ct@@@" // current table - Associations string = "@@@as@@@" // associations + PrimaryKey string = "~~~py~~~" // primary key + CurrentTable string = "~~~ct~~~" // current table + Associations string = "~~~as~~~" // associations ) var ( diff --git a/tests/query_test.go b/tests/query_test.go index be6768b1..ee157a13 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -628,6 +628,12 @@ func TestSelect(t *testing.T) { t.Fatalf("Build Select with func, but got %v", r.Statement.SQL.String()) } + // named arguments + r = dryDB.Table("users").Select("COALESCE(age, @default)", sql.Named("default", 42)).Find(&User{}) + if !regexp.MustCompile(`SELECT COALESCE\(age,.*\) FROM .*users.*`).MatchString(r.Statement.SQL.String()) { + t.Fatalf("Build Select with func, but got %v", r.Statement.SQL.String()) + } + if _, err := DB.Table("users").Select("COALESCE(age,?)", "42").Rows(); err != nil { t.Fatalf("Failed, got error: %v", err) } From 221d0a0ec1c929182cab16e9c2620dfae459796a Mon Sep 17 00:00:00 2001 From: heige Date: Mon, 8 Mar 2021 10:20:04 +0800 Subject: [PATCH 840/881] optimize value of reflection length (#4152) --- finisher_api.go | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/finisher_api.go b/finisher_api.go index 2d7409c7..bef65ae5 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -35,10 +35,12 @@ func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) { tx = db.getInstance() callFc := func(tx *DB) error { - for i := 0; i < reflectValue.Len(); i += batchSize { + // the reflection length judgment of the optimized value + reflectLen := reflectValue.Len() + for i := 0; i < reflectLen; i += batchSize { ends := i + batchSize - if ends > reflectValue.Len() { - ends = reflectValue.Len() + if ends > reflectLen { + ends = reflectLen } subtx := tx.getInstance() From 02cb40531ea2234acc8b201486588a0a6bc72da6 Mon Sep 17 00:00:00 2001 From: heige Date: Mon, 8 Mar 2021 10:21:33 +0800 Subject: [PATCH 841/881] Optimize parse constraint (#4153) * for Config.cacheStore store PreparedStmtDB key * invalid db error and value and invalid value length error (#4151) * support named params in Select API (#4142) * adds support for named arguments in select * changes clause identifies and adds test * optimize match english letters and midline Co-authored-by: Ratan Phayade --- schema/check.go | 6 +++--- schema/relationship.go | 7 +++++-- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/schema/check.go b/schema/check.go index ec66bad2..161a6ac6 100644 --- a/schema/check.go +++ b/schema/check.go @@ -6,8 +6,8 @@ import ( ) var ( - // match English letters and midline - regEnLetterAndmidline = regexp.MustCompile("^[A-Za-z-_]+$") + // reg match english letters and midline + regEnLetterAndMidline = regexp.MustCompile("^[A-Za-z-_]+$") ) type Check struct { @@ -22,7 +22,7 @@ func (schema *Schema) ParseCheckConstraints() map[string]Check { for _, field := range schema.FieldsByDBName { if chk := field.TagSettings["CHECK"]; chk != "" { names := strings.Split(chk, ",") - if len(names) > 1 && regEnLetterAndmidline.MatchString(names[0]) { + if len(names) > 1 && regEnLetterAndMidline.MatchString(names[0]) { checks[names[0]] = Check{Name: names[0], Constraint: strings.Join(names[1:], ","), Field: field} } else { if names[0] == "" { diff --git a/schema/relationship.go b/schema/relationship.go index 606e722a..1b93ef88 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -3,7 +3,6 @@ package schema import ( "fmt" "reflect" - "regexp" "strings" "github.com/jinzhu/inflection" @@ -536,7 +535,11 @@ func (rel *Relationship) ParseConstraint() *Constraint { settings = ParseTagSetting(str, ",") ) - if idx != -1 && regexp.MustCompile("^[A-Za-z-_]+$").MatchString(str[0:idx]) { + // optimize match english letters and midline + // The following code is basically called in for. + // In order to avoid the performance problems caused by repeated compilation of regular expressions, + // it only needs to be done once outside, so optimization is done here. + if idx != -1 && regEnLetterAndMidline.MatchString(str[0:idx]) { name = str[0:idx] } else { name = rel.Schema.namer.RelationshipFKName(*rel) From 0348b1d3c155b1c8b2f0ae3968a7c71be6e68ad1 Mon Sep 17 00:00:00 2001 From: Shubhendra Singh Chauhan Date: Mon, 8 Mar 2021 08:16:43 +0530 Subject: [PATCH 842/881] chore: improve code quality (#4123) * Combine multiple `append`s into a single call * Clean up copied struct fields with type conversion * Remove unnecessary use of slice --- schema/relationship.go | 4 +--- schema/utils.go | 2 +- soft_delete.go | 2 +- statement.go | 2 +- 4 files changed, 4 insertions(+), 6 deletions(-) diff --git a/schema/relationship.go b/schema/relationship.go index 1b93ef88..a8863bfe 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -428,9 +428,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu lookUpNames := []string{lookUpName} if len(primaryFields) == 1 { - lookUpNames = append(lookUpNames, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID") - lookUpNames = append(lookUpNames, strings.TrimSuffix(lookUpName, primaryField.Name)+"Id") - lookUpNames = append(lookUpNames, schema.namer.ColumnName(foreignSchema.Table, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID")) + lookUpNames = append(lookUpNames, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID", strings.TrimSuffix(lookUpName, primaryField.Name)+"Id", schema.namer.ColumnName(foreignSchema.Table, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID")) } for _, name := range lookUpNames { diff --git a/schema/utils.go b/schema/utils.go index 6e5fd528..d311c61b 100644 --- a/schema/utils.go +++ b/schema/utils.go @@ -142,7 +142,7 @@ func GetIdentityFieldValuesMap(reflectValue reflect.Value, fields []*Field) (map if notZero { dataKey := utils.ToStringKey(fieldValues...) if _, ok := dataResults[dataKey]; !ok { - results = append(results, fieldValues[:]) + results = append(results, fieldValues) dataResults[dataKey] = []reflect.Value{elem} } else { dataResults[dataKey] = append(dataResults[dataKey], elem) diff --git a/soft_delete.go b/soft_delete.go index bdbf03c2..b16041f1 100644 --- a/soft_delete.go +++ b/soft_delete.go @@ -129,7 +129,7 @@ func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) { if _, ok := stmt.Clauses["WHERE"]; !stmt.DB.AllowGlobalUpdate && !ok { stmt.DB.AddError(ErrMissingWhereClause) } else { - SoftDeleteQueryClause{Field: sd.Field}.ModifyStatement(stmt) + SoftDeleteQueryClause(sd).ModifyStatement(stmt) } stmt.AddClauseIfNotExists(clause.Update{}) diff --git a/statement.go b/statement.go index 7580d965..6f336799 100644 --- a/statement.go +++ b/statement.go @@ -288,7 +288,7 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] if where, ok := cs.Expression.(clause.Where); ok { if len(where.Exprs) == 1 { if orConds, ok := where.Exprs[0].(clause.OrConditions); ok { - where.Exprs[0] = clause.AndConditions{Exprs: orConds.Exprs} + where.Exprs[0] = clause.AndConditions(orConds) } } conds = append(conds, clause.And(where.Exprs...)) From 675de6fc165aaabdfee959d1d09be58fe41c67aa Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 8 Mar 2021 19:21:09 +0800 Subject: [PATCH 843/881] Clear scopes before invoke scopes methods --- callbacks.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/callbacks.go b/callbacks.go index 5b878af0..ba7dae04 100644 --- a/callbacks.go +++ b/callbacks.go @@ -109,10 +109,11 @@ func (p *processor) Execute(db *DB) { } // call scopes - for _, scope := range stmt.scopes { + scopes := stmt.scopes + stmt.scopes = nil + for _, scope := range scopes { db = scope(db) } - stmt.scopes = nil for _, f := range p.fns { f(db) From 14b9bd163ced1e25874eaae0fe9fbfe723f5b91f Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 10 Mar 2021 19:32:56 +0800 Subject: [PATCH 844/881] Don't panic when using nil pointer, close #4168 --- callbacks.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/callbacks.go b/callbacks.go index ba7dae04..3e6723a1 100644 --- a/callbacks.go +++ b/callbacks.go @@ -96,7 +96,7 @@ func (p *processor) Execute(db *DB) { if stmt.Dest != nil { stmt.ReflectValue = reflect.ValueOf(stmt.Dest) for stmt.ReflectValue.Kind() == reflect.Ptr { - if stmt.ReflectValue.IsNil() { + if stmt.ReflectValue.IsNil() && stmt.ReflectValue.CanAddr() { stmt.ReflectValue.Set(reflect.New(stmt.ReflectValue.Type().Elem())) break } From 9fccb17d076a6dafd0bfd3329169e50097d0f2fc Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 10 Mar 2021 19:46:59 +0800 Subject: [PATCH 845/881] Fix double pointer for where conditions, close #4159 --- statement.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/statement.go b/statement.go index 6f336799..3d64d443 100644 --- a/statement.go +++ b/statement.go @@ -339,6 +339,10 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] } default: reflectValue := reflect.Indirect(reflect.ValueOf(arg)) + for reflectValue.Kind() == reflect.Ptr { + reflectValue = reflectValue.Elem() + } + if s, err := schema.Parse(arg, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil { selectedColumns := map[string]bool{} if idx == 0 { From 912360097a2f54bb0f0ee4b02f9b39c591071837 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 11 Mar 2021 10:29:52 +0800 Subject: [PATCH 846/881] Fix Scopes with Migrator, close #4145 --- gorm.go | 6 ++++++ migrator.go | 7 +++++++ tests/migrate_test.go | 10 +++++++++- 3 files changed, 22 insertions(+), 1 deletion(-) diff --git a/gorm.go b/gorm.go index 53df4194..88212e94 100644 --- a/gorm.go +++ b/gorm.go @@ -122,6 +122,12 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) { } } + if d, ok := dialector.(interface{ Apply(*Config) error }); ok { + if err = d.Apply(config); err != nil { + return + } + } + if config.NamingStrategy == nil { config.NamingStrategy = schema.NamingStrategy{} } diff --git a/migrator.go b/migrator.go index 28ac35e7..40936ef9 100644 --- a/migrator.go +++ b/migrator.go @@ -7,6 +7,13 @@ import ( // Migrator returns migrator func (db *DB) Migrator() Migrator { + // apply scopes to migrator + scopes := db.Statement.scopes + db.Statement.scopes = nil + for _, scope := range scopes { + db = scope(db) + } + return db.Dialector.Migrator(db.Session(&Session{})) } diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 16c48405..4da3856f 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -15,7 +15,7 @@ func TestMigrate(t *testing.T) { rand.Seed(time.Now().UnixNano()) rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] }) - DB.Migrator().DropTable("user_speaks", "user_friends") + DB.Migrator().DropTable("user_speaks", "user_friends", "ccc") if err := DB.Migrator().DropTable(allModels...); err != nil { t.Fatalf("Failed to drop table, got error %v", err) @@ -31,6 +31,14 @@ func TestMigrate(t *testing.T) { } } + DB.Scopes(func(db *gorm.DB) *gorm.DB { + return db.Table("ccc") + }).Migrator().CreateTable(&Company{}) + + if !DB.Migrator().HasTable("ccc") { + t.Errorf("failed to create table ccc") + } + for _, indexes := range [][2]string{ {"user_speaks", "fk_user_speaks_user"}, {"user_speaks", "fk_user_speaks_language"}, From c575a4e71922f7eb1c892e12eb23a0cab4adccd2 Mon Sep 17 00:00:00 2001 From: ruozhixian Date: Thu, 11 Mar 2021 16:36:49 +0800 Subject: [PATCH 847/881] support to preload all children in multiple levels associations --- callbacks/query.go | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/callbacks/query.go b/callbacks/query.go index 1868c247..df5b4d60 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -185,12 +185,26 @@ func Preload(db *gorm.DB) { } } else { preloadFields := strings.Split(name, ".") - if _, ok := preloadMap[preloadFields[0]]; !ok { - preloadMap[preloadFields[0]] = map[string][]interface{}{} - } + if preloadFields[0] == clause.Associations { + for _, rel := range db.Statement.Schema.Relationships.Relations { + if rel.Schema == db.Statement.Schema { + if _, ok := preloadMap[rel.Name]; !ok { + preloadMap[rel.Name] = map[string][]interface{}{} + } - if value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), "."); value != "" { - preloadMap[preloadFields[0]][value] = db.Statement.Preloads[name] + if value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), "."); value != "" { + preloadMap[rel.Name][value] = db.Statement.Preloads[name] + } + } + } + } else { + if _, ok := preloadMap[preloadFields[0]]; !ok { + preloadMap[preloadFields[0]] = map[string][]interface{}{} + } + + if value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), "."); value != "" { + preloadMap[preloadFields[0]][value] = db.Statement.Preloads[name] + } } } } From 2055e29eb81281289673d7ebc612c245fce7c333 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 14 Mar 2021 10:18:43 +0800 Subject: [PATCH 848/881] Refactor nested preload all associations --- callbacks/query.go | 32 +++++++++++--------------------- tests/go.mod | 4 ++-- tests/preload_test.go | 4 ++++ 3 files changed, 17 insertions(+), 23 deletions(-) diff --git a/callbacks/query.go b/callbacks/query.go index df5b4d60..11753472 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -175,36 +175,26 @@ func Preload(db *gorm.DB) { if db.Error == nil && len(db.Statement.Preloads) > 0 { preloadMap := map[string]map[string][]interface{}{} for name := range db.Statement.Preloads { - if name == clause.Associations { + preloadFields := strings.Split(name, ".") + if preloadFields[0] == clause.Associations { for _, rel := range db.Statement.Schema.Relationships.Relations { if rel.Schema == db.Statement.Schema { if _, ok := preloadMap[rel.Name]; !ok { preloadMap[rel.Name] = map[string][]interface{}{} } + + if value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), "."); value != "" { + preloadMap[rel.Name][value] = db.Statement.Preloads[name] + } } } } else { - preloadFields := strings.Split(name, ".") - if preloadFields[0] == clause.Associations { - for _, rel := range db.Statement.Schema.Relationships.Relations { - if rel.Schema == db.Statement.Schema { - if _, ok := preloadMap[rel.Name]; !ok { - preloadMap[rel.Name] = map[string][]interface{}{} - } + if _, ok := preloadMap[preloadFields[0]]; !ok { + preloadMap[preloadFields[0]] = map[string][]interface{}{} + } - if value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), "."); value != "" { - preloadMap[rel.Name][value] = db.Statement.Preloads[name] - } - } - } - } else { - if _, ok := preloadMap[preloadFields[0]]; !ok { - preloadMap[preloadFields[0]] = map[string][]interface{}{} - } - - if value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), "."); value != "" { - preloadMap[preloadFields[0]][value] = db.Statement.Preloads[name] - } + if value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), "."); value != "" { + preloadMap[preloadFields[0]][value] = db.Statement.Preloads[name] } } } diff --git a/tests/go.mod b/tests/go.mod index 20d7206a..0765142c 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -7,11 +7,11 @@ require ( github.com/jinzhu/now v1.1.1 github.com/lib/pq v1.6.0 github.com/stretchr/testify v1.5.1 - gorm.io/driver/mysql v1.0.4 + gorm.io/driver/mysql v1.0.5 gorm.io/driver/postgres v1.0.8 gorm.io/driver/sqlite v1.1.4 gorm.io/driver/sqlserver v1.0.6 - gorm.io/gorm v1.20.12 + gorm.io/gorm v1.21.3 ) replace gorm.io/gorm => ../ diff --git a/tests/preload_test.go b/tests/preload_test.go index 4b31b12c..c9f5d278 100644 --- a/tests/preload_test.go +++ b/tests/preload_test.go @@ -65,6 +65,10 @@ func TestNestedPreload(t *testing.T) { DB.Preload("Pets.Toy").Find(&user2, "id = ?", user.ID) CheckUser(t, user2, user) + + var user3 User + DB.Preload(clause.Associations+"."+clause.Associations).Find(&user3, "id = ?", user.ID) + CheckUser(t, user3, user) } func TestNestedPreloadForSlice(t *testing.T) { From 07f3795f934819f3fd7f09fa8cbf2960a4d07b61 Mon Sep 17 00:00:00 2001 From: heige Date: Wed, 17 Mar 2021 11:32:17 +0800 Subject: [PATCH 849/881] optimize MigrateColumn method for regexp (#4188) --- migrator/migrator.go | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/migrator/migrator.go b/migrator/migrator.go index 263c3ffc..075b5ca6 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -12,6 +12,11 @@ import ( "gorm.io/gorm/schema" ) +var ( + regRealDataType = regexp.MustCompile(`[^\d](\d+)[^\d]?`) + regFullDataType = regexp.MustCompile(`[^\d]*(\d+)[^\d]?`) +) + // Migrator m struct type Migrator struct { Config @@ -373,8 +378,10 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy alterColumn = true } else { // has size in data type and not equal - matches := regexp.MustCompile(`[^\d](\d+)[^\d]?`).FindAllStringSubmatch(realDataType, -1) - matches2 := regexp.MustCompile(`[^\d]*(\d+)[^\d]?`).FindAllStringSubmatch(fullDataType, -1) + + // Since the following code is frequently called in the for loop, reg optimization is needed here + matches := regRealDataType.FindAllStringSubmatch(realDataType, -1) + matches2 := regFullDataType.FindAllStringSubmatch(fullDataType, -1) if (len(matches) == 1 && matches[0][1] != fmt.Sprint(field.Size) || !field.PrimaryKey) && (len(matches2) == 1 && matches2[0][1] != fmt.Sprint(length)) { alterColumn = true } From 27bb9137d3ad1751e47ceb6e325fb5d17b0eb7aa Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 18 Mar 2021 11:44:04 +0800 Subject: [PATCH 850/881] Refactor OnConflict.UpdateALl --- callbacks/create.go | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/callbacks/create.go b/callbacks/create.go index 10da731f..909d984a 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -353,15 +353,14 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { } } - onConflict := clause.OnConflict{ - Columns: make([]clause.Column, len(stmt.Schema.PrimaryFieldDBNames)), - DoUpdates: clause.AssignmentColumns(columns), - } + onConflict.DoUpdates = clause.AssignmentColumns(columns) - for idx, field := range stmt.Schema.PrimaryFields { - onConflict.Columns[idx] = clause.Column{Name: field.DBName} + // use primary fields as default OnConflict columns + if len(onConflict.Columns) == 0 { + for _, field := range stmt.Schema.PrimaryFields { + onConflict.Columns = append(onConflict.Columns, clause.Column{Name: field.DBName}) + } } - stmt.AddClause(onConflict) } } From a3d9bbfc36e40e1aa9b633f6a5c2fb2ad82d4dd6 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 19 Mar 2021 13:21:43 +0800 Subject: [PATCH 851/881] build *clause.Expr --- statement.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/statement.go b/statement.go index 3d64d443..7a827ca8 100644 --- a/statement.go +++ b/statement.go @@ -167,6 +167,8 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) { stmt.AddVar(writer, v.GormValue(stmt.Context, stmt.DB)) case clause.Expr: v.Build(stmt) + case *clause.Expr: + v.Build(stmt) case driver.Valuer: stmt.Vars = append(stmt.Vars, v) stmt.DB.Dialector.BindVarTo(writer, stmt, v) From e85b73e5a5d9de181c12ce4d4ed14da79119cf8a Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 19 Mar 2021 13:44:25 +0800 Subject: [PATCH 852/881] Fix nested Scopes, close #4196 --- callbacks.go | 10 ++++++---- migrator.go | 10 ++++++---- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/callbacks.go b/callbacks.go index 3e6723a1..315eea17 100644 --- a/callbacks.go +++ b/callbacks.go @@ -109,10 +109,12 @@ func (p *processor) Execute(db *DB) { } // call scopes - scopes := stmt.scopes - stmt.scopes = nil - for _, scope := range scopes { - db = scope(db) + for len(stmt.scopes) > 0 { + scopes := stmt.scopes + stmt.scopes = nil + for _, scope := range scopes { + db = scope(db) + } } for _, f := range p.fns { diff --git a/migrator.go b/migrator.go index 40936ef9..f39dd9fd 100644 --- a/migrator.go +++ b/migrator.go @@ -8,10 +8,12 @@ import ( // Migrator returns migrator func (db *DB) Migrator() Migrator { // apply scopes to migrator - scopes := db.Statement.scopes - db.Statement.scopes = nil - for _, scope := range scopes { - db = scope(db) + for len(db.Statement.scopes) > 0 { + scopes := db.Statement.scopes + db.Statement.scopes = nil + for _, scope := range scopes { + db = scope(db) + } } return db.Dialector.Migrator(db.Session(&Session{})) From 220349ccf2990c47988a54df94e838803829898c Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 19 Mar 2021 15:15:26 +0800 Subject: [PATCH 853/881] Fix omit associations, close #4161 --- callbacks/associations.go | 2 +- schema/relationship_test.go | 25 +++++++++++++++++++++++++ 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/callbacks/associations.go b/callbacks/associations.go index 10819dcc..2a4efbe1 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -377,7 +377,7 @@ func saveAssociations(db *gorm.DB, rel *schema.Relationship, values interface{}, if len(selects) > 0 { tx = tx.Select(selects) - } else if len(selectColumns) > 0 && len(omits) == 0 { + } else if restricted && len(omits) == 0 { tx = tx.Omit(clause.Associations) } diff --git a/schema/relationship_test.go b/schema/relationship_test.go index a34777b7..2971698c 100644 --- a/schema/relationship_test.go +++ b/schema/relationship_test.go @@ -398,6 +398,31 @@ func TestMultipleMany2Many(t *testing.T) { ) } +func TestSelfReferentialMany2Many(t *testing.T) { + type User struct { + ID int32 `gorm:"primaryKey"` + Name string + CreatedBy int32 + Creators []User `gorm:"foreignKey:CreatedBy"` + AnotherPro interface{} `gorm:"-"` + } + + checkStructRelation(t, &User{}, Relation{ + Name: "Creators", Type: schema.HasMany, Schema: "User", FieldSchema: "User", + References: []Reference{{"ID", "User", "CreatedBy", "User", "", true}}, + }) + + user, err := schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("failed to parse schema") + } + + relSchema := user.Relationships.Relations["Creators"].FieldSchema + if user != relSchema { + t.Fatalf("schema should be same, expects %p but got %p", user, relSchema) + } +} + type CreatedByModel struct { CreatedByID uint CreatedBy *CreatedUser From a9fe025ef53b419ea5d6406f5f79a2bc7e52d71a Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 19 Mar 2021 15:54:32 +0800 Subject: [PATCH 854/881] Add GetDBConnector interface --- gorm.go | 4 ++-- interfaces.go | 4 ++++ prepare_stmt.go | 12 ++++++++++++ 3 files changed, 18 insertions(+), 2 deletions(-) diff --git a/gorm.go b/gorm.go index 88212e94..9323c46d 100644 --- a/gorm.go +++ b/gorm.go @@ -331,8 +331,8 @@ func (db *DB) AddError(err error) error { func (db *DB) DB() (*sql.DB, error) { connPool := db.ConnPool - if stmtDB, ok := connPool.(*PreparedStmtDB); ok { - connPool = stmtDB.ConnPool + if dbConnector, ok := connPool.(GetDBConnector); ok && dbConnector != nil { + return dbConnector.GetDBConn() } if sqldb, ok := connPool.(*sql.DB); ok { diff --git a/interfaces.go b/interfaces.go index e933952b..44b2fced 100644 --- a/interfaces.go +++ b/interfaces.go @@ -57,3 +57,7 @@ type TxCommitter interface { type Valuer interface { GormValue(context.Context, *DB) clause.Expr } + +type GetDBConnector interface { + GetDBConn() (*sql.DB, error) +} diff --git a/prepare_stmt.go b/prepare_stmt.go index 78a8adb4..bc7ef180 100644 --- a/prepare_stmt.go +++ b/prepare_stmt.go @@ -18,6 +18,18 @@ type PreparedStmtDB struct { ConnPool } +func (db *PreparedStmtDB) GetDB() (*sql.DB, error) { + if dbConnector, ok := db.ConnPool.(GetDBConnector); ok && dbConnector != nil { + return dbConnector.GetDBConn() + } + + if sqldb, ok := db.ConnPool.(*sql.DB); ok { + return sqldb, nil + } + + return nil, ErrInvaildDB +} + func (db *PreparedStmtDB) Close() { db.Mux.Lock() for _, query := range db.PreparedSQL { From 8c92d9694a73c565351dc547f395453cc75ef94b Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 19 Mar 2021 16:34:51 +0800 Subject: [PATCH 855/881] Fix to call Scopes with using Migrator --- migrator.go | 12 +++++++----- tests/scopes_test.go | 9 +++++++++ 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/migrator.go b/migrator.go index f39dd9fd..7dddcabf 100644 --- a/migrator.go +++ b/migrator.go @@ -7,16 +7,18 @@ import ( // Migrator returns migrator func (db *DB) Migrator() Migrator { + tx := db.getInstance() + // apply scopes to migrator - for len(db.Statement.scopes) > 0 { - scopes := db.Statement.scopes - db.Statement.scopes = nil + for len(tx.Statement.scopes) > 0 { + scopes := tx.Statement.scopes + tx.Statement.scopes = nil for _, scope := range scopes { - db = scope(db) + tx = scope(tx) } } - return db.Dialector.Migrator(db.Session(&Session{})) + return tx.Dialector.Migrator(tx.Session(&Session{})) } // AutoMigrate run auto migration for given models diff --git a/tests/scopes_test.go b/tests/scopes_test.go index c9787d36..9836c41e 100644 --- a/tests/scopes_test.go +++ b/tests/scopes_test.go @@ -45,4 +45,13 @@ func TestScopes(t *testing.T) { if len(users3) != 2 { t.Errorf("Should found two users's name in 1, 3, but got %v", len(users3)) } + + db := DB.Scopes(func(tx *gorm.DB) *gorm.DB { + return tx.Table("custom_table") + }).Session(&gorm.Session{}) + + db.AutoMigrate(&User{}) + if db.Find(&User{}).Statement.Table != "custom_table" { + t.Errorf("failed to call Scopes") + } } From 26dd4c980a62d47c990a05da9e5566bff3b2b00c Mon Sep 17 00:00:00 2001 From: Genta Kamitani Date: Mon, 22 Mar 2021 15:11:07 +0900 Subject: [PATCH 856/881] Fix: FindInBatches ignores errors (#4203) --- finisher_api.go | 2 ++ tests/query_test.go | 28 ++++++++++++++++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/finisher_api.go b/finisher_api.go index bef65ae5..b5cbfaa6 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -190,6 +190,8 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat if result.Error == nil && result.RowsAffected != 0 { tx.AddError(fc(result, batch)) + } else if result.Error != nil { + tx.AddError(result.Error) } if tx.Error != nil || int(result.RowsAffected) < batchSize { diff --git a/tests/query_test.go b/tests/query_test.go index ee157a13..489ac807 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -292,6 +292,34 @@ func TestFindInBatches(t *testing.T) { } } +func TestFindInBatchesWithError(t *testing.T) { + var users = []User{ + *GetUser("find_in_batches_with_error", Config{}), + *GetUser("find_in_batches_with_error", Config{}), + *GetUser("find_in_batches_with_error", Config{}), + *GetUser("find_in_batches_with_error", Config{}), + *GetUser("find_in_batches_with_error", Config{}), + *GetUser("find_in_batches_with_error", Config{}), + } + + DB.Create(&users) + + var ( + results []User + totalBatch int + ) + + if result := DB.Table("wrong_table").Where("name = ?", users[0].Name).FindInBatches(&results, 2, func(tx *gorm.DB, batch int) error { + totalBatch += batch + return nil + }); result.Error == nil || result.RowsAffected > 0 { + t.Fatal("expected errors to have occurred, but nothing happened") + } + if totalBatch != 0 { + t.Fatalf("incorrect total batch, expected: %v, got: %v", 0, totalBatch) + } +} + func TestFillSmallerStruct(t *testing.T) { user := User{Name: "SmallerUser", Age: 100} DB.Save(&user) From 4d5cec8bdd901743a87df798b2c4d9320a0ac48c Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 24 Mar 2021 14:22:36 +0800 Subject: [PATCH 857/881] Add golang 1.16 --- .github/workflows/tests.yml | 12 ++++++------ go.mod | 2 +- go.sum | 4 ++-- tests/go.mod | 2 +- tests/tests_all.sh | 10 ++-------- 5 files changed, 12 insertions(+), 18 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index f26caa86..fec7d000 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -13,8 +13,8 @@ jobs: sqlite: strategy: matrix: - go: ['1.15', '1.14', '1.13'] - platform: [ubuntu-latest, macos-latest] # can not run in windows OS + go: ['1.16', '1.15', '1.14'] + platform: [ubuntu-latest] # can not run in windows OS runs-on: ${{ matrix.platform }} steps: @@ -38,7 +38,7 @@ jobs: sqlite_windows: strategy: matrix: - go: ['1.15', '1.14', '1.13'] + go: ['1.16', '1.15', '1.14'] platform: [windows-latest] runs-on: ${{ matrix.platform }} @@ -64,7 +64,7 @@ jobs: strategy: matrix: dbversion: ['mysql:latest', 'mysql:5.7', 'mysql:5.6', 'mariadb:latest'] - go: ['1.15', '1.14', '1.13'] + go: ['1.16', '1.15', '1.14'] platform: [ubuntu-latest] runs-on: ${{ matrix.platform }} @@ -108,7 +108,7 @@ jobs: strategy: matrix: dbversion: ['postgres:latest', 'postgres:11', 'postgres:10'] - go: ['1.15', '1.14', '1.13'] + go: ['1.16', '1.15', '1.14'] platform: [ubuntu-latest] # can not run in macOS and widnowsOS runs-on: ${{ matrix.platform }} @@ -150,7 +150,7 @@ jobs: sqlserver: strategy: matrix: - go: ['1.15', '1.14', '1.13'] + go: ['1.16', '1.15', '1.14'] platform: [ubuntu-latest] # can not run test in macOS and windows runs-on: ${{ matrix.platform }} diff --git a/go.mod b/go.mod index faf63a46..d95d3f10 100644 --- a/go.mod +++ b/go.mod @@ -4,5 +4,5 @@ go 1.14 require ( github.com/jinzhu/inflection v1.0.0 - github.com/jinzhu/now v1.1.1 + github.com/jinzhu/now v1.1.2 ) diff --git a/go.sum b/go.sum index 148bd6f5..c66a6b57 100644 --- a/go.sum +++ b/go.sum @@ -1,4 +1,4 @@ github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= -github.com/jinzhu/now v1.1.1 h1:g39TucaRWyV3dwDO++eEc6qf8TVIQ/Da48WmqjZ3i7E= -github.com/jinzhu/now v1.1.1/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= +github.com/jinzhu/now v1.1.2 h1:eVKgfIdy9b6zbWBMgFpfDPoAMifwSZagU9HmEU6zgiI= +github.com/jinzhu/now v1.1.2/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= diff --git a/tests/go.mod b/tests/go.mod index 0765142c..7743e63a 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -4,7 +4,7 @@ go 1.14 require ( github.com/google/uuid v1.1.1 - github.com/jinzhu/now v1.1.1 + github.com/jinzhu/now v1.1.2 github.com/lib/pq v1.6.0 github.com/stretchr/testify v1.5.1 gorm.io/driver/mysql v1.0.5 diff --git a/tests/tests_all.sh b/tests/tests_all.sh index 744a40e9..2d6c35c3 100755 --- a/tests/tests_all.sh +++ b/tests/tests_all.sh @@ -9,11 +9,11 @@ fi if [ -d tests ] then cd tests - cp go.mod go.mod.bak - sed '/^[[:blank:]]*gorm.io\/driver/d' go.mod.bak > go.mod cd .. fi +go get -u ./... + for dialect in "${dialects[@]}" ; do if [ "$GORM_DIALECT" = "" ] || [ "$GORM_DIALECT" = "${dialect}" ] then @@ -39,9 +39,3 @@ for dialect in "${dialects[@]}" ; do fi fi done - -if [ -d tests ] -then - cd tests - mv go.mod.bak go.mod -fi From 704e53a774f4e6ed1edaf4ffddc92833a7d4c918 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 24 Mar 2021 16:17:49 +0800 Subject: [PATCH 858/881] Call scopes before parse model value, close #4209 --- callbacks.go | 21 ++++++++++++--------- chainable_api.go | 2 +- tests/count_test.go | 8 ++++++++ tests/go.mod | 4 ++-- 4 files changed, 23 insertions(+), 12 deletions(-) diff --git a/callbacks.go b/callbacks.go index 315eea17..f2ee0ea5 100644 --- a/callbacks.go +++ b/callbacks.go @@ -77,12 +77,23 @@ func (p *processor) Execute(db *DB) { stmt = db.Statement ) + // call scopes + for len(stmt.scopes) > 0 { + scopes := stmt.scopes + stmt.scopes = nil + for _, scope := range scopes { + db = scope(db) + } + } + + // assign model values if stmt.Model == nil { stmt.Model = stmt.Dest } else if stmt.Dest == nil { stmt.Dest = stmt.Model } + // parse model values if stmt.Model != nil { if err := stmt.Parse(stmt.Model); err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || (stmt.Table == "" && stmt.SQL.Len() == 0)) { if errors.Is(err, schema.ErrUnsupportedDataType) && stmt.Table == "" { @@ -93,6 +104,7 @@ func (p *processor) Execute(db *DB) { } } + // assign stmt.ReflectValue if stmt.Dest != nil { stmt.ReflectValue = reflect.ValueOf(stmt.Dest) for stmt.ReflectValue.Kind() == reflect.Ptr { @@ -108,15 +120,6 @@ func (p *processor) Execute(db *DB) { } } - // call scopes - for len(stmt.scopes) > 0 { - scopes := stmt.scopes - stmt.scopes = nil - for _, scope := range scopes { - db = scope(db) - } - } - for _, f := range p.fns { f(db) } diff --git a/chainable_api.go b/chainable_api.go index 12db6830..e17d9bb2 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -103,7 +103,7 @@ func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) { Distinct: db.Statement.Distinct, Expression: clause.NamedExpr{SQL: v, Vars: args}, }) - } else { + } else { tx.Statement.Selects = []string{v} for _, arg := range args { diff --git a/tests/count_test.go b/tests/count_test.go index ffe675d9..0fef82f7 100644 --- a/tests/count_test.go +++ b/tests/count_test.go @@ -121,4 +121,12 @@ func TestCount(t *testing.T) { }) AssertEqual(t, users, expects) + + var count9 int64 + if err := DB.Debug().Scopes(func(tx *gorm.DB) *gorm.DB { + fmt.Println("kdkdkdkdk") + return tx.Table("users") + }).Where("name in ?", []string{user1.Name, user2.Name, user3.Name}).Count(&count9).Find(&users).Error; err != nil || count9 != 3 { + t.Fatalf(fmt.Sprintf("Count should work, but got err %v", err)) + } } diff --git a/tests/go.mod b/tests/go.mod index 7743e63a..d4b0c975 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -10,8 +10,8 @@ require ( gorm.io/driver/mysql v1.0.5 gorm.io/driver/postgres v1.0.8 gorm.io/driver/sqlite v1.1.4 - gorm.io/driver/sqlserver v1.0.6 - gorm.io/gorm v1.21.3 + gorm.io/driver/sqlserver v1.0.7 + gorm.io/gorm v1.21.4 ) replace gorm.io/gorm => ../ From 8204d0ada27896ec312b054f36a0e32fa8c1504a Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 24 Mar 2021 16:44:51 +0800 Subject: [PATCH 859/881] Update tests script --- tests/tests_all.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/tests_all.sh b/tests/tests_all.sh index 2d6c35c3..e0ed97a4 100755 --- a/tests/tests_all.sh +++ b/tests/tests_all.sh @@ -9,11 +9,11 @@ fi if [ -d tests ] then cd tests + go get -u ./... + go mod download cd .. fi -go get -u ./... - for dialect in "${dialects[@]}" ; do if [ "$GORM_DIALECT" = "" ] || [ "$GORM_DIALECT" = "${dialect}" ] then From 88078e48d0a0a3c8a31c6be4072182c7cee68756 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 24 Mar 2021 16:56:41 +0800 Subject: [PATCH 860/881] Remove sqlite_windows test case --- .github/workflows/tests.yml | 25 ------------------------- 1 file changed, 25 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index fec7d000..e2ea89a7 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -35,31 +35,6 @@ jobs: - name: Tests run: GORM_DIALECT=sqlite ./tests/tests_all.sh - sqlite_windows: - strategy: - matrix: - go: ['1.16', '1.15', '1.14'] - platform: [windows-latest] - runs-on: ${{ matrix.platform }} - - steps: - - name: Set up Go 1.x - uses: actions/setup-go@v2 - with: - go-version: ${{ matrix.go }} - - - name: Check out code into the Go module directory - uses: actions/checkout@v2 - - - name: go mod package cache - uses: actions/cache@v2 - with: - path: ~/go/pkg/mod - key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} - - - name: Tests - run: cd tests && set GORM_DIALECT=sqlite && go test $race -count=1 -v ./... #run the line in widnows's CMD, default GORM_DIALECT is sqlite - mysql: strategy: matrix: From 26e0c6fb69841be8c387746fb31559801b30a7b9 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 24 Mar 2021 17:12:30 +0800 Subject: [PATCH 861/881] skip test sqlserver due to it will raise data race for invalid sql --- tests/query_test.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/query_test.go b/tests/query_test.go index 489ac807..34999337 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -293,6 +293,10 @@ func TestFindInBatches(t *testing.T) { } func TestFindInBatchesWithError(t *testing.T) { + if name := DB.Dialector.Name(); name == "sqlserver" { + t.Skip("skip sqlserver due to it will raise data race for invalid sql") + } + var users = []User{ *GetUser("find_in_batches_with_error", Config{}), *GetUser("find_in_batches_with_error", Config{}), From a8b72546c1c9bbe01e126104095be842022ca6ef Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 25 Mar 2021 10:17:57 +0800 Subject: [PATCH 862/881] Fix get database connection for prepared stmt, close #4214 --- prepare_stmt.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/prepare_stmt.go b/prepare_stmt.go index bc7ef180..122e98d2 100644 --- a/prepare_stmt.go +++ b/prepare_stmt.go @@ -18,7 +18,7 @@ type PreparedStmtDB struct { ConnPool } -func (db *PreparedStmtDB) GetDB() (*sql.DB, error) { +func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) { if dbConnector, ok := db.ConnPool.(GetDBConnector); ok && dbConnector != nil { return dbConnector.GetDBConn() } From 0eba7a9ed16f415c5a20dbfec8d6e3d7864b4fc8 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 26 Mar 2021 14:20:42 +0800 Subject: [PATCH 863/881] Fix apply option --- gorm.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gorm.go b/gorm.go index 9323c46d..b612e1f4 100644 --- a/gorm.go +++ b/gorm.go @@ -116,9 +116,9 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) { if err := opt.Apply(config); err != nil { return nil, err } - defer func() { + defer func(opt Option) { opt.AfterInitialize(db) - }() + }(opt) } } From 73c6d3e64e4341bfa47d1d2a2bd72f7d20caf149 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 29 Mar 2021 18:36:01 +0800 Subject: [PATCH 864/881] Add AfterInitialize error --- gorm.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/gorm.go b/gorm.go index b612e1f4..9c4d444f 100644 --- a/gorm.go +++ b/gorm.go @@ -117,7 +117,9 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) { return nil, err } defer func(opt Option) { - opt.AfterInitialize(db) + if errr := opt.AfterInitialize(db); errr != nil { + err = errr + } }(opt) } } From 33601dc72f4abf86ce68cbb663f7f5c898bee0a3 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 30 Mar 2021 18:28:09 +0800 Subject: [PATCH 865/881] Support Having w/o Group --- clause/group_by.go | 6 ++++++ tests/group_by_test.go | 10 ++++++++++ 2 files changed, 16 insertions(+) diff --git a/clause/group_by.go b/clause/group_by.go index 88231916..84242fb8 100644 --- a/clause/group_by.go +++ b/clause/group_by.go @@ -39,4 +39,10 @@ func (groupBy GroupBy) MergeClause(clause *Clause) { groupBy.Having = append(copiedHaving, groupBy.Having...) } clause.Expression = groupBy + + if len(groupBy.Columns) == 0 { + clause.Name = "" + } else { + clause.Name = groupBy.Name() + } } diff --git a/tests/group_by_test.go b/tests/group_by_test.go index 7e41e94a..96dfc547 100644 --- a/tests/group_by_test.go +++ b/tests/group_by_test.go @@ -96,4 +96,14 @@ func TestGroupBy(t *testing.T) { if name != "groupby" || active != true || total != 40 { t.Errorf("group by two columns, name %v, age %v, active: %v", name, total, active) } + + if DB.Dialector.Name() == "mysql" { + if err := DB.Model(&User{}).Select("name, age as total").Where("name LIKE ?", "groupby%").Having("total > ?", 300).Scan(&result).Error; err != nil { + t.Errorf("no error should happen, but got %v", err) + } + + if result.Name != "groupby1" || result.Total != 330 { + t.Errorf("name should be groupby, total should be 660, but got %+v", result) + } + } } From 8cfa9d98f0ec913fdb1091a4cf3812b25b7fdce4 Mon Sep 17 00:00:00 2001 From: gavwu <68006288+gavwu@users.noreply.github.com> Date: Fri, 2 Apr 2021 09:56:38 +0800 Subject: [PATCH 866/881] Update field.go (#4228) seems like the `if-else` branch do the same thing, so remove it --- schema/field.go | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/schema/field.go b/schema/field.go index 5e792ed1..1881ad1a 100644 --- a/schema/field.go +++ b/schema/field.go @@ -441,15 +441,8 @@ func (field *Field) setupValuerAndSetter() { // ReflectValueOf switch { case len(field.StructField.Index) == 1: - if field.FieldType.Kind() == reflect.Ptr { - field.ReflectValueOf = func(value reflect.Value) reflect.Value { - fieldValue := reflect.Indirect(value).Field(field.StructField.Index[0]) - return fieldValue - } - } else { - field.ReflectValueOf = func(value reflect.Value) reflect.Value { - return reflect.Indirect(value).Field(field.StructField.Index[0]) - } + field.ReflectValueOf = func(value reflect.Value) reflect.Value { + return reflect.Indirect(value).Field(field.StructField.Index[0]) } case len(field.StructField.Index) == 2 && field.StructField.Index[0] >= 0 && field.FieldType.Kind() != reflect.Ptr: field.ReflectValueOf = func(value reflect.Value) reflect.Value { From 673053f56a037fdd01031bee397188ff17830376 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 9 Apr 2021 09:35:41 +0800 Subject: [PATCH 867/881] Fix context cancel error, close #4259, close #4260 --- scan.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/scan.go b/scan.go index acd637a4..e82e3f07 100644 --- a/scan.go +++ b/scan.go @@ -241,7 +241,11 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { } } - if db.RowsAffected == 0 && db.Statement.RaiseErrorOnNotFound { + if err := rows.Err(); err != nil && err != db.Error { + db.AddError(err) + } + + if db.RowsAffected == 0 && db.Statement.RaiseErrorOnNotFound && db.Error == nil { db.AddError(ErrRecordNotFound) } } From f3bdfa82616fc9cb6ec3b5c47ebc73cfbe73a309 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 9 Apr 2021 10:20:36 +0800 Subject: [PATCH 868/881] Add IgnoreRecordNotFoundError option for logger --- errors.go | 4 +++- logger/logger.go | 12 ++++++++---- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/errors.go b/errors.go index 5f464d2b..3126b8e7 100644 --- a/errors.go +++ b/errors.go @@ -2,11 +2,13 @@ package gorm import ( "errors" + + "gorm.io/gorm/logger" ) var ( // ErrRecordNotFound record not found error - ErrRecordNotFound = errors.New("record not found") + ErrRecordNotFound = logger.ErrRecordNotFound // ErrInvalidTransaction invalid transaction when you are trying to `Commit` or `Rollback` ErrInvalidTransaction = errors.New("no valid transaction") // ErrNotImplemented not implemented diff --git a/logger/logger.go b/logger/logger.go index cd6bf57f..f14748c1 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -2,6 +2,7 @@ package logger import ( "context" + "errors" "fmt" "io/ioutil" "log" @@ -11,6 +12,8 @@ import ( "gorm.io/gorm/utils" ) +var ErrRecordNotFound = errors.New("record not found") + // Colors const ( Reset = "\033[0m" @@ -43,9 +46,10 @@ type Writer interface { } type Config struct { - SlowThreshold time.Duration - Colorful bool - LogLevel LogLevel + SlowThreshold time.Duration + Colorful bool + IgnoreRecordNotFoundError bool + LogLevel LogLevel } // Interface logger interface @@ -138,7 +142,7 @@ func (l logger) Trace(ctx context.Context, begin time.Time, fc func() (string, i if l.LogLevel > Silent { elapsed := time.Since(begin) switch { - case err != nil && l.LogLevel >= Error: + case err != nil && l.LogLevel >= Error && (!errors.Is(err, ErrRecordNotFound) || !l.IgnoreRecordNotFoundError): sql, rows := fc() if rows == -1 { l.Printf(l.traceErrStr, utils.FileWithLineNum(), err, float64(elapsed.Nanoseconds())/1e6, "-", sql) From ad53074f1d548297205cd0a6affe333ab2b22e54 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 9 Apr 2021 11:07:14 +0800 Subject: [PATCH 869/881] Pass db error to new instance --- gorm.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gorm.go b/gorm.go index 9c4d444f..f1e3745f 100644 --- a/gorm.go +++ b/gorm.go @@ -346,7 +346,7 @@ func (db *DB) DB() (*sql.DB, error) { func (db *DB) getInstance() *DB { if db.clone > 0 { - tx := &DB{Config: db.Config} + tx := &DB{Config: db.Config, Error: db.Error} if db.clone == 1 { // clone with new statement From d278ca49ef30f003c9624ae58d4d8726f728c1f7 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 9 Apr 2021 11:43:24 +0800 Subject: [PATCH 870/881] sort GORM options before apply --- gorm.go | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/gorm.go b/gorm.go index f1e3745f..0da218f6 100644 --- a/gorm.go +++ b/gorm.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "fmt" + "sort" "sync" "time" @@ -111,6 +112,12 @@ type Session struct { func Open(dialector Dialector, opts ...Option) (db *DB, err error) { config := &Config{} + sort.Slice(opts, func(i, j int) bool { + _, isConfig := opts[i].(*Config) + _, isConfig2 := opts[j].(*Config) + return isConfig && !isConfig2 + }) + for _, opt := range opts { if opt != nil { if err := opt.Apply(config); err != nil { From d7911300f83d79a57bc456a487addc031f2d9ff5 Mon Sep 17 00:00:00 2001 From: yrong1997 Date: Tue, 13 Apr 2021 09:39:43 +0800 Subject: [PATCH 871/881] Respect ignore migration when add column (#4276) continue https://github.com/go-gorm/gorm/pull/4028 --- migrator/migrator.go | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/migrator/migrator.go b/migrator/migrator.go index 075b5ca6..1800ab54 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -295,10 +295,13 @@ func (m Migrator) RenameTable(oldName, newName interface{}) error { func (m Migrator) AddColumn(value interface{}, field string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { if field := stmt.Schema.LookUpField(field); field != nil { - return m.DB.Exec( - "ALTER TABLE ? ADD ? ?", - m.CurrentTable(stmt), clause.Column{Name: field.DBName}, m.DB.Migrator().FullDataTypeOf(field), - ).Error + if !field.IgnoreMigration { + return m.DB.Exec( + "ALTER TABLE ? ADD ? ?", + m.CurrentTable(stmt), clause.Column{Name: field.DBName}, m.DB.Migrator().FullDataTypeOf(field), + ).Error + } + return nil } return fmt.Errorf("failed to look up field with name: %s", field) }) From 5555b010dc2617b07dc4a444a130506b1f7e6e56 Mon Sep 17 00:00:00 2001 From: heige Date: Tue, 13 Apr 2021 09:41:30 +0800 Subject: [PATCH 872/881] feat: Optimal value type acquisition for v (#4278) --- schema/field.go | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/schema/field.go b/schema/field.go index 1881ad1a..5dbc96f1 100644 --- a/schema/field.go +++ b/schema/field.go @@ -479,17 +479,19 @@ func (field *Field) setupValuerAndSetter() { field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) } else { reflectV := reflect.ValueOf(v) + // Optimal value type acquisition for v + reflectValType := reflectV.Type() - if reflectV.Type().AssignableTo(field.FieldType) { + if reflectValType.AssignableTo(field.FieldType) { field.ReflectValueOf(value).Set(reflectV) return - } else if reflectV.Type().ConvertibleTo(field.FieldType) { + } else if reflectValType.ConvertibleTo(field.FieldType) { field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType)) return } else if field.FieldType.Kind() == reflect.Ptr { fieldValue := field.ReflectValueOf(value) - if reflectV.Type().AssignableTo(field.FieldType.Elem()) { + if reflectValType.AssignableTo(field.FieldType.Elem()) { if !fieldValue.IsValid() { fieldValue = reflect.New(field.FieldType.Elem()) } else if fieldValue.IsNil() { @@ -497,7 +499,7 @@ func (field *Field) setupValuerAndSetter() { } fieldValue.Elem().Set(reflectV) return - } else if reflectV.Type().ConvertibleTo(field.FieldType.Elem()) { + } else if reflectValType.ConvertibleTo(field.FieldType.Elem()) { if fieldValue.IsNil() { fieldValue.Set(reflect.New(field.FieldType.Elem())) } From 74e7a9ca079ba44c4e9038088dace76726f2b69c Mon Sep 17 00:00:00 2001 From: heige Date: Wed, 14 Apr 2021 13:00:54 +0800 Subject: [PATCH 873/881] Optimize reflect value length and method (#4280) * Respect ignore migration when add column (#4276) continue https://github.com/go-gorm/gorm/pull/4028 * feat: Optimal value type acquisition for v (#4278) * feat: optimize relect value length and value * feat: optimize ConvertSliceOfMapToValuesForCreate method Co-authored-by: yrong1997 --- callbacks/associations.go | 5 +++-- callbacks/helper.go | 11 ++++++++--- schema/utils.go | 6 +++--- statement.go | 12 ++++++++---- 4 files changed, 22 insertions(+), 12 deletions(-) diff --git a/callbacks/associations.go b/callbacks/associations.go index 2a4efbe1..6d74f20d 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -288,12 +288,13 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { appendToElems(db.Statement.ReflectValue) } - if elems.Len() > 0 { + // optimize elems of reflect value length + if elemLen := elems.Len(); elemLen > 0 { if v, ok := selectColumns[rel.Name+".*"]; !ok || v { saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, nil) } - for i := 0; i < elems.Len(); i++ { + for i := 0; i < elemLen; i++ { appendToJoins(objs[i], elems.Index(i)) } } diff --git a/callbacks/helper.go b/callbacks/helper.go index 3ac63fa1..ad85a1c6 100644 --- a/callbacks/helper.go +++ b/callbacks/helper.go @@ -41,16 +41,21 @@ func ConvertMapToValuesForCreate(stmt *gorm.Statement, mapValue map[string]inter // ConvertSliceOfMapToValuesForCreate convert slice of map to values func ConvertSliceOfMapToValuesForCreate(stmt *gorm.Statement, mapValues []map[string]interface{}) (values clause.Values) { var ( - columns = make([]string, 0, len(mapValues)) - result = map[string][]interface{}{} - selectColumns, restricted = stmt.SelectAndOmitColumns(true, false) + columns = make([]string, 0, len(mapValues)) ) + // when the length of mapValues,return directly here + // no need to call stmt.SelectAndOmitColumns method if len(mapValues) == 0 { stmt.AddError(gorm.ErrEmptySlice) return } + var ( + result = make(map[string][]interface{}, len(mapValues)) + selectColumns, restricted = stmt.SelectAndOmitColumns(true, false) + ) + for idx, mapValue := range mapValues { for k, v := range mapValue { if stmt.Schema != nil { diff --git a/schema/utils.go b/schema/utils.go index d311c61b..add22047 100644 --- a/schema/utils.go +++ b/schema/utils.go @@ -71,10 +71,10 @@ func GetRelationsValues(reflectValue reflect.Value, rels []*Relationship) (refle reflectResults = reflect.Append(reflectResults, result.Addr()) case reflect.Slice, reflect.Array: for i := 0; i < result.Len(); i++ { - if result.Index(i).Kind() == reflect.Ptr { - reflectResults = reflect.Append(reflectResults, result.Index(i)) + if elem := result.Index(i); elem.Kind() == reflect.Ptr { + reflectResults = reflect.Append(reflectResults, elem) } else { - reflectResults = reflect.Append(reflectResults, result.Index(i).Addr()) + reflectResults = reflect.Append(reflectResults, elem.Addr()) } } } diff --git a/statement.go b/statement.go index 7a827ca8..099c66d2 100644 --- a/statement.go +++ b/statement.go @@ -328,8 +328,10 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] } else if _, ok := v[key].(Valuer); ok { conds = append(conds, clause.Eq{Column: key, Value: v[key]}) } else { - values := make([]interface{}, reflectValue.Len()) - for i := 0; i < reflectValue.Len(); i++ { + // optimize relect value length + valueLen := reflectValue.Len() + values := make([]interface{}, valueLen) + for i := 0; i < valueLen; i++ { values[i] = reflectValue.Index(i).Interface() } @@ -396,8 +398,10 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] if len(args) == 1 { switch reflectValue.Kind() { case reflect.Slice, reflect.Array: - values := make([]interface{}, reflectValue.Len()) - for i := 0; i < reflectValue.Len(); i++ { + // optimize relect value length + valueLen := reflectValue.Len() + values := make([]interface{}, valueLen) + for i := 0; i < valueLen; i++ { values[i] = reflectValue.Index(i).Interface() } From d483ffa45c51162ba9defe3a59c0ed62793c037f Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 15 Apr 2021 10:37:05 +0800 Subject: [PATCH 874/881] Fix Preload with nil pointer --- callbacks.go | 1 - callbacks/preload.go | 1 - tests/preload_test.go | 5 ++++- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/callbacks.go b/callbacks.go index f2ee0ea5..ee96fcb9 100644 --- a/callbacks.go +++ b/callbacks.go @@ -110,7 +110,6 @@ func (p *processor) Execute(db *DB) { for stmt.ReflectValue.Kind() == reflect.Ptr { if stmt.ReflectValue.IsNil() && stmt.ReflectValue.CanAddr() { stmt.ReflectValue.Set(reflect.New(stmt.ReflectValue.Type().Elem())) - break } stmt.ReflectValue = stmt.ReflectValue.Elem() diff --git a/callbacks/preload.go b/callbacks/preload.go index eafd407d..25c5e659 100644 --- a/callbacks/preload.go +++ b/callbacks/preload.go @@ -27,7 +27,6 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload }) if rel.JoinTable != nil { - var ( joinForeignFields = make([]*schema.Field, 0, len(rel.References)) joinRelForeignFields = make([]*schema.Field, 0, len(rel.References)) diff --git a/tests/preload_test.go b/tests/preload_test.go index c9f5d278..8f49955e 100644 --- a/tests/preload_test.go +++ b/tests/preload_test.go @@ -63,12 +63,15 @@ func TestNestedPreload(t *testing.T) { var user2 User DB.Preload("Pets.Toy").Find(&user2, "id = ?", user.ID) - CheckUser(t, user2, user) var user3 User DB.Preload(clause.Associations+"."+clause.Associations).Find(&user3, "id = ?", user.ID) CheckUser(t, user3, user) + + var user4 *User + DB.Preload("Pets.Toy").Find(&user4, "id = ?", user.ID) + CheckUser(t, *user4, user) } func TestNestedPreloadForSlice(t *testing.T) { From 7701c885077051c864da309ed850631ada7d0eea Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 16 Apr 2021 19:27:23 +0800 Subject: [PATCH 875/881] Assign transaction error to db --- callbacks/transaction.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/callbacks/transaction.go b/callbacks/transaction.go index 45c6ca11..8ba2ba3b 100644 --- a/callbacks/transaction.go +++ b/callbacks/transaction.go @@ -11,6 +11,8 @@ func BeginTransaction(db *gorm.DB) { db.InstanceSet("gorm:started_transaction", true) } else if tx.Error == gorm.ErrInvalidTransaction { tx.Error = nil + } else { + db.Error = tx.Error } } } From 15a46bc0425cfdb59678f5a0a4af407853c08492 Mon Sep 17 00:00:00 2001 From: Chris Faulkner Date: Mon, 19 Apr 2021 06:03:39 -0700 Subject: [PATCH 876/881] Fix some typos (#4294) --- .github/workflows/tests.yml | 2 +- errors.go | 4 ++-- gorm.go | 2 +- logger/sql.go | 4 ++-- prepare_stmt.go | 2 +- schema/naming.go | 10 +++++----- schema/relationship.go | 8 ++++---- schema/schema_helper_test.go | 2 +- statement.go | 4 ++-- 9 files changed, 19 insertions(+), 19 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index e2ea89a7..370417fc 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -84,7 +84,7 @@ jobs: matrix: dbversion: ['postgres:latest', 'postgres:11', 'postgres:10'] go: ['1.16', '1.15', '1.14'] - platform: [ubuntu-latest] # can not run in macOS and widnowsOS + platform: [ubuntu-latest] # can not run in macOS and Windows runs-on: ${{ matrix.platform }} services: diff --git a/errors.go b/errors.go index 3126b8e7..569207a6 100644 --- a/errors.go +++ b/errors.go @@ -33,8 +33,8 @@ var ( ErrEmptySlice = errors.New("empty slice found") // ErrDryRunModeUnsupported dry run mode unsupported ErrDryRunModeUnsupported = errors.New("dry run mode unsupported") - // ErrInvaildDB invalid db - ErrInvaildDB = errors.New("invalid db") + // ErrInvalidDB invalid db + ErrInvalidDB = errors.New("invalid db") // ErrInvalidValue invalid value ErrInvalidValue = errors.New("invalid value") // ErrInvalidValueOfLength invalid values do not match length diff --git a/gorm.go b/gorm.go index 0da218f6..e105a933 100644 --- a/gorm.go +++ b/gorm.go @@ -348,7 +348,7 @@ func (db *DB) DB() (*sql.DB, error) { return sqldb, nil } - return nil, ErrInvaildDB + return nil, ErrInvalidDB } func (db *DB) getInstance() *DB { diff --git a/logger/sql.go b/logger/sql.go index 4c5f92ed..3d31d23c 100644 --- a/logger/sql.go +++ b/logger/sql.go @@ -28,7 +28,7 @@ func isPrintable(s []byte) bool { return true } -var convertableTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeOf(false), reflect.TypeOf([]byte{})} +var convertibleTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeOf(false), reflect.TypeOf([]byte{})} func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, avars ...interface{}) string { var convertParams func(interface{}, int) @@ -91,7 +91,7 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a } else if rv.Kind() == reflect.Ptr && !rv.IsZero() { convertParams(reflect.Indirect(rv).Interface(), idx) } else { - for _, t := range convertableTypes { + for _, t := range convertibleTypes { if rv.Type().ConvertibleTo(t) { convertParams(rv.Convert(t).Interface(), idx) return diff --git a/prepare_stmt.go b/prepare_stmt.go index 122e98d2..14570061 100644 --- a/prepare_stmt.go +++ b/prepare_stmt.go @@ -27,7 +27,7 @@ func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) { return sqldb, nil } - return nil, ErrInvaildDB + return nil, ErrInvalidDB } func (db *PreparedStmtDB) Close() { diff --git a/schema/naming.go b/schema/naming.go index 0643d1bd..1962c3c6 100644 --- a/schema/naming.go +++ b/schema/naming.go @@ -73,16 +73,16 @@ func (ns NamingStrategy) IndexName(table, column string) string { } func (ns NamingStrategy) formatName(prefix, table, name string) string { - formatedName := strings.Replace(fmt.Sprintf("%v_%v_%v", prefix, table, name), ".", "_", -1) + formattedName := strings.Replace(fmt.Sprintf("%v_%v_%v", prefix, table, name), ".", "_", -1) - if utf8.RuneCountInString(formatedName) > 64 { + if utf8.RuneCountInString(formattedName) > 64 { h := sha1.New() - h.Write([]byte(formatedName)) + h.Write([]byte(formattedName)) bs := h.Sum(nil) - formatedName = fmt.Sprintf("%v%v%v", prefix, table, name)[0:56] + string(bs)[:8] + formattedName = fmt.Sprintf("%v%v%v", prefix, table, name)[0:56] + string(bs)[:8] } - return formatedName + return formattedName } var ( diff --git a/schema/relationship.go b/schema/relationship.go index a8863bfe..061e9120 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -89,7 +89,7 @@ func (schema *Schema) parseRelation(field *Field) *Relationship { } if relation.Type == has { - // don't add relations to embeded schema, which might be shared + // don't add relations to embedded schema, which might be shared if relation.FieldSchema != relation.Schema && relation.Polymorphic == nil && field.OwnerSchema == nil { relation.FieldSchema.Relationships.Relations["_"+relation.Schema.Name+"_"+relation.Name] = relation } @@ -308,9 +308,9 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel f.Size = fieldsMap[f.Name].Size } relation.JoinTable.PrimaryFields = append(relation.JoinTable.PrimaryFields, f) - ownPriamryField := schema == fieldsMap[f.Name].Schema && ownFieldsMap[f.Name] + ownPrimaryField := schema == fieldsMap[f.Name].Schema && ownFieldsMap[f.Name] - if ownPriamryField { + if ownPrimaryField { joinRel := relation.JoinTable.Relationships.Relations[relName] joinRel.Field = relation.Field joinRel.References = append(joinRel.References, &Reference{ @@ -331,7 +331,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel relation.References = append(relation.References, &Reference{ PrimaryKey: fieldsMap[f.Name], ForeignKey: f, - OwnPrimaryKey: ownPriamryField, + OwnPrimaryKey: ownPrimaryField, }) } } diff --git a/schema/schema_helper_test.go b/schema/schema_helper_test.go index cc0306e0..6d2bc664 100644 --- a/schema/schema_helper_test.go +++ b/schema/schema_helper_test.go @@ -29,7 +29,7 @@ func checkSchema(t *testing.T, s *schema.Schema, v schema.Schema, primaryFields } if !found { - t.Errorf("schema %v failed to found priamry key: %v", s, field) + t.Errorf("schema %v failed to found primary key: %v", s, field) } } }) diff --git a/statement.go b/statement.go index 099c66d2..32bc462a 100644 --- a/statement.go +++ b/statement.go @@ -328,7 +328,7 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] } else if _, ok := v[key].(Valuer); ok { conds = append(conds, clause.Eq{Column: key, Value: v[key]}) } else { - // optimize relect value length + // optimize reflect value length valueLen := reflectValue.Len() values := make([]interface{}, valueLen) for i := 0; i < valueLen; i++ { @@ -398,7 +398,7 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] if len(args) == 1 { switch reflectValue.Kind() { case reflect.Slice, reflect.Array: - // optimize relect value length + // optimize reflect value length valueLen := reflectValue.Len() values := make([]interface{}, valueLen) for i := 0; i < valueLen; i++ { From d327926425afecbe084997ba195497107cd71a92 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 19 Apr 2021 21:32:32 +0800 Subject: [PATCH 877/881] Check ReflectValue.CanAddr before set field value --- errors.go | 2 +- statement.go | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/errors.go b/errors.go index 569207a6..f1f6c137 100644 --- a/errors.go +++ b/errors.go @@ -36,7 +36,7 @@ var ( // ErrInvalidDB invalid db ErrInvalidDB = errors.New("invalid db") // ErrInvalidValue invalid value - ErrInvalidValue = errors.New("invalid value") + ErrInvalidValue = errors.New("invalid value, should be pointer to struct or slice") // ErrInvalidValueOfLength invalid values do not match length ErrInvalidValueOfLength = errors.New("invalid association values, length doesn't match") ) diff --git a/statement.go b/statement.go index 32bc462a..2734752d 100644 --- a/statement.go +++ b/statement.go @@ -539,6 +539,11 @@ func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks . } } + if !stmt.ReflectValue.CanAddr() { + stmt.AddError(ErrInvalidValue) + return + } + switch stmt.ReflectValue.Kind() { case reflect.Slice, reflect.Array: if len(fromCallbacks) > 0 { From a855fe64026a65bba106d6614873638c64b3fc8b Mon Sep 17 00:00:00 2001 From: Sky34gl3 Date: Thu, 22 Apr 2021 07:11:19 +0200 Subject: [PATCH 878/881] Fixed naming longer than 64 characters (#4310) Co-authored-by: Mickael MAUGER --- schema/naming.go | 3 ++- schema/naming_test.go | 9 +++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/schema/naming.go b/schema/naming.go index 1962c3c6..d53942e4 100644 --- a/schema/naming.go +++ b/schema/naming.go @@ -2,6 +2,7 @@ package schema import ( "crypto/sha1" + "encoding/hex" "fmt" "strings" "unicode/utf8" @@ -80,7 +81,7 @@ func (ns NamingStrategy) formatName(prefix, table, name string) string { h.Write([]byte(formattedName)) bs := h.Sum(nil) - formattedName = fmt.Sprintf("%v%v%v", prefix, table, name)[0:56] + string(bs)[:8] + formattedName = fmt.Sprintf("%v%v%v", prefix, table, name)[0:56] + hex.EncodeToString(bs)[:8] } return formattedName } diff --git a/schema/naming_test.go b/schema/naming_test.go index 08f8d498..face9364 100644 --- a/schema/naming_test.go +++ b/schema/naming_test.go @@ -168,3 +168,12 @@ func TestCustomReplacerWithNoLowerCase(t *testing.T) { t.Errorf("invalid column name generated, got %v", columdName) } } + +func TestFormatNameWithStringLongerThan64Characters(t *testing.T) { + var ns = NamingStrategy{} + + formattedName := ns.formatName("prefix", "table", "thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVeryLongString") + if formattedName != "prefixtablethisIsAVeryVeryVeryVeryVeryVeryVeryVeryVeryLo180f2c67" { + t.Errorf("invalid formatted name generated, got %v", formattedName) + } +} From 82cb4ebfe2e69c8953536f12e1039807c5643334 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 22 Apr 2021 13:12:15 +0800 Subject: [PATCH 879/881] Fix overwrite Statement in scopes --- callbacks.go | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/callbacks.go b/callbacks.go index ee96fcb9..20fec429 100644 --- a/callbacks.go +++ b/callbacks.go @@ -72,20 +72,20 @@ func (cs *callbacks) Raw() *processor { } func (p *processor) Execute(db *DB) { - var ( - curTime = time.Now() - stmt = db.Statement - ) - // call scopes - for len(stmt.scopes) > 0 { - scopes := stmt.scopes - stmt.scopes = nil + for len(db.Statement.scopes) > 0 { + scopes := db.Statement.scopes + db.Statement.scopes = nil for _, scope := range scopes { db = scope(db) } } + var ( + curTime = time.Now() + stmt = db.Statement + ) + // assign model values if stmt.Model == nil { stmt.Model = stmt.Dest From 6951be0284135a5ecd6f359eb4d173b8fb35e572 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 28 Apr 2021 17:19:30 +0800 Subject: [PATCH 880/881] Allow customize clauses --- callbacks.go | 15 +++++++++++++-- callbacks/callbacks.go | 36 ++++++++++++++++++++++++++++++++++-- callbacks/create.go | 4 ++-- callbacks/delete.go | 2 +- callbacks/query.go | 2 +- callbacks/update.go | 2 +- statement.go | 1 + 7 files changed, 53 insertions(+), 9 deletions(-) diff --git a/callbacks.go b/callbacks.go index 20fec429..01d9ed30 100644 --- a/callbacks.go +++ b/callbacks.go @@ -32,6 +32,7 @@ type callbacks struct { type processor struct { db *DB + Clauses []string fns []func(*DB) callbacks []*callback } @@ -82,10 +83,16 @@ func (p *processor) Execute(db *DB) { } var ( - curTime = time.Now() - stmt = db.Statement + curTime = time.Now() + stmt = db.Statement + resetBuildClauses bool ) + if len(stmt.BuildClauses) == 0 { + stmt.BuildClauses = p.Clauses + resetBuildClauses = true + } + // assign model values if stmt.Model == nil { stmt.Model = stmt.Dest @@ -131,6 +138,10 @@ func (p *processor) Execute(db *DB) { stmt.SQL.Reset() stmt.Vars = nil } + + if resetBuildClauses { + stmt.BuildClauses = nil + } } func (p *processor) Get(name string) func(*DB) { diff --git a/callbacks/callbacks.go b/callbacks/callbacks.go index 7bb27318..d85c1928 100644 --- a/callbacks/callbacks.go +++ b/callbacks/callbacks.go @@ -4,9 +4,20 @@ import ( "gorm.io/gorm" ) +var ( + createClauses = []string{"INSERT", "VALUES", "ON CONFLICT"} + queryClauses = []string{"SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR"} + updateClauses = []string{"UPDATE", "SET", "WHERE"} + deleteClauses = []string{"DELETE", "FROM", "WHERE"} +) + type Config struct { LastInsertIDReversed bool WithReturning bool + CreateClauses []string + QueryClauses []string + UpdateClauses []string + DeleteClauses []string } func RegisterDefaultCallbacks(db *gorm.DB, config *Config) { @@ -22,11 +33,19 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) { createCallback.Register("gorm:save_after_associations", SaveAfterAssociations(true)) createCallback.Register("gorm:after_create", AfterCreate) createCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) + if len(config.CreateClauses) == 0 { + config.CreateClauses = createClauses + } + createCallback.Clauses = config.CreateClauses queryCallback := db.Callback().Query() queryCallback.Register("gorm:query", Query) queryCallback.Register("gorm:preload", Preload) queryCallback.Register("gorm:after_query", AfterQuery) + if len(config.QueryClauses) == 0 { + config.QueryClauses = queryClauses + } + queryCallback.Clauses = config.QueryClauses deleteCallback := db.Callback().Delete() deleteCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction) @@ -35,6 +54,10 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) { deleteCallback.Register("gorm:delete", Delete) deleteCallback.Register("gorm:after_delete", AfterDelete) deleteCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) + if len(config.DeleteClauses) == 0 { + config.DeleteClauses = deleteClauses + } + deleteCallback.Clauses = config.DeleteClauses updateCallback := db.Callback().Update() updateCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction) @@ -45,7 +68,16 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) { updateCallback.Register("gorm:save_after_associations", SaveAfterAssociations(false)) updateCallback.Register("gorm:after_update", AfterUpdate) updateCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) + if len(config.UpdateClauses) == 0 { + config.UpdateClauses = updateClauses + } + updateCallback.Clauses = config.UpdateClauses - db.Callback().Row().Register("gorm:row", RowQuery) - db.Callback().Raw().Register("gorm:raw", RawExec) + rowCallback := db.Callback().Row() + rowCallback.Register("gorm:row", RowQuery) + rowCallback.Clauses = config.QueryClauses + + rawCallback := db.Callback().Raw() + rawCallback.Register("gorm:raw", RawExec) + rawCallback.Clauses = config.QueryClauses } diff --git a/callbacks/create.go b/callbacks/create.go index 909d984a..727bd380 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -47,7 +47,7 @@ func Create(config *Config) func(db *gorm.DB) { db.Statement.AddClauseIfNotExists(clause.Insert{}) db.Statement.AddClause(ConvertToCreateValues(db.Statement)) - db.Statement.Build("INSERT", "VALUES", "ON CONFLICT") + db.Statement.Build(db.Statement.BuildClauses...) } if !db.DryRun && db.Error == nil { @@ -118,7 +118,7 @@ func CreateWithReturning(db *gorm.DB) { db.Statement.AddClauseIfNotExists(clause.Insert{}) db.Statement.AddClause(ConvertToCreateValues(db.Statement)) - db.Statement.Build("INSERT", "VALUES", "ON CONFLICT") + db.Statement.Build(db.Statement.BuildClauses...) } if sch := db.Statement.Schema; sch != nil && len(sch.FieldsWithDefaultDBValue) > 0 { diff --git a/callbacks/delete.go b/callbacks/delete.go index 64dd7236..91659c51 100644 --- a/callbacks/delete.go +++ b/callbacks/delete.go @@ -135,7 +135,7 @@ func Delete(db *gorm.DB) { } db.Statement.AddClauseIfNotExists(clause.From{}) - db.Statement.Build("DELETE", "FROM", "WHERE") + db.Statement.Build(db.Statement.BuildClauses...) } if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok && db.Error == nil { diff --git a/callbacks/query.go b/callbacks/query.go index 11753472..d0341284 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -167,7 +167,7 @@ func BuildQuerySQL(db *gorm.DB) { db.Statement.AddClauseIfNotExists(clauseSelect) - db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR") + db.Statement.Build(db.Statement.BuildClauses...) } } diff --git a/callbacks/update.go b/callbacks/update.go index db5b52fb..75bb02db 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -66,7 +66,7 @@ func Update(db *gorm.DB) { } else { return } - db.Statement.Build("UPDATE", "SET", "WHERE") + db.Statement.Build(db.Statement.BuildClauses...) } if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok { diff --git a/statement.go b/statement.go index 2734752d..a87fd212 100644 --- a/statement.go +++ b/statement.go @@ -27,6 +27,7 @@ type Statement struct { Dest interface{} ReflectValue reflect.Value Clauses map[string]clause.Clause + BuildClauses []string Distinct bool Selects []string // selected columns Omits []string // omit columns From f0d0bbbc1012d309ae7b1802da7c5a16d896e2a9 Mon Sep 17 00:00:00 2001 From: Karolos Lykos Date: Thu, 29 Apr 2021 02:15:37 +0300 Subject: [PATCH 881/881] Added missing white space (#4330) * Added missing white space * Added missing white space * Added missing white space --- clause/on_conflict.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/clause/on_conflict.go b/clause/on_conflict.go index f0c3d7e7..127d9bc1 100644 --- a/clause/on_conflict.go +++ b/clause/on_conflict.go @@ -40,7 +40,7 @@ func (onConflict OnConflict) Build(builder Builder) { } if len(onConflict.Where.Exprs) > 0 { - builder.WriteString("WHERE ") + builder.WriteString(" WHERE ") onConflict.Where.Build(builder) builder.WriteByte(' ') }