From 750fd9030a4c9dee3dcedce532a2181261dc26f5 Mon Sep 17 00:00:00 2001 From: Lukas Dietrich Date: Mon, 4 Sep 2017 16:22:02 +0200 Subject: [PATCH 01/47] Fix postgres dialect for dbs with multiple schemas (#1558) If a postgres database contains more than one schema methods like HasTable(...) would return true even if the current schema does not contain a table with that name. --- dialect_postgres.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/dialect_postgres.go b/dialect_postgres.go index ed5248e0..4d362919 100644 --- a/dialect_postgres.go +++ b/dialect_postgres.go @@ -85,7 +85,7 @@ func (s *postgres) DataTypeOf(field *StructField) string { func (s postgres) HasIndex(tableName string, indexName string) bool { var count int - s.db.QueryRow("SELECT count(*) FROM pg_indexes WHERE tablename = $1 AND indexname = $2", tableName, indexName).Scan(&count) + s.db.QueryRow("SELECT count(*) FROM pg_indexes WHERE tablename = $1 AND indexname = $2 AND schemaname = CURRENT_SCHEMA()", tableName, indexName).Scan(&count) return count > 0 } @@ -97,13 +97,13 @@ func (s postgres) HasForeignKey(tableName string, foreignKeyName string) bool { func (s postgres) HasTable(tableName string) bool { var count int - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = $1 AND table_type = 'BASE TABLE'", tableName).Scan(&count) + s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = $1 AND table_type = 'BASE TABLE' AND table_schema = CURRENT_SCHEMA()", tableName).Scan(&count) return count > 0 } func (s postgres) HasColumn(tableName string, columnName string) bool { var count int - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = $1 AND column_name = $2", tableName, columnName).Scan(&count) + s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = $1 AND column_name = $2 AND table_schema = CURRENT_SCHEMA()", tableName, columnName).Scan(&count) return count > 0 } From 981d5db663eb018386c8df25e79fbecb3a4722e1 Mon Sep 17 00:00:00 2001 From: Dhiver Date: Mon, 4 Sep 2017 16:23:42 +0200 Subject: [PATCH 02/47] Fix postgres dialect UUID sqlType evaluation (#1564) --- dialect_postgres.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/dialect_postgres.go b/dialect_postgres.go index 4d362919..75aef9ba 100644 --- a/dialect_postgres.go +++ b/dialect_postgres.go @@ -67,8 +67,9 @@ func (s *postgres) DataTypeOf(field *StructField) string { default: if IsByteArrayOrSlice(dataValue) { sqlType = "bytea" - } else if isUUID(dataValue) { - sqlType = "uuid" + if isUUID(dataValue) { + sqlType = "uuid" + } } } } From 6e456250f7ceb5a89da60223ae16ce6cbe563398 Mon Sep 17 00:00:00 2001 From: Teppei Fukuda Date: Mon, 4 Sep 2017 22:25:57 +0800 Subject: [PATCH 03/47] Erros skip nil in Add function (#1566) --- errors.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/errors.go b/errors.go index 832fa9b0..6845188e 100644 --- a/errors.go +++ b/errors.go @@ -29,6 +29,10 @@ func (errs Errors) GetErrors() []error { // Add adds an error func (errs Errors) Add(newErrors ...error) Errors { for _, err := range newErrors { + if err == nil { + continue + } + if errors, ok := err.(Errors); ok { errs = errs.Add(errors...) } else { From c0ac6a7d506f738a1239d8a6750f69dd67d626ef Mon Sep 17 00:00:00 2001 From: Domen Ipavec Date: Mon, 4 Sep 2017 16:35:37 +0200 Subject: [PATCH 04/47] Do not ignore order on distinct query (#1570) --- search.go | 7 ------- 1 file changed, 7 deletions(-) diff --git a/search.go b/search.go index 2e273584..90138595 100644 --- a/search.go +++ b/search.go @@ -2,7 +2,6 @@ package gorm import ( "fmt" - "regexp" ) type search struct { @@ -73,13 +72,7 @@ func (s *search) Order(value interface{}, reorder ...bool) *search { return s } -var distinctSQLRegexp = regexp.MustCompile(`(?i)distinct[^a-z]+[a-z]+`) - func (s *search) Select(query interface{}, args ...interface{}) *search { - if distinctSQLRegexp.MatchString(fmt.Sprint(query)) { - s.ignoreOrderQuery = true - } - s.selects = map[string]interface{}{"query": query, "args": args} return s } From b1885a643b4977c9089d77eb07c0fd96591f94b8 Mon Sep 17 00:00:00 2001 From: Cedric GESTES Date: Mon, 4 Sep 2017 16:39:19 +0200 Subject: [PATCH 05/47] Support cloudsqlpostgres dialect (#1577) This is needed for proper cloud sql proxy. see https://github.com/GoogleCloudPlatform/cloudsql-proxy and https://github.com/GoogleCloudPlatform/cloudsql-proxy/blob/master/proxy/dialers/postgres/hook_test.go for details. --- dialect_postgres.go | 1 + 1 file changed, 1 insertion(+) diff --git a/dialect_postgres.go b/dialect_postgres.go index 75aef9ba..6fdf4df1 100644 --- a/dialect_postgres.go +++ b/dialect_postgres.go @@ -13,6 +13,7 @@ type postgres struct { func init() { RegisterDialect("postgres", &postgres{}) + RegisterDialect("cloudsqlpostgres", &postgres{}) } func (postgres) GetName() string { From 3a9e91ab372120a0e35b518430255308e3d8d5ea Mon Sep 17 00:00:00 2001 From: Horacio Duran Date: Thu, 28 Sep 2017 11:48:21 -0300 Subject: [PATCH 06/47] Correct ModifyColumn SQL syntax. (#1614) * Correct ModifyColumn SQL syntax. The generated SQL for ModifyColumn was: `ALTER TABLE "tablename" MODIFY "columname" type` But should have been: `ALTER TABLE "tablename" ALTER COLUMN "columname" TYPE type` since Modify does not seem to be entirely compatible with all Engines * Test ModifyColumn * Skip ModifyColumnType test on incompatible DBs Some DB Engines don't fully support alter table so we skip when the dialect does not correspond to one of the ones that are known to support it. --- migration_test.go | 25 +++++++++++++++++++++++++ scope.go | 2 +- 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/migration_test.go b/migration_test.go index 9fc14fa0..3f3a5c8f 100644 --- a/migration_test.go +++ b/migration_test.go @@ -5,6 +5,7 @@ import ( "database/sql/driver" "errors" "fmt" + "os" "reflect" "testing" "time" @@ -432,3 +433,27 @@ func TestMultipleIndexes(t *testing.T) { t.Error("MultipleIndexes unique index failed") } } + +func TestModifyColumnType(t *testing.T) { + dialect := os.Getenv("GORM_DIALECT") + if dialect != "postgres" && + dialect != "mysql" && + dialect != "mssql" { + t.Skip("Skipping this because only postgres, mysql and mssql support altering a column type") + } + + type ModifyColumnType struct { + gorm.Model + Name1 string `gorm:"length:100"` + Name2 string `gorm:"length:200"` + } + DB.DropTable(&ModifyColumnType{}) + DB.CreateTable(&ModifyColumnType{}) + + name2Field, _ := DB.NewScope(&ModifyColumnType{}).FieldByName("Name2") + name2Type := DB.Dialect().DataTypeOf(name2Field.StructField) + + if err := DB.Model(&ModifyColumnType{}).ModifyColumn("name1", name2Type).Error; err != nil { + t.Errorf("No error should happen when ModifyColumn, but got %v", err) + } +} diff --git a/scope.go b/scope.go index fda7f653..51ebd5a0 100644 --- a/scope.go +++ b/scope.go @@ -1139,7 +1139,7 @@ func (scope *Scope) dropTable() *Scope { } func (scope *Scope) modifyColumn(column string, typ string) { - scope.Raw(fmt.Sprintf("ALTER TABLE %v MODIFY %v %v", scope.QuotedTableName(), scope.Quote(column), typ)).Exec() + scope.Raw(fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v TYPE %v", scope.QuotedTableName(), scope.Quote(column), typ)).Exec() } func (scope *Scope) dropColumn(column string) { From 9c9de896864248929269a7cb2d64ed73b5fdf834 Mon Sep 17 00:00:00 2001 From: Konrad Kleine Date: Tue, 10 Oct 2017 15:04:23 +0200 Subject: [PATCH 07/47] Use log.PrintX instead of fmt.PrintX (#1634) --- callback.go | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/callback.go b/callback.go index 17f75451..a4382147 100644 --- a/callback.go +++ b/callback.go @@ -1,8 +1,6 @@ package gorm -import ( - "fmt" -) +import "log" // DefaultCallback default callbacks defined by gorm var DefaultCallback = &Callback{} @@ -95,7 +93,7 @@ func (cp *CallbackProcessor) Before(callbackName string) *CallbackProcessor { func (cp *CallbackProcessor) Register(callbackName string, callback func(scope *Scope)) { if cp.kind == "row_query" { if cp.before == "" && cp.after == "" && callbackName != "gorm:row_query" { - fmt.Printf("Registing RowQuery callback %v without specify order with Before(), After(), applying Before('gorm:row_query') by default for compatibility...\n", callbackName) + log.Printf("Registing RowQuery callback %v without specify order with Before(), After(), applying Before('gorm:row_query') by default for compatibility...\n", callbackName) cp.before = "gorm:row_query" } } @@ -109,7 +107,7 @@ func (cp *CallbackProcessor) Register(callbackName string, callback func(scope * // Remove a registered callback // db.Callback().Create().Remove("gorm:update_time_stamp_when_create") func (cp *CallbackProcessor) Remove(callbackName string) { - fmt.Printf("[info] removing callback `%v` from %v\n", callbackName, fileWithLineNum()) + log.Printf("[info] removing callback `%v` from %v\n", callbackName, fileWithLineNum()) cp.name = callbackName cp.remove = true cp.parent.processors = append(cp.parent.processors, cp) @@ -122,7 +120,7 @@ func (cp *CallbackProcessor) Remove(callbackName string) { // scope.SetColumn("Updated", now) // }) func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *Scope)) { - fmt.Printf("[info] replacing callback `%v` from %v\n", callbackName, fileWithLineNum()) + log.Printf("[info] replacing callback `%v` from %v\n", callbackName, fileWithLineNum()) cp.name = callbackName cp.processor = &callback cp.replace = true @@ -161,7 +159,7 @@ func sortProcessors(cps []*CallbackProcessor) []*func(scope *Scope) { for _, cp := range cps { // show warning message the callback name already exists if index := getRIndex(allNames, cp.name); index > -1 && !cp.replace && !cp.remove { - fmt.Printf("[warning] duplicated callback `%v` from %v\n", cp.name, fileWithLineNum()) + log.Printf("[warning] duplicated callback `%v` from %v\n", cp.name, fileWithLineNum()) } allNames = append(allNames, cp.name) } From 0a51f6cdc55d1650d9ed3b4c13026cfa9133b01e Mon Sep 17 00:00:00 2001 From: Aetheus Date: Tue, 10 Oct 2017 21:28:39 +0800 Subject: [PATCH 08/47] add JSONB type (#1626) * add JSONB type * add comments to satisfy gofmt --- dialects/postgres/postgres.go | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/dialects/postgres/postgres.go b/dialects/postgres/postgres.go index adeeec7b..b8e76891 100644 --- a/dialects/postgres/postgres.go +++ b/dialects/postgres/postgres.go @@ -6,6 +6,9 @@ import ( _ "github.com/lib/pq" "github.com/lib/pq/hstore" + "encoding/json" + "errors" + "fmt" ) type Hstore map[string]*string @@ -52,3 +55,23 @@ func (h *Hstore) Scan(value interface{}) error { return nil } + +// Jsonb Postgresql's JSONB data type +type Jsonb struct { + json.RawMessage +} + +// Value get value of Jsonb +func (j Jsonb) Value() (driver.Value, error) { + return j.MarshalJSON() +} + +// Scan scan value into Jsonb +func (j *Jsonb) Scan(value interface{}) error { + bytes, ok := value.([]byte) + if !ok { + return errors.New(fmt.Sprint("Failed to unmarshal JSONB value:", value)) + } + + return json.Unmarshal(bytes, j) +} From 26262ef9bb897b06d4e7ad6f1316e1037e030283 Mon Sep 17 00:00:00 2001 From: Wing Gao Date: Tue, 28 Nov 2017 13:05:10 +0800 Subject: [PATCH 09/47] autoIndex should throw an error on failed --- scope.go | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/scope.go b/scope.go index 51ebd5a0..f1b9da4b 100644 --- a/scope.go +++ b/scope.go @@ -1228,11 +1228,19 @@ func (scope *Scope) autoIndex() *Scope { } for name, columns := range indexes { - scope.NewDB().Model(scope.Value).AddIndex(name, columns...) + db := scope.NewDB().Model(scope.Value).AddIndex(name, columns...) + if db.Error != nil { + scope.db.Error = db.Error + return scope + } } for name, columns := range uniqueIndexes { - scope.NewDB().Model(scope.Value).AddUniqueIndex(name, columns...) + db := scope.NewDB().Model(scope.Value).AddUniqueIndex(name, columns...) + if db.Error != nil { + scope.db.Error = db.Error + return scope + } } return scope From 2ff44ee8d72785386e42e11f637ac8fa816cc4ca Mon Sep 17 00:00:00 2001 From: s-takehana Date: Wed, 31 Jan 2018 17:32:36 +0900 Subject: [PATCH 10/47] Fix regex in BuildForeignKeyName #1681 (#1728) --- dialect_common.go | 2 +- dialect_mysql.go | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/dialect_common.go b/dialect_common.go index a99627f2..7d0c3ce7 100644 --- a/dialect_common.go +++ b/dialect_common.go @@ -146,7 +146,7 @@ func (commonDialect) LastInsertIDReturningSuffix(tableName, columnName string) s func (DefaultForeignKeyNamer) BuildForeignKeyName(tableName, field, dest string) string { keyName := fmt.Sprintf("%s_%s_%s_foreign", tableName, field, dest) - keyName = regexp.MustCompile("(_*[^a-zA-Z]+_*|_+)").ReplaceAllString(keyName, "_") + keyName = regexp.MustCompile("[^a-zA-Z0-9]+").ReplaceAllString(keyName, "_") return keyName } diff --git a/dialect_mysql.go b/dialect_mysql.go index 6fcd0079..686ad1ee 100644 --- a/dialect_mysql.go +++ b/dialect_mysql.go @@ -166,8 +166,8 @@ func (s mysql) BuildForeignKeyName(tableName, field, dest string) string { h.Write([]byte(keyName)) bs := h.Sum(nil) - // sha1 is 40 digits, keep first 24 characters of destination - destRunes := []rune(regexp.MustCompile("(_*[^a-zA-Z]+_*|_+)").ReplaceAllString(dest, "_")) + // sha1 is 40 characters, keep first 24 characters of destination + destRunes := []rune(regexp.MustCompile("[^a-zA-Z0-9]+").ReplaceAllString(dest, "_")) if len(destRunes) > 24 { destRunes = destRunes[:24] } From a2c7c4b63f2ba7da2ae6428269bfc43efd29a4e8 Mon Sep 17 00:00:00 2001 From: rightjoin Date: Wed, 31 Jan 2018 14:38:03 +0530 Subject: [PATCH 11/47] UID should come before UI in common abbreviations (#1678) This will fix the following issue https://github.com/jinzhu/gorm/issues/1460 --- utils.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils.go b/utils.go index 97a3d175..dfaae939 100644 --- a/utils.go +++ b/utils.go @@ -23,7 +23,7 @@ var NowFunc = func() time.Time { } // Copied from golint -var commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UI", "UID", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"} +var commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UID", "UI", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"} var commonInitialismsReplacer *strings.Replacer var goSrcRegexp = regexp.MustCompile(`jinzhu/gorm/.*.go`) From b9035a7602b7734076ac4a3146fc88d285e326a5 Mon Sep 17 00:00:00 2001 From: s-takehana Date: Wed, 31 Jan 2018 17:32:36 +0900 Subject: [PATCH 12/47] Fix regex in BuildForeignKeyName #1681 (#1728) --- dialect_common.go | 2 +- dialect_mysql.go | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/dialect_common.go b/dialect_common.go index a99627f2..7d0c3ce7 100644 --- a/dialect_common.go +++ b/dialect_common.go @@ -146,7 +146,7 @@ func (commonDialect) LastInsertIDReturningSuffix(tableName, columnName string) s func (DefaultForeignKeyNamer) BuildForeignKeyName(tableName, field, dest string) string { keyName := fmt.Sprintf("%s_%s_%s_foreign", tableName, field, dest) - keyName = regexp.MustCompile("(_*[^a-zA-Z]+_*|_+)").ReplaceAllString(keyName, "_") + keyName = regexp.MustCompile("[^a-zA-Z0-9]+").ReplaceAllString(keyName, "_") return keyName } diff --git a/dialect_mysql.go b/dialect_mysql.go index 6fcd0079..686ad1ee 100644 --- a/dialect_mysql.go +++ b/dialect_mysql.go @@ -166,8 +166,8 @@ func (s mysql) BuildForeignKeyName(tableName, field, dest string) string { h.Write([]byte(keyName)) bs := h.Sum(nil) - // sha1 is 40 digits, keep first 24 characters of destination - destRunes := []rune(regexp.MustCompile("(_*[^a-zA-Z]+_*|_+)").ReplaceAllString(dest, "_")) + // sha1 is 40 characters, keep first 24 characters of destination + destRunes := []rune(regexp.MustCompile("[^a-zA-Z0-9]+").ReplaceAllString(dest, "_")) if len(destRunes) > 24 { destRunes = destRunes[:24] } From 630c12b54936a0b20a6ddf8a35dab18279165dd8 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 31 Jan 2018 17:14:21 +0800 Subject: [PATCH 13/47] Refactor #1693 --- scope.go | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/scope.go b/scope.go index f1b9da4b..0a7e8861 100644 --- a/scope.go +++ b/scope.go @@ -1228,18 +1228,14 @@ func (scope *Scope) autoIndex() *Scope { } for name, columns := range indexes { - db := scope.NewDB().Model(scope.Value).AddIndex(name, columns...) - if db.Error != nil { - scope.db.Error = db.Error - return scope + if db := scope.NewDB().Model(scope.Value).AddIndex(name, columns...); db.Error != nil { + scope.db.AddError(db.Error) } } for name, columns := range uniqueIndexes { - db := scope.NewDB().Model(scope.Value).AddUniqueIndex(name, columns...) - if db.Error != nil { - scope.db.Error = db.Error - return scope + if db := scope.NewDB().Model(scope.Value).AddUniqueIndex(name, columns...); db.Error != nil { + scope.db.AddError(db.Error) } } From cbc3d3cd509bee9f1c0d6f03bf02ff91e9dd47dd Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 31 Jan 2018 18:16:20 +0800 Subject: [PATCH 14/47] Add go report card --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 44eb4a69..e904ef80 100644 --- a/README.md +++ b/README.md @@ -3,6 +3,7 @@ The fantastic ORM library for Golang, aims to be developer friendly. [![Join the chat at https://gitter.im/jinzhu/gorm](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/jinzhu/gorm?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) +[![go report card](https://goreportcard.com/badge/github.com/jinzhu/gorm "go report card")](https://goreportcard.com/report/github.com/jinzhu/gorm) [![wercker status](https://app.wercker.com/status/0cb7bb1039e21b74f8274941428e0921/s/master "wercker status")](https://app.wercker.com/project/bykey/0cb7bb1039e21b74f8274941428e0921) [![GoDoc](https://godoc.org/github.com/jinzhu/gorm?status.svg)](https://godoc.org/github.com/jinzhu/gorm) From ca46ec0770003aab3c0ed7d7b336643362221c21 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 31 Jan 2018 18:22:30 +0800 Subject: [PATCH 15/47] Smaller image --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index e904ef80..e5c21dc5 100644 --- a/README.md +++ b/README.md @@ -32,7 +32,7 @@ The fantastic ORM library for Golang, aims to be developer friendly. ## Supporting the project -[![http://patreon.com/jinzhu](http://patreon_public_assets.s3.amazonaws.com/sized/becomeAPatronBanner.png)](http://patreon.com/jinzhu) +[![http://patreon.com/jinzhu](https://c5.patreon.com/external/logo/become_a_patron_button.png)](http://patreon.com/jinzhu) ## Author From 802104cc7cfe58153cccc9bc76e5b9078296c16b Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 2 Feb 2018 22:01:31 +0800 Subject: [PATCH 16/47] Use BuildKeyName to build db's index name --- dialect.go | 4 ++-- dialect_common.go | 4 ++-- dialect_mysql.go | 6 +++--- scope.go | 7 ++++--- 4 files changed, 11 insertions(+), 10 deletions(-) diff --git a/dialect.go b/dialect.go index e879588b..9d3be249 100644 --- a/dialect.go +++ b/dialect.go @@ -41,8 +41,8 @@ type Dialect interface { // LastInsertIdReturningSuffix most dbs support LastInsertId, but postgres needs to use `RETURNING` LastInsertIDReturningSuffix(tableName, columnName string) string - // BuildForeignKeyName returns a foreign key name for the given table, field and reference - BuildForeignKeyName(tableName, field, dest string) string + // BuildKeyName returns a valid key name (foreign key, index key) for the given table, field and reference + BuildKeyName(kind, tableName string, fields ...string) string // CurrentDatabase return current database name CurrentDatabase() string diff --git a/dialect_common.go b/dialect_common.go index 7d0c3ce7..ef351f9e 100644 --- a/dialect_common.go +++ b/dialect_common.go @@ -144,8 +144,8 @@ func (commonDialect) LastInsertIDReturningSuffix(tableName, columnName string) s return "" } -func (DefaultForeignKeyNamer) BuildForeignKeyName(tableName, field, dest string) string { - keyName := fmt.Sprintf("%s_%s_%s_foreign", tableName, field, dest) +func (DefaultForeignKeyNamer) BuildKeyName(kind, tableName string, fields ...string) string { + keyName := fmt.Sprintf("%s_%s_%s", kind, tableName, strings.Join(fields, "_")) keyName = regexp.MustCompile("[^a-zA-Z0-9]+").ReplaceAllString(keyName, "_") return keyName } diff --git a/dialect_mysql.go b/dialect_mysql.go index 686ad1ee..d2fd53ca 100644 --- a/dialect_mysql.go +++ b/dialect_mysql.go @@ -157,8 +157,8 @@ func (mysql) SelectFromDummyTable() string { return "FROM DUAL" } -func (s mysql) BuildForeignKeyName(tableName, field, dest string) string { - keyName := s.commonDialect.BuildForeignKeyName(tableName, field, dest) +func (s mysql) BuildKeyName(kind, tableName string, fields ...string) string { + keyName := s.commonDialect.BuildKeyName(kind, tableName, fields...) if utf8.RuneCountInString(keyName) <= 64 { return keyName } @@ -167,7 +167,7 @@ func (s mysql) BuildForeignKeyName(tableName, field, dest string) string { bs := h.Sum(nil) // sha1 is 40 characters, keep first 24 characters of destination - destRunes := []rune(regexp.MustCompile("[^a-zA-Z0-9]+").ReplaceAllString(dest, "_")) + destRunes := []rune(regexp.MustCompile("[^a-zA-Z0-9]+").ReplaceAllString(fields[0], "_")) if len(destRunes) > 24 { destRunes = destRunes[:24] } diff --git a/scope.go b/scope.go index 0a7e8861..c447d8a0 100644 --- a/scope.go +++ b/scope.go @@ -1165,7 +1165,8 @@ func (scope *Scope) addIndex(unique bool, indexName string, column ...string) { } func (scope *Scope) addForeignKey(field string, dest string, onDelete string, onUpdate string) { - keyName := scope.Dialect().BuildForeignKeyName(scope.TableName(), field, dest) + // Compatible with old generated key + keyName := scope.Dialect().BuildKeyName(scope.TableName(), field, dest, "foreign") if scope.Dialect().HasForeignKey(scope.TableName(), keyName) { return @@ -1209,7 +1210,7 @@ func (scope *Scope) autoIndex() *Scope { for _, name := range names { if name == "INDEX" || name == "" { - name = fmt.Sprintf("idx_%v_%v", scope.TableName(), field.DBName) + name = scope.Dialect().BuildKeyName("idx", scope.TableName(), field.DBName) } indexes[name] = append(indexes[name], field.DBName) } @@ -1220,7 +1221,7 @@ func (scope *Scope) autoIndex() *Scope { for _, name := range names { if name == "UNIQUE_INDEX" || name == "" { - name = fmt.Sprintf("uix_%v_%v", scope.TableName(), field.DBName) + name = scope.Dialect().BuildKeyName("uix", scope.TableName(), field.DBName) } uniqueIndexes[name] = append(uniqueIndexes[name], field.DBName) } From 57f031e08380b8b76252ccbb6a0bc21c85b28a7d Mon Sep 17 00:00:00 2001 From: Piyush Mishra Date: Fri, 2 Feb 2018 22:29:40 +0530 Subject: [PATCH 17/47] Use table name to guess current database if none is given --- dialect_common.go | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/dialect_common.go b/dialect_common.go index ef351f9e..9ccff6e9 100644 --- a/dialect_common.go +++ b/dialect_common.go @@ -92,7 +92,8 @@ func (s *commonDialect) DataTypeOf(field *StructField) string { func (s commonDialect) HasIndex(tableName string, indexName string) bool { var count int - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", s.CurrentDatabase(), tableName, indexName).Scan(&count) + currentDatabase, tableName := s.currentDatabaseAndTable(tableName) + s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", currentDatabase, tableName, indexName).Scan(&count) return count > 0 } @@ -107,13 +108,25 @@ func (s commonDialect) HasForeignKey(tableName string, foreignKeyName string) bo func (s commonDialect) HasTable(tableName string) bool { var count int - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", s.CurrentDatabase(), tableName).Scan(&count) + currentDatabase, tableName := s.currentDatabaseAndTable(tableName) + s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", currentDatabase, tableName).Scan(&count) return count > 0 } +func (s commonDialect) currentDatabaseAndTable(tableName string) (string, string) { + currentDatabase := s.CurrentDatabase() + if currentDatabase == "" && strings.Contains(tableName, ".") { + splitStrings := strings.SplitN(tableName, ".", 2) + currentDatabase = splitStrings[0] + tableName = splitStrings[1] + } + return currentDatabase, tableName +} + func (s commonDialect) HasColumn(tableName string, columnName string) bool { var count int - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", s.CurrentDatabase(), tableName, columnName).Scan(&count) + currentDatabase, tableName := s.currentDatabaseAndTable(tableName) + s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", currentDatabase, tableName, columnName).Scan(&count) return count > 0 } From 87fc1b24737a885147240041293603eceb844356 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 3 Feb 2018 20:27:19 +0800 Subject: [PATCH 18/47] Refactor PR #1751 --- dialect.go | 8 ++++++++ dialect_common.go | 17 ++++------------- dialect_mysql.go | 3 ++- dialects/mssql/mssql.go | 14 ++++++++++++-- 4 files changed, 26 insertions(+), 16 deletions(-) diff --git a/dialect.go b/dialect.go index 9d3be249..90b1723f 100644 --- a/dialect.go +++ b/dialect.go @@ -114,3 +114,11 @@ var ParseFieldStructForDialect = func(field *StructField, dialect Dialect) (fiel return fieldValue, dataType, size, strings.TrimSpace(additionalType) } + +func currentDatabaseAndTable(dialect Dialect, tableName string) (string, string) { + if strings.Contains(tableName, ".") { + splitStrings := strings.SplitN(tableName, ".", 2) + return splitStrings[0], splitStrings[1] + } + return dialect.CurrentDatabase(), tableName +} diff --git a/dialect_common.go b/dialect_common.go index 9ccff6e9..30f035a5 100644 --- a/dialect_common.go +++ b/dialect_common.go @@ -92,7 +92,7 @@ func (s *commonDialect) DataTypeOf(field *StructField) string { func (s commonDialect) HasIndex(tableName string, indexName string) bool { var count int - currentDatabase, tableName := s.currentDatabaseAndTable(tableName) + currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", currentDatabase, tableName, indexName).Scan(&count) return count > 0 } @@ -108,24 +108,14 @@ func (s commonDialect) HasForeignKey(tableName string, foreignKeyName string) bo func (s commonDialect) HasTable(tableName string) bool { var count int - currentDatabase, tableName := s.currentDatabaseAndTable(tableName) + currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", currentDatabase, tableName).Scan(&count) return count > 0 } -func (s commonDialect) currentDatabaseAndTable(tableName string) (string, string) { - currentDatabase := s.CurrentDatabase() - if currentDatabase == "" && strings.Contains(tableName, ".") { - splitStrings := strings.SplitN(tableName, ".", 2) - currentDatabase = splitStrings[0] - tableName = splitStrings[1] - } - return currentDatabase, tableName -} - func (s commonDialect) HasColumn(tableName string, columnName string) bool { var count int - currentDatabase, tableName := s.currentDatabaseAndTable(tableName) + currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", currentDatabase, tableName, columnName).Scan(&count) return count > 0 } @@ -157,6 +147,7 @@ func (commonDialect) LastInsertIDReturningSuffix(tableName, columnName string) s return "" } +// BuildKeyName returns a valid key name (foreign key, index key) for the given table, field and reference func (DefaultForeignKeyNamer) BuildKeyName(kind, tableName string, fields ...string) string { keyName := fmt.Sprintf("%s_%s_%s", kind, tableName, strings.Join(fields, "_")) keyName = regexp.MustCompile("[^a-zA-Z0-9]+").ReplaceAllString(keyName, "_") diff --git a/dialect_mysql.go b/dialect_mysql.go index d2fd53ca..f4858e10 100644 --- a/dialect_mysql.go +++ b/dialect_mysql.go @@ -144,7 +144,8 @@ func (s mysql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) { func (s mysql) HasForeignKey(tableName string, foreignKeyName string) bool { var count int - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE CONSTRAINT_SCHEMA=? AND TABLE_NAME=? AND CONSTRAINT_NAME=? AND CONSTRAINT_TYPE='FOREIGN KEY'", s.CurrentDatabase(), tableName, foreignKeyName).Scan(&count) + currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) + s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE CONSTRAINT_SCHEMA=? AND TABLE_NAME=? AND CONSTRAINT_NAME=? AND CONSTRAINT_TYPE='FOREIGN KEY'", currentDatabase, tableName, foreignKeyName).Scan(&count) return count > 0 } diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go index de2ae7ca..a4f8e87c 100644 --- a/dialects/mssql/mssql.go +++ b/dialects/mssql/mssql.go @@ -128,13 +128,15 @@ func (s mssql) HasForeignKey(tableName string, foreignKeyName string) bool { func (s mssql) HasTable(tableName string) bool { var count int - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, s.CurrentDatabase()).Scan(&count) + currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) + s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, currentDatabase).Scan(&count) return count > 0 } func (s mssql) HasColumn(tableName string, columnName string) bool { var count int - s.db.QueryRow("SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", s.CurrentDatabase(), tableName, columnName).Scan(&count) + currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) + s.db.QueryRow("SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", currentDatabase, tableName, columnName).Scan(&count) return count > 0 } @@ -168,3 +170,11 @@ func (mssql) SelectFromDummyTable() string { func (mssql) LastInsertIDReturningSuffix(tableName, columnName string) string { return "" } + +func currentDatabaseAndTable(dialect gorm.Dialect, tableName string) (string, string) { + if strings.Contains(tableName, ".") { + splitStrings := strings.SplitN(tableName, ".", 2) + return splitStrings[0], splitStrings[1] + } + return dialect.CurrentDatabase(), tableName +} From 3f98904fe72ef13a4add9c051dbff5509e233679 Mon Sep 17 00:00:00 2001 From: Louis Tran Date: Thu, 8 Feb 2018 16:21:39 -0800 Subject: [PATCH 19/47] Update PULL_REQUEST_TEMPLATE.md, A vs. An (#1757) Only a small change. `a` agreement => `an` agreement --- .github/PULL_REQUEST_TEMPLATE.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 0ee0d73b..4923abdc 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -6,7 +6,7 @@ Make sure these boxes checked before submitting your pull request. - [] Write good commit message, try to squash your commits into a single one - [] Run `./build.sh` in `gh-pages` branch for document changes -For significant changes like big bug fixes, new features, please open an issue to make a agreement on an implementation design/plan first before starting it. +For significant changes like big bug fixes, new features, please open an issue to make an agreement on an implementation design/plan first before starting it. Thank you. From 48e41440afa6a741a3e345f2cfbabca08f6fb1ac Mon Sep 17 00:00:00 2001 From: Adrian Heng Date: Fri, 9 Feb 2018 08:22:30 +0800 Subject: [PATCH 20/47] Allow for proper table creation with Jsonb fields (#1758) * DataTypeOf should now correctly identify dataValues that are 'json.RawMessage' types as 'jsonb' columns * move the json check to its own function * ran gofmt and did some minor tweaks to satisfy CodeClimate --- dialect_postgres.go | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/dialect_postgres.go b/dialect_postgres.go index 6fdf4df1..3bcea536 100644 --- a/dialect_postgres.go +++ b/dialect_postgres.go @@ -1,6 +1,7 @@ package gorm import ( + "encoding/json" "fmt" "reflect" "strings" @@ -68,9 +69,14 @@ func (s *postgres) DataTypeOf(field *StructField) string { default: if IsByteArrayOrSlice(dataValue) { sqlType = "bytea" + if isUUID(dataValue) { sqlType = "uuid" } + + if isJSON(dataValue) { + sqlType = "jsonb" + } } } } @@ -130,3 +136,8 @@ func isUUID(value reflect.Value) bool { lower := strings.ToLower(typename) return "uuid" == lower || "guid" == lower } + +func isJSON(value reflect.Value) bool { + _, ok := value.Interface().(json.RawMessage) + return ok +} From 38f96c65140f00f0b15efc495a487cfd5db510b8 Mon Sep 17 00:00:00 2001 From: daisy1754 Date: Fri, 9 Feb 2018 05:59:33 -0800 Subject: [PATCH 21/47] Add handling for empty Jsonb to fix #1649 (#1650) --- dialects/postgres/postgres.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/dialects/postgres/postgres.go b/dialects/postgres/postgres.go index b8e76891..1d0dcb60 100644 --- a/dialects/postgres/postgres.go +++ b/dialects/postgres/postgres.go @@ -63,6 +63,9 @@ type Jsonb struct { // Value get value of Jsonb func (j Jsonb) Value() (driver.Value, error) { + if len(j.RawMessage) == 0 { + return nil, nil + } return j.MarshalJSON() } From 0e1cb6ece9d27b56ee6c1e514987175bba94711b Mon Sep 17 00:00:00 2001 From: Amit Yadav <154998+ayadav@users.noreply.github.com> Date: Fri, 9 Feb 2018 19:50:26 +0530 Subject: [PATCH 22/47] Add support to remove foreign key constraints (#1686) --- main.go | 8 ++++++++ scope.go | 10 ++++++++++ 2 files changed, 18 insertions(+) diff --git a/main.go b/main.go index 16fa0b79..b23ae2f2 100644 --- a/main.go +++ b/main.go @@ -611,6 +611,14 @@ func (s *DB) AddForeignKey(field string, dest string, onDelete string, onUpdate return scope.db } +// RemoveForeignKey Remove foreign key from the given scope, e.g: +// db.Model(&User{}).RemoveForeignKey("city_id", "cities(id)") +func (s *DB) RemoveForeignKey(field string, dest string) *DB { + scope := s.clone().NewScope(s.Value) + scope.removeForeignKey(field, dest) + return scope.db +} + // Association start `Association Mode` to handler relations things easir in that mode, refer: https://jinzhu.github.io/gorm/associations.html#association-mode func (s *DB) Association(column string) *Association { var err error diff --git a/scope.go b/scope.go index c447d8a0..4c404b38 100644 --- a/scope.go +++ b/scope.go @@ -1175,6 +1175,16 @@ func (scope *Scope) addForeignKey(field string, dest string, onDelete string, on scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName), scope.quoteIfPossible(field), dest, onDelete, onUpdate)).Exec() } +func (scope *Scope) removeForeignKey(field string, dest string) { + keyName := scope.Dialect().BuildForeignKeyName(scope.TableName(), field, dest) + + if !scope.Dialect().HasForeignKey(scope.TableName(), keyName) { + return + } + var query = `ALTER TABLE %s DROP CONSTRAINT %s;` + scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName))).Exec() +} + func (scope *Scope) removeIndex(indexName string) { scope.Dialect().RemoveIndex(scope.TableName(), indexName) } From e9309d361f8777f861997089ce142744109e1aa2 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 9 Feb 2018 22:34:59 +0800 Subject: [PATCH 23/47] Fix build exception --- scope.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scope.go b/scope.go index 4c404b38..0ef087bc 100644 --- a/scope.go +++ b/scope.go @@ -1176,7 +1176,7 @@ func (scope *Scope) addForeignKey(field string, dest string, onDelete string, on } func (scope *Scope) removeForeignKey(field string, dest string) { - keyName := scope.Dialect().BuildForeignKeyName(scope.TableName(), field, dest) + keyName := scope.Dialect().BuildKeyName(scope.TableName(), field, dest) if !scope.Dialect().HasForeignKey(scope.TableName(), keyName) { return From 89a726ce5da26da893dd3c2d8475e1d66677fd9c Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 9 Feb 2018 22:58:34 +0800 Subject: [PATCH 24/47] Move ModifyColumn implemention to Dialect --- dialect.go | 2 ++ dialect_common.go | 5 +++++ dialect_mysql.go | 5 +++++ dialects/mssql/mssql.go | 5 +++++ migration_test.go | 5 +---- scope.go | 2 +- 6 files changed, 19 insertions(+), 5 deletions(-) diff --git a/dialect.go b/dialect.go index 90b1723f..fe8e2f62 100644 --- a/dialect.go +++ b/dialect.go @@ -33,6 +33,8 @@ type Dialect interface { HasTable(tableName string) bool // HasColumn check has column or not HasColumn(tableName string, columnName string) bool + // ModifyColumn modify column's type + ModifyColumn(tableName string, columnName string, typ string) error // LimitAndOffsetSQL return generated SQL with Limit and Offset, as mssql has special case LimitAndOffsetSQL(limit, offset interface{}) string diff --git a/dialect_common.go b/dialect_common.go index 30f035a5..06d0bd07 100644 --- a/dialect_common.go +++ b/dialect_common.go @@ -120,6 +120,11 @@ func (s commonDialect) HasColumn(tableName string, columnName string) bool { return count > 0 } +func (s commonDialect) ModifyColumn(tableName string, columnName string, typ string) error { + _, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v TYPE %v", tableName, columnName, typ)) + return err +} + func (s commonDialect) CurrentDatabase() (name string) { s.db.QueryRow("SELECT DATABASE()").Scan(&name) return diff --git a/dialect_mysql.go b/dialect_mysql.go index f4858e10..b9887a5c 100644 --- a/dialect_mysql.go +++ b/dialect_mysql.go @@ -127,6 +127,11 @@ func (s mysql) RemoveIndex(tableName string, indexName string) error { return err } +func (s mysql) ModifyColumn(tableName string, columnName string, typ string) error { + _, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v MODIFY COLUMN %v %v", tableName, columnName, typ)) + return err +} + func (s mysql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) { if limit != nil { if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 { diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go index a4f8e87c..10a779de 100644 --- a/dialects/mssql/mssql.go +++ b/dialects/mssql/mssql.go @@ -140,6 +140,11 @@ func (s mssql) HasColumn(tableName string, columnName string) bool { return count > 0 } +func (s mssql) ModifyColumn(tableName string, columnName string, typ string) error { + _, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v %v", tableName, columnName, typ)) + return err +} + func (s mssql) CurrentDatabase() (name string) { s.db.QueryRow("SELECT DB_NAME() AS [Current Database]").Scan(&name) return diff --git a/migration_test.go b/migration_test.go index 3f3a5c8f..6b4470a6 100644 --- a/migration_test.go +++ b/migration_test.go @@ -435,10 +435,7 @@ func TestMultipleIndexes(t *testing.T) { } func TestModifyColumnType(t *testing.T) { - dialect := os.Getenv("GORM_DIALECT") - if dialect != "postgres" && - dialect != "mysql" && - dialect != "mssql" { + if dialect := os.Getenv("GORM_DIALECT"); dialect != "postgres" && dialect != "mysql" && dialect != "mssql" { t.Skip("Skipping this because only postgres, mysql and mssql support altering a column type") } diff --git a/scope.go b/scope.go index 0ef087bc..a10cb3a2 100644 --- a/scope.go +++ b/scope.go @@ -1139,7 +1139,7 @@ func (scope *Scope) dropTable() *Scope { } func (scope *Scope) modifyColumn(column string, typ string) { - scope.Raw(fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v TYPE %v", scope.QuotedTableName(), scope.Quote(column), typ)).Exec() + scope.db.AddError(scope.Dialect().ModifyColumn(scope.QuotedTableName(), scope.Quote(column), typ)) } func (scope *Scope) dropColumn(column string) { From ae696d051fdd183c27ca75f7aa13bde5649b7264 Mon Sep 17 00:00:00 2001 From: miyauchi Date: Fri, 20 Oct 2017 10:24:09 +0900 Subject: [PATCH 25/47] corresponds timestamp precision for mysql --- dialect_mysql.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dialect_mysql.go b/dialect_mysql.go index b9887a5c..573bfc0f 100644 --- a/dialect_mysql.go +++ b/dialect_mysql.go @@ -96,9 +96,9 @@ func (s *mysql) DataTypeOf(field *StructField) string { case reflect.Struct: if _, ok := dataValue.Interface().(time.Time); ok { if _, ok := field.TagSettings["NOT NULL"]; ok { - sqlType = "timestamp" + sqlType = fmt.Sprintf("timestamp(%d)", size) } else { - sqlType = "timestamp NULL" + sqlType = fmt.Sprintf("timestamp(%d) NULL", size) } } default: From 8d4e3e5a832d78a11ea13bb1166569095238cfd0 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 9 Feb 2018 23:18:47 +0800 Subject: [PATCH 26/47] Use tag PRECISION to set time's precision for mysql --- dialect_mysql.go | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/dialect_mysql.go b/dialect_mysql.go index 573bfc0f..fee61819 100644 --- a/dialect_mysql.go +++ b/dialect_mysql.go @@ -95,10 +95,15 @@ func (s *mysql) DataTypeOf(field *StructField) string { } case reflect.Struct: if _, ok := dataValue.Interface().(time.Time); ok { + precision := "" + if p, ok := field.TagSettings["PRECISION"]; ok { + precision = fmt.Sprintf("(%s)", p) + } + if _, ok := field.TagSettings["NOT NULL"]; ok { - sqlType = fmt.Sprintf("timestamp(%d)", size) + sqlType = fmt.Sprintf("timestamp%v", precision) } else { - sqlType = fmt.Sprintf("timestamp(%d) NULL", size) + sqlType = fmt.Sprintf("timestamp%v NULL", precision) } } default: From ec72a4cb6b0fc60c2dda9ab842416b17ae4b3ad7 Mon Sep 17 00:00:00 2001 From: Geoff Baskwill Date: Fri, 9 Feb 2018 10:22:53 -0500 Subject: [PATCH 27/47] Call Query callback chain when preloading many2many (#1622) When using `Preload` on a `many2many` association, the `Query` callback chain was not being called. This made it difficult to write a plugin that could reliably get called regardless of how objects were being queried. Now `handleManyToManyPreload` will call the `Query` callback chain for each object that is retrieved by following the association. Since the data has already been read by the `handleManyToManyPreload` method, a new scope setting called `gorm:skip_queryCallback` is set to `true` before calling the callbacks. Callbacks can check for the presence of this setting if they should not be run; the default `queryCallback` is an example of this case. Fixes jinzhu/gorm#1621. --- callback_query.go | 4 ++++ callback_query_preload.go | 4 ++++ preload_test.go | 40 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 48 insertions(+) diff --git a/callback_query.go b/callback_query.go index 20e88161..f9940880 100644 --- a/callback_query.go +++ b/callback_query.go @@ -15,6 +15,10 @@ func init() { // queryCallback used to query data from database func queryCallback(scope *Scope) { + if _, skip := scope.Get("gorm:skip_query_callback"); skip { + return + } + defer scope.trace(NowFunc()) var ( diff --git a/callback_query_preload.go b/callback_query_preload.go index 21ab22ce..f2a218c7 100644 --- a/callback_query_preload.go +++ b/callback_query_preload.go @@ -324,6 +324,10 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface scope.scan(rows, columns, append(fields, joinTableFields...)) + scope.New(elem.Addr().Interface()). + Set("gorm:skip_query_callback", true). + callCallbacks(scope.db.parent.callbacks.queries) + var foreignKeys = make([]interface{}, len(sourceKeys)) // generate hashed forkey keys in join table for idx, joinTableField := range joinTableFields { diff --git a/preload_test.go b/preload_test.go index 1b89e77b..66f2629b 100644 --- a/preload_test.go +++ b/preload_test.go @@ -1627,6 +1627,46 @@ func TestPrefixedPreloadDuplication(t *testing.T) { } } +func TestPreloadManyToManyCallbacks(t *testing.T) { + type ( + Level2 struct { + ID uint + } + Level1 struct { + ID uint + Level2s []Level2 `gorm:"many2many:level1_level2s;AssociationForeignKey:ID;ForeignKey:ID"` + } + ) + + DB.DropTableIfExists("level1_level2s") + DB.DropTableIfExists(new(Level1)) + DB.DropTableIfExists(new(Level2)) + + if err := DB.AutoMigrate(new(Level1), new(Level2)).Error; err != nil { + t.Error(err) + } + + lvl := Level1{ + Level2s: []Level2{ + Level2{}, + }, + } + DB.Save(&lvl) + + called := 0 + + DB.Callback().Query().After("gorm:query").Register("TestPreloadManyToManyCallbacks", func(scope *gorm.Scope) { + called = called + 1 + }) + + found := Level1{ID: lvl.ID} + DB.Preload("Level2s").First(&found, &found) + + if called != 2 { + t.Errorf("Wanted callback to be called 2 times but got %d", called) + } +} + func toJSONString(v interface{}) []byte { r, _ := json.MarshalIndent(v, "", " ") return r From 77eb925ea09471b7082d9d5749b2c96be726eac2 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 10 Feb 2018 00:07:16 +0800 Subject: [PATCH 28/47] Refactor preloading many2many for auto preload --- callback_query.go | 2 +- callback_query_preload.go | 5 ++++- preload_test.go | 14 ++++++++------ 3 files changed, 13 insertions(+), 8 deletions(-) diff --git a/callback_query.go b/callback_query.go index f9940880..ba10cc7d 100644 --- a/callback_query.go +++ b/callback_query.go @@ -15,7 +15,7 @@ func init() { // queryCallback used to query data from database func queryCallback(scope *Scope) { - if _, skip := scope.Get("gorm:skip_query_callback"); skip { + if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip { return } diff --git a/callback_query_preload.go b/callback_query_preload.go index f2a218c7..30f6b585 100644 --- a/callback_query_preload.go +++ b/callback_query_preload.go @@ -10,6 +10,9 @@ import ( // preloadCallback used to preload associations func preloadCallback(scope *Scope) { + if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip { + return + } if _, ok := scope.Get("gorm:auto_preload"); ok { autoPreload(scope) @@ -325,7 +328,7 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface scope.scan(rows, columns, append(fields, joinTableFields...)) scope.New(elem.Addr().Interface()). - Set("gorm:skip_query_callback", true). + InstanceSet("gorm:skip_query_callback", true). callCallbacks(scope.db.parent.callbacks.queries) var foreignKeys = make([]interface{}, len(sourceKeys)) diff --git a/preload_test.go b/preload_test.go index 66f2629b..311ad0be 100644 --- a/preload_test.go +++ b/preload_test.go @@ -1630,10 +1630,12 @@ func TestPrefixedPreloadDuplication(t *testing.T) { func TestPreloadManyToManyCallbacks(t *testing.T) { type ( Level2 struct { - ID uint + ID uint + Name string } Level1 struct { ID uint + Name string Level2s []Level2 `gorm:"many2many:level1_level2s;AssociationForeignKey:ID;ForeignKey:ID"` } ) @@ -1647,8 +1649,9 @@ func TestPreloadManyToManyCallbacks(t *testing.T) { } lvl := Level1{ + Name: "l1", Level2s: []Level2{ - Level2{}, + Level2{Name: "l2-1"}, Level2{Name: "l2-2"}, }, } DB.Save(&lvl) @@ -1659,11 +1662,10 @@ func TestPreloadManyToManyCallbacks(t *testing.T) { called = called + 1 }) - found := Level1{ID: lvl.ID} - DB.Preload("Level2s").First(&found, &found) + DB.Preload("Level2s").First(&Level1{}, "id = ?", lvl.ID) - if called != 2 { - t.Errorf("Wanted callback to be called 2 times but got %d", called) + if called != 3 { + t.Errorf("Wanted callback to be called 3 times but got %d", called) } } From 97495a5e4067bc254ac33ce0e54c0af97a8c35d5 Mon Sep 17 00:00:00 2001 From: Wing Gao Date: Fri, 13 Oct 2017 15:08:55 +0800 Subject: [PATCH 29/47] Add new tag "not_auto_increment" to set a column can auto increase or not --- dialect_common.go | 14 ++++++++++++-- dialect_mysql.go | 12 ++++++------ dialect_postgres.go | 4 ++-- dialect_sqlite3.go | 4 ++-- 4 files changed, 22 insertions(+), 12 deletions(-) diff --git a/dialect_common.go b/dialect_common.go index 06d0bd07..64d720db 100644 --- a/dialect_common.go +++ b/dialect_common.go @@ -38,6 +38,16 @@ func (commonDialect) Quote(key string) string { return fmt.Sprintf(`"%s"`, key) } +func (s *commonDialect) fieldCanAutoIncrement(field *StructField) bool { + // add a new tag "NOT_AUTO_INCREMENT" + _, not := field.TagSettings["NOT_AUTO_INCREMENT"] + if not { + return false + } + _, ok := field.TagSettings["AUTO_INCREMENT"] + return ok || field.IsPrimaryKey +} + func (s *commonDialect) DataTypeOf(field *StructField) string { var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s) @@ -46,13 +56,13 @@ func (s *commonDialect) DataTypeOf(field *StructField) string { case reflect.Bool: sqlType = "BOOLEAN" case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok { + if s.fieldCanAutoIncrement(field) { sqlType = "INTEGER AUTO_INCREMENT" } else { sqlType = "INTEGER" } case reflect.Int64, reflect.Uint64: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok { + if s.fieldCanAutoIncrement(field) { sqlType = "BIGINT AUTO_INCREMENT" } else { sqlType = "BIGINT" diff --git a/dialect_mysql.go b/dialect_mysql.go index fee61819..1feed1f6 100644 --- a/dialect_mysql.go +++ b/dialect_mysql.go @@ -44,42 +44,42 @@ func (s *mysql) DataTypeOf(field *StructField) string { case reflect.Bool: sqlType = "boolean" case reflect.Int8: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { + if s.fieldCanAutoIncrement(field) { field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" sqlType = "tinyint AUTO_INCREMENT" } else { sqlType = "tinyint" } case reflect.Int, reflect.Int16, reflect.Int32: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { + if s.fieldCanAutoIncrement(field) { field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" sqlType = "int AUTO_INCREMENT" } else { sqlType = "int" } case reflect.Uint8: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { + if s.fieldCanAutoIncrement(field) { field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" sqlType = "tinyint unsigned AUTO_INCREMENT" } else { sqlType = "tinyint unsigned" } case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uintptr: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { + if s.fieldCanAutoIncrement(field) { field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" sqlType = "int unsigned AUTO_INCREMENT" } else { sqlType = "int unsigned" } case reflect.Int64: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { + if s.fieldCanAutoIncrement(field) { field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" sqlType = "bigint AUTO_INCREMENT" } else { sqlType = "bigint" } case reflect.Uint64: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { + if s.fieldCanAutoIncrement(field) { field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" sqlType = "bigint unsigned AUTO_INCREMENT" } else { diff --git a/dialect_postgres.go b/dialect_postgres.go index 3bcea536..c44c6a5b 100644 --- a/dialect_postgres.go +++ b/dialect_postgres.go @@ -33,14 +33,14 @@ func (s *postgres) DataTypeOf(field *StructField) string { case reflect.Bool: sqlType = "boolean" case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uintptr: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { + if s.fieldCanAutoIncrement(field) { field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" sqlType = "serial" } else { sqlType = "integer" } case reflect.Int64, reflect.Uint32, reflect.Uint64: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { + if s.fieldCanAutoIncrement(field) { field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" sqlType = "bigserial" } else { diff --git a/dialect_sqlite3.go b/dialect_sqlite3.go index de9c05cb..f26f6be3 100644 --- a/dialect_sqlite3.go +++ b/dialect_sqlite3.go @@ -28,14 +28,14 @@ func (s *sqlite3) DataTypeOf(field *StructField) string { case reflect.Bool: sqlType = "bool" case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: - if field.IsPrimaryKey { + if s.fieldCanAutoIncrement(field) { field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" sqlType = "integer primary key autoincrement" } else { sqlType = "integer" } case reflect.Int64, reflect.Uint64: - if field.IsPrimaryKey { + if s.fieldCanAutoIncrement(field) { field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" sqlType = "integer primary key autoincrement" } else { From 2c68f695c3de3b05f31e0f4c0132a19e236a0f23 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 10 Feb 2018 08:24:39 +0800 Subject: [PATCH 30/47] Set AutoIncrement to false with tag --- dialect_common.go | 9 +++------ main_test.go | 4 +++- test_all.sh | 2 +- 3 files changed, 7 insertions(+), 8 deletions(-) diff --git a/dialect_common.go b/dialect_common.go index 64d720db..1e5e3b61 100644 --- a/dialect_common.go +++ b/dialect_common.go @@ -39,13 +39,10 @@ func (commonDialect) Quote(key string) string { } func (s *commonDialect) fieldCanAutoIncrement(field *StructField) bool { - // add a new tag "NOT_AUTO_INCREMENT" - _, not := field.TagSettings["NOT_AUTO_INCREMENT"] - if not { - return false + if value, ok := field.TagSettings["AUTO_INCREMENT"]; ok { + return value != "FALSE" } - _, ok := field.TagSettings["AUTO_INCREMENT"] - return ok || field.IsPrimaryKey + return field.IsPrimaryKey } func (s *commonDialect) DataTypeOf(field *StructField) string { diff --git a/main_test.go b/main_test.go index 34f96a86..499324bc 100644 --- a/main_test.go +++ b/main_test.go @@ -72,8 +72,10 @@ func OpenTestConnection() (db *gorm.DB, err error) { // db.SetLogger(Logger{log.New(os.Stdout, "\r\n", 0)}) // db.SetLogger(log.New(os.Stdout, "\r\n", 0)) - if os.Getenv("DEBUG") == "true" { + if debug := os.Getenv("DEBUG"); debug == "true" { db.LogMode(true) + } else if debug == "false" { + db.LogMode(false) } db.DB().SetMaxIdleConns(10) diff --git a/test_all.sh b/test_all.sh index 80b319bf..5cfb3321 100755 --- a/test_all.sh +++ b/test_all.sh @@ -1,5 +1,5 @@ dialects=("postgres" "mysql" "mssql" "sqlite") for dialect in "${dialects[@]}" ; do - GORM_DIALECT=${dialect} go test + DEBUG=false GORM_DIALECT=${dialect} go test done From ae509ab23743e034b8c4e1d0d72d60a31ac7f6fd Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 10 Feb 2018 08:30:05 +0800 Subject: [PATCH 31/47] Port AUTO_INCREMENT false support to mssql --- dialects/mssql/mssql.go | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go index 10a779de..1c735a84 100644 --- a/dialects/mssql/mssql.go +++ b/dialects/mssql/mssql.go @@ -65,14 +65,14 @@ func (s *mssql) DataTypeOf(field *gorm.StructField) string { case reflect.Bool: sqlType = "bit" case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { + if s.fieldCanAutoIncrement(field) { field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" sqlType = "int IDENTITY(1,1)" } else { sqlType = "int" } case reflect.Int64, reflect.Uint64: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { + if s.fieldCanAutoIncrement(field) { field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" sqlType = "bigint IDENTITY(1,1)" } else { @@ -111,6 +111,13 @@ func (s *mssql) DataTypeOf(field *gorm.StructField) string { return fmt.Sprintf("%v %v", sqlType, additionalType) } +func (s mssql) fieldCanAutoIncrement(field *gorm.StructField) bool { + if value, ok := field.TagSettings["AUTO_INCREMENT"]; ok { + return value != "FALSE" + } + return field.IsPrimaryKey +} + func (s mssql) HasIndex(tableName string, indexName string) bool { var count int s.db.QueryRow("SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName).Scan(&count) From e0f9087c8d67b035172c15aabe1953aae4293d9f Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 10 Feb 2018 11:06:43 +0800 Subject: [PATCH 32/47] Setup test env --- docker-compose.yml | 30 ++++++++++++++++++++++++++++++ main_test.go | 26 +++++++++++--------------- wercker.yml | 17 +++++++++++++++-- 3 files changed, 56 insertions(+), 17 deletions(-) create mode 100644 docker-compose.yml diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 00000000..79bf5fc3 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,30 @@ +version: '3' + +services: + mysql: + image: 'mysql:latest' + ports: + - 9910:3306 + environment: + - MYSQL_DATABASE=gorm + - MYSQL_USER=gorm + - MYSQL_PASSWORD=gorm + - MYSQL_RANDOM_ROOT_PASSWORD="yes" + postgres: + image: 'postgres:latest' + ports: + - 9920:5432 + environment: + - POSTGRES_USER=gorm + - POSTGRES_DB=gorm + - POSTGRES_PASSWORD=gorm + mssql: + image: 'mcmoe/mssqldocker:latest' + ports: + - 9930:1433 + environment: + - ACCEPT_EULA=Y + - SA_PASSWORD=LoremIpsum86 + - MSSQL_DB=gorm + - MSSQL_USER=gorm + - MSSQL_PASSWORD=LoremIpsum86 diff --git a/main_test.go b/main_test.go index 499324bc..83e6f7aa 100644 --- a/main_test.go +++ b/main_test.go @@ -36,27 +36,20 @@ func init() { } func OpenTestConnection() (db *gorm.DB, err error) { + dbDSN := os.Getenv("GORM_DSN") switch os.Getenv("GORM_DIALECT") { case "mysql": - // CREATE USER 'gorm'@'localhost' IDENTIFIED BY 'gorm'; - // CREATE DATABASE gorm; - // GRANT ALL ON gorm.* TO 'gorm'@'localhost'; fmt.Println("testing mysql...") - dbhost := os.Getenv("GORM_DBADDRESS") - if dbhost != "" { - dbhost = fmt.Sprintf("tcp(%v)", dbhost) + if dbDSN == "" { + dbDSN = "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True" } - db, err = gorm.Open("mysql", fmt.Sprintf("gorm:gorm@%v/gorm?charset=utf8&parseTime=True", dbhost)) + db, err = gorm.Open("mysql", dbDSN) case "postgres": fmt.Println("testing postgres...") - dbhost := os.Getenv("GORM_DBHOST") - if dbhost != "" { - dbhost = fmt.Sprintf("host=%v ", dbhost) + if dbDSN == "" { + dbDSN = "user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable" } - db, err = gorm.Open("postgres", fmt.Sprintf("%vuser=gorm password=gorm DB.name=gorm sslmode=disable", dbhost)) - case "foundation": - fmt.Println("testing foundation...") - db, err = gorm.Open("foundation", "dbname=gorm port=15432 sslmode=disable") + db, err = gorm.Open("postgres", dbDSN) case "mssql": // CREATE LOGIN gorm WITH PASSWORD = 'LoremIpsum86'; // CREATE DATABASE gorm; @@ -64,7 +57,10 @@ func OpenTestConnection() (db *gorm.DB, err error) { // CREATE USER gorm FROM LOGIN gorm; // sp_changedbowner 'gorm'; fmt.Println("testing mssql...") - db, err = gorm.Open("mssql", "sqlserver://gorm:LoremIpsum86@localhost:1433?database=gorm") + if dbDSN == "" { + dbDSN = "sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm" + } + db, err = gorm.Open("mssql", dbDSN) default: fmt.Println("testing sqlite3...") db, err = gorm.Open("sqlite3", filepath.Join(os.TempDir(), "gorm.db")) diff --git a/wercker.yml b/wercker.yml index ff6fb17c..c3045c54 100644 --- a/wercker.yml +++ b/wercker.yml @@ -13,6 +13,14 @@ services: POSTGRES_USER: gorm POSTGRES_PASSWORD: gorm POSTGRES_DB: gorm + - name: mssql + id: mcmoe/mssqldocker: + env: + ACCEPT_EULA: Y + SA_PASSWORD: LoremIpsum86 + MSSQL_DB: gorm + MSSQL_USER: gorm + MSSQL_PASSWORD: LoremIpsum86 # The steps that will be executed in the build pipeline build: @@ -45,9 +53,14 @@ build: - script: name: test mysql code: | - GORM_DIALECT=mysql GORM_DBADDRESS=mariadb:3306 go test ./... + GORM_DIALECT=mysql GORM_DSN=gorm:gorm@tcp(mariadb:3306)/gorm?charset=utf8&parseTime=True go test ./... - script: name: test postgres code: | - GORM_DIALECT=postgres GORM_DBHOST=postgres go test ./... + GORM_DIALECT=postgres GORM_DSN="host=postgres user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test ./... + + - script: + name: test mssql + code: | + GORM_DIALECT=mssql GORM_DSN="sqlserver://gorm:LoremIpsum86@mssql:1433?database=gorm" go test ./... From 2e5d98a42020e99e9270e5caa9125b9de2dc56e8 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 10 Feb 2018 11:45:38 +0800 Subject: [PATCH 33/47] Update wercker.yml --- wercker.yml | 102 +++++++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 98 insertions(+), 4 deletions(-) diff --git a/wercker.yml b/wercker.yml index c3045c54..2f2370b3 100644 --- a/wercker.yml +++ b/wercker.yml @@ -2,19 +2,73 @@ box: golang services: - - id: mariadb:10.0 + - name: mariadb + id: mariadb:latest env: MYSQL_DATABASE: gorm MYSQL_USER: gorm MYSQL_PASSWORD: gorm MYSQL_RANDOM_ROOT_PASSWORD: "yes" - - id: postgres + - name: mysql + id: mysql:8 + env: + MYSQL_DATABASE: gorm + MYSQL_USER: gorm + MYSQL_PASSWORD: gorm + MYSQL_RANDOM_ROOT_PASSWORD: "yes" + - name: mysql57 + id: mysql:5.7 + env: + MYSQL_DATABASE: gorm + MYSQL_USER: gorm + MYSQL_PASSWORD: gorm + MYSQL_RANDOM_ROOT_PASSWORD: "yes" + - name: mysql56 + id: mysql:5.6 + env: + MYSQL_DATABASE: gorm + MYSQL_USER: gorm + MYSQL_PASSWORD: gorm + MYSQL_RANDOM_ROOT_PASSWORD: "yes" + - name: mysql55 + id: mysql:5.5 + env: + MYSQL_DATABASE: gorm + MYSQL_USER: gorm + MYSQL_PASSWORD: gorm + MYSQL_RANDOM_ROOT_PASSWORD: "yes" + - name: postgres + id: postgres:latest + env: + POSTGRES_USER: gorm + POSTGRES_PASSWORD: gorm + POSTGRES_DB: gorm + - name: postgres96 + id: postgres:9.6 + env: + POSTGRES_USER: gorm + POSTGRES_PASSWORD: gorm + POSTGRES_DB: gorm + - name: postgres95 + id: postgres:9.5 + env: + POSTGRES_USER: gorm + POSTGRES_PASSWORD: gorm + POSTGRES_DB: gorm + - name: postgres94 + id: postgres:9.4 + env: + POSTGRES_USER: gorm + POSTGRES_PASSWORD: gorm + POSTGRES_DB: gorm + - name: postgres93 + id: postgres:9.3 env: POSTGRES_USER: gorm POSTGRES_PASSWORD: gorm POSTGRES_DB: gorm - name: mssql - id: mcmoe/mssqldocker: + id: mcmoe/mssqldocker:latest env: ACCEPT_EULA: Y SA_PASSWORD: LoremIpsum86 @@ -50,16 +104,56 @@ build: code: | go test ./... + - script: + name: test mariadb + code: | + GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mariadb:3306)/gorm?charset=utf8&parseTime=True" go test ./... + - script: name: test mysql code: | - GORM_DIALECT=mysql GORM_DSN=gorm:gorm@tcp(mariadb:3306)/gorm?charset=utf8&parseTime=True go test ./... + GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql:3306)/gorm?charset=utf8&parseTime=True" go test ./... + + - script: + name: test mysql5.7 + code: | + GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql57:3306)/gorm?charset=utf8&parseTime=True" go test ./... + + - script: + name: test mysql5.6 + code: | + GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql56:3306)/gorm?charset=utf8&parseTime=True" go test ./... + + - script: + name: test mysql5.5 + code: | + GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql55:3306)/gorm?charset=utf8&parseTime=True" go test ./... - script: name: test postgres code: | GORM_DIALECT=postgres GORM_DSN="host=postgres user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test ./... + - script: + name: test postgres96 + code: | + GORM_DIALECT=postgres GORM_DSN="host=postgres96 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test ./... + + - script: + name: test postgres95 + code: | + GORM_DIALECT=postgres GORM_DSN="host=postgres95 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test ./... + + - script: + name: test postgres94 + code: | + GORM_DIALECT=postgres GORM_DSN="host=postgres94 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test ./... + + - script: + name: test postgres93 + code: | + GORM_DIALECT=postgres GORM_DSN="host=postgres93 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test ./... + - script: name: test mssql code: | From 706b8f55da67c097aede7662a45c9ae577ea3ed9 Mon Sep 17 00:00:00 2001 From: Giuseppe Date: Sat, 10 Feb 2018 05:28:01 +0100 Subject: [PATCH 34/47] Use brackets for quoting (#1736) --- dialects/mssql/mssql.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go index 1c735a84..1dd5fb69 100644 --- a/dialects/mssql/mssql.go +++ b/dialects/mssql/mssql.go @@ -54,7 +54,7 @@ func (mssql) BindVar(i int) string { } func (mssql) Quote(key string) string { - return fmt.Sprintf(`"%s"`, key) + return fmt.Sprintf(`[%s]`, key) } func (s *mssql) DataTypeOf(field *gorm.StructField) string { From 21fb3ae1febe4581f80a4d5633f3fffd6d10a606 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 10 Feb 2018 13:15:04 +0800 Subject: [PATCH 35/47] Simplify GitHub templates --- .github/ISSUE_TEMPLATE.md | 21 ++++++--------------- .github/PULL_REQUEST_TEMPLATE.md | 5 ----- README.md | 2 +- 3 files changed, 7 insertions(+), 21 deletions(-) diff --git a/.github/ISSUE_TEMPLATE.md b/.github/ISSUE_TEMPLATE.md index 8b4f03b7..a0b64bfa 100644 --- a/.github/ISSUE_TEMPLATE.md +++ b/.github/ISSUE_TEMPLATE.md @@ -1,10 +1,4 @@ -Before posting a bug report about a problem, please try to verify that it is a bug and that it has not been reported already, please apply corresponding GitHub labels to the issue, for feature requests, please apply `type:feature`. - -DON'T post usage related questions, ask in https://gitter.im/jinzhu/gorm or http://stackoverflow.com/questions/tagged/go-gorm, - -Please answer these questions before submitting your issue. Thanks! - - +Your issue may already be reported! Please search on the [issue track](https://github.com/jinzhu/gorm/issues) before creating one. ### What version of Go are you using (`go version`)? @@ -12,9 +6,9 @@ Please answer these questions before submitting your issue. Thanks! ### Which database and its version are you using? -### What did you do? +### Please provide a complete runnable program to reproduce your issue. **IMPORTANT** -Please provide a complete runnable program to reproduce your issue. +Need to runnable with [GORM's docker compose config](https://github.com/jinzhu/gorm/blob/master/docker-compose.yml) or please provides your config. ```go package main @@ -32,10 +26,9 @@ var db *gorm.DB func init() { var err error db, err = gorm.Open("sqlite3", "test.db") - // Please use below username, password as your database's account for the script. - // db, err = gorm.Open("postgres", "user=gorm dbname=gorm sslmode=disable") - // db, err = gorm.Open("mysql", "gorm:gorm@/dbname?charset=utf8&parseTime=True") - // db, err = gorm.Open("mssql", "sqlserver://gorm:LoremIpsum86@localhost:1433?database=gorm") + // db, err = gorm.Open("postgres", "user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable") + // db, err = gorm.Open("mysql", "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True") + // db, err = gorm.Open("mssql", "sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm") if err != nil { panic(err) } @@ -43,8 +36,6 @@ func init() { } func main() { - // your code here - if /* failure condition */ { fmt.Println("failed") } else { diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 4923abdc..b467b6ce 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -3,12 +3,7 @@ Make sure these boxes checked before submitting your pull request. - [] Do only one thing - [] No API-breaking changes - [] New code/logic commented & tested -- [] Write good commit message, try to squash your commits into a single one -- [] Run `./build.sh` in `gh-pages` branch for document changes For significant changes like big bug fixes, new features, please open an issue to make an agreement on an implementation design/plan first before starting it. -Thank you. - - ### What did this pull request do? diff --git a/README.md b/README.md index e5c21dc5..8c6e2302 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ The fantastic ORM library for Golang, aims to be developer friendly. [![Join the chat at https://gitter.im/jinzhu/gorm](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/jinzhu/gorm?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) [![go report card](https://goreportcard.com/badge/github.com/jinzhu/gorm "go report card")](https://goreportcard.com/report/github.com/jinzhu/gorm) -[![wercker status](https://app.wercker.com/status/0cb7bb1039e21b74f8274941428e0921/s/master "wercker status")](https://app.wercker.com/project/bykey/0cb7bb1039e21b74f8274941428e0921) +[![wercker status](https://app.wercker.com/status/8596cace912c9947dd9c8542ecc8cb8b/s/master "wercker status")](https://app.wercker.com/project/byKey/8596cace912c9947dd9c8542ecc8cb8b) [![GoDoc](https://godoc.org/github.com/jinzhu/gorm?status.svg)](https://godoc.org/github.com/jinzhu/gorm) ## Overview From aa3fd6de13fee7e0ae715eeaad3bc2f329db2366 Mon Sep 17 00:00:00 2001 From: Jess Smith Date: Sat, 10 Feb 2018 01:26:01 -0500 Subject: [PATCH 36/47] Sort column names before generating SQL in `DB.UpdateColumns` (#1734) --- callback_update.go | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/callback_update.go b/callback_update.go index 6948439f..373bd726 100644 --- a/callback_update.go +++ b/callback_update.go @@ -3,6 +3,7 @@ package gorm import ( "errors" "fmt" + "sort" "strings" ) @@ -59,7 +60,16 @@ func updateCallback(scope *Scope) { var sqls []string if updateAttrs, ok := scope.InstanceGet("gorm:update_attrs"); ok { - for column, value := range updateAttrs.(map[string]interface{}) { + // Sort the column names so that the generated SQL is the same every time. + updateMap := updateAttrs.(map[string]interface{}) + var columns []string + for c := range updateMap { + columns = append(columns, c) + } + sort.Strings(columns) + + for _, column := range columns { + value := updateMap[column] sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(column), scope.AddToVars(value))) } } else { From 9235b47ea28d816ef25d6bf4e037ccb5c7c7096b Mon Sep 17 00:00:00 2001 From: joe-at-startupmedia Date: Wed, 4 Oct 2017 08:19:16 +0000 Subject: [PATCH 37/47] Allows foreign keys to be saved without saving the assoication when specified #1628 --- callback_save.go | 53 ++++++++++++++++++++++++------------------------ 1 file changed, 26 insertions(+), 27 deletions(-) diff --git a/callback_save.go b/callback_save.go index f4bc918e..ad4eda2f 100644 --- a/callback_save.go +++ b/callback_save.go @@ -11,35 +11,34 @@ func commitOrRollbackTransactionCallback(scope *Scope) { } func saveFieldAsAssociation(scope *Scope, field *Field) (bool, *Relationship) { - if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored { - if value, ok := field.TagSettings["SAVE_ASSOCIATIONS"]; !ok || (value != "false" && value != "skip") { - if relationship := field.Relationship; relationship != nil { - return true, relationship - } - } - } - return false, nil + if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored { + if value, ok := field.TagSettings["SAVE_ASSOCIATIONS"]; !ok || (value != "false" && value != "skip") { + return true, field.Relationship + } + return false, field.Relationship + } + return false, nil } func saveBeforeAssociationsCallback(scope *Scope) { - if !scope.shouldSaveAssociations() { - return - } - for _, field := range scope.Fields() { - if ok, relationship := saveFieldAsAssociation(scope, field); ok && relationship.Kind == "belongs_to" { - fieldValue := field.Field.Addr().Interface() - scope.Err(scope.NewDB().Save(fieldValue).Error) - if len(relationship.ForeignFieldNames) != 0 { - // set value's foreign key - for idx, fieldName := range relationship.ForeignFieldNames { - associationForeignName := relationship.AssociationForeignDBNames[idx] - if foreignField, ok := scope.New(fieldValue).FieldByName(associationForeignName); ok { - scope.Err(scope.SetColumn(fieldName, foreignField.Field.Interface())) - } - } - } - } - } + for _, field := range scope.Fields() { + ok, relationship := saveFieldAsAssociation(scope, field); + if relationship != nil && relationship.Kind == "belongs_to" { + fieldValue := field.Field.Addr().Interface() + if ok && scope.shouldSaveAssociations() { + scope.Err(scope.NewDB().Save(fieldValue).Error) + } + if len(relationship.ForeignFieldNames) != 0 { + // set value's foreign key + for idx, fieldName := range relationship.ForeignFieldNames { + associationForeignName := relationship.AssociationForeignDBNames[idx] + if foreignField, ok := scope.New(fieldValue).FieldByName(associationForeignName); ok { + scope.Err(scope.SetColumn(fieldName, foreignField.Field.Interface())) + } + } + } + } + } } func saveAfterAssociationsCallback(scope *Scope) { @@ -47,7 +46,7 @@ func saveAfterAssociationsCallback(scope *Scope) { return } for _, field := range scope.Fields() { - if ok, relationship := saveFieldAsAssociation(scope, field); ok && + if ok, relationship := saveFieldAsAssociation(scope, field); ok && relationship != nil && (relationship.Kind == "has_one" || relationship.Kind == "has_many" || relationship.Kind == "many_to_many") { value := field.Field From 9f409820dfdfc2ab7cb20a56d4cefdf1a111c315 Mon Sep 17 00:00:00 2001 From: joe-at-startupmedia Date: Tue, 10 Oct 2017 18:20:56 +0000 Subject: [PATCH 38/47] Formatting code with gomt --- callback_save.go | 50 ++++++++++++++++++++++++------------------------ 1 file changed, 25 insertions(+), 25 deletions(-) diff --git a/callback_save.go b/callback_save.go index ad4eda2f..fa32c907 100644 --- a/callback_save.go +++ b/callback_save.go @@ -11,34 +11,34 @@ func commitOrRollbackTransactionCallback(scope *Scope) { } func saveFieldAsAssociation(scope *Scope, field *Field) (bool, *Relationship) { - if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored { - if value, ok := field.TagSettings["SAVE_ASSOCIATIONS"]; !ok || (value != "false" && value != "skip") { - return true, field.Relationship - } - return false, field.Relationship - } - return false, nil + if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored { + if value, ok := field.TagSettings["SAVE_ASSOCIATIONS"]; !ok || (value != "false" && value != "skip") { + return true, field.Relationship + } + return false, field.Relationship + } + return false, nil } func saveBeforeAssociationsCallback(scope *Scope) { - for _, field := range scope.Fields() { - ok, relationship := saveFieldAsAssociation(scope, field); - if relationship != nil && relationship.Kind == "belongs_to" { - fieldValue := field.Field.Addr().Interface() - if ok && scope.shouldSaveAssociations() { - scope.Err(scope.NewDB().Save(fieldValue).Error) - } - if len(relationship.ForeignFieldNames) != 0 { - // set value's foreign key - for idx, fieldName := range relationship.ForeignFieldNames { - associationForeignName := relationship.AssociationForeignDBNames[idx] - if foreignField, ok := scope.New(fieldValue).FieldByName(associationForeignName); ok { - scope.Err(scope.SetColumn(fieldName, foreignField.Field.Interface())) - } - } - } - } - } + for _, field := range scope.Fields() { + ok, relationship := saveFieldAsAssociation(scope, field) + if relationship != nil && relationship.Kind == "belongs_to" { + fieldValue := field.Field.Addr().Interface() + if ok && scope.shouldSaveAssociations() { + scope.Err(scope.NewDB().Save(fieldValue).Error) + } + if len(relationship.ForeignFieldNames) != 0 { + // set value's foreign key + for idx, fieldName := range relationship.ForeignFieldNames { + associationForeignName := relationship.AssociationForeignDBNames[idx] + if foreignField, ok := scope.New(fieldValue).FieldByName(associationForeignName); ok { + scope.Err(scope.SetColumn(fieldName, foreignField.Field.Interface())) + } + } + } + } + } } func saveAfterAssociationsCallback(scope *Scope) { From 63cb513b4978a49870ff20d27fb18c721f64d977 Mon Sep 17 00:00:00 2001 From: Ezequiel Muns Date: Wed, 1 Nov 2017 18:45:08 +0100 Subject: [PATCH 39/47] Tests for saving foreign key when save_associations:false --- association_test.go | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/association_test.go b/association_test.go index c84f84ed..f37047d1 100644 --- a/association_test.go +++ b/association_test.go @@ -902,6 +902,20 @@ func TestSkipSaveAssociation(t *testing.T) { DB.Save(&User{Name: "jinzhu", Company: Company{Name: "skip_save_association"}}) if !DB.Where("name = ?", "skip_save_association").First(&Company{}).RecordNotFound() { - t.Errorf("Company skip_save_association should not been saved") + t.Errorf("Company skip_save_association should not have been saved") + } + + // if foreign key is set, this should be saved even if association isn't + company := Company{Name: "skip_save_association"} + DB.Save(&company) + company.Name = "skip_save_association_modified" + user := User{Name: "jinzhu", CompanyID: company.ID, Company: company} + DB.Save(&user) + + if !DB.Where("name = ?", "skip_save_association_modified").First(&Company{}).RecordNotFound() { + t.Errorf("Company skip_save_association should not have been updated") + } + if DB.Where("id = ? AND company_id = ?", user.ID, company.ID).First(&User{}).RecordNotFound() { + t.Errorf("User's foreign key should have been saved") } } From 43dc867644b879f8f87fd0598ac0b459232d9293 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 10 Feb 2018 15:16:20 +0800 Subject: [PATCH 40/47] Allow save association relations w/o saving association --- association_test.go | 2 +- callback_save.go | 31 ++++++++++++++++++------------- scope.go | 2 +- 3 files changed, 20 insertions(+), 15 deletions(-) diff --git a/association_test.go b/association_test.go index f37047d1..34822dbc 100644 --- a/association_test.go +++ b/association_test.go @@ -909,7 +909,7 @@ func TestSkipSaveAssociation(t *testing.T) { company := Company{Name: "skip_save_association"} DB.Save(&company) company.Name = "skip_save_association_modified" - user := User{Name: "jinzhu", CompanyID: company.ID, Company: company} + user := User{Name: "jinzhu", Company: company} DB.Save(&user) if !DB.Where("name = ?", "skip_save_association_modified").First(&Company{}).RecordNotFound() { diff --git a/callback_save.go b/callback_save.go index fa32c907..544354d0 100644 --- a/callback_save.go +++ b/callback_save.go @@ -12,22 +12,25 @@ func commitOrRollbackTransactionCallback(scope *Scope) { func saveFieldAsAssociation(scope *Scope, field *Field) (bool, *Relationship) { if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored { - if value, ok := field.TagSettings["SAVE_ASSOCIATIONS"]; !ok || (value != "false" && value != "skip") { - return true, field.Relationship + if field.Relationship != nil { + if value, ok := field.TagSettings["SAVE_ASSOCIATIONS"]; (!ok || (value != "false" && value != "skip")) && scope.allowSaveAssociations() { + return true, field.Relationship + } + return false, field.Relationship } - return false, field.Relationship } return false, nil } func saveBeforeAssociationsCallback(scope *Scope) { for _, field := range scope.Fields() { - ok, relationship := saveFieldAsAssociation(scope, field) - if relationship != nil && relationship.Kind == "belongs_to" { + if allowSaveAssociation, relationship := saveFieldAsAssociation(scope, field); relationship != nil && relationship.Kind == "belongs_to" { fieldValue := field.Field.Addr().Interface() - if ok && scope.shouldSaveAssociations() { + + if allowSaveAssociation { scope.Err(scope.NewDB().Save(fieldValue).Error) } + if len(relationship.ForeignFieldNames) != 0 { // set value's foreign key for idx, fieldName := range relationship.ForeignFieldNames { @@ -42,11 +45,8 @@ func saveBeforeAssociationsCallback(scope *Scope) { } func saveAfterAssociationsCallback(scope *Scope) { - if !scope.shouldSaveAssociations() { - return - } for _, field := range scope.Fields() { - if ok, relationship := saveFieldAsAssociation(scope, field); ok && relationship != nil && + if allowSaveAssociation, relationship := saveFieldAsAssociation(scope, field); relationship != nil && (relationship.Kind == "has_one" || relationship.Kind == "has_many" || relationship.Kind == "many_to_many") { value := field.Field @@ -70,9 +70,11 @@ func saveAfterAssociationsCallback(scope *Scope) { scope.Err(newScope.SetColumn(relationship.PolymorphicType, relationship.PolymorphicValue)) } - scope.Err(newDB.Save(elem).Error) + if allowSaveAssociation { + scope.Err(newDB.Save(elem).Error) + } - if joinTableHandler := relationship.JoinTableHandler; joinTableHandler != nil { + if joinTableHandler := relationship.JoinTableHandler; joinTableHandler != nil && !newScope.PrimaryKeyZero() { scope.Err(joinTableHandler.Add(joinTableHandler, newDB, scope.Value, newScope.Value)) } } @@ -91,7 +93,10 @@ func saveAfterAssociationsCallback(scope *Scope) { if relationship.PolymorphicType != "" { scope.Err(newScope.SetColumn(relationship.PolymorphicType, relationship.PolymorphicValue)) } - scope.Err(scope.NewDB().Save(elem).Error) + + if allowSaveAssociation { + scope.Err(scope.NewDB().Save(elem).Error) + } } } } diff --git a/scope.go b/scope.go index a10cb3a2..9ae33913 100644 --- a/scope.go +++ b/scope.go @@ -993,7 +993,7 @@ func (scope *Scope) changeableField(field *Field) bool { return true } -func (scope *Scope) shouldSaveAssociations() bool { +func (scope *Scope) allowSaveAssociations() bool { if saveAssociations, ok := scope.Get("gorm:save_associations"); ok { if v, ok := saveAssociations.(bool); ok && !v { return false From b2b568daa8e27966c39c942e5aefc74bcc8af88d Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 10 Feb 2018 16:47:48 +0800 Subject: [PATCH 41/47] Add tag association_autoupdate, association_autocreate, association_save_reference support --- association_test.go | 147 +++++++++++++++++++++++++++++++++++++++--- callback_save.go | 153 +++++++++++++++++++++++++++++++------------- query_test.go | 2 +- scope.go | 12 ---- 4 files changed, 248 insertions(+), 66 deletions(-) diff --git a/association_test.go b/association_test.go index 34822dbc..60d0cf48 100644 --- a/association_test.go +++ b/association_test.go @@ -885,7 +885,7 @@ func TestHasManyChildrenWithOneStruct(t *testing.T) { DB.Save(&category) } -func TestSkipSaveAssociation(t *testing.T) { +func TestAutoSaveBelongsToAssociation(t *testing.T) { type Company struct { gorm.Model Name string @@ -895,27 +895,156 @@ func TestSkipSaveAssociation(t *testing.T) { gorm.Model Name string CompanyID uint - Company Company `gorm:"save_associations:false"` + Company Company `gorm:"association_autoupdate:false;association_autocreate:false;"` } + + DB.Where("name = ?", "auto_save_association").Delete(&Company{}) DB.AutoMigrate(&Company{}, &User{}) - DB.Save(&User{Name: "jinzhu", Company: Company{Name: "skip_save_association"}}) + DB.Save(&User{Name: "jinzhu", Company: Company{Name: "auto_save_association"}}) - if !DB.Where("name = ?", "skip_save_association").First(&Company{}).RecordNotFound() { - t.Errorf("Company skip_save_association should not have been saved") + if !DB.Where("name = ?", "auto_save_association").First(&Company{}).RecordNotFound() { + t.Errorf("Company auto_save_association should not have been saved when autosave is false") } // if foreign key is set, this should be saved even if association isn't - company := Company{Name: "skip_save_association"} + company := Company{Name: "auto_save_association"} DB.Save(&company) - company.Name = "skip_save_association_modified" + + company.Name = "auto_save_association_new_name" user := User{Name: "jinzhu", Company: company} + DB.Save(&user) - if !DB.Where("name = ?", "skip_save_association_modified").First(&Company{}).RecordNotFound() { - t.Errorf("Company skip_save_association should not have been updated") + if !DB.Where("name = ?", "auto_save_association_new_name").First(&Company{}).RecordNotFound() { + t.Errorf("Company should not have been updated") } + if DB.Where("id = ? AND company_id = ?", user.ID, company.ID).First(&User{}).RecordNotFound() { t.Errorf("User's foreign key should have been saved") } + + user2 := User{Name: "jinzhu_2", Company: Company{Name: "auto_save_association_2"}} + DB.Set("gorm:association_autocreate", true).Save(&user2) + if DB.Where("name = ?", "auto_save_association_2").First(&Company{}).RecordNotFound() { + t.Errorf("Company auto_save_association_2 should been created when autocreate is true") + } + + user2.Company.Name = "auto_save_association_2_newname" + DB.Set("gorm:association_autoupdate", true).Save(&user2) + + if DB.Where("name = ?", "auto_save_association_2_newname").First(&Company{}).RecordNotFound() { + t.Errorf("Company should been updated") + } +} + +func TestAutoSaveHasOneAssociation(t *testing.T) { + type Company struct { + gorm.Model + UserID uint + Name string + } + + type User struct { + gorm.Model + Name string + Company Company `gorm:"association_autoupdate:false;association_autocreate:false;"` + } + + DB.Where("name = ?", "auto_save_has_one_association").Delete(&Company{}) + DB.AutoMigrate(&Company{}, &User{}) + + DB.Save(&User{Name: "jinzhu", Company: Company{Name: "auto_save_has_one_association"}}) + + if !DB.Where("name = ?", "auto_save_has_one_association").First(&Company{}).RecordNotFound() { + t.Errorf("Company auto_save_has_one_association should not have been saved when autosave is false") + } + + company := Company{Name: "auto_save_has_one_association"} + DB.Save(&company) + + company.Name = "auto_save_has_one_association_new_name" + user := User{Name: "jinzhu", Company: company} + + DB.Save(&user) + + if !DB.Where("name = ?", "auto_save_has_one_association_new_name").First(&Company{}).RecordNotFound() { + t.Errorf("Company should not have been updated") + } + + if !DB.Where("name = ? AND user_id = ?", "auto_save_has_one_association", user.ID).First(&Company{}).RecordNotFound() { + t.Errorf("Company should not have been updated") + } + + if user.Company.UserID == 0 { + t.Errorf("UserID should be assigned") + } + + company.Name = "auto_save_has_one_association_2_new_name" + DB.Set("gorm:association_autoupdate", true).Save(&user) + + if DB.Where("name = ? AND user_id = ?", "auto_save_has_one_association_new_name", user.ID).First(&Company{}).RecordNotFound() { + t.Errorf("Company should been updated") + } + + user2 := User{Name: "jinzhu_2", Company: Company{Name: "auto_save_has_one_association_2"}} + DB.Set("gorm:association_autocreate", true).Save(&user2) + if DB.Where("name = ?", "auto_save_has_one_association_2").First(&Company{}).RecordNotFound() { + t.Errorf("Company auto_save_has_one_association_2 should been created when autocreate is true") + } +} + +func TestAutoSaveMany2ManyAssociation(t *testing.T) { + type Company struct { + gorm.Model + Name string + } + + type User struct { + gorm.Model + Name string + Companies []Company `gorm:"many2many:user_companies;association_autoupdate:false;association_autocreate:false;"` + } + + DB.AutoMigrate(&Company{}, &User{}) + + DB.Save(&User{Name: "jinzhu", Companies: []Company{{Name: "auto_save_m2m_association"}}}) + + if !DB.Where("name = ?", "auto_save_m2m_association").First(&Company{}).RecordNotFound() { + t.Errorf("Company auto_save_m2m_association should not have been saved when autosave is false") + } + + company := Company{Name: "auto_save_m2m_association"} + DB.Save(&company) + + company.Name = "auto_save_m2m_association_new_name" + user := User{Name: "jinzhu", Companies: []Company{company, {Name: "auto_save_m2m_association_new_name_2"}}} + + DB.Save(&user) + + if !DB.Where("name = ?", "auto_save_m2m_association_new_name").First(&Company{}).RecordNotFound() { + t.Errorf("Company should not have been updated") + } + + if !DB.Where("name = ?", "auto_save_m2m_association_new_name_2").First(&Company{}).RecordNotFound() { + t.Errorf("Company should not been created") + } + + if DB.Model(&user).Association("Companies").Count() != 1 { + t.Errorf("Relationship should been saved") + } + + DB.Set("gorm:association_autoupdate", true).Set("gorm:association_autocreate", true).Save(&user) + + if DB.Where("name = ?", "auto_save_m2m_association_new_name").First(&Company{}).RecordNotFound() { + t.Errorf("Company should been updated") + } + + if DB.Where("name = ?", "auto_save_m2m_association_new_name_2").First(&Company{}).RecordNotFound() { + t.Errorf("Company should been created") + } + + if DB.Model(&user).Association("Companies").Count() != 2 { + t.Errorf("Relationship should been updated") + } } diff --git a/callback_save.go b/callback_save.go index 544354d0..243c986e 100644 --- a/callback_save.go +++ b/callback_save.go @@ -1,6 +1,9 @@ package gorm -import "reflect" +import ( + "reflect" + "strings" +) func beginTransactionCallback(scope *Scope) { scope.Begin() @@ -10,33 +13,79 @@ func commitOrRollbackTransactionCallback(scope *Scope) { scope.CommitOrRollback() } -func saveFieldAsAssociation(scope *Scope, field *Field) (bool, *Relationship) { - if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored { - if field.Relationship != nil { - if value, ok := field.TagSettings["SAVE_ASSOCIATIONS"]; (!ok || (value != "false" && value != "skip")) && scope.allowSaveAssociations() { - return true, field.Relationship +func saveAssociationCheck(scope *Scope, field *Field) (autoUpdate bool, autoCreate bool, saveReference bool, r *Relationship) { + checkTruth := func(value interface{}) bool { + if v, ok := value.(bool); ok && !v { + return false + } + + if v, ok := value.(string); ok { + v = strings.ToLower(v) + if v == "false" || v != "skip" { + return false + } + } + + return true + } + + if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored { + if r = field.Relationship; r != nil { + autoUpdate, autoCreate, saveReference = true, true, true + + if value, ok := scope.Get("gorm:save_associations"); ok { + autoUpdate = checkTruth(value) + autoCreate = autoUpdate + } else if value, ok := field.TagSettings["SAVE_ASSOCIATIONS"]; ok { + autoUpdate = checkTruth(value) + autoCreate = autoUpdate + } + + if value, ok := scope.Get("gorm:association_autoupdate"); ok { + autoUpdate = checkTruth(value) + } else if value, ok := field.TagSettings["ASSOCIATION_AUTOUPDATE"]; ok { + autoUpdate = checkTruth(value) + } + + if value, ok := scope.Get("gorm:association_autocreate"); ok { + autoCreate = checkTruth(value) + } else if value, ok := field.TagSettings["ASSOCIATION_AUTOCREATE"]; ok { + autoCreate = checkTruth(value) + } + + if value, ok := field.TagSettings["ASSOCIATION_SAVE_REFERENCE"]; ok { + saveReference = checkTruth(value) } - return false, field.Relationship } } - return false, nil + + return } func saveBeforeAssociationsCallback(scope *Scope) { for _, field := range scope.Fields() { - if allowSaveAssociation, relationship := saveFieldAsAssociation(scope, field); relationship != nil && relationship.Kind == "belongs_to" { - fieldValue := field.Field.Addr().Interface() + autoUpdate, autoCreate, saveReference, relationship := saveAssociationCheck(scope, field) - if allowSaveAssociation { + if relationship != nil && relationship.Kind == "belongs_to" { + fieldValue := field.Field.Addr().Interface() + newScope := scope.New(fieldValue) + + if newScope.PrimaryKeyZero() { + if autoCreate { + scope.Err(scope.NewDB().Save(fieldValue).Error) + } + } else if autoUpdate { scope.Err(scope.NewDB().Save(fieldValue).Error) } - if len(relationship.ForeignFieldNames) != 0 { - // set value's foreign key - for idx, fieldName := range relationship.ForeignFieldNames { - associationForeignName := relationship.AssociationForeignDBNames[idx] - if foreignField, ok := scope.New(fieldValue).FieldByName(associationForeignName); ok { - scope.Err(scope.SetColumn(fieldName, foreignField.Field.Interface())) + if saveReference { + if len(relationship.ForeignFieldNames) != 0 { + // set value's foreign key + for idx, fieldName := range relationship.ForeignFieldNames { + associationForeignName := relationship.AssociationForeignDBNames[idx] + if foreignField, ok := scope.New(fieldValue).FieldByName(associationForeignName); ok { + scope.Err(scope.SetColumn(fieldName, foreignField.Field.Interface())) + } } } } @@ -46,8 +95,9 @@ func saveBeforeAssociationsCallback(scope *Scope) { func saveAfterAssociationsCallback(scope *Scope) { for _, field := range scope.Fields() { - if allowSaveAssociation, relationship := saveFieldAsAssociation(scope, field); relationship != nil && - (relationship.Kind == "has_one" || relationship.Kind == "has_many" || relationship.Kind == "many_to_many") { + autoUpdate, autoCreate, saveReference, relationship := saveAssociationCheck(scope, field) + + if relationship != nil && (relationship.Kind == "has_one" || relationship.Kind == "has_many" || relationship.Kind == "many_to_many") { value := field.Field switch value.Kind() { @@ -57,7 +107,41 @@ func saveAfterAssociationsCallback(scope *Scope) { elem := value.Index(i).Addr().Interface() newScope := newDB.NewScope(elem) - if relationship.JoinTableHandler == nil && len(relationship.ForeignFieldNames) != 0 { + if saveReference { + if relationship.JoinTableHandler == nil && len(relationship.ForeignFieldNames) != 0 { + for idx, fieldName := range relationship.ForeignFieldNames { + associationForeignName := relationship.AssociationForeignDBNames[idx] + if f, ok := scope.FieldByName(associationForeignName); ok { + scope.Err(newScope.SetColumn(fieldName, f.Field.Interface())) + } + } + } + + if relationship.PolymorphicType != "" { + scope.Err(newScope.SetColumn(relationship.PolymorphicType, relationship.PolymorphicValue)) + } + } + + if newScope.PrimaryKeyZero() { + if autoCreate { + scope.Err(newDB.Save(elem).Error) + } + } else if autoUpdate { + scope.Err(newDB.Save(elem).Error) + } + + if !scope.New(newScope.Value).PrimaryKeyZero() && saveReference { + if joinTableHandler := relationship.JoinTableHandler; joinTableHandler != nil { + scope.Err(joinTableHandler.Add(joinTableHandler, newDB, scope.Value, newScope.Value)) + } + } + } + default: + elem := value.Addr().Interface() + newScope := scope.New(elem) + + if saveReference { + if len(relationship.ForeignFieldNames) != 0 { for idx, fieldName := range relationship.ForeignFieldNames { associationForeignName := relationship.AssociationForeignDBNames[idx] if f, ok := scope.FieldByName(associationForeignName); ok { @@ -69,32 +153,13 @@ func saveAfterAssociationsCallback(scope *Scope) { if relationship.PolymorphicType != "" { scope.Err(newScope.SetColumn(relationship.PolymorphicType, relationship.PolymorphicValue)) } - - if allowSaveAssociation { - scope.Err(newDB.Save(elem).Error) - } - - if joinTableHandler := relationship.JoinTableHandler; joinTableHandler != nil && !newScope.PrimaryKeyZero() { - scope.Err(joinTableHandler.Add(joinTableHandler, newDB, scope.Value, newScope.Value)) - } - } - default: - elem := value.Addr().Interface() - newScope := scope.New(elem) - if len(relationship.ForeignFieldNames) != 0 { - for idx, fieldName := range relationship.ForeignFieldNames { - associationForeignName := relationship.AssociationForeignDBNames[idx] - if f, ok := scope.FieldByName(associationForeignName); ok { - scope.Err(newScope.SetColumn(fieldName, f.Field.Interface())) - } - } } - if relationship.PolymorphicType != "" { - scope.Err(newScope.SetColumn(relationship.PolymorphicType, relationship.PolymorphicValue)) - } - - if allowSaveAssociation { + if newScope.PrimaryKeyZero() { + if autoCreate { + scope.Err(scope.NewDB().Save(elem).Error) + } + } else if autoUpdate { scope.Err(scope.NewDB().Save(elem).Error) } } diff --git a/query_test.go b/query_test.go index def84e04..98721800 100644 --- a/query_test.go +++ b/query_test.go @@ -389,7 +389,7 @@ func TestOffset(t *testing.T) { DB.Save(&User{Name: fmt.Sprintf("OffsetUser%v", i)}) } var users1, users2, users3, users4 []User - DB.Limit(100).Order("age desc").Find(&users1).Offset(3).Find(&users2).Offset(5).Find(&users3).Offset(-1).Find(&users4) + DB.Limit(100).Where("name like ?", "OffsetUser%").Order("age desc").Find(&users1).Offset(3).Find(&users2).Offset(5).Find(&users3).Offset(-1).Find(&users4) if (len(users1) != len(users4)) || (len(users1)-len(users2) != 3) || (len(users1)-len(users3) != 5) { t.Errorf("Offset should work") diff --git a/scope.go b/scope.go index 9ae33913..125e02b0 100644 --- a/scope.go +++ b/scope.go @@ -993,18 +993,6 @@ func (scope *Scope) changeableField(field *Field) bool { return true } -func (scope *Scope) allowSaveAssociations() bool { - if saveAssociations, ok := scope.Get("gorm:save_associations"); ok { - if v, ok := saveAssociations.(bool); ok && !v { - return false - } - if v, ok := saveAssociations.(string); ok && (v != "skip") { - return false - } - } - return true && !scope.HasError() -} - func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope { toScope := scope.db.NewScope(value) tx := scope.db.Set("gorm:association:source", scope.Value) From 2940c553eb9763e966effbdca702e2d5b2b255da Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 10 Feb 2018 18:01:41 +0800 Subject: [PATCH 42/47] Add DB setting gorm:association_save_reference --- callback_save.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/callback_save.go b/callback_save.go index 243c986e..ef267141 100644 --- a/callback_save.go +++ b/callback_save.go @@ -53,7 +53,9 @@ func saveAssociationCheck(scope *Scope, field *Field) (autoUpdate bool, autoCrea autoCreate = checkTruth(value) } - if value, ok := field.TagSettings["ASSOCIATION_SAVE_REFERENCE"]; ok { + if value, ok := scope.Get("gorm:association_save_reference"); ok { + saveReference = checkTruth(value) + } else if value, ok := field.TagSettings["ASSOCIATION_SAVE_REFERENCE"]; ok { saveReference = checkTruth(value) } } From c6ce739b2a4d3b26af9326a31723883b4f136a74 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 10 Feb 2018 19:25:58 +0800 Subject: [PATCH 43/47] Convert auto_increment's value to lower case when checking its value --- dialect_common.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dialect_common.go b/dialect_common.go index 1e5e3b61..fbbaef33 100644 --- a/dialect_common.go +++ b/dialect_common.go @@ -40,7 +40,7 @@ func (commonDialect) Quote(key string) string { func (s *commonDialect) fieldCanAutoIncrement(field *StructField) bool { if value, ok := field.TagSettings["AUTO_INCREMENT"]; ok { - return value != "FALSE" + return strings.ToLower(value) != "false" } return field.IsPrimaryKey } From c0359226dc500354fd8c18366ad2fb6616f8c322 Mon Sep 17 00:00:00 2001 From: Emil Davtyan Date: Sat, 10 Feb 2018 12:31:55 +0100 Subject: [PATCH 44/47] Removed unnecessary cloning. (#1462) `NewScope` clones `DB` no need to chain a call to clone with `NewScope`. --- main.go | 38 +++++++++++++++++++------------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/main.go b/main.go index b23ae2f2..fc4859ac 100644 --- a/main.go +++ b/main.go @@ -274,7 +274,7 @@ func (s *DB) Assign(attrs ...interface{}) *DB { // First find first record that match given conditions, order by primary key func (s *DB) First(out interface{}, where ...interface{}) *DB { - newScope := s.clone().NewScope(out) + newScope := s.NewScope(out) newScope.Search.Limit(1) return newScope.Set("gorm:order_by_primary_key", "ASC"). inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db @@ -282,7 +282,7 @@ func (s *DB) First(out interface{}, where ...interface{}) *DB { // Last find last record that match given conditions, order by primary key func (s *DB) Last(out interface{}, where ...interface{}) *DB { - newScope := s.clone().NewScope(out) + newScope := s.NewScope(out) newScope.Search.Limit(1) return newScope.Set("gorm:order_by_primary_key", "DESC"). inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db @@ -290,12 +290,12 @@ func (s *DB) Last(out interface{}, where ...interface{}) *DB { // Find find records that match given conditions func (s *DB) Find(out interface{}, where ...interface{}) *DB { - return s.clone().NewScope(out).inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db + return s.NewScope(out).inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db } // Scan scan value to a struct func (s *DB) Scan(dest interface{}) *DB { - return s.clone().NewScope(s.Value).Set("gorm:query_destination", dest).callCallbacks(s.parent.callbacks.queries).db + return s.NewScope(s.Value).Set("gorm:query_destination", dest).callCallbacks(s.parent.callbacks.queries).db } // Row return `*sql.Row` with given conditions @@ -311,8 +311,8 @@ func (s *DB) Rows() (*sql.Rows, error) { // ScanRows scan `*sql.Rows` to give struct func (s *DB) ScanRows(rows *sql.Rows, result interface{}) error { var ( - clone = s.clone() - scope = clone.NewScope(result) + scope = s.NewScope(result) + clone = scope.db columns, err = rows.Columns() ) @@ -337,7 +337,7 @@ func (s *DB) Count(value interface{}) *DB { // Related get related associations func (s *DB) Related(value interface{}, foreignKeys ...string) *DB { - return s.clone().NewScope(s.Value).related(value, foreignKeys...).db + return s.NewScope(s.Value).related(value, foreignKeys...).db } // FirstOrInit find first matched record or initialize a new one with given conditions (only works with struct, map conditions) @@ -377,7 +377,7 @@ func (s *DB) Update(attrs ...interface{}) *DB { // Updates update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update func (s *DB) Updates(values interface{}, ignoreProtectedAttrs ...bool) *DB { - return s.clone().NewScope(s.Value). + return s.NewScope(s.Value). Set("gorm:ignore_protected_attrs", len(ignoreProtectedAttrs) > 0). InstanceSet("gorm:update_interface", values). callCallbacks(s.parent.callbacks.updates).db @@ -390,7 +390,7 @@ func (s *DB) UpdateColumn(attrs ...interface{}) *DB { // UpdateColumns update attributes without callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update func (s *DB) UpdateColumns(values interface{}) *DB { - return s.clone().NewScope(s.Value). + return s.NewScope(s.Value). Set("gorm:update_column", true). Set("gorm:save_associations", false). InstanceSet("gorm:update_interface", values). @@ -399,7 +399,7 @@ func (s *DB) UpdateColumns(values interface{}) *DB { // Save update value in database, if the value doesn't have primary key, will insert it func (s *DB) Save(value interface{}) *DB { - scope := s.clone().NewScope(value) + scope := s.NewScope(value) if !scope.PrimaryKeyZero() { newDB := scope.callCallbacks(s.parent.callbacks.updates).db if newDB.Error == nil && newDB.RowsAffected == 0 { @@ -412,13 +412,13 @@ func (s *DB) Save(value interface{}) *DB { // Create insert the value into database func (s *DB) Create(value interface{}) *DB { - scope := s.clone().NewScope(value) + scope := s.NewScope(value) return scope.callCallbacks(s.parent.callbacks.creates).db } // Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition func (s *DB) Delete(value interface{}, where ...interface{}) *DB { - return s.clone().NewScope(value).inlineCondition(where...).callCallbacks(s.parent.callbacks.deletes).db + return s.NewScope(value).inlineCondition(where...).callCallbacks(s.parent.callbacks.deletes).db } // Raw use raw sql as conditions, won't run it unless invoked by other methods @@ -429,7 +429,7 @@ func (s *DB) Raw(sql string, values ...interface{}) *DB { // Exec execute raw sql func (s *DB) Exec(sql string, values ...interface{}) *DB { - scope := s.clone().NewScope(nil) + scope := s.NewScope(nil) generatedSQL := scope.buildWhereCondition(map[string]interface{}{"query": sql, "args": values}) generatedSQL = strings.TrimSuffix(strings.TrimPrefix(generatedSQL, "("), ")") scope.Raw(generatedSQL) @@ -495,7 +495,7 @@ func (s *DB) Rollback() *DB { // NewRecord check if value's primary key is blank func (s *DB) NewRecord(value interface{}) bool { - return s.clone().NewScope(value).PrimaryKeyZero() + return s.NewScope(value).PrimaryKeyZero() } // RecordNotFound check if returning ErrRecordNotFound error @@ -544,7 +544,7 @@ func (s *DB) DropTableIfExists(values ...interface{}) *DB { // HasTable check has table or not func (s *DB) HasTable(value interface{}) bool { var ( - scope = s.clone().NewScope(value) + scope = s.NewScope(value) tableName string ) @@ -570,14 +570,14 @@ func (s *DB) AutoMigrate(values ...interface{}) *DB { // ModifyColumn modify column to type func (s *DB) ModifyColumn(column string, typ string) *DB { - scope := s.clone().NewScope(s.Value) + scope := s.NewScope(s.Value) scope.modifyColumn(column, typ) return scope.db } // DropColumn drop a column func (s *DB) DropColumn(column string) *DB { - scope := s.clone().NewScope(s.Value) + scope := s.NewScope(s.Value) scope.dropColumn(column) return scope.db } @@ -598,7 +598,7 @@ func (s *DB) AddUniqueIndex(indexName string, columns ...string) *DB { // RemoveIndex remove index with name func (s *DB) RemoveIndex(indexName string) *DB { - scope := s.clone().NewScope(s.Value) + scope := s.NewScope(s.Value) scope.removeIndex(indexName) return scope.db } @@ -606,7 +606,7 @@ func (s *DB) RemoveIndex(indexName string) *DB { // AddForeignKey Add foreign key to the given scope, e.g: // db.Model(&User{}).AddForeignKey("city_id", "cities(id)", "RESTRICT", "RESTRICT") func (s *DB) AddForeignKey(field string, dest string, onDelete string, onUpdate string) *DB { - scope := s.clone().NewScope(s.Value) + scope := s.NewScope(s.Value) scope.addForeignKey(field, dest, onDelete, onUpdate) return scope.db } From 8e7d807ebf902bf239ac9ccd509b42659ee378ba Mon Sep 17 00:00:00 2001 From: Nathan Osman Date: Fri, 22 Dec 2017 17:59:15 -0800 Subject: [PATCH 45/47] Allow name of column to be customized to support self-referencing many2many fields. --- customize_column_test.go | 22 ++++++++++++++++++++++ join_table_handler.go | 19 ++++++++++++++++++- model_struct.go | 15 ++++++++++++++- 3 files changed, 54 insertions(+), 2 deletions(-) diff --git a/customize_column_test.go b/customize_column_test.go index ddb536b8..c96b2d40 100644 --- a/customize_column_test.go +++ b/customize_column_test.go @@ -279,3 +279,25 @@ func TestBelongsToWithPartialCustomizedColumn(t *testing.T) { t.Errorf("should preload discount from coupon") } } + +type SelfReferencingUser struct { + gorm.Model + Friends []*SelfReferencingUser `gorm:"many2many:UserFriends;AssociationForeignKey:ID=friend_id"` +} + +func TestSelfReferencingMany2ManyColumn(t *testing.T) { + DB.DropTable(&SelfReferencingUser{}, "UserFriends") + DB.AutoMigrate(&SelfReferencingUser{}) + + friend := SelfReferencingUser{} + if err := DB.Create(&friend).Error; err != nil { + t.Errorf("no error should happen, but got %v", err) + } + + user := SelfReferencingUser{ + Friends: []*SelfReferencingUser{&friend}, + } + if err := DB.Create(&user).Error; err != nil { + t.Errorf("no error should happen, but got %v", err) + } +} diff --git a/join_table_handler.go b/join_table_handler.go index 2d1a5055..b4be6cf9 100644 --- a/join_table_handler.go +++ b/join_table_handler.go @@ -109,7 +109,24 @@ func (s JoinTableHandler) getSearchMap(db *DB, sources ...interface{}) map[strin // Add create relationship in join table for source and destination func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error { scope := db.NewScope("") - searchMap := s.getSearchMap(db, source, destination) + searchMap := map[string]interface{}{} + + // getSearchMap() cannot be used here since the source and destination + // model types may be identical + + sourceScope := db.NewScope(source) + for _, foreignKey := range s.Source.ForeignKeys { + if field, ok := sourceScope.FieldByName(foreignKey.AssociationDBName); ok { + searchMap[foreignKey.DBName] = field.Field.Interface() + } + } + + destinationScope := db.NewScope(destination) + for _, foreignKey := range s.Destination.ForeignKeys { + if field, ok := destinationScope.FieldByName(foreignKey.AssociationDBName); ok { + searchMap[foreignKey.DBName] = field.Field.Interface() + } + } var assignColumns, binVars, conditions []string var values []interface{} diff --git a/model_struct.go b/model_struct.go index 315028c4..463ec517 100644 --- a/model_struct.go +++ b/model_struct.go @@ -289,11 +289,24 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } for _, name := range associationForeignKeys { + + // In order to allow self-referencing many2many tables, the name + // may be followed by "=" to allow renaming the column + parts := strings.Split(name, "=") + name = parts[0] + if field, ok := toScope.FieldByName(name); ok { // association foreign keys (db names) relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.DBName) + + // If a new name was provided for the field, use it + name = field.DBName + if len(parts) > 1 { + name = parts[1] + } + // join table foreign keys for association - joinTableDBName := ToDBName(elemType.Name()) + "_" + field.DBName + joinTableDBName := ToDBName(elemType.Name()) + "_" + name relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, joinTableDBName) } } From 44b9911f5157e6b7d03c08fcf730ded96b2eda66 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 10 Feb 2018 20:57:39 +0800 Subject: [PATCH 46/47] Refactor self referencing m2m support --- association.go | 4 +- customize_column_test.go | 30 +++++++++-- join_table_handler.go | 60 +++++++++------------ model_struct.go | 111 +++++++++++++++++++++++---------------- 4 files changed, 117 insertions(+), 88 deletions(-) diff --git a/association.go b/association.go index 3d522ccc..8c6d9864 100644 --- a/association.go +++ b/association.go @@ -107,7 +107,7 @@ func (association *Association) Replace(values ...interface{}) *Association { if sourcePrimaryKeys := scope.getColumnAsArray(sourceForeignFieldNames, scope.Value); len(sourcePrimaryKeys) > 0 { newDB = newDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(sourcePrimaryKeys)), toQueryValues(sourcePrimaryKeys)...) - association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB, relationship)) + association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB)) } } else if relationship.Kind == "has_one" || relationship.Kind == "has_many" { // has_one or has_many relations, set foreign key to be nil (TODO or delete them?) @@ -173,7 +173,7 @@ func (association *Association) Delete(values ...interface{}) *Association { sql := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(deletingPrimaryKeys)) newDB = newDB.Where(sql, toQueryValues(deletingPrimaryKeys)...) - association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB, relationship)) + association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB)) } else { var foreignKeyMap = map[string]interface{}{} for _, foreignKey := range relationship.ForeignDBNames { diff --git a/customize_column_test.go b/customize_column_test.go index c96b2d40..629d85f9 100644 --- a/customize_column_test.go +++ b/customize_column_test.go @@ -282,22 +282,44 @@ func TestBelongsToWithPartialCustomizedColumn(t *testing.T) { type SelfReferencingUser struct { gorm.Model - Friends []*SelfReferencingUser `gorm:"many2many:UserFriends;AssociationForeignKey:ID=friend_id"` + Name string + Friends []*SelfReferencingUser `gorm:"many2many:UserFriends;association_jointable_foreignkey:friend_id"` } func TestSelfReferencingMany2ManyColumn(t *testing.T) { DB.DropTable(&SelfReferencingUser{}, "UserFriends") DB.AutoMigrate(&SelfReferencingUser{}) - friend := SelfReferencingUser{} - if err := DB.Create(&friend).Error; err != nil { + friend1 := SelfReferencingUser{Name: "friend1_m2m"} + if err := DB.Create(&friend1).Error; err != nil { + t.Errorf("no error should happen, but got %v", err) + } + + friend2 := SelfReferencingUser{Name: "friend2_m2m"} + if err := DB.Create(&friend2).Error; err != nil { t.Errorf("no error should happen, but got %v", err) } user := SelfReferencingUser{ - Friends: []*SelfReferencingUser{&friend}, + Name: "self_m2m", + Friends: []*SelfReferencingUser{&friend1, &friend2}, } + if err := DB.Create(&user).Error; err != nil { t.Errorf("no error should happen, but got %v", err) } + + if DB.Model(&user).Association("Friends").Count() != 2 { + t.Errorf("Should find created friends correctly") + } + + var newUser = SelfReferencingUser{} + + if err := DB.Preload("Friends").First(&newUser, "id = ?", user.ID).Error; err != nil { + t.Errorf("no error should happen, but got %v", err) + } + + if len(newUser.Friends) != 2 { + t.Errorf("Should preload created frineds for self reference m2m") + } } diff --git a/join_table_handler.go b/join_table_handler.go index b4be6cf9..f07541ba 100644 --- a/join_table_handler.go +++ b/join_table_handler.go @@ -82,55 +82,40 @@ func (s JoinTableHandler) Table(db *DB) string { return s.TableName } -func (s JoinTableHandler) getSearchMap(db *DB, sources ...interface{}) map[string]interface{} { - values := map[string]interface{}{} - +func (s JoinTableHandler) updateConditionMap(conditionMap map[string]interface{}, db *DB, joinTableSources []JoinTableSource, sources ...interface{}) { for _, source := range sources { scope := db.NewScope(source) modelType := scope.GetModelStruct().ModelType - if s.Source.ModelType == modelType { - for _, foreignKey := range s.Source.ForeignKeys { - if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok { - values[foreignKey.DBName] = field.Field.Interface() - } - } - } else if s.Destination.ModelType == modelType { - for _, foreignKey := range s.Destination.ForeignKeys { - if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok { - values[foreignKey.DBName] = field.Field.Interface() + for _, joinTableSource := range joinTableSources { + if joinTableSource.ModelType == modelType { + for _, foreignKey := range joinTableSource.ForeignKeys { + if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok { + conditionMap[foreignKey.DBName] = field.Field.Interface() + } } + break } } } - return values } // Add create relationship in join table for source and destination func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error { - scope := db.NewScope("") - searchMap := map[string]interface{}{} + var ( + scope = db.NewScope("") + conditionMap = map[string]interface{}{} + ) - // getSearchMap() cannot be used here since the source and destination - // model types may be identical + // Update condition map for source + s.updateConditionMap(conditionMap, db, []JoinTableSource{s.Source}, source) - sourceScope := db.NewScope(source) - for _, foreignKey := range s.Source.ForeignKeys { - if field, ok := sourceScope.FieldByName(foreignKey.AssociationDBName); ok { - searchMap[foreignKey.DBName] = field.Field.Interface() - } - } - - destinationScope := db.NewScope(destination) - for _, foreignKey := range s.Destination.ForeignKeys { - if field, ok := destinationScope.FieldByName(foreignKey.AssociationDBName); ok { - searchMap[foreignKey.DBName] = field.Field.Interface() - } - } + // Update condition map for destination + s.updateConditionMap(conditionMap, db, []JoinTableSource{s.Destination}, destination) var assignColumns, binVars, conditions []string var values []interface{} - for key, value := range searchMap { + for key, value := range conditionMap { assignColumns = append(assignColumns, scope.Quote(key)) binVars = append(binVars, `?`) conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key))) @@ -158,12 +143,15 @@ func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source // Delete delete relationship in join table for sources func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error { var ( - scope = db.NewScope(nil) - conditions []string - values []interface{} + scope = db.NewScope(nil) + conditions []string + values []interface{} + conditionMap = map[string]interface{}{} ) - for key, value := range s.getSearchMap(db, sources...) { + s.updateConditionMap(conditionMap, db, []JoinTableSource{s.Source, s.Destination}, sources...) + + for key, value := range conditionMap { conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key))) values = append(values, value) } diff --git a/model_struct.go b/model_struct.go index 463ec517..f571e2e8 100644 --- a/model_struct.go +++ b/model_struct.go @@ -249,11 +249,13 @@ func (scope *Scope) GetModelStruct() *ModelStruct { ) if foreignKey := field.TagSettings["FOREIGNKEY"]; foreignKey != "" { - foreignKeys = strings.Split(field.TagSettings["FOREIGNKEY"], ",") + foreignKeys = strings.Split(foreignKey, ",") } - if foreignKey := field.TagSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" { - associationForeignKeys = strings.Split(field.TagSettings["ASSOCIATIONFOREIGNKEY"], ",") + if foreignKey := field.TagSettings["ASSOCIATION_FOREIGNKEY"]; foreignKey != "" { + associationForeignKeys = strings.Split(foreignKey, ",") + } else if foreignKey := field.TagSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" { + associationForeignKeys = strings.Split(foreignKey, ",") } for elemType.Kind() == reflect.Slice || elemType.Kind() == reflect.Ptr { @@ -264,50 +266,65 @@ func (scope *Scope) GetModelStruct() *ModelStruct { if many2many := field.TagSettings["MANY2MANY"]; many2many != "" { relationship.Kind = "many_to_many" - // if no foreign keys defined with tag - if len(foreignKeys) == 0 { - for _, field := range modelStruct.PrimaryFields { - foreignKeys = append(foreignKeys, field.DBName) + { // Foreign Keys for Source + joinTableDBNames := []string{} + + if foreignKey := field.TagSettings["JOINTABLE_FOREIGNKEY"]; foreignKey != "" { + joinTableDBNames = strings.Split(foreignKey, ",") } - } - for _, foreignKey := range foreignKeys { - if foreignField := getForeignField(foreignKey, modelStruct.StructFields); foreignField != nil { - // source foreign keys (db names) - relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.DBName) - // join table foreign keys for source - joinTableDBName := ToDBName(reflectType.Name()) + "_" + foreignField.DBName - relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBName) - } - } - - // if no association foreign keys defined with tag - if len(associationForeignKeys) == 0 { - for _, field := range toScope.PrimaryFields() { - associationForeignKeys = append(associationForeignKeys, field.DBName) - } - } - - for _, name := range associationForeignKeys { - - // In order to allow self-referencing many2many tables, the name - // may be followed by "=" to allow renaming the column - parts := strings.Split(name, "=") - name = parts[0] - - if field, ok := toScope.FieldByName(name); ok { - // association foreign keys (db names) - relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.DBName) - - // If a new name was provided for the field, use it - name = field.DBName - if len(parts) > 1 { - name = parts[1] + // if no foreign keys defined with tag + if len(foreignKeys) == 0 { + for _, field := range modelStruct.PrimaryFields { + foreignKeys = append(foreignKeys, field.DBName) } + } - // join table foreign keys for association - joinTableDBName := ToDBName(elemType.Name()) + "_" + name - relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, joinTableDBName) + for idx, foreignKey := range foreignKeys { + if foreignField := getForeignField(foreignKey, modelStruct.StructFields); foreignField != nil { + // source foreign keys (db names) + relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.DBName) + + // setup join table foreign keys for source + if len(joinTableDBNames) > idx { + // if defined join table's foreign key + relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBNames[idx]) + } else { + defaultJointableForeignKey := ToDBName(reflectType.Name()) + "_" + foreignField.DBName + relationship.ForeignDBNames = append(relationship.ForeignDBNames, defaultJointableForeignKey) + } + } + } + } + + { // Foreign Keys for Association (Destination) + associationJoinTableDBNames := []string{} + + if foreignKey := field.TagSettings["ASSOCIATION_JOINTABLE_FOREIGNKEY"]; foreignKey != "" { + associationJoinTableDBNames = strings.Split(foreignKey, ",") + } + + // if no association foreign keys defined with tag + if len(associationForeignKeys) == 0 { + for _, field := range toScope.PrimaryFields() { + associationForeignKeys = append(associationForeignKeys, field.DBName) + } + } + + for idx, name := range associationForeignKeys { + if field, ok := toScope.FieldByName(name); ok { + // association foreign keys (db names) + relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.DBName) + + // setup join table foreign keys for association + if len(associationJoinTableDBNames) > idx { + relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationJoinTableDBNames[idx]) + } else { + // join table foreign keys for association + joinTableDBName := ToDBName(elemType.Name()) + "_" + field.DBName + relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, joinTableDBName) + } + } } } @@ -412,11 +429,13 @@ func (scope *Scope) GetModelStruct() *ModelStruct { ) if foreignKey := field.TagSettings["FOREIGNKEY"]; foreignKey != "" { - tagForeignKeys = strings.Split(field.TagSettings["FOREIGNKEY"], ",") + tagForeignKeys = strings.Split(foreignKey, ",") } - if foreignKey := field.TagSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" { - tagAssociationForeignKeys = strings.Split(field.TagSettings["ASSOCIATIONFOREIGNKEY"], ",") + if foreignKey := field.TagSettings["ASSOCIATION_FOREIGNKEY"]; foreignKey != "" { + tagAssociationForeignKeys = strings.Split(foreignKey, ",") + } else if foreignKey := field.TagSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" { + tagAssociationForeignKeys = strings.Split(foreignKey, ",") } if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" { From cb7c41e0b6e3863e7934a50c0aed76b8cfb61bfd Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 10 Feb 2018 22:14:18 +0800 Subject: [PATCH 47/47] Add more tests for self-referencing many2many relationship --- customize_column_test.go | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/customize_column_test.go b/customize_column_test.go index 629d85f9..5e19d6f4 100644 --- a/customize_column_test.go +++ b/customize_column_test.go @@ -322,4 +322,25 @@ func TestSelfReferencingMany2ManyColumn(t *testing.T) { if len(newUser.Friends) != 2 { t.Errorf("Should preload created frineds for self reference m2m") } + + DB.Model(&newUser).Association("Friends").Append(&SelfReferencingUser{Name: "friend3_m2m"}) + if DB.Model(&user).Association("Friends").Count() != 3 { + t.Errorf("Should find created friends correctly") + } + + DB.Model(&newUser).Association("Friends").Replace(&SelfReferencingUser{Name: "friend4_m2m"}) + if DB.Model(&user).Association("Friends").Count() != 1 { + t.Errorf("Should find created friends correctly") + } + + friend := SelfReferencingUser{} + DB.Model(&newUser).Association("Friends").Find(&friend) + if friend.Name != "friend4_m2m" { + t.Errorf("Should find created friends correctly") + } + + DB.Model(&newUser).Association("Friends").Delete(friend) + if DB.Model(&user).Association("Friends").Count() != 0 { + t.Errorf("All friends should be deleted") + } }