From 6f58f8a52cc3ad21950402d1adaa09682e07ec2f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adem=20=C3=96zay?= Date: Mon, 10 Sep 2018 00:52:20 +0300 Subject: [PATCH 01/17] added naming strategy option for db, table and column names (#2040) --- model_struct.go | 12 ++--- naming.go | 124 ++++++++++++++++++++++++++++++++++++++++++++++++ naming_test.go | 69 +++++++++++++++++++++++++++ scope.go | 4 +- utils.go | 61 ------------------------ utils_test.go | 35 -------------- 6 files changed, 201 insertions(+), 104 deletions(-) create mode 100644 naming.go create mode 100644 naming_test.go delete mode 100644 utils_test.go diff --git a/model_struct.go b/model_struct.go index 8506fe87..5b5be618 100644 --- a/model_struct.go +++ b/model_struct.go @@ -34,7 +34,7 @@ func (s *ModelStruct) TableName(db *DB) string { if tabler, ok := reflect.New(s.ModelType).Interface().(tabler); ok { s.defaultTableName = tabler.TableName() } else { - tableName := ToDBName(s.ModelType.Name()) + tableName := ToTableName(s.ModelType.Name()) if db == nil || !db.parent.singularTable { tableName = inflection.Plural(tableName) } @@ -105,7 +105,7 @@ type Relationship struct { func getForeignField(column string, fields []*StructField) *StructField { for _, field := range fields { - if field.Name == column || field.DBName == column || field.DBName == ToDBName(column) { + if field.Name == column || field.DBName == column || field.DBName == ToColumnName(column) { return field } } @@ -269,7 +269,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { // if defined join table's foreign key relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBNames[idx]) } else { - defaultJointableForeignKey := ToDBName(reflectType.Name()) + "_" + foreignField.DBName + defaultJointableForeignKey := ToColumnName(reflectType.Name()) + "_" + foreignField.DBName relationship.ForeignDBNames = append(relationship.ForeignDBNames, defaultJointableForeignKey) } } @@ -300,7 +300,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationJoinTableDBNames[idx]) } else { // join table foreign keys for association - joinTableDBName := ToDBName(elemType.Name()) + "_" + field.DBName + joinTableDBName := ToColumnName(elemType.Name()) + "_" + field.DBName relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, joinTableDBName) } } @@ -308,7 +308,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } joinTableHandler := JoinTableHandler{} - joinTableHandler.Setup(relationship, many2many, reflectType, elemType) + joinTableHandler.Setup(relationship, ToTableName(many2many), reflectType, elemType) relationship.JoinTableHandler = &joinTableHandler field.Relationship = relationship } else { @@ -566,7 +566,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { if value, ok := field.TagSettings["COLUMN"]; ok { field.DBName = value } else { - field.DBName = ToDBName(fieldStruct.Name) + field.DBName = ToColumnName(fieldStruct.Name) } modelStruct.StructFields = append(modelStruct.StructFields, field) diff --git a/naming.go b/naming.go new file mode 100644 index 00000000..6b0a4fdd --- /dev/null +++ b/naming.go @@ -0,0 +1,124 @@ +package gorm + +import ( + "bytes" + "strings" +) + +// Namer is a function type which is given a string and return a string +type Namer func(string) string + +// NamingStrategy represents naming strategies +type NamingStrategy struct { + DB Namer + Table Namer + Column Namer +} + +// TheNamingStrategy is being initialized with defaultNamingStrategy +var TheNamingStrategy = &NamingStrategy{ + DB: defaultNamer, + Table: defaultNamer, + Column: defaultNamer, +} + +// AddNamingStrategy sets the naming strategy +func AddNamingStrategy(ns *NamingStrategy) { + if ns.DB == nil { + ns.DB = defaultNamer + } + if ns.Table == nil { + ns.Table = defaultNamer + } + if ns.Column == nil { + ns.Column = defaultNamer + } + TheNamingStrategy = ns +} + +// DBName alters the given name by DB +func (ns *NamingStrategy) DBName(name string) string { + return ns.DB(name) +} + +// TableName alters the given name by Table +func (ns *NamingStrategy) TableName(name string) string { + return ns.Table(name) +} + +// ColumnName alters the given name by Column +func (ns *NamingStrategy) ColumnName(name string) string { + return ns.Column(name) +} + +// ToDBName convert string to db name +func ToDBName(name string) string { + return TheNamingStrategy.DBName(name) +} + +// ToTableName convert string to table name +func ToTableName(name string) string { + return TheNamingStrategy.TableName(name) +} + +// ToColumnName convert string to db name +func ToColumnName(name string) string { + return TheNamingStrategy.ColumnName(name) +} + +var smap = newSafeMap() + +func defaultNamer(name string) string { + const ( + lower = false + upper = true + ) + + if v := smap.Get(name); v != "" { + return v + } + + if name == "" { + return "" + } + + var ( + value = commonInitialismsReplacer.Replace(name) + buf = bytes.NewBufferString("") + lastCase, currCase, nextCase, nextNumber bool + ) + + for i, v := range value[:len(value)-1] { + nextCase = bool(value[i+1] >= 'A' && value[i+1] <= 'Z') + nextNumber = bool(value[i+1] >= '0' && value[i+1] <= '9') + + if i > 0 { + if currCase == upper { + if lastCase == upper && (nextCase == upper || nextNumber == upper) { + buf.WriteRune(v) + } else { + if value[i-1] != '_' && value[i+1] != '_' { + buf.WriteRune('_') + } + buf.WriteRune(v) + } + } else { + buf.WriteRune(v) + if i == len(value)-2 && (nextCase == upper && nextNumber == lower) { + buf.WriteRune('_') + } + } + } else { + currCase = upper + buf.WriteRune(v) + } + lastCase = currCase + currCase = nextCase + } + + buf.WriteByte(value[len(value)-1]) + + s := strings.ToLower(buf.String()) + smap.Set(name, s) + return s +} diff --git a/naming_test.go b/naming_test.go new file mode 100644 index 00000000..0c6f7713 --- /dev/null +++ b/naming_test.go @@ -0,0 +1,69 @@ +package gorm_test + +import ( + "testing" + + "github.com/jinzhu/gorm" +) + +func TestTheNamingStrategy(t *testing.T) { + + cases := []struct { + name string + namer gorm.Namer + expected string + }{ + {name: "auth", expected: "auth", namer: gorm.TheNamingStrategy.DB}, + {name: "userRestrictions", expected: "user_restrictions", namer: gorm.TheNamingStrategy.Table}, + {name: "clientID", expected: "client_id", namer: gorm.TheNamingStrategy.Column}, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + result := c.namer(c.name) + if result != c.expected { + t.Errorf("error in naming strategy. expected: %v got :%v\n", c.expected, result) + } + }) + } + +} + +func TestNamingStrategy(t *testing.T) { + + dbNameNS := func(name string) string { + return "db_" + name + } + tableNameNS := func(name string) string { + return "tbl_" + name + } + columnNameNS := func(name string) string { + return "col_" + name + } + + ns := &gorm.NamingStrategy{ + DB: dbNameNS, + Table: tableNameNS, + Column: columnNameNS, + } + + cases := []struct { + name string + namer gorm.Namer + expected string + }{ + {name: "auth", expected: "db_auth", namer: ns.DB}, + {name: "user", expected: "tbl_user", namer: ns.Table}, + {name: "password", expected: "col_password", namer: ns.Column}, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + result := c.namer(c.name) + if result != c.expected { + t.Errorf("error in naming strategy. expected: %v got :%v\n", c.expected, result) + } + }) + } + +} diff --git a/scope.go b/scope.go index ca861d8a..fbf7634e 100644 --- a/scope.go +++ b/scope.go @@ -134,7 +134,7 @@ func (scope *Scope) Fields() []*Field { // FieldByName find `gorm.Field` with field name or db name func (scope *Scope) FieldByName(name string) (field *Field, ok bool) { var ( - dbName = ToDBName(name) + dbName = ToColumnName(name) mostMatchedField *Field ) @@ -880,7 +880,7 @@ func convertInterfaceToMap(values interface{}, withIgnoredField bool) map[string switch reflectValue.Kind() { case reflect.Map: for _, key := range reflectValue.MapKeys() { - attrs[ToDBName(key.Interface().(string))] = reflectValue.MapIndex(key).Interface() + attrs[ToColumnName(key.Interface().(string))] = reflectValue.MapIndex(key).Interface() } default: for _, field := range (&Scope{Value: values}).Fields() { diff --git a/utils.go b/utils.go index 99b532c5..ad700b98 100644 --- a/utils.go +++ b/utils.go @@ -1,7 +1,6 @@ package gorm import ( - "bytes" "database/sql/driver" "fmt" "reflect" @@ -58,66 +57,6 @@ func newSafeMap() *safeMap { return &safeMap{l: new(sync.RWMutex), m: make(map[string]string)} } -var smap = newSafeMap() - -type strCase bool - -const ( - lower strCase = false - upper strCase = true -) - -// ToDBName convert string to db name -func ToDBName(name string) string { - if v := smap.Get(name); v != "" { - return v - } - - if name == "" { - return "" - } - - var ( - value = commonInitialismsReplacer.Replace(name) - buf = bytes.NewBufferString("") - lastCase, currCase, nextCase, nextNumber strCase - ) - - for i, v := range value[:len(value)-1] { - nextCase = strCase(value[i+1] >= 'A' && value[i+1] <= 'Z') - nextNumber = strCase(value[i+1] >= '0' && value[i+1] <= '9') - - if i > 0 { - if currCase == upper { - if lastCase == upper && (nextCase == upper || nextNumber == upper) { - buf.WriteRune(v) - } else { - if value[i-1] != '_' && value[i+1] != '_' { - buf.WriteRune('_') - } - buf.WriteRune(v) - } - } else { - buf.WriteRune(v) - if i == len(value)-2 && (nextCase == upper && nextNumber == lower) { - buf.WriteRune('_') - } - } - } else { - currCase = upper - buf.WriteRune(v) - } - lastCase = currCase - currCase = nextCase - } - - buf.WriteByte(value[len(value)-1]) - - s := strings.ToLower(buf.String()) - smap.Set(name, s) - return s -} - // SQL expression type expr struct { expr string diff --git a/utils_test.go b/utils_test.go deleted file mode 100644 index 086c4450..00000000 --- a/utils_test.go +++ /dev/null @@ -1,35 +0,0 @@ -package gorm_test - -import ( - "testing" - - "github.com/jinzhu/gorm" -) - -func TestToDBNameGenerateFriendlyName(t *testing.T) { - var maps = map[string]string{ - "": "", - "X": "x", - "ThisIsATest": "this_is_a_test", - "PFAndESI": "pf_and_esi", - "AbcAndJkl": "abc_and_jkl", - "EmployeeID": "employee_id", - "SKU_ID": "sku_id", - "UTF8": "utf8", - "Level1": "level1", - "SHA256Hash": "sha256_hash", - "FieldX": "field_x", - "HTTPAndSMTP": "http_and_smtp", - "HTTPServerHandlerForURLID": "http_server_handler_for_url_id", - "UUID": "uuid", - "HTTPURL": "http_url", - "HTTP_URL": "http_url", - "ThisIsActuallyATestSoWeMayBeAbleToUseThisCodeInGormPackageAlsoIdCanBeUsedAtTheEndAsID": "this_is_actually_a_test_so_we_may_be_able_to_use_this_code_in_gorm_package_also_id_can_be_used_at_the_end_as_id", - } - - for key, value := range maps { - if gorm.ToDBName(key) != value { - t.Errorf("%v ToDBName should equal %v, but got %v", key, value, gorm.ToDBName(key)) - } - } -} From dc3b2476c4eb61c37424a1ca2f46859e4e6fcd81 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 10 Sep 2018 06:03:41 +0800 Subject: [PATCH 02/17] Don't save ignored fields into database --- callback_create.go | 2 +- scope.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/callback_create.go b/callback_create.go index e7fe6f86..2ab05d3b 100644 --- a/callback_create.go +++ b/callback_create.go @@ -59,7 +59,7 @@ func createCallback(scope *Scope) { for _, field := range scope.Fields() { if scope.changeableField(field) { - if field.IsNormal { + if field.IsNormal && !field.IsIgnored { if field.IsBlank && field.HasDefaultValue { blankColumnsWithDefaultValue = append(blankColumnsWithDefaultValue, scope.Quote(field.DBName)) scope.InstanceSet("gorm:blank_columns_with_default_value", blankColumnsWithDefaultValue) diff --git a/scope.go b/scope.go index fbf7634e..7d6ba1c0 100644 --- a/scope.go +++ b/scope.go @@ -907,7 +907,7 @@ func (scope *Scope) updatedAttrsWithValues(value interface{}) (results map[strin results[field.DBName] = value } else { err := field.Set(value) - if field.IsNormal { + if field.IsNormal && !field.IsIgnored { hasUpdate = true if err == ErrUnaddressable { results[field.DBName] = value From 71b7f19aad77eaf99a90324c7d2ac5634eaefca8 Mon Sep 17 00:00:00 2001 From: Xy Ziemba Date: Sun, 9 Sep 2018 15:12:58 -0700 Subject: [PATCH 03/17] Fix scanning identical column names occurring >2 times (#2080) Fix the indexing logic used in selectedColumnsMap to skip fields that have already been seen. The values of selectedColumns map must be indexed relative to fields, not relative to selectFields. --- main_test.go | 54 ++++++++++++++++++++++++++++++++++++++++++++++++++++ scope.go | 6 ++++-- 2 files changed, 58 insertions(+), 2 deletions(-) diff --git a/main_test.go b/main_test.go index 265e0be7..11c4bb87 100644 --- a/main_test.go +++ b/main_test.go @@ -581,6 +581,60 @@ func TestJoins(t *testing.T) { } } +type JoinedIds struct { + UserID int64 `gorm:"column:id"` + BillingAddressID int64 `gorm:"column:id"` + EmailID int64 `gorm:"column:id"` +} + +func TestScanIdenticalColumnNames(t *testing.T) { + var user = User{ + Name: "joinsIds", + Email: "joinIds@example.com", + BillingAddress: Address{ + Address1: "One Park Place", + }, + Emails: []Email{{Email: "join1@example.com"}, {Email: "join2@example.com"}}, + } + DB.Save(&user) + + var users []JoinedIds + DB.Select("users.id, addresses.id, emails.id").Table("users"). + Joins("left join addresses on users.billing_address_id = addresses.id"). + Joins("left join emails on emails.user_id = users.id"). + Where("name = ?", "joinsIds").Scan(&users) + + if len(users) != 2 { + t.Fatal("should find two rows using left join") + } + + if user.Id != users[0].UserID { + t.Errorf("Expected result row to contain UserID %d, but got %d", user.Id, users[0].UserID) + } + if user.Id != users[1].UserID { + t.Errorf("Expected result row to contain UserID %d, but got %d", user.Id, users[1].UserID) + } + + if user.BillingAddressID.Int64 != users[0].BillingAddressID { + t.Errorf("Expected result row to contain BillingAddressID %d, but got %d", user.BillingAddressID.Int64, users[0].BillingAddressID) + } + if user.BillingAddressID.Int64 != users[1].BillingAddressID { + t.Errorf("Expected result row to contain BillingAddressID %d, but got %d", user.BillingAddressID.Int64, users[0].BillingAddressID) + } + + if users[0].EmailID == users[1].EmailID { + t.Errorf("Email ids should be unique. Got %d and %d", users[0].EmailID, users[1].EmailID) + } + + if int64(user.Emails[0].Id) != users[0].EmailID && int64(user.Emails[1].Id) != users[0].EmailID { + t.Errorf("Expected result row ID to be either %d or %d, but was %d", user.Emails[0].Id, user.Emails[1].Id, users[0].EmailID) + } + + if int64(user.Emails[0].Id) != users[1].EmailID && int64(user.Emails[1].Id) != users[1].EmailID { + t.Errorf("Expected result row ID to be either %d or %d, but was %d", user.Emails[0].Id, user.Emails[1].Id, users[1].EmailID) + } +} + func TestJoinsWithSelect(t *testing.T) { type result struct { Name string diff --git a/scope.go b/scope.go index 7d6ba1c0..ce80ab86 100644 --- a/scope.go +++ b/scope.go @@ -486,8 +486,10 @@ func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field) { values[index] = &ignored selectFields = fields + offset := 0 if idx, ok := selectedColumnsMap[column]; ok { - selectFields = selectFields[idx+1:] + offset = idx + 1 + selectFields = selectFields[offset:] } for fieldIndex, field := range selectFields { @@ -501,7 +503,7 @@ func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field) { resetFields[index] = field } - selectedColumnsMap[column] = fieldIndex + selectedColumnsMap[column] = offset + fieldIndex if field.IsNormal { break From 12607e8bdf4a724492d53d8c788edc77ad4439e7 Mon Sep 17 00:00:00 2001 From: kuangzhiqiang Date: Mon, 10 Sep 2018 06:14:05 +0800 Subject: [PATCH 04/17] for go1.11 go mod (#2072) when used go1.11 gomodules the code dir will be `$GOPATH/pkg/mod/github.com/jinzhu/gorm@*/` fileWithLineNum check failed --- utils.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/utils.go b/utils.go index ad700b98..8489538c 100644 --- a/utils.go +++ b/utils.go @@ -25,8 +25,8 @@ var NowFunc = func() time.Time { var commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UID", "UI", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"} var commonInitialismsReplacer *strings.Replacer -var goSrcRegexp = regexp.MustCompile(`jinzhu/gorm/.*.go`) -var goTestRegexp = regexp.MustCompile(`jinzhu/gorm/.*test.go`) +var goSrcRegexp = regexp.MustCompile(`jinzhu/gorm(@.*)?/.*.go`) +var goTestRegexp = regexp.MustCompile(`jinzhu/gorm(@.*)?/.*test.go`) func init() { var commonInitialismsForReplacer []string From d3e666a1e086a020905e3f6cf293941806520d97 Mon Sep 17 00:00:00 2001 From: Ikhtiyor <33823221+iahmedov@users.noreply.github.com> Date: Mon, 10 Sep 2018 03:25:26 +0500 Subject: [PATCH 05/17] save_associations:true should store related item (#2067) * save_associations:true should store related item, save_associations priority on related objects * code quality --- callback_save.go | 6 ++-- main_test.go | 88 +++++++++++++++++++++++++++++++++++++++++++++++ migration_test.go | 10 +++++- 3 files changed, 100 insertions(+), 4 deletions(-) diff --git a/callback_save.go b/callback_save.go index ef267141..ebfd0b34 100644 --- a/callback_save.go +++ b/callback_save.go @@ -21,9 +21,7 @@ func saveAssociationCheck(scope *Scope, field *Field) (autoUpdate bool, autoCrea if v, ok := value.(string); ok { v = strings.ToLower(v) - if v == "false" || v != "skip" { - return false - } + return v == "true" } return true @@ -36,9 +34,11 @@ func saveAssociationCheck(scope *Scope, field *Field) (autoUpdate bool, autoCrea if value, ok := scope.Get("gorm:save_associations"); ok { autoUpdate = checkTruth(value) autoCreate = autoUpdate + saveReference = autoUpdate } else if value, ok := field.TagSettings["SAVE_ASSOCIATIONS"]; ok { autoUpdate = checkTruth(value) autoCreate = autoUpdate + saveReference = autoUpdate } if value, ok := scope.Get("gorm:association_autoupdate"); ok { diff --git a/main_test.go b/main_test.go index 11c4bb87..94d2fa39 100644 --- a/main_test.go +++ b/main_test.go @@ -933,6 +933,94 @@ func TestOpenWithOneParameter(t *testing.T) { } } +func TestSaveAssociations(t *testing.T) { + db := DB.New() + deltaAddressCount := 0 + if err := db.Model(&Address{}).Count(&deltaAddressCount).Error; err != nil { + t.Errorf("failed to fetch address count") + t.FailNow() + } + + placeAddress := &Address{ + Address1: "somewhere on earth", + } + ownerAddress1 := &Address{ + Address1: "near place address", + } + ownerAddress2 := &Address{ + Address1: "address2", + } + db.Create(placeAddress) + + addressCountShouldBe := func(t *testing.T, expectedCount int) { + countFromDB := 0 + t.Helper() + err := db.Model(&Address{}).Count(&countFromDB).Error + if err != nil { + t.Error("failed to fetch address count") + } + if countFromDB != expectedCount { + t.Errorf("address count mismatch: %d", countFromDB) + } + } + addressCountShouldBe(t, deltaAddressCount+1) + + // owner address should be created, place address should be reused + place1 := &Place{ + PlaceAddressID: placeAddress.ID, + PlaceAddress: placeAddress, + OwnerAddress: ownerAddress1, + } + err := db.Create(place1).Error + if err != nil { + t.Errorf("failed to store place: %s", err.Error()) + } + addressCountShouldBe(t, deltaAddressCount+2) + + // owner address should be created again, place address should be reused + place2 := &Place{ + PlaceAddressID: placeAddress.ID, + PlaceAddress: &Address{ + ID: 777, + Address1: "address1", + }, + OwnerAddress: ownerAddress2, + OwnerAddressID: 778, + } + err = db.Create(place2).Error + if err != nil { + t.Errorf("failed to store place: %s", err.Error()) + } + addressCountShouldBe(t, deltaAddressCount+3) + + count := 0 + db.Model(&Place{}).Where(&Place{ + PlaceAddressID: placeAddress.ID, + OwnerAddressID: ownerAddress1.ID, + }).Count(&count) + if count != 1 { + t.Errorf("only one instance of (%d, %d) should be available, found: %d", + placeAddress.ID, ownerAddress1.ID, count) + } + + db.Model(&Place{}).Where(&Place{ + PlaceAddressID: placeAddress.ID, + OwnerAddressID: ownerAddress2.ID, + }).Count(&count) + if count != 1 { + t.Errorf("only one instance of (%d, %d) should be available, found: %d", + placeAddress.ID, ownerAddress2.ID, count) + } + + db.Model(&Place{}).Where(&Place{ + PlaceAddressID: placeAddress.ID, + }).Count(&count) + if count != 2 { + t.Errorf("two instances of (%d) should be available, found: %d", + placeAddress.ID, count) + } +} + func TestBlockGlobalUpdate(t *testing.T) { db := DB.New() db.Create(&Toy{Name: "Stuffed Animal", OwnerType: "Nobody"}) diff --git a/migration_test.go b/migration_test.go index 78555dcc..3fb06648 100644 --- a/migration_test.go +++ b/migration_test.go @@ -118,6 +118,14 @@ type Company struct { Owner *User `sql:"-"` } +type Place struct { + Id int64 + PlaceAddressID int + PlaceAddress *Address `gorm:"save_associations:false"` + OwnerAddressID int + OwnerAddress *Address `gorm:"save_associations:true"` +} + type EncryptedData []byte func (data *EncryptedData) Scan(value interface{}) error { @@ -284,7 +292,7 @@ func runMigration() { DB.Exec(fmt.Sprintf("drop table %v;", table)) } - values := []interface{}{&Short{}, &ReallyLongThingThatReferencesShort{}, &ReallyLongTableNameToTestMySQLNameLengthLimit{}, &NotSoLongTableName{}, &Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}, &Animal{}, &User{}, &JoinTable{}, &Post{}, &Category{}, &Comment{}, &Cat{}, &Dog{}, &Hamster{}, &Toy{}, &ElementWithIgnoredField{}} + values := []interface{}{&Short{}, &ReallyLongThingThatReferencesShort{}, &ReallyLongTableNameToTestMySQLNameLengthLimit{}, &NotSoLongTableName{}, &Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}, &Animal{}, &User{}, &JoinTable{}, &Post{}, &Category{}, &Comment{}, &Cat{}, &Dog{}, &Hamster{}, &Toy{}, &ElementWithIgnoredField{}, &Place{}} for _, value := range values { DB.DropTable(value) } From 73e7561e20e8e554ec54463ccbed38e426aad17f Mon Sep 17 00:00:00 2001 From: Aaron Leung Date: Sun, 9 Sep 2018 15:26:29 -0700 Subject: [PATCH 06/17] Use sync.Map for DB.values (#2064) * Replace the regular map with a sync.Map to avoid fatal concurrent map reads/writes * fix the formatting --- main.go | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/main.go b/main.go index 993e19b1..364d8e8e 100644 --- a/main.go +++ b/main.go @@ -22,7 +22,7 @@ type DB struct { logMode int logger logger search *search - values map[string]interface{} + values sync.Map // global db parent *DB @@ -72,7 +72,6 @@ func Open(dialect string, args ...interface{}) (db *DB, err error) { db = &DB{ db: dbSQL, logger: defaultLogger, - values: map[string]interface{}{}, callbacks: DefaultCallback, dialect: newDialect(dialect, dbSQL), } @@ -680,13 +679,13 @@ func (s *DB) Set(name string, value interface{}) *DB { // InstantSet instant set setting, will affect current db func (s *DB) InstantSet(name string, value interface{}) *DB { - s.values[name] = value + s.values.Store(name, value) return s } // Get get setting by name func (s *DB) Get(name string) (value interface{}, ok bool) { - value, ok = s.values[name] + value, ok = s.values.Load(name) return } @@ -750,16 +749,16 @@ func (s *DB) clone() *DB { parent: s.parent, logger: s.logger, logMode: s.logMode, - values: map[string]interface{}{}, Value: s.Value, Error: s.Error, blockGlobalUpdate: s.blockGlobalUpdate, dialect: newDialect(s.dialect.GetName(), s.db), } - for key, value := range s.values { - db.values[key] = value - } + s.values.Range(func(k, v interface{}) bool { + db.values.Store(k, v) + return true + }) if s.search == nil { db.search = &search{limit: -1, offset: -1} From 012d1479740ec593b0c07f0372e0111c01c3b34a Mon Sep 17 00:00:00 2001 From: maddie Date: Mon, 10 Sep 2018 06:45:55 +0800 Subject: [PATCH 07/17] Improve preload speed (#2058) All credits to @vanjapt who came up with this patch. Closes #1672 --- callback_query_preload.go | 33 ++++++++++++++++++++++----------- 1 file changed, 22 insertions(+), 11 deletions(-) diff --git a/callback_query_preload.go b/callback_query_preload.go index 481bfbe3..46405c38 100644 --- a/callback_query_preload.go +++ b/callback_query_preload.go @@ -161,14 +161,17 @@ func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) ) if indirectScopeValue.Kind() == reflect.Slice { + foreignValuesToResults := make(map[string]reflect.Value) + for i := 0; i < resultsValue.Len(); i++ { + result := resultsValue.Index(i) + foreignValues := toString(getValueFromFields(result, relation.ForeignFieldNames)) + foreignValuesToResults[foreignValues] = result + } for j := 0; j < indirectScopeValue.Len(); j++ { - for i := 0; i < resultsValue.Len(); i++ { - result := resultsValue.Index(i) - foreignValues := getValueFromFields(result, relation.ForeignFieldNames) - if indirectValue := indirect(indirectScopeValue.Index(j)); equalAsString(getValueFromFields(indirectValue, relation.AssociationForeignFieldNames), foreignValues) { - indirectValue.FieldByName(field.Name).Set(result) - break - } + indirectValue := indirect(indirectScopeValue.Index(j)) + valueString := toString(getValueFromFields(indirectValue, relation.AssociationForeignFieldNames)) + if result, found := foreignValuesToResults[valueString]; found { + indirectValue.FieldByName(field.Name).Set(result) } } } else { @@ -255,13 +258,21 @@ func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{ indirectScopeValue = scope.IndirectValue() ) + foreignFieldToObjects := make(map[string][]*reflect.Value) + if indirectScopeValue.Kind() == reflect.Slice { + for j := 0; j < indirectScopeValue.Len(); j++ { + object := indirect(indirectScopeValue.Index(j)) + valueString := toString(getValueFromFields(object, relation.ForeignFieldNames)) + foreignFieldToObjects[valueString] = append(foreignFieldToObjects[valueString], &object) + } + } + for i := 0; i < resultsValue.Len(); i++ { result := resultsValue.Index(i) if indirectScopeValue.Kind() == reflect.Slice { - value := getValueFromFields(result, relation.AssociationForeignFieldNames) - for j := 0; j < indirectScopeValue.Len(); j++ { - object := indirect(indirectScopeValue.Index(j)) - if equalAsString(getValueFromFields(object, relation.ForeignFieldNames), value) { + valueString := toString(getValueFromFields(result, relation.AssociationForeignFieldNames)) + if objects, found := foreignFieldToObjects[valueString]; found { + for _, object := range objects { object.FieldByName(field.Name).Set(result) } } From 26fde9110f932df8cb5cc24396e7a54a6d3a94c2 Mon Sep 17 00:00:00 2001 From: Gustavo Brunoro Date: Sun, 9 Sep 2018 19:47:18 -0300 Subject: [PATCH 08/17] getValueFromFields doesn't panic on nil pointers (#2021) * `IsValid()` won't return `false` for nil pointers unless Value is wrapped in a `reflect.Indirect`. --- utils.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils.go b/utils.go index 8489538c..e58e57a5 100644 --- a/utils.go +++ b/utils.go @@ -206,7 +206,7 @@ func getValueFromFields(value reflect.Value, fieldNames []string) (results []int // as FieldByName could panic if indirectValue := reflect.Indirect(value); indirectValue.IsValid() { for _, fieldName := range fieldNames { - if fieldValue := indirectValue.FieldByName(fieldName); fieldValue.IsValid() { + if fieldValue := reflect.Indirect(indirectValue.FieldByName(fieldName)); fieldValue.IsValid() { result := fieldValue.Interface() if r, ok := result.(driver.Valuer); ok { result, _ = r.Value() From 588b598f9fbf9a0c84b6ec18f617940b045c54d4 Mon Sep 17 00:00:00 2001 From: Phillip Shipley Date: Sun, 9 Sep 2018 18:50:22 -0400 Subject: [PATCH 09/17] Fix issue updating models with foreign key constraints (#1988) * fix update callback to not try to write zero values when field has default value * fix to update callback for gorm tests --- callback_update.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/callback_update.go b/callback_update.go index 373bd726..f6ba0ffd 100644 --- a/callback_update.go +++ b/callback_update.go @@ -76,7 +76,9 @@ func updateCallback(scope *Scope) { for _, field := range scope.Fields() { if scope.changeableField(field) { if !field.IsPrimaryKey && field.IsNormal { - sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) + if !field.IsForeignKey || !field.IsBlank || !field.HasDefaultValue { + sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) + } } else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" { for _, foreignKey := range relationship.ForeignDBNames { if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) { From 282f11af1900a36646b607797273056d76350223 Mon Sep 17 00:00:00 2001 From: gstvg <28798827+gstvg@users.noreply.github.com> Date: Sun, 9 Sep 2018 19:52:32 -0300 Subject: [PATCH 10/17] Support only preloading (#1926) * add support for only preloading relations on an already populated model * Update callback_query.go comments --- callback_query.go | 5 +++++ main.go | 5 +++++ 2 files changed, 10 insertions(+) diff --git a/callback_query.go b/callback_query.go index ba10cc7d..593e5d30 100644 --- a/callback_query.go +++ b/callback_query.go @@ -18,6 +18,11 @@ func queryCallback(scope *Scope) { if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip { return } + + //we are only preloading relations, dont touch base model + if _, skip := scope.InstanceGet("gorm:only_preload"); skip { + return + } defer scope.trace(NowFunc()) diff --git a/main.go b/main.go index 364d8e8e..4dbda61e 100644 --- a/main.go +++ b/main.go @@ -314,6 +314,11 @@ func (s *DB) Find(out interface{}, where ...interface{}) *DB { return s.NewScope(out).inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db } +//Preloads preloads relations, don`t touch out +func (s *DB) Preloads(out interface{}) *DB { + return s.NewScope(out).InstanceSet("gorm:only_preload", 1).callCallbacks(s.parent.callbacks.queries).db +} + // Scan scan value to a struct func (s *DB) Scan(dest interface{}) *DB { return s.NewScope(s.Value).Set("gorm:query_destination", dest).callCallbacks(s.parent.callbacks.queries).db From 123d4f50ef8a8209ee8434daa41c6045a9111864 Mon Sep 17 00:00:00 2001 From: Eyal Posener Date: Mon, 10 Sep 2018 02:11:00 +0300 Subject: [PATCH 11/17] lock TagSettings structure when modified (#1796) The map is modified in different places in the code which results in race conditions on execution. This commit locks the map with read-write lock when it is modified --- callback_query_preload.go | 2 +- callback_save.go | 8 ++--- dialect.go | 10 +++--- dialect_common.go | 2 +- dialect_mysql.go | 22 ++++++------ dialect_postgres.go | 6 ++-- dialect_sqlite3.go | 4 +-- dialects/mssql/mssql.go | 8 ++--- field_test.go | 2 +- main.go | 2 +- model_struct.go | 73 +++++++++++++++++++++++++++------------ scope.go | 12 +++---- 12 files changed, 90 insertions(+), 61 deletions(-) diff --git a/callback_query_preload.go b/callback_query_preload.go index 46405c38..d7c8a133 100644 --- a/callback_query_preload.go +++ b/callback_query_preload.go @@ -100,7 +100,7 @@ func autoPreload(scope *Scope) { continue } - if val, ok := field.TagSettings["PRELOAD"]; ok { + if val, ok := field.TagSettingsGet("PRELOAD"); ok { if preload, err := strconv.ParseBool(val); err != nil { scope.Err(errors.New("invalid preload option")) return diff --git a/callback_save.go b/callback_save.go index ebfd0b34..3b4e0589 100644 --- a/callback_save.go +++ b/callback_save.go @@ -35,7 +35,7 @@ func saveAssociationCheck(scope *Scope, field *Field) (autoUpdate bool, autoCrea autoUpdate = checkTruth(value) autoCreate = autoUpdate saveReference = autoUpdate - } else if value, ok := field.TagSettings["SAVE_ASSOCIATIONS"]; ok { + } else if value, ok := field.TagSettingsGet("SAVE_ASSOCIATIONS"); ok { autoUpdate = checkTruth(value) autoCreate = autoUpdate saveReference = autoUpdate @@ -43,19 +43,19 @@ func saveAssociationCheck(scope *Scope, field *Field) (autoUpdate bool, autoCrea if value, ok := scope.Get("gorm:association_autoupdate"); ok { autoUpdate = checkTruth(value) - } else if value, ok := field.TagSettings["ASSOCIATION_AUTOUPDATE"]; ok { + } else if value, ok := field.TagSettingsGet("ASSOCIATION_AUTOUPDATE"); ok { autoUpdate = checkTruth(value) } if value, ok := scope.Get("gorm:association_autocreate"); ok { autoCreate = checkTruth(value) - } else if value, ok := field.TagSettings["ASSOCIATION_AUTOCREATE"]; ok { + } else if value, ok := field.TagSettingsGet("ASSOCIATION_AUTOCREATE"); ok { autoCreate = checkTruth(value) } if value, ok := scope.Get("gorm:association_save_reference"); ok { saveReference = checkTruth(value) - } else if value, ok := field.TagSettings["ASSOCIATION_SAVE_REFERENCE"]; ok { + } else if value, ok := field.TagSettingsGet("ASSOCIATION_SAVE_REFERENCE"); ok { saveReference = checkTruth(value) } } diff --git a/dialect.go b/dialect.go index 506a6e86..27b308af 100644 --- a/dialect.go +++ b/dialect.go @@ -83,7 +83,7 @@ var ParseFieldStructForDialect = func(field *StructField, dialect Dialect) (fiel // Get redirected field type var ( reflectType = field.Struct.Type - dataType = field.TagSettings["TYPE"] + dataType, _ = field.TagSettingsGet("TYPE") ) for reflectType.Kind() == reflect.Ptr { @@ -112,15 +112,17 @@ var ParseFieldStructForDialect = func(field *StructField, dialect Dialect) (fiel } // Default Size - if num, ok := field.TagSettings["SIZE"]; ok { + if num, ok := field.TagSettingsGet("SIZE"); ok { size, _ = strconv.Atoi(num) } else { size = 255 } // Default type from tag setting - additionalType = field.TagSettings["NOT NULL"] + " " + field.TagSettings["UNIQUE"] - if value, ok := field.TagSettings["DEFAULT"]; ok { + notNull, _ := field.TagSettingsGet("NOT NULL") + unique, _ := field.TagSettingsGet("UNIQUE") + additionalType = notNull + " " + unique + if value, ok := field.TagSettingsGet("DEFAULT"); ok { additionalType = additionalType + " DEFAULT " + value } diff --git a/dialect_common.go b/dialect_common.go index b9f0c7da..a479be79 100644 --- a/dialect_common.go +++ b/dialect_common.go @@ -39,7 +39,7 @@ func (commonDialect) Quote(key string) string { } func (s *commonDialect) fieldCanAutoIncrement(field *StructField) bool { - if value, ok := field.TagSettings["AUTO_INCREMENT"]; ok { + if value, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok { return strings.ToLower(value) != "false" } return field.IsPrimaryKey diff --git a/dialect_mysql.go b/dialect_mysql.go index b162bade..5d63e5cd 100644 --- a/dialect_mysql.go +++ b/dialect_mysql.go @@ -33,9 +33,9 @@ func (s *mysql) DataTypeOf(field *StructField) string { // MySQL allows only one auto increment column per table, and it must // be a KEY column. - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok { - if _, ok = field.TagSettings["INDEX"]; !ok && !field.IsPrimaryKey { - delete(field.TagSettings, "AUTO_INCREMENT") + if _, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok { + if _, ok = field.TagSettingsGet("INDEX"); !ok && !field.IsPrimaryKey { + field.TagSettingsDelete("AUTO_INCREMENT") } } @@ -45,42 +45,42 @@ func (s *mysql) DataTypeOf(field *StructField) string { sqlType = "boolean" case reflect.Int8: if s.fieldCanAutoIncrement(field) { - field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" + field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") sqlType = "tinyint AUTO_INCREMENT" } else { sqlType = "tinyint" } case reflect.Int, reflect.Int16, reflect.Int32: if s.fieldCanAutoIncrement(field) { - field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" + field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") sqlType = "int AUTO_INCREMENT" } else { sqlType = "int" } case reflect.Uint8: if s.fieldCanAutoIncrement(field) { - field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" + field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") sqlType = "tinyint unsigned AUTO_INCREMENT" } else { sqlType = "tinyint unsigned" } case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uintptr: if s.fieldCanAutoIncrement(field) { - field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" + field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") sqlType = "int unsigned AUTO_INCREMENT" } else { sqlType = "int unsigned" } case reflect.Int64: if s.fieldCanAutoIncrement(field) { - field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" + field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") sqlType = "bigint AUTO_INCREMENT" } else { sqlType = "bigint" } case reflect.Uint64: if s.fieldCanAutoIncrement(field) { - field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" + field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") sqlType = "bigint unsigned AUTO_INCREMENT" } else { sqlType = "bigint unsigned" @@ -96,11 +96,11 @@ func (s *mysql) DataTypeOf(field *StructField) string { case reflect.Struct: if _, ok := dataValue.Interface().(time.Time); ok { precision := "" - if p, ok := field.TagSettings["PRECISION"]; ok { + if p, ok := field.TagSettingsGet("PRECISION"); ok { precision = fmt.Sprintf("(%s)", p) } - if _, ok := field.TagSettings["NOT NULL"]; ok { + if _, ok := field.TagSettingsGet("NOT NULL"); ok { sqlType = fmt.Sprintf("timestamp%v", precision) } else { sqlType = fmt.Sprintf("timestamp%v NULL", precision) diff --git a/dialect_postgres.go b/dialect_postgres.go index c44c6a5b..53d31388 100644 --- a/dialect_postgres.go +++ b/dialect_postgres.go @@ -34,14 +34,14 @@ func (s *postgres) DataTypeOf(field *StructField) string { sqlType = "boolean" case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uintptr: if s.fieldCanAutoIncrement(field) { - field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" + field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") sqlType = "serial" } else { sqlType = "integer" } case reflect.Int64, reflect.Uint32, reflect.Uint64: if s.fieldCanAutoIncrement(field) { - field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" + field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") sqlType = "bigserial" } else { sqlType = "bigint" @@ -49,7 +49,7 @@ func (s *postgres) DataTypeOf(field *StructField) string { case reflect.Float32, reflect.Float64: sqlType = "numeric" case reflect.String: - if _, ok := field.TagSettings["SIZE"]; !ok { + if _, ok := field.TagSettingsGet("SIZE"); !ok { size = 0 // if SIZE haven't been set, use `text` as the default type, as there are no performance different } diff --git a/dialect_sqlite3.go b/dialect_sqlite3.go index f26f6be3..5f96c363 100644 --- a/dialect_sqlite3.go +++ b/dialect_sqlite3.go @@ -29,14 +29,14 @@ func (s *sqlite3) DataTypeOf(field *StructField) string { sqlType = "bool" case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: if s.fieldCanAutoIncrement(field) { - field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" + field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") sqlType = "integer primary key autoincrement" } else { sqlType = "integer" } case reflect.Int64, reflect.Uint64: if s.fieldCanAutoIncrement(field) { - field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" + field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") sqlType = "integer primary key autoincrement" } else { sqlType = "bigint" diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go index 731721cb..6c424bc1 100644 --- a/dialects/mssql/mssql.go +++ b/dialects/mssql/mssql.go @@ -18,7 +18,7 @@ import ( func setIdentityInsert(scope *gorm.Scope) { if scope.Dialect().GetName() == "mssql" { for _, field := range scope.PrimaryFields() { - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok && !field.IsBlank { + if _, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok && !field.IsBlank { scope.NewDB().Exec(fmt.Sprintf("SET IDENTITY_INSERT %v ON", scope.TableName())) scope.InstanceSet("mssql:identity_insert_on", true) } @@ -70,14 +70,14 @@ func (s *mssql) DataTypeOf(field *gorm.StructField) string { sqlType = "bit" case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: if s.fieldCanAutoIncrement(field) { - field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" + field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") sqlType = "int IDENTITY(1,1)" } else { sqlType = "int" } case reflect.Int64, reflect.Uint64: if s.fieldCanAutoIncrement(field) { - field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" + field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") sqlType = "bigint IDENTITY(1,1)" } else { sqlType = "bigint" @@ -116,7 +116,7 @@ func (s *mssql) DataTypeOf(field *gorm.StructField) string { } func (s mssql) fieldCanAutoIncrement(field *gorm.StructField) bool { - if value, ok := field.TagSettings["AUTO_INCREMENT"]; ok { + if value, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok { return value != "FALSE" } return field.IsPrimaryKey diff --git a/field_test.go b/field_test.go index 30e9a778..c3afdff5 100644 --- a/field_test.go +++ b/field_test.go @@ -43,7 +43,7 @@ func TestCalculateField(t *testing.T) { if field, ok := scope.FieldByName("embedded_name"); !ok { t.Errorf("should find embedded field") - } else if _, ok := field.TagSettings["NOT NULL"]; !ok { + } else if _, ok := field.TagSettingsGet("NOT NULL"); !ok { t.Errorf("should find embedded field's tag settings") } } diff --git a/main.go b/main.go index 4dbda61e..17c75ed3 100644 --- a/main.go +++ b/main.go @@ -699,7 +699,7 @@ func (s *DB) SetJoinTableHandler(source interface{}, column string, handler Join scope := s.NewScope(source) for _, field := range scope.GetModelStruct().StructFields { if field.Name == column || field.DBName == column { - if many2many := field.TagSettings["MANY2MANY"]; many2many != "" { + if many2many, _ := field.TagSettingsGet("MANY2MANY"); many2many != "" { source := (&Scope{Value: source}).GetModelStruct().ModelType destination := (&Scope{Value: reflect.New(field.Struct.Type).Interface()}).GetModelStruct().ModelType handler.Setup(field.Relationship, many2many, source, destination) diff --git a/model_struct.go b/model_struct.go index 5b5be618..12860e67 100644 --- a/model_struct.go +++ b/model_struct.go @@ -60,6 +60,30 @@ type StructField struct { Struct reflect.StructField IsForeignKey bool Relationship *Relationship + + tagSettingsLock sync.RWMutex +} + +// TagSettingsSet Sets a tag in the tag settings map +func (s *StructField) TagSettingsSet(key, val string) { + s.tagSettingsLock.Lock() + defer s.tagSettingsLock.Unlock() + s.TagSettings[key] = val +} + +// TagSettingsGet returns a tag from the tag settings +func (s *StructField) TagSettingsGet(key string) (string, bool) { + s.tagSettingsLock.RLock() + defer s.tagSettingsLock.RUnlock() + val, ok := s.TagSettings[key] + return val, ok +} + +// TagSettingsDelete deletes a tag +func (s *StructField) TagSettingsDelete(key string) { + s.tagSettingsLock.Lock() + defer s.tagSettingsLock.Unlock() + delete(s.TagSettings, key) } func (structField *StructField) clone() *StructField { @@ -83,6 +107,9 @@ func (structField *StructField) clone() *StructField { clone.Relationship = &relationship } + // copy the struct field tagSettings, they should be read-locked while they are copied + structField.tagSettingsLock.Lock() + defer structField.tagSettingsLock.Unlock() for key, value := range structField.TagSettings { clone.TagSettings[key] = value } @@ -149,19 +176,19 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } // is ignored field - if _, ok := field.TagSettings["-"]; ok { + if _, ok := field.TagSettingsGet("-"); ok { field.IsIgnored = true } else { - if _, ok := field.TagSettings["PRIMARY_KEY"]; ok { + if _, ok := field.TagSettingsGet("PRIMARY_KEY"); ok { field.IsPrimaryKey = true modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field) } - if _, ok := field.TagSettings["DEFAULT"]; ok { + if _, ok := field.TagSettingsGet("DEFAULT"); ok { field.HasDefaultValue = true } - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok && !field.IsPrimaryKey { + if _, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok && !field.IsPrimaryKey { field.HasDefaultValue = true } @@ -177,8 +204,8 @@ func (scope *Scope) GetModelStruct() *ModelStruct { if indirectType.Kind() == reflect.Struct { for i := 0; i < indirectType.NumField(); i++ { for key, value := range parseTagSetting(indirectType.Field(i).Tag) { - if _, ok := field.TagSettings[key]; !ok { - field.TagSettings[key] = value + if _, ok := field.TagSettingsGet(key); !ok { + field.TagSettingsSet(key, value) } } } @@ -186,17 +213,17 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } else if _, isTime := fieldValue.(*time.Time); isTime { // is time field.IsNormal = true - } else if _, ok := field.TagSettings["EMBEDDED"]; ok || fieldStruct.Anonymous { + } else if _, ok := field.TagSettingsGet("EMBEDDED"); ok || fieldStruct.Anonymous { // is embedded struct for _, subField := range scope.New(fieldValue).GetModelStruct().StructFields { subField = subField.clone() subField.Names = append([]string{fieldStruct.Name}, subField.Names...) - if prefix, ok := field.TagSettings["EMBEDDED_PREFIX"]; ok { + if prefix, ok := field.TagSettingsGet("EMBEDDED_PREFIX"); ok { subField.DBName = prefix + subField.DBName } if subField.IsPrimaryKey { - if _, ok := subField.TagSettings["PRIMARY_KEY"]; ok { + if _, ok := subField.TagSettingsGet("PRIMARY_KEY"); ok { modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, subField) } else { subField.IsPrimaryKey = false @@ -227,13 +254,13 @@ func (scope *Scope) GetModelStruct() *ModelStruct { elemType = field.Struct.Type ) - if foreignKey := field.TagSettings["FOREIGNKEY"]; foreignKey != "" { + if foreignKey, _ := field.TagSettingsGet("FOREIGNKEY"); foreignKey != "" { foreignKeys = strings.Split(foreignKey, ",") } - if foreignKey := field.TagSettings["ASSOCIATION_FOREIGNKEY"]; foreignKey != "" { + if foreignKey, _ := field.TagSettingsGet("ASSOCIATION_FOREIGNKEY"); foreignKey != "" { associationForeignKeys = strings.Split(foreignKey, ",") - } else if foreignKey := field.TagSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" { + } else if foreignKey, _ := field.TagSettingsGet("ASSOCIATIONFOREIGNKEY"); foreignKey != "" { associationForeignKeys = strings.Split(foreignKey, ",") } @@ -242,13 +269,13 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } if elemType.Kind() == reflect.Struct { - if many2many := field.TagSettings["MANY2MANY"]; many2many != "" { + if many2many, _ := field.TagSettingsGet("MANY2MANY"); many2many != "" { relationship.Kind = "many_to_many" { // Foreign Keys for Source joinTableDBNames := []string{} - if foreignKey := field.TagSettings["JOINTABLE_FOREIGNKEY"]; foreignKey != "" { + if foreignKey, _ := field.TagSettingsGet("JOINTABLE_FOREIGNKEY"); foreignKey != "" { joinTableDBNames = strings.Split(foreignKey, ",") } @@ -279,7 +306,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { { // Foreign Keys for Association (Destination) associationJoinTableDBNames := []string{} - if foreignKey := field.TagSettings["ASSOCIATION_JOINTABLE_FOREIGNKEY"]; foreignKey != "" { + if foreignKey, _ := field.TagSettingsGet("ASSOCIATION_JOINTABLE_FOREIGNKEY"); foreignKey != "" { associationJoinTableDBNames = strings.Split(foreignKey, ",") } @@ -317,7 +344,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { var toFields = toScope.GetStructFields() relationship.Kind = "has_many" - if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" { + if polymorphic, _ := field.TagSettingsGet("POLYMORPHIC"); polymorphic != "" { // Dog has many toys, tag polymorphic is Owner, then associationType is Owner // Toy use OwnerID, OwnerType ('dogs') as foreign key if polymorphicType := getForeignField(polymorphic+"Type", toFields); polymorphicType != nil { @@ -325,7 +352,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { relationship.PolymorphicType = polymorphicType.Name relationship.PolymorphicDBName = polymorphicType.DBName // if Dog has multiple set of toys set name of the set (instead of default 'dogs') - if value, ok := field.TagSettings["POLYMORPHIC_VALUE"]; ok { + if value, ok := field.TagSettingsGet("POLYMORPHIC_VALUE"); ok { relationship.PolymorphicValue = value } else { relationship.PolymorphicValue = scope.TableName() @@ -407,17 +434,17 @@ func (scope *Scope) GetModelStruct() *ModelStruct { tagAssociationForeignKeys []string ) - if foreignKey := field.TagSettings["FOREIGNKEY"]; foreignKey != "" { + if foreignKey, _ := field.TagSettingsGet("FOREIGNKEY"); foreignKey != "" { tagForeignKeys = strings.Split(foreignKey, ",") } - if foreignKey := field.TagSettings["ASSOCIATION_FOREIGNKEY"]; foreignKey != "" { + if foreignKey, _ := field.TagSettingsGet("ASSOCIATION_FOREIGNKEY"); foreignKey != "" { tagAssociationForeignKeys = strings.Split(foreignKey, ",") - } else if foreignKey := field.TagSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" { + } else if foreignKey, _ := field.TagSettingsGet("ASSOCIATIONFOREIGNKEY"); foreignKey != "" { tagAssociationForeignKeys = strings.Split(foreignKey, ",") } - if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" { + if polymorphic, _ := field.TagSettingsGet("POLYMORPHIC"); polymorphic != "" { // Cat has one toy, tag polymorphic is Owner, then associationType is Owner // Toy use OwnerID, OwnerType ('cats') as foreign key if polymorphicType := getForeignField(polymorphic+"Type", toFields); polymorphicType != nil { @@ -425,7 +452,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { relationship.PolymorphicType = polymorphicType.Name relationship.PolymorphicDBName = polymorphicType.DBName // if Cat has several different types of toys set name for each (instead of default 'cats') - if value, ok := field.TagSettings["POLYMORPHIC_VALUE"]; ok { + if value, ok := field.TagSettingsGet("POLYMORPHIC_VALUE"); ok { relationship.PolymorphicValue = value } else { relationship.PolymorphicValue = scope.TableName() @@ -563,7 +590,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } // Even it is ignored, also possible to decode db value into the field - if value, ok := field.TagSettings["COLUMN"]; ok { + if value, ok := field.TagSettingsGet("COLUMN"); ok { field.DBName = value } else { field.DBName = ToColumnName(fieldStruct.Name) diff --git a/scope.go b/scope.go index ce80ab86..fa521ca2 100644 --- a/scope.go +++ b/scope.go @@ -1115,8 +1115,8 @@ func (scope *Scope) createJoinTable(field *StructField) { if field, ok := scope.FieldByName(fieldName); ok { foreignKeyStruct := field.clone() foreignKeyStruct.IsPrimaryKey = false - foreignKeyStruct.TagSettings["IS_JOINTABLE_FOREIGNKEY"] = "true" - delete(foreignKeyStruct.TagSettings, "AUTO_INCREMENT") + foreignKeyStruct.TagSettingsSet("IS_JOINTABLE_FOREIGNKEY", "true") + foreignKeyStruct.TagSettingsDelete("AUTO_INCREMENT") sqlTypes = append(sqlTypes, scope.Quote(relationship.ForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(foreignKeyStruct)) primaryKeys = append(primaryKeys, scope.Quote(relationship.ForeignDBNames[idx])) } @@ -1126,8 +1126,8 @@ func (scope *Scope) createJoinTable(field *StructField) { if field, ok := toScope.FieldByName(fieldName); ok { foreignKeyStruct := field.clone() foreignKeyStruct.IsPrimaryKey = false - foreignKeyStruct.TagSettings["IS_JOINTABLE_FOREIGNKEY"] = "true" - delete(foreignKeyStruct.TagSettings, "AUTO_INCREMENT") + foreignKeyStruct.TagSettingsSet("IS_JOINTABLE_FOREIGNKEY", "true") + foreignKeyStruct.TagSettingsDelete("AUTO_INCREMENT") sqlTypes = append(sqlTypes, scope.Quote(relationship.AssociationForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(foreignKeyStruct)) primaryKeys = append(primaryKeys, scope.Quote(relationship.AssociationForeignDBNames[idx])) } @@ -1262,7 +1262,7 @@ func (scope *Scope) autoIndex() *Scope { var uniqueIndexes = map[string][]string{} for _, field := range scope.GetStructFields() { - if name, ok := field.TagSettings["INDEX"]; ok { + if name, ok := field.TagSettingsGet("INDEX"); ok { names := strings.Split(name, ",") for _, name := range names { @@ -1273,7 +1273,7 @@ func (scope *Scope) autoIndex() *Scope { } } - if name, ok := field.TagSettings["UNIQUE_INDEX"]; ok { + if name, ok := field.TagSettingsGet("UNIQUE_INDEX"); ok { names := strings.Split(name, ",") for _, name := range names { From 5be9bd34135805e0332b993378864b159784d8a8 Mon Sep 17 00:00:00 2001 From: ch3rub1m Date: Fri, 14 Sep 2018 15:53:49 +0800 Subject: [PATCH 12/17] Rollback transaction when a panic happens in callback (#2093) --- scope.go | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/scope.go b/scope.go index fa521ca2..378025bd 100644 --- a/scope.go +++ b/scope.go @@ -855,6 +855,14 @@ func (scope *Scope) inlineCondition(values ...interface{}) *Scope { } func (scope *Scope) callCallbacks(funcs []*func(s *Scope)) *Scope { + defer func() { + if err := recover(); err != nil { + if db, ok := scope.db.db.(sqlTx); ok { + db.Rollback() + } + panic(err) + } + }() for _, f := range funcs { (*f)(scope) if scope.skipLeft { From f6260a00852946a10a57e8bb9f505f19bc9389b7 Mon Sep 17 00:00:00 2001 From: Artemij Shepelev Date: Sat, 22 Sep 2018 14:59:11 +0300 Subject: [PATCH 13/17] Second part of the defaultTableName field race fix (#2060) * fix (https://github.com/jinzhu/gorm/issues/1407) * changed map with mutex to sync.Map (https://github.com/jinzhu/gorm/issues/1407) * removed newModelStructsMap func * commit to rerun pipeline, comment changed * fix race with defaultTableName field (again) --- model_struct.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/model_struct.go b/model_struct.go index 12860e67..8c27e209 100644 --- a/model_struct.go +++ b/model_struct.go @@ -24,11 +24,16 @@ type ModelStruct struct { PrimaryFields []*StructField StructFields []*StructField ModelType reflect.Type + defaultTableName string + l sync.Mutex } // TableName returns model's table name func (s *ModelStruct) TableName(db *DB) string { + s.l.Lock() + defer s.l.Unlock() + if s.defaultTableName == "" && db != nil && s.ModelType != nil { // Set default table name if tabler, ok := reflect.New(s.ModelType).Interface().(tabler); ok { From 742154be9a26e849f02d296073c077e0a7c23828 Mon Sep 17 00:00:00 2001 From: "Iskander (Alex) Sharipov" Date: Sun, 7 Oct 2018 03:49:37 +0300 Subject: [PATCH 14/17] rewrite if-else chain as switch statement (#2121) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit From effective Go: https://golang.org/doc/effective_go.html#switch > It's therefore possible—and idiomatic—to write an if-else-if-else chain as a switch. --- association.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/association.go b/association.go index 8c6d9864..1b7744b5 100644 --- a/association.go +++ b/association.go @@ -267,15 +267,16 @@ func (association *Association) Count() int { query = scope.DB() ) - if relationship.Kind == "many_to_many" { + switch relationship.Kind { + case "many_to_many": query = relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, query, scope.Value) - } else if relationship.Kind == "has_many" || relationship.Kind == "has_one" { + case "has_many", "has_one": primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, scope.Value) query = query.Where( fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)..., ) - } else if relationship.Kind == "belongs_to" { + case "belongs_to": primaryKeys := scope.getColumnAsArray(relationship.ForeignFieldNames, scope.Value) query = query.Where( fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(primaryKeys)), From 50c61291de2f96a25627c55adcfda719ff5adae8 Mon Sep 17 00:00:00 2001 From: RikiyaFujii Date: Sat, 3 Nov 2018 22:55:52 +0900 Subject: [PATCH 15/17] add comment (#2163) * add comment * typo --- association.go | 1 + 1 file changed, 1 insertion(+) diff --git a/association.go b/association.go index 1b7744b5..a73344fe 100644 --- a/association.go +++ b/association.go @@ -368,6 +368,7 @@ func (association *Association) saveAssociations(values ...interface{}) *Associa return association } +// setErr set error when the error is not nil. And return Association. func (association *Association) setErr(err error) *Association { if err != nil { association.Error = err From 68f5d25d640b04d1b302993b609b2b1c693432ad Mon Sep 17 00:00:00 2001 From: teresy <43420401+teresy@users.noreply.github.com> Date: Sat, 3 Nov 2018 09:56:27 -0400 Subject: [PATCH 16/17] simplify cases of strings.Index with strings.Contains (#2162) --- scope.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scope.go b/scope.go index 378025bd..806ccb7d 100644 --- a/scope.go +++ b/scope.go @@ -68,7 +68,7 @@ func (scope *Scope) Dialect() Dialect { // Quote used to quote string to escape them for database func (scope *Scope) Quote(str string) string { - if strings.Index(str, ".") != -1 { + if strings.Contains(str, ".") { newStrs := []string{} for _, str := range strings.Split(str, ".") { newStrs = append(newStrs, scope.Dialect().Quote(str)) @@ -330,7 +330,7 @@ func (scope *Scope) TableName() string { // QuotedTableName return quoted table name func (scope *Scope) QuotedTableName() (name string) { if scope.Search != nil && len(scope.Search.tableName) > 0 { - if strings.Index(scope.Search.tableName, " ") != -1 { + if strings.Contains(scope.Search.tableName, " ") { return scope.Search.tableName } return scope.Quote(scope.Search.tableName) From 472c70caa40267cb89fd8facb07fe6454b578626 Mon Sep 17 00:00:00 2001 From: Jun Jie Nan Date: Sat, 3 Nov 2018 22:14:39 +0800 Subject: [PATCH 17/17] Check valuer interface before scan value (#2155) Scan interface only accept int64, float64, bool, []byte, string, time.Time or nil. When do scan, it's better to check whether the type support valuer interface and do convert. --- field.go | 10 +++++++++- field_test.go | 18 ++++++++++++++++++ 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/field.go b/field.go index 11c410b0..acd06e20 100644 --- a/field.go +++ b/field.go @@ -2,6 +2,7 @@ package gorm import ( "database/sql" + "database/sql/driver" "errors" "fmt" "reflect" @@ -44,7 +45,14 @@ func (field *Field) Set(value interface{}) (err error) { if reflectValue.Type().ConvertibleTo(fieldValue.Type()) { fieldValue.Set(reflectValue.Convert(fieldValue.Type())) } else if scanner, ok := fieldValue.Addr().Interface().(sql.Scanner); ok { - err = scanner.Scan(reflectValue.Interface()) + v := reflectValue.Interface() + if valuer, ok := v.(driver.Valuer); ok { + if v, err = valuer.Value(); err == nil { + err = scanner.Scan(v) + } + } else { + err = scanner.Scan(v) + } } else { err = fmt.Errorf("could not convert argument of field %s from %s to %s", field.Name, reflectValue.Type(), fieldValue.Type()) } diff --git a/field_test.go b/field_test.go index c3afdff5..03a3b3b7 100644 --- a/field_test.go +++ b/field_test.go @@ -3,6 +3,7 @@ package gorm_test import ( "testing" + "github.com/gofrs/uuid" "github.com/jinzhu/gorm" ) @@ -47,3 +48,20 @@ func TestCalculateField(t *testing.T) { t.Errorf("should find embedded field's tag settings") } } + +func TestFieldSet(t *testing.T) { + type TestFieldSetNullUUID struct { + NullUUID uuid.NullUUID + } + scope := DB.NewScope(&TestFieldSetNullUUID{}) + field := scope.Fields()[0] + err := field.Set(uuid.FromStringOrNil("3034d44a-da03-11e8-b366-4a00070b9f00")) + if err != nil { + t.Fatal(err) + } + if id, ok := field.Field.Addr().Interface().(*uuid.NullUUID); !ok { + t.Fatal() + } else if !id.Valid || id.UUID.String() != "3034d44a-da03-11e8-b366-4a00070b9f00" { + t.Fatal(id) + } +}