diff --git a/.github/ISSUE_TEMPLATE.md b/.github/ISSUE_TEMPLATE.md index 8b4f03b7..a0b64bfa 100644 --- a/.github/ISSUE_TEMPLATE.md +++ b/.github/ISSUE_TEMPLATE.md @@ -1,10 +1,4 @@ -Before posting a bug report about a problem, please try to verify that it is a bug and that it has not been reported already, please apply corresponding GitHub labels to the issue, for feature requests, please apply `type:feature`. - -DON'T post usage related questions, ask in https://gitter.im/jinzhu/gorm or http://stackoverflow.com/questions/tagged/go-gorm, - -Please answer these questions before submitting your issue. Thanks! - - +Your issue may already be reported! Please search on the [issue track](https://github.com/jinzhu/gorm/issues) before creating one. ### What version of Go are you using (`go version`)? @@ -12,9 +6,9 @@ Please answer these questions before submitting your issue. Thanks! ### Which database and its version are you using? -### What did you do? +### Please provide a complete runnable program to reproduce your issue. **IMPORTANT** -Please provide a complete runnable program to reproduce your issue. +Need to runnable with [GORM's docker compose config](https://github.com/jinzhu/gorm/blob/master/docker-compose.yml) or please provides your config. ```go package main @@ -32,10 +26,9 @@ var db *gorm.DB func init() { var err error db, err = gorm.Open("sqlite3", "test.db") - // Please use below username, password as your database's account for the script. - // db, err = gorm.Open("postgres", "user=gorm dbname=gorm sslmode=disable") - // db, err = gorm.Open("mysql", "gorm:gorm@/dbname?charset=utf8&parseTime=True") - // db, err = gorm.Open("mssql", "sqlserver://gorm:LoremIpsum86@localhost:1433?database=gorm") + // db, err = gorm.Open("postgres", "user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable") + // db, err = gorm.Open("mysql", "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True") + // db, err = gorm.Open("mssql", "sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm") if err != nil { panic(err) } @@ -43,8 +36,6 @@ func init() { } func main() { - // your code here - if /* failure condition */ { fmt.Println("failed") } else { diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 0ee0d73b..b467b6ce 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -3,12 +3,7 @@ Make sure these boxes checked before submitting your pull request. - [] Do only one thing - [] No API-breaking changes - [] New code/logic commented & tested -- [] Write good commit message, try to squash your commits into a single one -- [] Run `./build.sh` in `gh-pages` branch for document changes - -For significant changes like big bug fixes, new features, please open an issue to make a agreement on an implementation design/plan first before starting it. - -Thank you. +For significant changes like big bug fixes, new features, please open an issue to make an agreement on an implementation design/plan first before starting it. ### What did this pull request do? diff --git a/README.md b/README.md index 44eb4a69..8c6e2302 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,8 @@ The fantastic ORM library for Golang, aims to be developer friendly. [![Join the chat at https://gitter.im/jinzhu/gorm](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/jinzhu/gorm?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) -[![wercker status](https://app.wercker.com/status/0cb7bb1039e21b74f8274941428e0921/s/master "wercker status")](https://app.wercker.com/project/bykey/0cb7bb1039e21b74f8274941428e0921) +[![go report card](https://goreportcard.com/badge/github.com/jinzhu/gorm "go report card")](https://goreportcard.com/report/github.com/jinzhu/gorm) +[![wercker status](https://app.wercker.com/status/8596cace912c9947dd9c8542ecc8cb8b/s/master "wercker status")](https://app.wercker.com/project/byKey/8596cace912c9947dd9c8542ecc8cb8b) [![GoDoc](https://godoc.org/github.com/jinzhu/gorm?status.svg)](https://godoc.org/github.com/jinzhu/gorm) ## Overview @@ -31,7 +32,7 @@ The fantastic ORM library for Golang, aims to be developer friendly. ## Supporting the project -[![http://patreon.com/jinzhu](http://patreon_public_assets.s3.amazonaws.com/sized/becomeAPatronBanner.png)](http://patreon.com/jinzhu) +[![http://patreon.com/jinzhu](https://c5.patreon.com/external/logo/become_a_patron_button.png)](http://patreon.com/jinzhu) ## Author diff --git a/association.go b/association.go index 3d522ccc..8c6d9864 100644 --- a/association.go +++ b/association.go @@ -107,7 +107,7 @@ func (association *Association) Replace(values ...interface{}) *Association { if sourcePrimaryKeys := scope.getColumnAsArray(sourceForeignFieldNames, scope.Value); len(sourcePrimaryKeys) > 0 { newDB = newDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(sourcePrimaryKeys)), toQueryValues(sourcePrimaryKeys)...) - association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB, relationship)) + association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB)) } } else if relationship.Kind == "has_one" || relationship.Kind == "has_many" { // has_one or has_many relations, set foreign key to be nil (TODO or delete them?) @@ -173,7 +173,7 @@ func (association *Association) Delete(values ...interface{}) *Association { sql := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(deletingPrimaryKeys)) newDB = newDB.Where(sql, toQueryValues(deletingPrimaryKeys)...) - association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB, relationship)) + association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB)) } else { var foreignKeyMap = map[string]interface{}{} for _, foreignKey := range relationship.ForeignDBNames { diff --git a/association_test.go b/association_test.go index c84f84ed..60d0cf48 100644 --- a/association_test.go +++ b/association_test.go @@ -885,7 +885,7 @@ func TestHasManyChildrenWithOneStruct(t *testing.T) { DB.Save(&category) } -func TestSkipSaveAssociation(t *testing.T) { +func TestAutoSaveBelongsToAssociation(t *testing.T) { type Company struct { gorm.Model Name string @@ -895,13 +895,156 @@ func TestSkipSaveAssociation(t *testing.T) { gorm.Model Name string CompanyID uint - Company Company `gorm:"save_associations:false"` + Company Company `gorm:"association_autoupdate:false;association_autocreate:false;"` } + + DB.Where("name = ?", "auto_save_association").Delete(&Company{}) DB.AutoMigrate(&Company{}, &User{}) - DB.Save(&User{Name: "jinzhu", Company: Company{Name: "skip_save_association"}}) + DB.Save(&User{Name: "jinzhu", Company: Company{Name: "auto_save_association"}}) - if !DB.Where("name = ?", "skip_save_association").First(&Company{}).RecordNotFound() { - t.Errorf("Company skip_save_association should not been saved") + if !DB.Where("name = ?", "auto_save_association").First(&Company{}).RecordNotFound() { + t.Errorf("Company auto_save_association should not have been saved when autosave is false") + } + + // if foreign key is set, this should be saved even if association isn't + company := Company{Name: "auto_save_association"} + DB.Save(&company) + + company.Name = "auto_save_association_new_name" + user := User{Name: "jinzhu", Company: company} + + DB.Save(&user) + + if !DB.Where("name = ?", "auto_save_association_new_name").First(&Company{}).RecordNotFound() { + t.Errorf("Company should not have been updated") + } + + if DB.Where("id = ? AND company_id = ?", user.ID, company.ID).First(&User{}).RecordNotFound() { + t.Errorf("User's foreign key should have been saved") + } + + user2 := User{Name: "jinzhu_2", Company: Company{Name: "auto_save_association_2"}} + DB.Set("gorm:association_autocreate", true).Save(&user2) + if DB.Where("name = ?", "auto_save_association_2").First(&Company{}).RecordNotFound() { + t.Errorf("Company auto_save_association_2 should been created when autocreate is true") + } + + user2.Company.Name = "auto_save_association_2_newname" + DB.Set("gorm:association_autoupdate", true).Save(&user2) + + if DB.Where("name = ?", "auto_save_association_2_newname").First(&Company{}).RecordNotFound() { + t.Errorf("Company should been updated") + } +} + +func TestAutoSaveHasOneAssociation(t *testing.T) { + type Company struct { + gorm.Model + UserID uint + Name string + } + + type User struct { + gorm.Model + Name string + Company Company `gorm:"association_autoupdate:false;association_autocreate:false;"` + } + + DB.Where("name = ?", "auto_save_has_one_association").Delete(&Company{}) + DB.AutoMigrate(&Company{}, &User{}) + + DB.Save(&User{Name: "jinzhu", Company: Company{Name: "auto_save_has_one_association"}}) + + if !DB.Where("name = ?", "auto_save_has_one_association").First(&Company{}).RecordNotFound() { + t.Errorf("Company auto_save_has_one_association should not have been saved when autosave is false") + } + + company := Company{Name: "auto_save_has_one_association"} + DB.Save(&company) + + company.Name = "auto_save_has_one_association_new_name" + user := User{Name: "jinzhu", Company: company} + + DB.Save(&user) + + if !DB.Where("name = ?", "auto_save_has_one_association_new_name").First(&Company{}).RecordNotFound() { + t.Errorf("Company should not have been updated") + } + + if !DB.Where("name = ? AND user_id = ?", "auto_save_has_one_association", user.ID).First(&Company{}).RecordNotFound() { + t.Errorf("Company should not have been updated") + } + + if user.Company.UserID == 0 { + t.Errorf("UserID should be assigned") + } + + company.Name = "auto_save_has_one_association_2_new_name" + DB.Set("gorm:association_autoupdate", true).Save(&user) + + if DB.Where("name = ? AND user_id = ?", "auto_save_has_one_association_new_name", user.ID).First(&Company{}).RecordNotFound() { + t.Errorf("Company should been updated") + } + + user2 := User{Name: "jinzhu_2", Company: Company{Name: "auto_save_has_one_association_2"}} + DB.Set("gorm:association_autocreate", true).Save(&user2) + if DB.Where("name = ?", "auto_save_has_one_association_2").First(&Company{}).RecordNotFound() { + t.Errorf("Company auto_save_has_one_association_2 should been created when autocreate is true") + } +} + +func TestAutoSaveMany2ManyAssociation(t *testing.T) { + type Company struct { + gorm.Model + Name string + } + + type User struct { + gorm.Model + Name string + Companies []Company `gorm:"many2many:user_companies;association_autoupdate:false;association_autocreate:false;"` + } + + DB.AutoMigrate(&Company{}, &User{}) + + DB.Save(&User{Name: "jinzhu", Companies: []Company{{Name: "auto_save_m2m_association"}}}) + + if !DB.Where("name = ?", "auto_save_m2m_association").First(&Company{}).RecordNotFound() { + t.Errorf("Company auto_save_m2m_association should not have been saved when autosave is false") + } + + company := Company{Name: "auto_save_m2m_association"} + DB.Save(&company) + + company.Name = "auto_save_m2m_association_new_name" + user := User{Name: "jinzhu", Companies: []Company{company, {Name: "auto_save_m2m_association_new_name_2"}}} + + DB.Save(&user) + + if !DB.Where("name = ?", "auto_save_m2m_association_new_name").First(&Company{}).RecordNotFound() { + t.Errorf("Company should not have been updated") + } + + if !DB.Where("name = ?", "auto_save_m2m_association_new_name_2").First(&Company{}).RecordNotFound() { + t.Errorf("Company should not been created") + } + + if DB.Model(&user).Association("Companies").Count() != 1 { + t.Errorf("Relationship should been saved") + } + + DB.Set("gorm:association_autoupdate", true).Set("gorm:association_autocreate", true).Save(&user) + + if DB.Where("name = ?", "auto_save_m2m_association_new_name").First(&Company{}).RecordNotFound() { + t.Errorf("Company should been updated") + } + + if DB.Where("name = ?", "auto_save_m2m_association_new_name_2").First(&Company{}).RecordNotFound() { + t.Errorf("Company should been created") + } + + if DB.Model(&user).Association("Companies").Count() != 2 { + t.Errorf("Relationship should been updated") } } diff --git a/callback.go b/callback.go index 17f75451..a4382147 100644 --- a/callback.go +++ b/callback.go @@ -1,8 +1,6 @@ package gorm -import ( - "fmt" -) +import "log" // DefaultCallback default callbacks defined by gorm var DefaultCallback = &Callback{} @@ -95,7 +93,7 @@ func (cp *CallbackProcessor) Before(callbackName string) *CallbackProcessor { func (cp *CallbackProcessor) Register(callbackName string, callback func(scope *Scope)) { if cp.kind == "row_query" { if cp.before == "" && cp.after == "" && callbackName != "gorm:row_query" { - fmt.Printf("Registing RowQuery callback %v without specify order with Before(), After(), applying Before('gorm:row_query') by default for compatibility...\n", callbackName) + log.Printf("Registing RowQuery callback %v without specify order with Before(), After(), applying Before('gorm:row_query') by default for compatibility...\n", callbackName) cp.before = "gorm:row_query" } } @@ -109,7 +107,7 @@ func (cp *CallbackProcessor) Register(callbackName string, callback func(scope * // Remove a registered callback // db.Callback().Create().Remove("gorm:update_time_stamp_when_create") func (cp *CallbackProcessor) Remove(callbackName string) { - fmt.Printf("[info] removing callback `%v` from %v\n", callbackName, fileWithLineNum()) + log.Printf("[info] removing callback `%v` from %v\n", callbackName, fileWithLineNum()) cp.name = callbackName cp.remove = true cp.parent.processors = append(cp.parent.processors, cp) @@ -122,7 +120,7 @@ func (cp *CallbackProcessor) Remove(callbackName string) { // scope.SetColumn("Updated", now) // }) func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *Scope)) { - fmt.Printf("[info] replacing callback `%v` from %v\n", callbackName, fileWithLineNum()) + log.Printf("[info] replacing callback `%v` from %v\n", callbackName, fileWithLineNum()) cp.name = callbackName cp.processor = &callback cp.replace = true @@ -161,7 +159,7 @@ func sortProcessors(cps []*CallbackProcessor) []*func(scope *Scope) { for _, cp := range cps { // show warning message the callback name already exists if index := getRIndex(allNames, cp.name); index > -1 && !cp.replace && !cp.remove { - fmt.Printf("[warning] duplicated callback `%v` from %v\n", cp.name, fileWithLineNum()) + log.Printf("[warning] duplicated callback `%v` from %v\n", cp.name, fileWithLineNum()) } allNames = append(allNames, cp.name) } diff --git a/callback_query.go b/callback_query.go index 20e88161..ba10cc7d 100644 --- a/callback_query.go +++ b/callback_query.go @@ -15,6 +15,10 @@ func init() { // queryCallback used to query data from database func queryCallback(scope *Scope) { + if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip { + return + } + defer scope.trace(NowFunc()) var ( diff --git a/callback_query_preload.go b/callback_query_preload.go index 21ab22ce..30f6b585 100644 --- a/callback_query_preload.go +++ b/callback_query_preload.go @@ -10,6 +10,9 @@ import ( // preloadCallback used to preload associations func preloadCallback(scope *Scope) { + if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip { + return + } if _, ok := scope.Get("gorm:auto_preload"); ok { autoPreload(scope) @@ -324,6 +327,10 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface scope.scan(rows, columns, append(fields, joinTableFields...)) + scope.New(elem.Addr().Interface()). + InstanceSet("gorm:skip_query_callback", true). + callCallbacks(scope.db.parent.callbacks.queries) + var foreignKeys = make([]interface{}, len(sourceKeys)) // generate hashed forkey keys in join table for idx, joinTableField := range joinTableFields { diff --git a/callback_save.go b/callback_save.go index f4bc918e..ef267141 100644 --- a/callback_save.go +++ b/callback_save.go @@ -1,6 +1,9 @@ package gorm -import "reflect" +import ( + "reflect" + "strings" +) func beginTransactionCallback(scope *Scope) { scope.Begin() @@ -10,31 +13,81 @@ func commitOrRollbackTransactionCallback(scope *Scope) { scope.CommitOrRollback() } -func saveFieldAsAssociation(scope *Scope, field *Field) (bool, *Relationship) { +func saveAssociationCheck(scope *Scope, field *Field) (autoUpdate bool, autoCreate bool, saveReference bool, r *Relationship) { + checkTruth := func(value interface{}) bool { + if v, ok := value.(bool); ok && !v { + return false + } + + if v, ok := value.(string); ok { + v = strings.ToLower(v) + if v == "false" || v != "skip" { + return false + } + } + + return true + } + if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored { - if value, ok := field.TagSettings["SAVE_ASSOCIATIONS"]; !ok || (value != "false" && value != "skip") { - if relationship := field.Relationship; relationship != nil { - return true, relationship + if r = field.Relationship; r != nil { + autoUpdate, autoCreate, saveReference = true, true, true + + if value, ok := scope.Get("gorm:save_associations"); ok { + autoUpdate = checkTruth(value) + autoCreate = autoUpdate + } else if value, ok := field.TagSettings["SAVE_ASSOCIATIONS"]; ok { + autoUpdate = checkTruth(value) + autoCreate = autoUpdate + } + + if value, ok := scope.Get("gorm:association_autoupdate"); ok { + autoUpdate = checkTruth(value) + } else if value, ok := field.TagSettings["ASSOCIATION_AUTOUPDATE"]; ok { + autoUpdate = checkTruth(value) + } + + if value, ok := scope.Get("gorm:association_autocreate"); ok { + autoCreate = checkTruth(value) + } else if value, ok := field.TagSettings["ASSOCIATION_AUTOCREATE"]; ok { + autoCreate = checkTruth(value) + } + + if value, ok := scope.Get("gorm:association_save_reference"); ok { + saveReference = checkTruth(value) + } else if value, ok := field.TagSettings["ASSOCIATION_SAVE_REFERENCE"]; ok { + saveReference = checkTruth(value) } } } - return false, nil + + return } func saveBeforeAssociationsCallback(scope *Scope) { - if !scope.shouldSaveAssociations() { - return - } for _, field := range scope.Fields() { - if ok, relationship := saveFieldAsAssociation(scope, field); ok && relationship.Kind == "belongs_to" { + autoUpdate, autoCreate, saveReference, relationship := saveAssociationCheck(scope, field) + + if relationship != nil && relationship.Kind == "belongs_to" { fieldValue := field.Field.Addr().Interface() - scope.Err(scope.NewDB().Save(fieldValue).Error) - if len(relationship.ForeignFieldNames) != 0 { - // set value's foreign key - for idx, fieldName := range relationship.ForeignFieldNames { - associationForeignName := relationship.AssociationForeignDBNames[idx] - if foreignField, ok := scope.New(fieldValue).FieldByName(associationForeignName); ok { - scope.Err(scope.SetColumn(fieldName, foreignField.Field.Interface())) + newScope := scope.New(fieldValue) + + if newScope.PrimaryKeyZero() { + if autoCreate { + scope.Err(scope.NewDB().Save(fieldValue).Error) + } + } else if autoUpdate { + scope.Err(scope.NewDB().Save(fieldValue).Error) + } + + if saveReference { + if len(relationship.ForeignFieldNames) != 0 { + // set value's foreign key + for idx, fieldName := range relationship.ForeignFieldNames { + associationForeignName := relationship.AssociationForeignDBNames[idx] + if foreignField, ok := scope.New(fieldValue).FieldByName(associationForeignName); ok { + scope.Err(scope.SetColumn(fieldName, foreignField.Field.Interface())) + } } } } @@ -43,12 +96,10 @@ func saveBeforeAssociationsCallback(scope *Scope) { } func saveAfterAssociationsCallback(scope *Scope) { - if !scope.shouldSaveAssociations() { - return - } for _, field := range scope.Fields() { - if ok, relationship := saveFieldAsAssociation(scope, field); ok && - (relationship.Kind == "has_one" || relationship.Kind == "has_many" || relationship.Kind == "many_to_many") { + autoUpdate, autoCreate, saveReference, relationship := saveAssociationCheck(scope, field) + + if relationship != nil && (relationship.Kind == "has_one" || relationship.Kind == "has_many" || relationship.Kind == "many_to_many") { value := field.Field switch value.Kind() { @@ -58,7 +109,41 @@ func saveAfterAssociationsCallback(scope *Scope) { elem := value.Index(i).Addr().Interface() newScope := newDB.NewScope(elem) - if relationship.JoinTableHandler == nil && len(relationship.ForeignFieldNames) != 0 { + if saveReference { + if relationship.JoinTableHandler == nil && len(relationship.ForeignFieldNames) != 0 { + for idx, fieldName := range relationship.ForeignFieldNames { + associationForeignName := relationship.AssociationForeignDBNames[idx] + if f, ok := scope.FieldByName(associationForeignName); ok { + scope.Err(newScope.SetColumn(fieldName, f.Field.Interface())) + } + } + } + + if relationship.PolymorphicType != "" { + scope.Err(newScope.SetColumn(relationship.PolymorphicType, relationship.PolymorphicValue)) + } + } + + if newScope.PrimaryKeyZero() { + if autoCreate { + scope.Err(newDB.Save(elem).Error) + } + } else if autoUpdate { + scope.Err(newDB.Save(elem).Error) + } + + if !scope.New(newScope.Value).PrimaryKeyZero() && saveReference { + if joinTableHandler := relationship.JoinTableHandler; joinTableHandler != nil { + scope.Err(joinTableHandler.Add(joinTableHandler, newDB, scope.Value, newScope.Value)) + } + } + } + default: + elem := value.Addr().Interface() + newScope := scope.New(elem) + + if saveReference { + if len(relationship.ForeignFieldNames) != 0 { for idx, fieldName := range relationship.ForeignFieldNames { associationForeignName := relationship.AssociationForeignDBNames[idx] if f, ok := scope.FieldByName(associationForeignName); ok { @@ -70,29 +155,15 @@ func saveAfterAssociationsCallback(scope *Scope) { if relationship.PolymorphicType != "" { scope.Err(newScope.SetColumn(relationship.PolymorphicType, relationship.PolymorphicValue)) } + } - scope.Err(newDB.Save(elem).Error) - - if joinTableHandler := relationship.JoinTableHandler; joinTableHandler != nil { - scope.Err(joinTableHandler.Add(joinTableHandler, newDB, scope.Value, newScope.Value)) + if newScope.PrimaryKeyZero() { + if autoCreate { + scope.Err(scope.NewDB().Save(elem).Error) } + } else if autoUpdate { + scope.Err(scope.NewDB().Save(elem).Error) } - default: - elem := value.Addr().Interface() - newScope := scope.New(elem) - if len(relationship.ForeignFieldNames) != 0 { - for idx, fieldName := range relationship.ForeignFieldNames { - associationForeignName := relationship.AssociationForeignDBNames[idx] - if f, ok := scope.FieldByName(associationForeignName); ok { - scope.Err(newScope.SetColumn(fieldName, f.Field.Interface())) - } - } - } - - if relationship.PolymorphicType != "" { - scope.Err(newScope.SetColumn(relationship.PolymorphicType, relationship.PolymorphicValue)) - } - scope.Err(scope.NewDB().Save(elem).Error) } } } diff --git a/callback_update.go b/callback_update.go index 6948439f..373bd726 100644 --- a/callback_update.go +++ b/callback_update.go @@ -3,6 +3,7 @@ package gorm import ( "errors" "fmt" + "sort" "strings" ) @@ -59,7 +60,16 @@ func updateCallback(scope *Scope) { var sqls []string if updateAttrs, ok := scope.InstanceGet("gorm:update_attrs"); ok { - for column, value := range updateAttrs.(map[string]interface{}) { + // Sort the column names so that the generated SQL is the same every time. + updateMap := updateAttrs.(map[string]interface{}) + var columns []string + for c := range updateMap { + columns = append(columns, c) + } + sort.Strings(columns) + + for _, column := range columns { + value := updateMap[column] sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(column), scope.AddToVars(value))) } } else { diff --git a/customize_column_test.go b/customize_column_test.go index ddb536b8..5e19d6f4 100644 --- a/customize_column_test.go +++ b/customize_column_test.go @@ -279,3 +279,68 @@ func TestBelongsToWithPartialCustomizedColumn(t *testing.T) { t.Errorf("should preload discount from coupon") } } + +type SelfReferencingUser struct { + gorm.Model + Name string + Friends []*SelfReferencingUser `gorm:"many2many:UserFriends;association_jointable_foreignkey:friend_id"` +} + +func TestSelfReferencingMany2ManyColumn(t *testing.T) { + DB.DropTable(&SelfReferencingUser{}, "UserFriends") + DB.AutoMigrate(&SelfReferencingUser{}) + + friend1 := SelfReferencingUser{Name: "friend1_m2m"} + if err := DB.Create(&friend1).Error; err != nil { + t.Errorf("no error should happen, but got %v", err) + } + + friend2 := SelfReferencingUser{Name: "friend2_m2m"} + if err := DB.Create(&friend2).Error; err != nil { + t.Errorf("no error should happen, but got %v", err) + } + + user := SelfReferencingUser{ + Name: "self_m2m", + Friends: []*SelfReferencingUser{&friend1, &friend2}, + } + + if err := DB.Create(&user).Error; err != nil { + t.Errorf("no error should happen, but got %v", err) + } + + if DB.Model(&user).Association("Friends").Count() != 2 { + t.Errorf("Should find created friends correctly") + } + + var newUser = SelfReferencingUser{} + + if err := DB.Preload("Friends").First(&newUser, "id = ?", user.ID).Error; err != nil { + t.Errorf("no error should happen, but got %v", err) + } + + if len(newUser.Friends) != 2 { + t.Errorf("Should preload created frineds for self reference m2m") + } + + DB.Model(&newUser).Association("Friends").Append(&SelfReferencingUser{Name: "friend3_m2m"}) + if DB.Model(&user).Association("Friends").Count() != 3 { + t.Errorf("Should find created friends correctly") + } + + DB.Model(&newUser).Association("Friends").Replace(&SelfReferencingUser{Name: "friend4_m2m"}) + if DB.Model(&user).Association("Friends").Count() != 1 { + t.Errorf("Should find created friends correctly") + } + + friend := SelfReferencingUser{} + DB.Model(&newUser).Association("Friends").Find(&friend) + if friend.Name != "friend4_m2m" { + t.Errorf("Should find created friends correctly") + } + + DB.Model(&newUser).Association("Friends").Delete(friend) + if DB.Model(&user).Association("Friends").Count() != 0 { + t.Errorf("All friends should be deleted") + } +} diff --git a/dialect.go b/dialect.go index e879588b..fe8e2f62 100644 --- a/dialect.go +++ b/dialect.go @@ -33,6 +33,8 @@ type Dialect interface { HasTable(tableName string) bool // HasColumn check has column or not HasColumn(tableName string, columnName string) bool + // ModifyColumn modify column's type + ModifyColumn(tableName string, columnName string, typ string) error // LimitAndOffsetSQL return generated SQL with Limit and Offset, as mssql has special case LimitAndOffsetSQL(limit, offset interface{}) string @@ -41,8 +43,8 @@ type Dialect interface { // LastInsertIdReturningSuffix most dbs support LastInsertId, but postgres needs to use `RETURNING` LastInsertIDReturningSuffix(tableName, columnName string) string - // BuildForeignKeyName returns a foreign key name for the given table, field and reference - BuildForeignKeyName(tableName, field, dest string) string + // BuildKeyName returns a valid key name (foreign key, index key) for the given table, field and reference + BuildKeyName(kind, tableName string, fields ...string) string // CurrentDatabase return current database name CurrentDatabase() string @@ -114,3 +116,11 @@ var ParseFieldStructForDialect = func(field *StructField, dialect Dialect) (fiel return fieldValue, dataType, size, strings.TrimSpace(additionalType) } + +func currentDatabaseAndTable(dialect Dialect, tableName string) (string, string) { + if strings.Contains(tableName, ".") { + splitStrings := strings.SplitN(tableName, ".", 2) + return splitStrings[0], splitStrings[1] + } + return dialect.CurrentDatabase(), tableName +} diff --git a/dialect_common.go b/dialect_common.go index a99627f2..fbbaef33 100644 --- a/dialect_common.go +++ b/dialect_common.go @@ -38,6 +38,13 @@ func (commonDialect) Quote(key string) string { return fmt.Sprintf(`"%s"`, key) } +func (s *commonDialect) fieldCanAutoIncrement(field *StructField) bool { + if value, ok := field.TagSettings["AUTO_INCREMENT"]; ok { + return strings.ToLower(value) != "false" + } + return field.IsPrimaryKey +} + func (s *commonDialect) DataTypeOf(field *StructField) string { var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s) @@ -46,13 +53,13 @@ func (s *commonDialect) DataTypeOf(field *StructField) string { case reflect.Bool: sqlType = "BOOLEAN" case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok { + if s.fieldCanAutoIncrement(field) { sqlType = "INTEGER AUTO_INCREMENT" } else { sqlType = "INTEGER" } case reflect.Int64, reflect.Uint64: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok { + if s.fieldCanAutoIncrement(field) { sqlType = "BIGINT AUTO_INCREMENT" } else { sqlType = "BIGINT" @@ -92,7 +99,8 @@ func (s *commonDialect) DataTypeOf(field *StructField) string { func (s commonDialect) HasIndex(tableName string, indexName string) bool { var count int - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", s.CurrentDatabase(), tableName, indexName).Scan(&count) + currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) + s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", currentDatabase, tableName, indexName).Scan(&count) return count > 0 } @@ -107,16 +115,23 @@ func (s commonDialect) HasForeignKey(tableName string, foreignKeyName string) bo func (s commonDialect) HasTable(tableName string) bool { var count int - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", s.CurrentDatabase(), tableName).Scan(&count) + currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) + s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", currentDatabase, tableName).Scan(&count) return count > 0 } func (s commonDialect) HasColumn(tableName string, columnName string) bool { var count int - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", s.CurrentDatabase(), tableName, columnName).Scan(&count) + currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) + s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", currentDatabase, tableName, columnName).Scan(&count) return count > 0 } +func (s commonDialect) ModifyColumn(tableName string, columnName string, typ string) error { + _, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v TYPE %v", tableName, columnName, typ)) + return err +} + func (s commonDialect) CurrentDatabase() (name string) { s.db.QueryRow("SELECT DATABASE()").Scan(&name) return @@ -144,9 +159,10 @@ func (commonDialect) LastInsertIDReturningSuffix(tableName, columnName string) s return "" } -func (DefaultForeignKeyNamer) BuildForeignKeyName(tableName, field, dest string) string { - keyName := fmt.Sprintf("%s_%s_%s_foreign", tableName, field, dest) - keyName = regexp.MustCompile("(_*[^a-zA-Z]+_*|_+)").ReplaceAllString(keyName, "_") +// BuildKeyName returns a valid key name (foreign key, index key) for the given table, field and reference +func (DefaultForeignKeyNamer) BuildKeyName(kind, tableName string, fields ...string) string { + keyName := fmt.Sprintf("%s_%s_%s", kind, tableName, strings.Join(fields, "_")) + keyName = regexp.MustCompile("[^a-zA-Z0-9]+").ReplaceAllString(keyName, "_") return keyName } diff --git a/dialect_mysql.go b/dialect_mysql.go index 6fcd0079..1feed1f6 100644 --- a/dialect_mysql.go +++ b/dialect_mysql.go @@ -44,42 +44,42 @@ func (s *mysql) DataTypeOf(field *StructField) string { case reflect.Bool: sqlType = "boolean" case reflect.Int8: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { + if s.fieldCanAutoIncrement(field) { field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" sqlType = "tinyint AUTO_INCREMENT" } else { sqlType = "tinyint" } case reflect.Int, reflect.Int16, reflect.Int32: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { + if s.fieldCanAutoIncrement(field) { field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" sqlType = "int AUTO_INCREMENT" } else { sqlType = "int" } case reflect.Uint8: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { + if s.fieldCanAutoIncrement(field) { field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" sqlType = "tinyint unsigned AUTO_INCREMENT" } else { sqlType = "tinyint unsigned" } case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uintptr: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { + if s.fieldCanAutoIncrement(field) { field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" sqlType = "int unsigned AUTO_INCREMENT" } else { sqlType = "int unsigned" } case reflect.Int64: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { + if s.fieldCanAutoIncrement(field) { field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" sqlType = "bigint AUTO_INCREMENT" } else { sqlType = "bigint" } case reflect.Uint64: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { + if s.fieldCanAutoIncrement(field) { field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" sqlType = "bigint unsigned AUTO_INCREMENT" } else { @@ -95,10 +95,15 @@ func (s *mysql) DataTypeOf(field *StructField) string { } case reflect.Struct: if _, ok := dataValue.Interface().(time.Time); ok { + precision := "" + if p, ok := field.TagSettings["PRECISION"]; ok { + precision = fmt.Sprintf("(%s)", p) + } + if _, ok := field.TagSettings["NOT NULL"]; ok { - sqlType = "timestamp" + sqlType = fmt.Sprintf("timestamp%v", precision) } else { - sqlType = "timestamp NULL" + sqlType = fmt.Sprintf("timestamp%v NULL", precision) } } default: @@ -127,6 +132,11 @@ func (s mysql) RemoveIndex(tableName string, indexName string) error { return err } +func (s mysql) ModifyColumn(tableName string, columnName string, typ string) error { + _, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v MODIFY COLUMN %v %v", tableName, columnName, typ)) + return err +} + func (s mysql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) { if limit != nil { if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 { @@ -144,7 +154,8 @@ func (s mysql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) { func (s mysql) HasForeignKey(tableName string, foreignKeyName string) bool { var count int - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE CONSTRAINT_SCHEMA=? AND TABLE_NAME=? AND CONSTRAINT_NAME=? AND CONSTRAINT_TYPE='FOREIGN KEY'", s.CurrentDatabase(), tableName, foreignKeyName).Scan(&count) + currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) + s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE CONSTRAINT_SCHEMA=? AND TABLE_NAME=? AND CONSTRAINT_NAME=? AND CONSTRAINT_TYPE='FOREIGN KEY'", currentDatabase, tableName, foreignKeyName).Scan(&count) return count > 0 } @@ -157,8 +168,8 @@ func (mysql) SelectFromDummyTable() string { return "FROM DUAL" } -func (s mysql) BuildForeignKeyName(tableName, field, dest string) string { - keyName := s.commonDialect.BuildForeignKeyName(tableName, field, dest) +func (s mysql) BuildKeyName(kind, tableName string, fields ...string) string { + keyName := s.commonDialect.BuildKeyName(kind, tableName, fields...) if utf8.RuneCountInString(keyName) <= 64 { return keyName } @@ -166,8 +177,8 @@ func (s mysql) BuildForeignKeyName(tableName, field, dest string) string { h.Write([]byte(keyName)) bs := h.Sum(nil) - // sha1 is 40 digits, keep first 24 characters of destination - destRunes := []rune(regexp.MustCompile("(_*[^a-zA-Z]+_*|_+)").ReplaceAllString(dest, "_")) + // sha1 is 40 characters, keep first 24 characters of destination + destRunes := []rune(regexp.MustCompile("[^a-zA-Z0-9]+").ReplaceAllString(fields[0], "_")) if len(destRunes) > 24 { destRunes = destRunes[:24] } diff --git a/dialect_postgres.go b/dialect_postgres.go index ed5248e0..c44c6a5b 100644 --- a/dialect_postgres.go +++ b/dialect_postgres.go @@ -1,6 +1,7 @@ package gorm import ( + "encoding/json" "fmt" "reflect" "strings" @@ -13,6 +14,7 @@ type postgres struct { func init() { RegisterDialect("postgres", &postgres{}) + RegisterDialect("cloudsqlpostgres", &postgres{}) } func (postgres) GetName() string { @@ -31,14 +33,14 @@ func (s *postgres) DataTypeOf(field *StructField) string { case reflect.Bool: sqlType = "boolean" case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uintptr: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { + if s.fieldCanAutoIncrement(field) { field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" sqlType = "serial" } else { sqlType = "integer" } case reflect.Int64, reflect.Uint32, reflect.Uint64: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { + if s.fieldCanAutoIncrement(field) { field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" sqlType = "bigserial" } else { @@ -67,8 +69,14 @@ func (s *postgres) DataTypeOf(field *StructField) string { default: if IsByteArrayOrSlice(dataValue) { sqlType = "bytea" - } else if isUUID(dataValue) { - sqlType = "uuid" + + if isUUID(dataValue) { + sqlType = "uuid" + } + + if isJSON(dataValue) { + sqlType = "jsonb" + } } } } @@ -85,7 +93,7 @@ func (s *postgres) DataTypeOf(field *StructField) string { func (s postgres) HasIndex(tableName string, indexName string) bool { var count int - s.db.QueryRow("SELECT count(*) FROM pg_indexes WHERE tablename = $1 AND indexname = $2", tableName, indexName).Scan(&count) + s.db.QueryRow("SELECT count(*) FROM pg_indexes WHERE tablename = $1 AND indexname = $2 AND schemaname = CURRENT_SCHEMA()", tableName, indexName).Scan(&count) return count > 0 } @@ -97,13 +105,13 @@ func (s postgres) HasForeignKey(tableName string, foreignKeyName string) bool { func (s postgres) HasTable(tableName string) bool { var count int - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = $1 AND table_type = 'BASE TABLE'", tableName).Scan(&count) + s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = $1 AND table_type = 'BASE TABLE' AND table_schema = CURRENT_SCHEMA()", tableName).Scan(&count) return count > 0 } func (s postgres) HasColumn(tableName string, columnName string) bool { var count int - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = $1 AND column_name = $2", tableName, columnName).Scan(&count) + s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = $1 AND column_name = $2 AND table_schema = CURRENT_SCHEMA()", tableName, columnName).Scan(&count) return count > 0 } @@ -128,3 +136,8 @@ func isUUID(value reflect.Value) bool { lower := strings.ToLower(typename) return "uuid" == lower || "guid" == lower } + +func isJSON(value reflect.Value) bool { + _, ok := value.Interface().(json.RawMessage) + return ok +} diff --git a/dialect_sqlite3.go b/dialect_sqlite3.go index de9c05cb..f26f6be3 100644 --- a/dialect_sqlite3.go +++ b/dialect_sqlite3.go @@ -28,14 +28,14 @@ func (s *sqlite3) DataTypeOf(field *StructField) string { case reflect.Bool: sqlType = "bool" case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: - if field.IsPrimaryKey { + if s.fieldCanAutoIncrement(field) { field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" sqlType = "integer primary key autoincrement" } else { sqlType = "integer" } case reflect.Int64, reflect.Uint64: - if field.IsPrimaryKey { + if s.fieldCanAutoIncrement(field) { field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" sqlType = "integer primary key autoincrement" } else { diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go index de2ae7ca..1dd5fb69 100644 --- a/dialects/mssql/mssql.go +++ b/dialects/mssql/mssql.go @@ -54,7 +54,7 @@ func (mssql) BindVar(i int) string { } func (mssql) Quote(key string) string { - return fmt.Sprintf(`"%s"`, key) + return fmt.Sprintf(`[%s]`, key) } func (s *mssql) DataTypeOf(field *gorm.StructField) string { @@ -65,14 +65,14 @@ func (s *mssql) DataTypeOf(field *gorm.StructField) string { case reflect.Bool: sqlType = "bit" case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { + if s.fieldCanAutoIncrement(field) { field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" sqlType = "int IDENTITY(1,1)" } else { sqlType = "int" } case reflect.Int64, reflect.Uint64: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { + if s.fieldCanAutoIncrement(field) { field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" sqlType = "bigint IDENTITY(1,1)" } else { @@ -111,6 +111,13 @@ func (s *mssql) DataTypeOf(field *gorm.StructField) string { return fmt.Sprintf("%v %v", sqlType, additionalType) } +func (s mssql) fieldCanAutoIncrement(field *gorm.StructField) bool { + if value, ok := field.TagSettings["AUTO_INCREMENT"]; ok { + return value != "FALSE" + } + return field.IsPrimaryKey +} + func (s mssql) HasIndex(tableName string, indexName string) bool { var count int s.db.QueryRow("SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName).Scan(&count) @@ -128,16 +135,23 @@ func (s mssql) HasForeignKey(tableName string, foreignKeyName string) bool { func (s mssql) HasTable(tableName string) bool { var count int - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, s.CurrentDatabase()).Scan(&count) + currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) + s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, currentDatabase).Scan(&count) return count > 0 } func (s mssql) HasColumn(tableName string, columnName string) bool { var count int - s.db.QueryRow("SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", s.CurrentDatabase(), tableName, columnName).Scan(&count) + currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) + s.db.QueryRow("SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", currentDatabase, tableName, columnName).Scan(&count) return count > 0 } +func (s mssql) ModifyColumn(tableName string, columnName string, typ string) error { + _, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v %v", tableName, columnName, typ)) + return err +} + func (s mssql) CurrentDatabase() (name string) { s.db.QueryRow("SELECT DB_NAME() AS [Current Database]").Scan(&name) return @@ -168,3 +182,11 @@ func (mssql) SelectFromDummyTable() string { func (mssql) LastInsertIDReturningSuffix(tableName, columnName string) string { return "" } + +func currentDatabaseAndTable(dialect gorm.Dialect, tableName string) (string, string) { + if strings.Contains(tableName, ".") { + splitStrings := strings.SplitN(tableName, ".", 2) + return splitStrings[0], splitStrings[1] + } + return dialect.CurrentDatabase(), tableName +} diff --git a/dialects/postgres/postgres.go b/dialects/postgres/postgres.go index adeeec7b..1d0dcb60 100644 --- a/dialects/postgres/postgres.go +++ b/dialects/postgres/postgres.go @@ -6,6 +6,9 @@ import ( _ "github.com/lib/pq" "github.com/lib/pq/hstore" + "encoding/json" + "errors" + "fmt" ) type Hstore map[string]*string @@ -52,3 +55,26 @@ func (h *Hstore) Scan(value interface{}) error { return nil } + +// Jsonb Postgresql's JSONB data type +type Jsonb struct { + json.RawMessage +} + +// Value get value of Jsonb +func (j Jsonb) Value() (driver.Value, error) { + if len(j.RawMessage) == 0 { + return nil, nil + } + return j.MarshalJSON() +} + +// Scan scan value into Jsonb +func (j *Jsonb) Scan(value interface{}) error { + bytes, ok := value.([]byte) + if !ok { + return errors.New(fmt.Sprint("Failed to unmarshal JSONB value:", value)) + } + + return json.Unmarshal(bytes, j) +} diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 00000000..79bf5fc3 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,30 @@ +version: '3' + +services: + mysql: + image: 'mysql:latest' + ports: + - 9910:3306 + environment: + - MYSQL_DATABASE=gorm + - MYSQL_USER=gorm + - MYSQL_PASSWORD=gorm + - MYSQL_RANDOM_ROOT_PASSWORD="yes" + postgres: + image: 'postgres:latest' + ports: + - 9920:5432 + environment: + - POSTGRES_USER=gorm + - POSTGRES_DB=gorm + - POSTGRES_PASSWORD=gorm + mssql: + image: 'mcmoe/mssqldocker:latest' + ports: + - 9930:1433 + environment: + - ACCEPT_EULA=Y + - SA_PASSWORD=LoremIpsum86 + - MSSQL_DB=gorm + - MSSQL_USER=gorm + - MSSQL_PASSWORD=LoremIpsum86 diff --git a/errors.go b/errors.go index 832fa9b0..6845188e 100644 --- a/errors.go +++ b/errors.go @@ -29,6 +29,10 @@ func (errs Errors) GetErrors() []error { // Add adds an error func (errs Errors) Add(newErrors ...error) Errors { for _, err := range newErrors { + if err == nil { + continue + } + if errors, ok := err.(Errors); ok { errs = errs.Add(errors...) } else { diff --git a/join_table_handler.go b/join_table_handler.go index 2d1a5055..f07541ba 100644 --- a/join_table_handler.go +++ b/join_table_handler.go @@ -82,38 +82,40 @@ func (s JoinTableHandler) Table(db *DB) string { return s.TableName } -func (s JoinTableHandler) getSearchMap(db *DB, sources ...interface{}) map[string]interface{} { - values := map[string]interface{}{} - +func (s JoinTableHandler) updateConditionMap(conditionMap map[string]interface{}, db *DB, joinTableSources []JoinTableSource, sources ...interface{}) { for _, source := range sources { scope := db.NewScope(source) modelType := scope.GetModelStruct().ModelType - if s.Source.ModelType == modelType { - for _, foreignKey := range s.Source.ForeignKeys { - if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok { - values[foreignKey.DBName] = field.Field.Interface() - } - } - } else if s.Destination.ModelType == modelType { - for _, foreignKey := range s.Destination.ForeignKeys { - if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok { - values[foreignKey.DBName] = field.Field.Interface() + for _, joinTableSource := range joinTableSources { + if joinTableSource.ModelType == modelType { + for _, foreignKey := range joinTableSource.ForeignKeys { + if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok { + conditionMap[foreignKey.DBName] = field.Field.Interface() + } } + break } } } - return values } // Add create relationship in join table for source and destination func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error { - scope := db.NewScope("") - searchMap := s.getSearchMap(db, source, destination) + var ( + scope = db.NewScope("") + conditionMap = map[string]interface{}{} + ) + + // Update condition map for source + s.updateConditionMap(conditionMap, db, []JoinTableSource{s.Source}, source) + + // Update condition map for destination + s.updateConditionMap(conditionMap, db, []JoinTableSource{s.Destination}, destination) var assignColumns, binVars, conditions []string var values []interface{} - for key, value := range searchMap { + for key, value := range conditionMap { assignColumns = append(assignColumns, scope.Quote(key)) binVars = append(binVars, `?`) conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key))) @@ -141,12 +143,15 @@ func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source // Delete delete relationship in join table for sources func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error { var ( - scope = db.NewScope(nil) - conditions []string - values []interface{} + scope = db.NewScope(nil) + conditions []string + values []interface{} + conditionMap = map[string]interface{}{} ) - for key, value := range s.getSearchMap(db, sources...) { + s.updateConditionMap(conditionMap, db, []JoinTableSource{s.Source, s.Destination}, sources...) + + for key, value := range conditionMap { conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key))) values = append(values, value) } diff --git a/main.go b/main.go index 16fa0b79..fc4859ac 100644 --- a/main.go +++ b/main.go @@ -274,7 +274,7 @@ func (s *DB) Assign(attrs ...interface{}) *DB { // First find first record that match given conditions, order by primary key func (s *DB) First(out interface{}, where ...interface{}) *DB { - newScope := s.clone().NewScope(out) + newScope := s.NewScope(out) newScope.Search.Limit(1) return newScope.Set("gorm:order_by_primary_key", "ASC"). inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db @@ -282,7 +282,7 @@ func (s *DB) First(out interface{}, where ...interface{}) *DB { // Last find last record that match given conditions, order by primary key func (s *DB) Last(out interface{}, where ...interface{}) *DB { - newScope := s.clone().NewScope(out) + newScope := s.NewScope(out) newScope.Search.Limit(1) return newScope.Set("gorm:order_by_primary_key", "DESC"). inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db @@ -290,12 +290,12 @@ func (s *DB) Last(out interface{}, where ...interface{}) *DB { // Find find records that match given conditions func (s *DB) Find(out interface{}, where ...interface{}) *DB { - return s.clone().NewScope(out).inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db + return s.NewScope(out).inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db } // Scan scan value to a struct func (s *DB) Scan(dest interface{}) *DB { - return s.clone().NewScope(s.Value).Set("gorm:query_destination", dest).callCallbacks(s.parent.callbacks.queries).db + return s.NewScope(s.Value).Set("gorm:query_destination", dest).callCallbacks(s.parent.callbacks.queries).db } // Row return `*sql.Row` with given conditions @@ -311,8 +311,8 @@ func (s *DB) Rows() (*sql.Rows, error) { // ScanRows scan `*sql.Rows` to give struct func (s *DB) ScanRows(rows *sql.Rows, result interface{}) error { var ( - clone = s.clone() - scope = clone.NewScope(result) + scope = s.NewScope(result) + clone = scope.db columns, err = rows.Columns() ) @@ -337,7 +337,7 @@ func (s *DB) Count(value interface{}) *DB { // Related get related associations func (s *DB) Related(value interface{}, foreignKeys ...string) *DB { - return s.clone().NewScope(s.Value).related(value, foreignKeys...).db + return s.NewScope(s.Value).related(value, foreignKeys...).db } // FirstOrInit find first matched record or initialize a new one with given conditions (only works with struct, map conditions) @@ -377,7 +377,7 @@ func (s *DB) Update(attrs ...interface{}) *DB { // Updates update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update func (s *DB) Updates(values interface{}, ignoreProtectedAttrs ...bool) *DB { - return s.clone().NewScope(s.Value). + return s.NewScope(s.Value). Set("gorm:ignore_protected_attrs", len(ignoreProtectedAttrs) > 0). InstanceSet("gorm:update_interface", values). callCallbacks(s.parent.callbacks.updates).db @@ -390,7 +390,7 @@ func (s *DB) UpdateColumn(attrs ...interface{}) *DB { // UpdateColumns update attributes without callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update func (s *DB) UpdateColumns(values interface{}) *DB { - return s.clone().NewScope(s.Value). + return s.NewScope(s.Value). Set("gorm:update_column", true). Set("gorm:save_associations", false). InstanceSet("gorm:update_interface", values). @@ -399,7 +399,7 @@ func (s *DB) UpdateColumns(values interface{}) *DB { // Save update value in database, if the value doesn't have primary key, will insert it func (s *DB) Save(value interface{}) *DB { - scope := s.clone().NewScope(value) + scope := s.NewScope(value) if !scope.PrimaryKeyZero() { newDB := scope.callCallbacks(s.parent.callbacks.updates).db if newDB.Error == nil && newDB.RowsAffected == 0 { @@ -412,13 +412,13 @@ func (s *DB) Save(value interface{}) *DB { // Create insert the value into database func (s *DB) Create(value interface{}) *DB { - scope := s.clone().NewScope(value) + scope := s.NewScope(value) return scope.callCallbacks(s.parent.callbacks.creates).db } // Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition func (s *DB) Delete(value interface{}, where ...interface{}) *DB { - return s.clone().NewScope(value).inlineCondition(where...).callCallbacks(s.parent.callbacks.deletes).db + return s.NewScope(value).inlineCondition(where...).callCallbacks(s.parent.callbacks.deletes).db } // Raw use raw sql as conditions, won't run it unless invoked by other methods @@ -429,7 +429,7 @@ func (s *DB) Raw(sql string, values ...interface{}) *DB { // Exec execute raw sql func (s *DB) Exec(sql string, values ...interface{}) *DB { - scope := s.clone().NewScope(nil) + scope := s.NewScope(nil) generatedSQL := scope.buildWhereCondition(map[string]interface{}{"query": sql, "args": values}) generatedSQL = strings.TrimSuffix(strings.TrimPrefix(generatedSQL, "("), ")") scope.Raw(generatedSQL) @@ -495,7 +495,7 @@ func (s *DB) Rollback() *DB { // NewRecord check if value's primary key is blank func (s *DB) NewRecord(value interface{}) bool { - return s.clone().NewScope(value).PrimaryKeyZero() + return s.NewScope(value).PrimaryKeyZero() } // RecordNotFound check if returning ErrRecordNotFound error @@ -544,7 +544,7 @@ func (s *DB) DropTableIfExists(values ...interface{}) *DB { // HasTable check has table or not func (s *DB) HasTable(value interface{}) bool { var ( - scope = s.clone().NewScope(value) + scope = s.NewScope(value) tableName string ) @@ -570,14 +570,14 @@ func (s *DB) AutoMigrate(values ...interface{}) *DB { // ModifyColumn modify column to type func (s *DB) ModifyColumn(column string, typ string) *DB { - scope := s.clone().NewScope(s.Value) + scope := s.NewScope(s.Value) scope.modifyColumn(column, typ) return scope.db } // DropColumn drop a column func (s *DB) DropColumn(column string) *DB { - scope := s.clone().NewScope(s.Value) + scope := s.NewScope(s.Value) scope.dropColumn(column) return scope.db } @@ -598,7 +598,7 @@ func (s *DB) AddUniqueIndex(indexName string, columns ...string) *DB { // RemoveIndex remove index with name func (s *DB) RemoveIndex(indexName string) *DB { - scope := s.clone().NewScope(s.Value) + scope := s.NewScope(s.Value) scope.removeIndex(indexName) return scope.db } @@ -606,11 +606,19 @@ func (s *DB) RemoveIndex(indexName string) *DB { // AddForeignKey Add foreign key to the given scope, e.g: // db.Model(&User{}).AddForeignKey("city_id", "cities(id)", "RESTRICT", "RESTRICT") func (s *DB) AddForeignKey(field string, dest string, onDelete string, onUpdate string) *DB { - scope := s.clone().NewScope(s.Value) + scope := s.NewScope(s.Value) scope.addForeignKey(field, dest, onDelete, onUpdate) return scope.db } +// RemoveForeignKey Remove foreign key from the given scope, e.g: +// db.Model(&User{}).RemoveForeignKey("city_id", "cities(id)") +func (s *DB) RemoveForeignKey(field string, dest string) *DB { + scope := s.clone().NewScope(s.Value) + scope.removeForeignKey(field, dest) + return scope.db +} + // Association start `Association Mode` to handler relations things easir in that mode, refer: https://jinzhu.github.io/gorm/associations.html#association-mode func (s *DB) Association(column string) *Association { var err error diff --git a/main_test.go b/main_test.go index 34f96a86..83e6f7aa 100644 --- a/main_test.go +++ b/main_test.go @@ -36,27 +36,20 @@ func init() { } func OpenTestConnection() (db *gorm.DB, err error) { + dbDSN := os.Getenv("GORM_DSN") switch os.Getenv("GORM_DIALECT") { case "mysql": - // CREATE USER 'gorm'@'localhost' IDENTIFIED BY 'gorm'; - // CREATE DATABASE gorm; - // GRANT ALL ON gorm.* TO 'gorm'@'localhost'; fmt.Println("testing mysql...") - dbhost := os.Getenv("GORM_DBADDRESS") - if dbhost != "" { - dbhost = fmt.Sprintf("tcp(%v)", dbhost) + if dbDSN == "" { + dbDSN = "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True" } - db, err = gorm.Open("mysql", fmt.Sprintf("gorm:gorm@%v/gorm?charset=utf8&parseTime=True", dbhost)) + db, err = gorm.Open("mysql", dbDSN) case "postgres": fmt.Println("testing postgres...") - dbhost := os.Getenv("GORM_DBHOST") - if dbhost != "" { - dbhost = fmt.Sprintf("host=%v ", dbhost) + if dbDSN == "" { + dbDSN = "user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable" } - db, err = gorm.Open("postgres", fmt.Sprintf("%vuser=gorm password=gorm DB.name=gorm sslmode=disable", dbhost)) - case "foundation": - fmt.Println("testing foundation...") - db, err = gorm.Open("foundation", "dbname=gorm port=15432 sslmode=disable") + db, err = gorm.Open("postgres", dbDSN) case "mssql": // CREATE LOGIN gorm WITH PASSWORD = 'LoremIpsum86'; // CREATE DATABASE gorm; @@ -64,7 +57,10 @@ func OpenTestConnection() (db *gorm.DB, err error) { // CREATE USER gorm FROM LOGIN gorm; // sp_changedbowner 'gorm'; fmt.Println("testing mssql...") - db, err = gorm.Open("mssql", "sqlserver://gorm:LoremIpsum86@localhost:1433?database=gorm") + if dbDSN == "" { + dbDSN = "sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm" + } + db, err = gorm.Open("mssql", dbDSN) default: fmt.Println("testing sqlite3...") db, err = gorm.Open("sqlite3", filepath.Join(os.TempDir(), "gorm.db")) @@ -72,8 +68,10 @@ func OpenTestConnection() (db *gorm.DB, err error) { // db.SetLogger(Logger{log.New(os.Stdout, "\r\n", 0)}) // db.SetLogger(log.New(os.Stdout, "\r\n", 0)) - if os.Getenv("DEBUG") == "true" { + if debug := os.Getenv("DEBUG"); debug == "true" { db.LogMode(true) + } else if debug == "false" { + db.LogMode(false) } db.DB().SetMaxIdleConns(10) diff --git a/migration_test.go b/migration_test.go index 9fc14fa0..6b4470a6 100644 --- a/migration_test.go +++ b/migration_test.go @@ -5,6 +5,7 @@ import ( "database/sql/driver" "errors" "fmt" + "os" "reflect" "testing" "time" @@ -432,3 +433,24 @@ func TestMultipleIndexes(t *testing.T) { t.Error("MultipleIndexes unique index failed") } } + +func TestModifyColumnType(t *testing.T) { + if dialect := os.Getenv("GORM_DIALECT"); dialect != "postgres" && dialect != "mysql" && dialect != "mssql" { + t.Skip("Skipping this because only postgres, mysql and mssql support altering a column type") + } + + type ModifyColumnType struct { + gorm.Model + Name1 string `gorm:"length:100"` + Name2 string `gorm:"length:200"` + } + DB.DropTable(&ModifyColumnType{}) + DB.CreateTable(&ModifyColumnType{}) + + name2Field, _ := DB.NewScope(&ModifyColumnType{}).FieldByName("Name2") + name2Type := DB.Dialect().DataTypeOf(name2Field.StructField) + + if err := DB.Model(&ModifyColumnType{}).ModifyColumn("name1", name2Type).Error; err != nil { + t.Errorf("No error should happen when ModifyColumn, but got %v", err) + } +} diff --git a/model_struct.go b/model_struct.go index 315028c4..f571e2e8 100644 --- a/model_struct.go +++ b/model_struct.go @@ -249,11 +249,13 @@ func (scope *Scope) GetModelStruct() *ModelStruct { ) if foreignKey := field.TagSettings["FOREIGNKEY"]; foreignKey != "" { - foreignKeys = strings.Split(field.TagSettings["FOREIGNKEY"], ",") + foreignKeys = strings.Split(foreignKey, ",") } - if foreignKey := field.TagSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" { - associationForeignKeys = strings.Split(field.TagSettings["ASSOCIATIONFOREIGNKEY"], ",") + if foreignKey := field.TagSettings["ASSOCIATION_FOREIGNKEY"]; foreignKey != "" { + associationForeignKeys = strings.Split(foreignKey, ",") + } else if foreignKey := field.TagSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" { + associationForeignKeys = strings.Split(foreignKey, ",") } for elemType.Kind() == reflect.Slice || elemType.Kind() == reflect.Ptr { @@ -264,37 +266,65 @@ func (scope *Scope) GetModelStruct() *ModelStruct { if many2many := field.TagSettings["MANY2MANY"]; many2many != "" { relationship.Kind = "many_to_many" - // if no foreign keys defined with tag - if len(foreignKeys) == 0 { - for _, field := range modelStruct.PrimaryFields { - foreignKeys = append(foreignKeys, field.DBName) + { // Foreign Keys for Source + joinTableDBNames := []string{} + + if foreignKey := field.TagSettings["JOINTABLE_FOREIGNKEY"]; foreignKey != "" { + joinTableDBNames = strings.Split(foreignKey, ",") + } + + // if no foreign keys defined with tag + if len(foreignKeys) == 0 { + for _, field := range modelStruct.PrimaryFields { + foreignKeys = append(foreignKeys, field.DBName) + } + } + + for idx, foreignKey := range foreignKeys { + if foreignField := getForeignField(foreignKey, modelStruct.StructFields); foreignField != nil { + // source foreign keys (db names) + relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.DBName) + + // setup join table foreign keys for source + if len(joinTableDBNames) > idx { + // if defined join table's foreign key + relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBNames[idx]) + } else { + defaultJointableForeignKey := ToDBName(reflectType.Name()) + "_" + foreignField.DBName + relationship.ForeignDBNames = append(relationship.ForeignDBNames, defaultJointableForeignKey) + } + } } } - for _, foreignKey := range foreignKeys { - if foreignField := getForeignField(foreignKey, modelStruct.StructFields); foreignField != nil { - // source foreign keys (db names) - relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.DBName) - // join table foreign keys for source - joinTableDBName := ToDBName(reflectType.Name()) + "_" + foreignField.DBName - relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBName) - } - } + { // Foreign Keys for Association (Destination) + associationJoinTableDBNames := []string{} - // if no association foreign keys defined with tag - if len(associationForeignKeys) == 0 { - for _, field := range toScope.PrimaryFields() { - associationForeignKeys = append(associationForeignKeys, field.DBName) + if foreignKey := field.TagSettings["ASSOCIATION_JOINTABLE_FOREIGNKEY"]; foreignKey != "" { + associationJoinTableDBNames = strings.Split(foreignKey, ",") } - } - for _, name := range associationForeignKeys { - if field, ok := toScope.FieldByName(name); ok { - // association foreign keys (db names) - relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.DBName) - // join table foreign keys for association - joinTableDBName := ToDBName(elemType.Name()) + "_" + field.DBName - relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, joinTableDBName) + // if no association foreign keys defined with tag + if len(associationForeignKeys) == 0 { + for _, field := range toScope.PrimaryFields() { + associationForeignKeys = append(associationForeignKeys, field.DBName) + } + } + + for idx, name := range associationForeignKeys { + if field, ok := toScope.FieldByName(name); ok { + // association foreign keys (db names) + relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.DBName) + + // setup join table foreign keys for association + if len(associationJoinTableDBNames) > idx { + relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationJoinTableDBNames[idx]) + } else { + // join table foreign keys for association + joinTableDBName := ToDBName(elemType.Name()) + "_" + field.DBName + relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, joinTableDBName) + } + } } } @@ -399,11 +429,13 @@ func (scope *Scope) GetModelStruct() *ModelStruct { ) if foreignKey := field.TagSettings["FOREIGNKEY"]; foreignKey != "" { - tagForeignKeys = strings.Split(field.TagSettings["FOREIGNKEY"], ",") + tagForeignKeys = strings.Split(foreignKey, ",") } - if foreignKey := field.TagSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" { - tagAssociationForeignKeys = strings.Split(field.TagSettings["ASSOCIATIONFOREIGNKEY"], ",") + if foreignKey := field.TagSettings["ASSOCIATION_FOREIGNKEY"]; foreignKey != "" { + tagAssociationForeignKeys = strings.Split(foreignKey, ",") + } else if foreignKey := field.TagSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" { + tagAssociationForeignKeys = strings.Split(foreignKey, ",") } if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" { diff --git a/preload_test.go b/preload_test.go index 1b89e77b..311ad0be 100644 --- a/preload_test.go +++ b/preload_test.go @@ -1627,6 +1627,48 @@ func TestPrefixedPreloadDuplication(t *testing.T) { } } +func TestPreloadManyToManyCallbacks(t *testing.T) { + type ( + Level2 struct { + ID uint + Name string + } + Level1 struct { + ID uint + Name string + Level2s []Level2 `gorm:"many2many:level1_level2s;AssociationForeignKey:ID;ForeignKey:ID"` + } + ) + + DB.DropTableIfExists("level1_level2s") + DB.DropTableIfExists(new(Level1)) + DB.DropTableIfExists(new(Level2)) + + if err := DB.AutoMigrate(new(Level1), new(Level2)).Error; err != nil { + t.Error(err) + } + + lvl := Level1{ + Name: "l1", + Level2s: []Level2{ + Level2{Name: "l2-1"}, Level2{Name: "l2-2"}, + }, + } + DB.Save(&lvl) + + called := 0 + + DB.Callback().Query().After("gorm:query").Register("TestPreloadManyToManyCallbacks", func(scope *gorm.Scope) { + called = called + 1 + }) + + DB.Preload("Level2s").First(&Level1{}, "id = ?", lvl.ID) + + if called != 3 { + t.Errorf("Wanted callback to be called 3 times but got %d", called) + } +} + func toJSONString(v interface{}) []byte { r, _ := json.MarshalIndent(v, "", " ") return r diff --git a/query_test.go b/query_test.go index def84e04..98721800 100644 --- a/query_test.go +++ b/query_test.go @@ -389,7 +389,7 @@ func TestOffset(t *testing.T) { DB.Save(&User{Name: fmt.Sprintf("OffsetUser%v", i)}) } var users1, users2, users3, users4 []User - DB.Limit(100).Order("age desc").Find(&users1).Offset(3).Find(&users2).Offset(5).Find(&users3).Offset(-1).Find(&users4) + DB.Limit(100).Where("name like ?", "OffsetUser%").Order("age desc").Find(&users1).Offset(3).Find(&users2).Offset(5).Find(&users3).Offset(-1).Find(&users4) if (len(users1) != len(users4)) || (len(users1)-len(users2) != 3) || (len(users1)-len(users3) != 5) { t.Errorf("Offset should work") diff --git a/scope.go b/scope.go index 6b8ce53f..63bf618f 100644 --- a/scope.go +++ b/scope.go @@ -998,18 +998,6 @@ func (scope *Scope) changeableField(field *Field) bool { return true } -func (scope *Scope) shouldSaveAssociations() bool { - if saveAssociations, ok := scope.Get("gorm:save_associations"); ok { - if v, ok := saveAssociations.(bool); ok && !v { - return false - } - if v, ok := saveAssociations.(string); ok && (v != "skip") { - return false - } - } - return true && !scope.HasError() -} - func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope { toScope := scope.db.NewScope(value) tx := scope.db.Set("gorm:association:source", scope.Value) @@ -1144,7 +1132,7 @@ func (scope *Scope) dropTable() *Scope { } func (scope *Scope) modifyColumn(column string, typ string) { - scope.Raw(fmt.Sprintf("ALTER TABLE %v MODIFY %v %v", scope.QuotedTableName(), scope.Quote(column), typ)).Exec() + scope.db.AddError(scope.Dialect().ModifyColumn(scope.QuotedTableName(), scope.Quote(column), typ)) } func (scope *Scope) dropColumn(column string) { @@ -1170,7 +1158,8 @@ func (scope *Scope) addIndex(unique bool, indexName string, column ...string) { } func (scope *Scope) addForeignKey(field string, dest string, onDelete string, onUpdate string) { - keyName := scope.Dialect().BuildForeignKeyName(scope.TableName(), field, dest) + // Compatible with old generated key + keyName := scope.Dialect().BuildKeyName(scope.TableName(), field, dest, "foreign") if scope.Dialect().HasForeignKey(scope.TableName(), keyName) { return @@ -1179,6 +1168,16 @@ func (scope *Scope) addForeignKey(field string, dest string, onDelete string, on scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName), scope.quoteIfPossible(field), dest, onDelete, onUpdate)).Exec() } +func (scope *Scope) removeForeignKey(field string, dest string) { + keyName := scope.Dialect().BuildKeyName(scope.TableName(), field, dest) + + if !scope.Dialect().HasForeignKey(scope.TableName(), keyName) { + return + } + var query = `ALTER TABLE %s DROP CONSTRAINT %s;` + scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName))).Exec() +} + func (scope *Scope) removeIndex(indexName string) { scope.Dialect().RemoveIndex(scope.TableName(), indexName) } @@ -1214,7 +1213,7 @@ func (scope *Scope) autoIndex() *Scope { for _, name := range names { if name == "INDEX" || name == "" { - name = fmt.Sprintf("idx_%v_%v", scope.TableName(), field.DBName) + name = scope.Dialect().BuildKeyName("idx", scope.TableName(), field.DBName) } indexes[name] = append(indexes[name], field.DBName) } @@ -1225,7 +1224,7 @@ func (scope *Scope) autoIndex() *Scope { for _, name := range names { if name == "UNIQUE_INDEX" || name == "" { - name = fmt.Sprintf("uix_%v_%v", scope.TableName(), field.DBName) + name = scope.Dialect().BuildKeyName("uix", scope.TableName(), field.DBName) } uniqueIndexes[name] = append(uniqueIndexes[name], field.DBName) } @@ -1233,11 +1232,15 @@ func (scope *Scope) autoIndex() *Scope { } for name, columns := range indexes { - scope.NewDB().Model(scope.Value).AddIndex(name, columns...) + if db := scope.NewDB().Model(scope.Value).AddIndex(name, columns...); db.Error != nil { + scope.db.AddError(db.Error) + } } for name, columns := range uniqueIndexes { - scope.NewDB().Model(scope.Value).AddUniqueIndex(name, columns...) + if db := scope.NewDB().Model(scope.Value).AddUniqueIndex(name, columns...); db.Error != nil { + scope.db.AddError(db.Error) + } } return scope diff --git a/search.go b/search.go index 2e273584..90138595 100644 --- a/search.go +++ b/search.go @@ -2,7 +2,6 @@ package gorm import ( "fmt" - "regexp" ) type search struct { @@ -73,13 +72,7 @@ func (s *search) Order(value interface{}, reorder ...bool) *search { return s } -var distinctSQLRegexp = regexp.MustCompile(`(?i)distinct[^a-z]+[a-z]+`) - func (s *search) Select(query interface{}, args ...interface{}) *search { - if distinctSQLRegexp.MatchString(fmt.Sprint(query)) { - s.ignoreOrderQuery = true - } - s.selects = map[string]interface{}{"query": query, "args": args} return s } diff --git a/test_all.sh b/test_all.sh index 80b319bf..5cfb3321 100755 --- a/test_all.sh +++ b/test_all.sh @@ -1,5 +1,5 @@ dialects=("postgres" "mysql" "mssql" "sqlite") for dialect in "${dialects[@]}" ; do - GORM_DIALECT=${dialect} go test + DEBUG=false GORM_DIALECT=${dialect} go test done diff --git a/utils.go b/utils.go index 97a3d175..dfaae939 100644 --- a/utils.go +++ b/utils.go @@ -23,7 +23,7 @@ var NowFunc = func() time.Time { } // Copied from golint -var commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UI", "UID", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"} +var commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UID", "UI", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"} var commonInitialismsReplacer *strings.Replacer var goSrcRegexp = regexp.MustCompile(`jinzhu/gorm/.*.go`) diff --git a/wercker.yml b/wercker.yml index ff6fb17c..2f2370b3 100644 --- a/wercker.yml +++ b/wercker.yml @@ -2,17 +2,79 @@ box: golang services: - - id: mariadb:10.0 + - name: mariadb + id: mariadb:latest env: MYSQL_DATABASE: gorm MYSQL_USER: gorm MYSQL_PASSWORD: gorm MYSQL_RANDOM_ROOT_PASSWORD: "yes" - - id: postgres + - name: mysql + id: mysql:8 + env: + MYSQL_DATABASE: gorm + MYSQL_USER: gorm + MYSQL_PASSWORD: gorm + MYSQL_RANDOM_ROOT_PASSWORD: "yes" + - name: mysql57 + id: mysql:5.7 + env: + MYSQL_DATABASE: gorm + MYSQL_USER: gorm + MYSQL_PASSWORD: gorm + MYSQL_RANDOM_ROOT_PASSWORD: "yes" + - name: mysql56 + id: mysql:5.6 + env: + MYSQL_DATABASE: gorm + MYSQL_USER: gorm + MYSQL_PASSWORD: gorm + MYSQL_RANDOM_ROOT_PASSWORD: "yes" + - name: mysql55 + id: mysql:5.5 + env: + MYSQL_DATABASE: gorm + MYSQL_USER: gorm + MYSQL_PASSWORD: gorm + MYSQL_RANDOM_ROOT_PASSWORD: "yes" + - name: postgres + id: postgres:latest env: POSTGRES_USER: gorm POSTGRES_PASSWORD: gorm POSTGRES_DB: gorm + - name: postgres96 + id: postgres:9.6 + env: + POSTGRES_USER: gorm + POSTGRES_PASSWORD: gorm + POSTGRES_DB: gorm + - name: postgres95 + id: postgres:9.5 + env: + POSTGRES_USER: gorm + POSTGRES_PASSWORD: gorm + POSTGRES_DB: gorm + - name: postgres94 + id: postgres:9.4 + env: + POSTGRES_USER: gorm + POSTGRES_PASSWORD: gorm + POSTGRES_DB: gorm + - name: postgres93 + id: postgres:9.3 + env: + POSTGRES_USER: gorm + POSTGRES_PASSWORD: gorm + POSTGRES_DB: gorm + - name: mssql + id: mcmoe/mssqldocker:latest + env: + ACCEPT_EULA: Y + SA_PASSWORD: LoremIpsum86 + MSSQL_DB: gorm + MSSQL_USER: gorm + MSSQL_PASSWORD: LoremIpsum86 # The steps that will be executed in the build pipeline build: @@ -42,12 +104,57 @@ build: code: | go test ./... + - script: + name: test mariadb + code: | + GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mariadb:3306)/gorm?charset=utf8&parseTime=True" go test ./... + - script: name: test mysql code: | - GORM_DIALECT=mysql GORM_DBADDRESS=mariadb:3306 go test ./... + GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql:3306)/gorm?charset=utf8&parseTime=True" go test ./... + + - script: + name: test mysql5.7 + code: | + GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql57:3306)/gorm?charset=utf8&parseTime=True" go test ./... + + - script: + name: test mysql5.6 + code: | + GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql56:3306)/gorm?charset=utf8&parseTime=True" go test ./... + + - script: + name: test mysql5.5 + code: | + GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql55:3306)/gorm?charset=utf8&parseTime=True" go test ./... - script: name: test postgres code: | - GORM_DIALECT=postgres GORM_DBHOST=postgres go test ./... + GORM_DIALECT=postgres GORM_DSN="host=postgres user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test ./... + + - script: + name: test postgres96 + code: | + GORM_DIALECT=postgres GORM_DSN="host=postgres96 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test ./... + + - script: + name: test postgres95 + code: | + GORM_DIALECT=postgres GORM_DSN="host=postgres95 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test ./... + + - script: + name: test postgres94 + code: | + GORM_DIALECT=postgres GORM_DSN="host=postgres94 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test ./... + + - script: + name: test postgres93 + code: | + GORM_DIALECT=postgres GORM_DSN="host=postgres93 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test ./... + + - script: + name: test mssql + code: | + GORM_DIALECT=mssql GORM_DSN="sqlserver://gorm:LoremIpsum86@mssql:1433?database=gorm" go test ./...